diff --git a/docs/modules/index.rst b/docs/modules/index.rst index e3a13499..5f9d0e0d 100644 --- a/docs/modules/index.rst +++ b/docs/modules/index.rst @@ -15,3 +15,4 @@ This is the API Reference for tpcp. Parameter Optimization Validation and Scoring Parallel Helpers + Testing diff --git a/docs/modules/testing.rst b/docs/modules/testing.rst index 7c7c4a09..677b21d0 100644 --- a/docs/modules/testing.rst +++ b/docs/modules/testing.rst @@ -11,8 +11,8 @@ Classes .. currentmodule:: tpcp.testing .. autosummary:: - :toctree: generated/testing - :template: class.rst + :toctree: generated/testing + :template: class.rst TestAlgorithmMixin PyTestSnapshotTest \ No newline at end of file diff --git a/tpcp/testing/__init__.py b/tpcp/testing/__init__.py index 40dee832..3b254b3a 100644 --- a/tpcp/testing/__init__.py +++ b/tpcp/testing/__init__.py @@ -1,4 +1,5 @@ """Helper for testing of algorithms and pipelines implemented in tpcp.""" from tpcp.testing._algorithm_test_mixin import TestAlgorithmMixin +from tpcp.testing._regression_utils import PyTestSnapshotTest -__all__ = ["TestAlgorithmMixin"] +__all__ = ["TestAlgorithmMixin", "PyTestSnapshotTest"] diff --git a/tpcp/testing/_algorithm_test_mixin.py b/tpcp/testing/_algorithm_test_mixin.py index c39289b3..0bcf28f6 100644 --- a/tpcp/testing/_algorithm_test_mixin.py +++ b/tpcp/testing/_algorithm_test_mixin.py @@ -38,8 +38,8 @@ class that should be tested. In some very specific cases you might want to ignore some parameters in the docstring tests. For this, set the _IGNORED_NAMES attribute to a tuple of parameter names that should be ignored. - Example - ------- + Examples + -------- >>> class TestMyAlgorithm(TestAlgorithmMixin): ... ALGORITHM_CLASS = MyAlgorithm @@ -66,7 +66,7 @@ def after_action_instance(self, **kwargs) -> BaseTpcpObjectObjT: # noqa: PT004 The returned algorithm instance should have the result attributes and the "other parameters" (i.e. the action method inputs) set. """ - raise NotImplementedError() + raise NotImplementedError def test_is_algorithm(self): """Test that the class is actually an algorithm.""" @@ -203,5 +203,6 @@ def test_nested_algo_marked_default(self): def test_passes_safe_action_checks(self, after_action_instance): """Test that the algorithm passes the safe action checks.""" return make_action_safe(get_action_method(after_action_instance))( - after_action_instance, **get_action_params(after_action_instance) + after_action_instance, + **get_action_params(after_action_instance), ) diff --git a/tpcp/testing/_regression_utils.py b/tpcp/testing/_regression_utils.py index 87814ca6..ad11d5ea 100644 --- a/tpcp/testing/_regression_utils.py +++ b/tpcp/testing/_regression_utils.py @@ -5,6 +5,7 @@ """ import re from pathlib import Path +from typing import Optional, Union import numpy as np import pandas as pd @@ -18,7 +19,7 @@ class SnapshotNotFoundError(Exception): class PyTestSnapshotTest: """Perform snapshot tests in pytest. - This supports standard datatypes like numpy arrays and pandas DataFrames. + This supports standard datatypes and scientific datatypes like numpy arrays and pandas DataFrames. To use this in your tests, add the following lines to your conftest.py: @@ -42,39 +43,57 @@ def pytest_addoption(parser): This will register the snapshot fixture that you can use in your tests. Further, it will register the `--snapshot-update` commandline flag, which you can use to update the snapshots. + To use the fixture in your tests, simply add it as a parameter to your test function: + + .. code-block:: python + + def test_my_test(snapshot): + result = my_calculation() + snapshot.assert_match(result, "my_result_1") + + This will store the result of `my_calculation()` in a snapshot file in a folder called `snapshot` in the same folder + as the test file. + The name of the snapshot file will be the name of the test function, suffixed with `_my_result_1`. + When the test is run again, the result will be compared to the stored snapshot. + + To update a snapshot, either delete the snapshot file and manually run the test again or run pytest with the + `--snapshot-update` flag. + """ - def __init__(self, request=None): + curr_snapshot: str + + def __init__(self, request=None) -> None: self.request = request self.curr_snapshot_number = 0 super().__init__() @property - def update(self): + def _update(self): return self.request.config.option.snapshot_update @property - def module(self): + def _module(self): return Path(self.request.node.fspath.strpath).parent @property - def snapshot_folder(self): - return self.module / "snapshot" + def _snapshot_folder(self): + return self._module / "snapshot" @property - def file_name_json(self): - return self.snapshot_folder / f"{self.test_name}.json" + def _file_name_json(self): + return self._snapshot_folder / f"{self._test_name}.json" @property - def file_name_csv(self): - return self.snapshot_folder / f"{self.test_name}.csv" + def _file_name_csv(self): + return self._snapshot_folder / f"{self._test_name}.csv" @property - def file_name_txt(self): - return self.snapshot_folder / f"{self.test_name}.txt" + def _file_name_txt(self): + return self._snapshot_folder / f"{self._test_name}.txt" @property - def test_name(self): + def _test_name(self): cls_name = getattr(self.request.node.cls, "__name__", "") flattened_node_name = re.sub(r"\s+", " ", self.request.node.name.replace(r"\n", " ")) return "{}{}_{}".format(f"{cls_name}." if cls_name else "", flattened_node_name, self.curr_snapshot) @@ -85,48 +104,75 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): pass - def store(self, value): - self.snapshot_folder.mkdir(parents=True, exist_ok=True) + def _store(self, value): + self._snapshot_folder.mkdir(parents=True, exist_ok=True) if isinstance(value, pd.DataFrame): - value.to_json(self.file_name_json, indent=4, orient="table") + value.to_json(self._file_name_json, indent=4, orient="table") elif isinstance(value, np.ndarray): - np.savetxt(self.file_name_csv, value, delimiter=",") + np.savetxt(self._file_name_csv, value, delimiter=",") elif isinstance(value, str): - with Path(self.file_name_txt).open("w") as f: + with Path(self._file_name_txt).open("w") as f: f.write(value) else: raise TypeError(f"The dtype {type(value)} is not supported for snapshot testing") - def retrieve(self, dtype): + def _retrieve(self, dtype): if dtype == pd.DataFrame: - filename = self.file_name_json + filename = self._file_name_json if not filename.is_file(): - raise SnapshotNotFoundError() + raise SnapshotNotFoundError return pd.read_json(filename, orient="table") if dtype == np.ndarray: - filename = self.file_name_csv + filename = self._file_name_csv if not filename.is_file(): - raise SnapshotNotFoundError() + raise SnapshotNotFoundError return np.genfromtxt(filename, delimiter=",") if dtype == str: - filename = self.file_name_txt + filename = self._file_name_txt if not filename.is_file(): - raise SnapshotNotFoundError() - with Path(self.file_name_txt).open() as f: + raise SnapshotNotFoundError + with Path(self._file_name_txt).open() as f: value = f.read() return value raise ValueError(f"The dtype {dtype} is not supported for snapshot testing") - def assert_match(self, value, name="", **kwargs): - self.curr_snapshot = name or self.curr_snapshot_number - if self.update: - self.store(value) + def assert_match(self, value: Union[str, pd.DataFrame, np.ndarray], name: Optional[str] = None, **kwargs): + """Assert that the value matches the snapshot. + + This compares the value with a stored snapshot of the same name. + If no snapshot exists, it will be created. + + The snapshot name is automatically generated from the test name and the `name` parameter passed to this + function. + If no name is passed, the name will be suffixed with a number, starting at 0. + If you have multiple snapshots in one test, we highly recommend to pass a name to this function. + Otherwise, changing the order of the snapshots will break your tests. + + Parameters + ---------- + value + The value to compare with the snapshot. + We support strings, numpy arrays and pandas DataFrames. + For other datatypes like floats or short lists, we recommend to just use the standard pytest assertions + and hardcode the expected value. + name + Optional name suffix of the snapshot-file. + If not provided the name will be suffixed with a number, starting at 0. + kwargs + Additional keyword arguments passed to the comparison function. + This is only supported for DataFrames and numpy arrays. + There they will be passed to `assert_frame_equal` and `assert_array_almost_equal` respectively. + + """ + self.curr_snapshot = name or str(self.curr_snapshot_number) + if self._update: + self._store(value) else: value_dtype = type(value) try: - prev_snapshot = self.retrieve(value_dtype) + prev_snapshot = self._retrieve(value_dtype) except SnapshotNotFoundError: - self.store(value) # first time this test has been seen + self._store(value) # first time this test has been seen except: raise else: