Skip to content

Commit

Permalink
Add snapto validation tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sco1 committed Jun 18, 2024
1 parent 9ea760d commit e7af1bc
Show file tree
Hide file tree
Showing 9 changed files with 207 additions and 72 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ repos:
- id: isort
name: isort
- repo: https://github.com/pycqa/flake8
rev: 7.0.0
rev: 7.1.0
hooks:
- id: flake8
additional_dependencies:
Expand All @@ -35,6 +35,6 @@ repos:
- id: python-check-blanket-type-ignore
- id: python-use-type-annotations
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.7
rev: v0.4.9
hooks:
- id: ruff
3 changes: 3 additions & 0 deletions .ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,6 @@ ignore = [
"tests/test_*.py" = [
"D103",
]
"tests/conftest.py" = [
"D103",
]
22 changes: 15 additions & 7 deletions matplotlib_window/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,16 @@ def validate_snap_to(self, snap_to: Line2D | None) -> Line2D | None:
If `snap_to` is `None`, or is a plot object that contains x data, it is returned unchanged.
Otherwise an exception is raised.
"""
if snap_to is not None:
try:
snap_to.get_xdata()
except AttributeError as e:
raise ValueError("Cannot provide an empty lineseries to snapto") from e
if snap_to is None:
return None

try:
xydata = snap_to.get_xydata()
except AttributeError as e:
raise ValueError("Cannot provide an empty lineseries to snapto") from e

if len(xydata) == 0: # type: ignore[arg-type]
raise ValueError("Cannot provide an empty lineseries to snapto")

return snap_to

Expand Down Expand Up @@ -297,9 +302,9 @@ def location(self) -> NUMERIC_T:
"""Return the location of the `DragLine` along its relevant axis."""
pos: t.Sequence[NUMERIC_T]
if self.orientation == Orientation.VERTICAL:
pos = self.myobj.get_ydata() # type: ignore[assignment]
else:
pos = self.myobj.get_xdata() # type: ignore[assignment]
else:
pos = self.myobj.get_ydata() # type: ignore[assignment]

return pos[0] # Should be a (location, location) tuple

Expand Down Expand Up @@ -350,6 +355,9 @@ def __init__(
alpha: NUMERIC_T = 0.4,
**kwargs: t.Any,
) -> None:
if width <= 0:
raise ValueError(f"Width value must be greater than 1. Received: {width}")

# Rectangle patches are located from their bottom left corner; because we want to span the
# full y range, we need to translate the y position to the bottom of the axes
rect_params = transform_rect_params(ax, position)
Expand Down
127 changes: 65 additions & 62 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ mypy = "^1.0"
pre-commit = "^3.0"
pytest = "^8.0"
pytest-cov = "^5.0"
pytest-mock = "^3.14"
pytest-randomly = "^3.12"
ruff = "^0.4"
tox = "^4.4"
Expand Down
Empty file added tests/__init__.py
Empty file.
14 changes: 14 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import typing as t

import matplotlib.pyplot as plt
import pytest
from matplotlib.axes import Axes
from matplotlib.figure import Figure

PLOTOBJ_T: t.TypeAlias = tuple[Figure, Axes]


@pytest.fixture
def plotobj() -> tuple[Figure, Axes]:
fig, ax = plt.subplots()
return fig, ax
10 changes: 10 additions & 0 deletions tests/test_base_objs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import pytest

from matplotlib_window.base import DragRect
from tests.conftest import PLOTOBJ_T


def test_dragrect_invalid_width_raises(plotobj: PLOTOBJ_T) -> None:
_, ax = plotobj
with pytest.raises(ValueError, match="greater than 1"):
_ = DragRect(ax=ax, position=0, width=0)
98 changes: 98 additions & 0 deletions tests/test_snapto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import numpy as np
import pytest
from matplotlib.lines import Line2D

from matplotlib_window.base import DragLine, DragRect, Orientation, _DraggableObject
from tests.conftest import PLOTOBJ_T


def test_dragobj_snapto_none_passthrough() -> None:
do = _DraggableObject()
assert do.validate_snap_to(None) is None


def test_dragline_snapto_none_passthrough(plotobj: PLOTOBJ_T) -> None:
_, ax = plotobj
dl = DragLine(ax=ax, position=0)
assert dl.validate_snap_to(None) is None


def test_dragrect_snapto_none_passthrough(plotobj: PLOTOBJ_T) -> None:
_, ax = plotobj
dr = DragRect(ax=ax, position=0, width=1)
assert dr.validate_snap_to(None) is None


EMPTY_LINE = Line2D(xdata=np.array([]), ydata=np.array([]))


def test_dragobj_snapto_empty_data_raises() -> None:
do = _DraggableObject()
with pytest.raises(ValueError, match="empty"):
do.validate_snap_to(EMPTY_LINE)


def test_dragline_snapto_empty_data_raises(plotobj: PLOTOBJ_T) -> None:
_, ax = plotobj
dl = DragLine(ax=ax, position=0)
with pytest.raises(ValueError, match="empty"):
dl.validate_snap_to(EMPTY_LINE)


def test_dragrect_snapto_empty_data_raises(plotobj: PLOTOBJ_T) -> None:
_, ax = plotobj
dr = DragRect(ax=ax, position=0, width=1)
with pytest.raises(ValueError, match="empty"):
dr.validate_snap_to(EMPTY_LINE)


DUMMY_LINE = Line2D(xdata=np.array([0, 1, 2]), ydata=np.array([0, 1, 2]))


def test_vertical_dragline_snapto_out_of_bounds_raises(plotobj: PLOTOBJ_T) -> None:
_, ax = plotobj
dl = DragLine(ax, position=3, orientation=Orientation.VERTICAL)

with pytest.raises(ValueError, match="bounds"):
dl.validate_snap_to(DUMMY_LINE)


def test_horizontal_dragline_snapto_out_of_bounds_raises(plotobj: PLOTOBJ_T) -> None:
_, ax = plotobj
dl = DragLine(ax, position=3, orientation=Orientation.HORIZONTAL)

with pytest.raises(ValueError, match="bounds"):
dl.validate_snap_to(DUMMY_LINE)


def test_vertical_dragline_valid_snapto(plotobj: PLOTOBJ_T) -> None:
_, ax = plotobj
_ = DragLine(ax, position=1, orientation=Orientation.VERTICAL, snap_to=DUMMY_LINE)


def test_horizontal_dragline_valid_snapto(plotobj: PLOTOBJ_T) -> None:
_, ax = plotobj
_ = DragLine(ax, position=1, orientation=Orientation.HORIZONTAL, snap_to=DUMMY_LINE)


RECT_BOUNDS_TEST_CASES = (
(-1, 1), # Left edge out
(0, 3), # Right edge out
(-1, 4), # Both edges out
)


@pytest.mark.parametrize(("position", "width"), RECT_BOUNDS_TEST_CASES)
def test_dragrect_snapto_out_of_bounds_raises(
position: int, width: int, plotobj: PLOTOBJ_T
) -> None:
_, ax = plotobj
dr = DragRect(ax=ax, position=position, width=width)

with pytest.raises(ValueError, match="bounds"):
dr.validate_snap_to(DUMMY_LINE)


def test_dragrect_valid_snapto(plotobj: PLOTOBJ_T) -> None:
_, ax = plotobj
_ = DragRect(ax=ax, position=0, width=1, snap_to=DUMMY_LINE)

0 comments on commit e7af1bc

Please sign in to comment.