Skip to content

Commit

Permalink
fixed arrow_3d incorrect behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
Davide-sd committed Mar 19, 2024
1 parent d490c27 commit e9877bb
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 18 deletions.
8 changes: 4 additions & 4 deletions spb/backends/matplotlib/renderers/arrow2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class Arrow2DRendererQuivers(MatplotlibRenderer):

def _draw_arrow_2d(renderer, data):
p, s = renderer.plot, renderer.series
xx, yy, uu, vv = data
x1, y1, x2, y2 = data
mpatches = p.matplotlib.patches

arrowstyle = mpatches.ArrowStyle(
Expand All @@ -23,7 +23,7 @@ def _draw_arrow_2d(renderer, data):
color=next(p._cl)
)
kw = p.merge({}, pkw, s.rendering_kw)
arrow = mpatches.FancyArrowPatch((xx, yy), (uu, vv), **kw)
arrow = mpatches.FancyArrowPatch((x1, y1), (x2, y2), **kw)
p._ax.add_patch(arrow)

if s.show_in_legend:
Expand All @@ -36,9 +36,9 @@ def _draw_arrow_2d(renderer, data):


def _update_arrow2d(renderer, data, handle):
xx, yy, uu, vv = data
x1, y1, x2, y2 = data
arrow = handle[0]
arrow.set_positions((xx, yy), (uu, vv))
arrow.set_positions((x1, y1), (x2, y2))


class Arrow2DRendererFancyArrowPatch(MatplotlibRenderer):
Expand Down
10 changes: 6 additions & 4 deletions spb/backends/matplotlib/renderers/arrow3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def do_3d_projection(self, renderer=None):

def _draw_arrow_3d(renderer, data):
p, s = renderer.plot, renderer.series
xx, yy, zz, uu, vv, ww = data
x1, y1, z1, x2, y2, z2 = data
dx, dy, dz = x2 - x1, y2 - y1, z2 - z1
mpatches = p.matplotlib.patches

arrowstyle = mpatches.ArrowStyle(
Expand All @@ -58,7 +59,7 @@ def _draw_arrow_3d(renderer, data):
color=next(p._cl)
)
kw = p.merge({}, pkw, s.rendering_kw)
arrow = Arrow3D(xx, yy, zz, uu, vv, ww, **kw)
arrow = Arrow3D(x1, y1, z1, dx, dy, dz, **kw)
p._ax.add_patch(arrow)

if s.show_in_legend:
Expand All @@ -77,10 +78,11 @@ def _update_arrow3d(renderer, data, handle):
arrow and add a new one.
"""
p = renderer.plot
xx, yy, zz, uu, vv, ww = data
x1, y1, z1, x2, y2, z2 = data
dx, dy, dz = x2 - x1, y2 - y1, z2 - z1
handle[0].remove()
kw = handle[1]
arrow = Arrow3D(xx, yy, zz, uu, vv, ww, **kw)
arrow = Arrow3D(x1, y1, z1, dx, dy, dz, **kw)
p._ax.add_patch(arrow)
handle[0] = arrow

Expand Down
18 changes: 8 additions & 10 deletions spb/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3695,8 +3695,11 @@ def __init__(self, start, direction, label="", **kwargs):
self.is_streamlines = kwargs.get("streamlines", False)

def __str__(self):
pre = "3D " if self.is_3D else "2D "
start = tuple(self.start)
end = tuple(s + d for s, d in zip(start, self.direction))
return self._str_helper(
f"2D arrow from {self.start} to {self.direction}"
pre + f"arrow from {start} to {end}"
)

def get_label(self, use_latex=False, wrapper="$%s$"):
Expand Down Expand Up @@ -3724,9 +3727,9 @@ def get_data(self):
Returns
=======
x, y : float
x1, y1, z1 [optional] : float
Coordinates of the start position.
u, v : float
x2, y2, z2 [optional] : float
Coordinates of the end position.
"""
np = import_module('numpy')
Expand All @@ -3742,8 +3745,8 @@ def get_data(self):
direction = np.array(
[t.evalf(subs=self.params) for t in direction], dtype=float)

direction += start
return self._apply_transform(*start, *direction)
end = start + direction
return self._apply_transform(*start, *end)


class Arrow3DSeries(Arrow2DSeries):
Expand All @@ -3765,11 +3768,6 @@ def get_data(self):
"""
return super().get_data()

def __str__(self):
return self._str_helper(
f"3D arrow from {self.start} to {self.direction}"
)


class GridBase:
"""
Expand Down
6 changes: 6 additions & 0 deletions tests/backends/test_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -2423,11 +2423,17 @@ def test_arrow_3d():
assert len(p.ax.get_legend().legend_handles) == 1
assert p.ax.get_legend().legend_handles[0].get_label() == "$test$"
assert p.ax.get_legend().legend_handles[0].get_color() == "r"
# only way to test if it renders what it's supposed to
assert np.allclose(p.ax.patches[0]._xyz, [1, 2, 3])
assert np.allclose(p.ax.patches[0]._dxdydz, [4, 5, 6])

p = make_test_arrow_3d(MB, "test", {"color": "r"}, False)
p.fig
assert len(p.ax.patches) == 1
assert p.ax.get_legend() is None
# only way to test if it renders what it's supposed to
assert np.allclose(p.ax.patches[0]._xyz, [1, 2, 3])
assert np.allclose(p.ax.patches[0]._dxdydz, [4, 5, 6])


@pytest.mark.skipif(ct is None, reason="control is not installed")
Expand Down
5 changes: 5 additions & 0 deletions tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4124,14 +4124,17 @@ def test_arrow2dserie(start, direc, label, rkw, sil, params):
if not params:
assert s.get_label(False) == (
"(1.0, 2.0) -> (4.0, 6.0)" if not label else label)
assert str(s) == "2D arrow from (1.0, 2.0) to (4.0, 6.0)"
else:
assert s.get_label(False) == (
"(j, k) -> (j + 3, k + 4)" if not label else label)
assert str(s) == "interactive 2D arrow from (j, k) to (j + 3, k + 4) and parameters (j, k)"
assert s.rendering_kw == {} if not rkw else rkw
assert s.is_interactive == (len(s.params) > 0)
assert s.params == {} if not params else params



@pytest.mark.parametrize(
"start, direc, label, rkw, sil, params",
[
Expand All @@ -4155,9 +4158,11 @@ def test_arrow3dserie(start, direc, label, rkw, sil, params):
if not params:
assert s.get_label(False) == (
"(1.0, 2.0, 3.0) -> (5.0, 7.0, 9.0)" if not label else label)
assert str(s) == "3D arrow from (1.0, 2.0, 3.0) to (5.0, 7.0, 9.0)"
else:
assert s.get_label(False) == (
"(j, k, l) -> (j + 4, k + 5, l + 6)" if not label else label)
assert str(s) == "interactive 3D arrow from (j, k, l) to (j + 4, k + 5, l + 6) and parameters (j, k, l)"
assert s.rendering_kw == {} if not rkw else rkw
assert s.is_interactive == (len(s.params) > 0)
assert s.params == {} if not params else params
Expand Down

0 comments on commit e9877bb

Please sign in to comment.