diff --git a/spb/backends/matplotlib/renderers/arrow2d.py b/spb/backends/matplotlib/renderers/arrow2d.py index c5431b7..89af843 100644 --- a/spb/backends/matplotlib/renderers/arrow2d.py +++ b/spb/backends/matplotlib/renderers/arrow2d.py @@ -11,7 +11,7 @@ class Arrow2DRendererQuivers(MatplotlibRenderer): def _draw_arrow_2d(renderer, data): p, s = renderer.plot, renderer.series - xx, yy, uu, vv = data + x1, y1, x2, y2 = data mpatches = p.matplotlib.patches arrowstyle = mpatches.ArrowStyle( @@ -23,7 +23,7 @@ def _draw_arrow_2d(renderer, data): color=next(p._cl) ) kw = p.merge({}, pkw, s.rendering_kw) - arrow = mpatches.FancyArrowPatch((xx, yy), (uu, vv), **kw) + arrow = mpatches.FancyArrowPatch((x1, y1), (x2, y2), **kw) p._ax.add_patch(arrow) if s.show_in_legend: @@ -36,9 +36,9 @@ def _draw_arrow_2d(renderer, data): def _update_arrow2d(renderer, data, handle): - xx, yy, uu, vv = data + x1, y1, x2, y2 = data arrow = handle[0] - arrow.set_positions((xx, yy), (uu, vv)) + arrow.set_positions((x1, y1), (x2, y2)) class Arrow2DRendererFancyArrowPatch(MatplotlibRenderer): diff --git a/spb/backends/matplotlib/renderers/arrow3d.py b/spb/backends/matplotlib/renderers/arrow3d.py index fc910b4..822fd05 100644 --- a/spb/backends/matplotlib/renderers/arrow3d.py +++ b/spb/backends/matplotlib/renderers/arrow3d.py @@ -46,7 +46,8 @@ def do_3d_projection(self, renderer=None): def _draw_arrow_3d(renderer, data): p, s = renderer.plot, renderer.series - xx, yy, zz, uu, vv, ww = data + x1, y1, z1, x2, y2, z2 = data + dx, dy, dz = x2 - x1, y2 - y1, z2 - z1 mpatches = p.matplotlib.patches arrowstyle = mpatches.ArrowStyle( @@ -58,7 +59,7 @@ def _draw_arrow_3d(renderer, data): color=next(p._cl) ) kw = p.merge({}, pkw, s.rendering_kw) - arrow = Arrow3D(xx, yy, zz, uu, vv, ww, **kw) + arrow = Arrow3D(x1, y1, z1, dx, dy, dz, **kw) p._ax.add_patch(arrow) if s.show_in_legend: @@ -77,10 +78,11 @@ def _update_arrow3d(renderer, data, handle): arrow and add a new one. """ p = renderer.plot - xx, yy, zz, uu, vv, ww = data + x1, y1, z1, x2, y2, z2 = data + dx, dy, dz = x2 - x1, y2 - y1, z2 - z1 handle[0].remove() kw = handle[1] - arrow = Arrow3D(xx, yy, zz, uu, vv, ww, **kw) + arrow = Arrow3D(x1, y1, z1, dx, dy, dz, **kw) p._ax.add_patch(arrow) handle[0] = arrow diff --git a/spb/series.py b/spb/series.py index 08c3501..3a0c600 100644 --- a/spb/series.py +++ b/spb/series.py @@ -3695,8 +3695,11 @@ def __init__(self, start, direction, label="", **kwargs): self.is_streamlines = kwargs.get("streamlines", False) def __str__(self): + pre = "3D " if self.is_3D else "2D " + start = tuple(self.start) + end = tuple(s + d for s, d in zip(start, self.direction)) return self._str_helper( - f"2D arrow from {self.start} to {self.direction}" + pre + f"arrow from {start} to {end}" ) def get_label(self, use_latex=False, wrapper="$%s$"): @@ -3724,9 +3727,9 @@ def get_data(self): Returns ======= - x, y : float + x1, y1, z1 [optional] : float Coordinates of the start position. - u, v : float + x2, y2, z2 [optional] : float Coordinates of the end position. """ np = import_module('numpy') @@ -3742,8 +3745,8 @@ def get_data(self): direction = np.array( [t.evalf(subs=self.params) for t in direction], dtype=float) - direction += start - return self._apply_transform(*start, *direction) + end = start + direction + return self._apply_transform(*start, *end) class Arrow3DSeries(Arrow2DSeries): @@ -3765,11 +3768,6 @@ def get_data(self): """ return super().get_data() - def __str__(self): - return self._str_helper( - f"3D arrow from {self.start} to {self.direction}" - ) - class GridBase: """ diff --git a/tests/backends/test_matplotlib.py b/tests/backends/test_matplotlib.py index 101da0e..24cf793 100644 --- a/tests/backends/test_matplotlib.py +++ b/tests/backends/test_matplotlib.py @@ -2423,11 +2423,17 @@ def test_arrow_3d(): assert len(p.ax.get_legend().legend_handles) == 1 assert p.ax.get_legend().legend_handles[0].get_label() == "$test$" assert p.ax.get_legend().legend_handles[0].get_color() == "r" + # only way to test if it renders what it's supposed to + assert np.allclose(p.ax.patches[0]._xyz, [1, 2, 3]) + assert np.allclose(p.ax.patches[0]._dxdydz, [4, 5, 6]) p = make_test_arrow_3d(MB, "test", {"color": "r"}, False) p.fig assert len(p.ax.patches) == 1 assert p.ax.get_legend() is None + # only way to test if it renders what it's supposed to + assert np.allclose(p.ax.patches[0]._xyz, [1, 2, 3]) + assert np.allclose(p.ax.patches[0]._dxdydz, [4, 5, 6]) @pytest.mark.skipif(ct is None, reason="control is not installed") diff --git a/tests/test_series.py b/tests/test_series.py index c42b43d..1292ec0 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -4124,14 +4124,17 @@ def test_arrow2dserie(start, direc, label, rkw, sil, params): if not params: assert s.get_label(False) == ( "(1.0, 2.0) -> (4.0, 6.0)" if not label else label) + assert str(s) == "2D arrow from (1.0, 2.0) to (4.0, 6.0)" else: assert s.get_label(False) == ( "(j, k) -> (j + 3, k + 4)" if not label else label) + assert str(s) == "interactive 2D arrow from (j, k) to (j + 3, k + 4) and parameters (j, k)" assert s.rendering_kw == {} if not rkw else rkw assert s.is_interactive == (len(s.params) > 0) assert s.params == {} if not params else params + @pytest.mark.parametrize( "start, direc, label, rkw, sil, params", [ @@ -4155,9 +4158,11 @@ def test_arrow3dserie(start, direc, label, rkw, sil, params): if not params: assert s.get_label(False) == ( "(1.0, 2.0, 3.0) -> (5.0, 7.0, 9.0)" if not label else label) + assert str(s) == "3D arrow from (1.0, 2.0, 3.0) to (5.0, 7.0, 9.0)" else: assert s.get_label(False) == ( "(j, k, l) -> (j + 4, k + 5, l + 6)" if not label else label) + assert str(s) == "interactive 3D arrow from (j, k, l) to (j + 4, k + 5, l + 6) and parameters (j, k, l)" assert s.rendering_kw == {} if not rkw else rkw assert s.is_interactive == (len(s.params) > 0) assert s.params == {} if not params else params