Skip to content

Commit

Permalink
Added documentation for snapshot test
Browse files Browse the repository at this point in the history
  • Loading branch information
AKuederle committed Aug 30, 2023
1 parent 1ba1bad commit 35ccb2f
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 39 deletions.
1 change: 1 addition & 0 deletions docs/modules/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ This is the API Reference for tpcp.
Parameter Optimization <optimize>
Validation and Scoring <validate>
Parallel Helpers <parallel>
Testing <testing>
4 changes: 2 additions & 2 deletions docs/modules/testing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ Classes
.. currentmodule:: tpcp.testing

.. autosummary::
:toctree: generated/testing
:template: class.rst
:toctree: generated/testing
:template: class.rst

TestAlgorithmMixin
PyTestSnapshotTest
3 changes: 2 additions & 1 deletion tpcp/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Helper for testing of algorithms and pipelines implemented in tpcp."""
from tpcp.testing._algorithm_test_mixin import TestAlgorithmMixin
from tpcp.testing._regression_utils import PyTestSnapshotTest

__all__ = ["TestAlgorithmMixin"]
__all__ = ["TestAlgorithmMixin", "PyTestSnapshotTest"]
9 changes: 5 additions & 4 deletions tpcp/testing/_algorithm_test_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ class that should be tested.
In some very specific cases you might want to ignore some parameters in the docstring tests.
For this, set the _IGNORED_NAMES attribute to a tuple of parameter names that should be ignored.
Example
-------
Examples
--------
>>> class TestMyAlgorithm(TestAlgorithmMixin):
... ALGORITHM_CLASS = MyAlgorithm
Expand All @@ -66,7 +66,7 @@ def after_action_instance(self, **kwargs) -> BaseTpcpObjectObjT: # noqa: PT004
The returned algorithm instance should have the result attributes and the "other parameters" (i.e. the action
method inputs) set.
"""
raise NotImplementedError()
raise NotImplementedError

def test_is_algorithm(self):
"""Test that the class is actually an algorithm."""
Expand Down Expand Up @@ -203,5 +203,6 @@ def test_nested_algo_marked_default(self):
def test_passes_safe_action_checks(self, after_action_instance):
"""Test that the algorithm passes the safe action checks."""
return make_action_safe(get_action_method(after_action_instance))(
after_action_instance, **get_action_params(after_action_instance)
after_action_instance,
**get_action_params(after_action_instance),
)
110 changes: 78 additions & 32 deletions tpcp/testing/_regression_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""
import re
from pathlib import Path
from typing import Optional, Union

import numpy as np
import pandas as pd
Expand All @@ -18,7 +19,7 @@ class SnapshotNotFoundError(Exception):
class PyTestSnapshotTest:
"""Perform snapshot tests in pytest.
This supports standard datatypes like numpy arrays and pandas DataFrames.
This supports standard datatypes and scientific datatypes like numpy arrays and pandas DataFrames.
To use this in your tests, add the following lines to your conftest.py:
Expand All @@ -42,39 +43,57 @@ def pytest_addoption(parser):
This will register the snapshot fixture that you can use in your tests.
Further, it will register the `--snapshot-update` command-line flag, which you can use to update the snapshots.
To use the fixture in your tests, simply add it as a parameter to your test function:
.. code-block:: python
def test_my_test(snapshot):
result = my_calculation()
snapshot.assert_match(result, "my_result_1")
This will store the result of `my_calculation()` in a snapshot file in a folder called `snapshot` in the same folder
as the test file.
The name of the snapshot file will be the name of the test function, suffixed with `_my_result_1`.
When the test is run again, the result will be compared to the stored snapshot.
To update a snapshot, either delete the snapshot file and manually run the test again or run pytest with the
`--snapshot-update` flag.
"""

def __init__(self, request=None):
curr_snapshot: str

def __init__(self, request=None) -> None:
self.request = request
self.curr_snapshot_number = 0
super().__init__()

@property
def update(self):
def _update(self):
return self.request.config.option.snapshot_update

@property
def module(self):
def _module(self):
return Path(self.request.node.fspath.strpath).parent

@property
def snapshot_folder(self):
return self.module / "snapshot"
def _snapshot_folder(self):
return self._module / "snapshot"

@property
def file_name_json(self):
return self.snapshot_folder / f"{self.test_name}.json"
def _file_name_json(self):
return self._snapshot_folder / f"{self._test_name}.json"

@property
def file_name_csv(self):
return self.snapshot_folder / f"{self.test_name}.csv"
def _file_name_csv(self):
return self._snapshot_folder / f"{self._test_name}.csv"

@property
def file_name_txt(self):
return self.snapshot_folder / f"{self.test_name}.txt"
def _file_name_txt(self):
return self._snapshot_folder / f"{self._test_name}.txt"

@property
def test_name(self):
def _test_name(self):
cls_name = getattr(self.request.node.cls, "__name__", "")
flattened_node_name = re.sub(r"\s+", " ", self.request.node.name.replace(r"\n", " "))
return "{}{}_{}".format(f"{cls_name}." if cls_name else "", flattened_node_name, self.curr_snapshot)
Expand All @@ -85,48 +104,75 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
pass

def store(self, value):
self.snapshot_folder.mkdir(parents=True, exist_ok=True)
def _store(self, value):
self._snapshot_folder.mkdir(parents=True, exist_ok=True)
if isinstance(value, pd.DataFrame):
value.to_json(self.file_name_json, indent=4, orient="table")
value.to_json(self._file_name_json, indent=4, orient="table")
elif isinstance(value, np.ndarray):
np.savetxt(self.file_name_csv, value, delimiter=",")
np.savetxt(self._file_name_csv, value, delimiter=",")
elif isinstance(value, str):
with Path(self.file_name_txt).open("w") as f:
with Path(self._file_name_txt).open("w") as f:
f.write(value)
else:
raise TypeError(f"The dtype {type(value)} is not supported for snapshot testing")

def retrieve(self, dtype):
def _retrieve(self, dtype):
if dtype == pd.DataFrame:
filename = self.file_name_json
filename = self._file_name_json
if not filename.is_file():
raise SnapshotNotFoundError()
raise SnapshotNotFoundError
return pd.read_json(filename, orient="table")
if dtype == np.ndarray:
filename = self.file_name_csv
filename = self._file_name_csv
if not filename.is_file():
raise SnapshotNotFoundError()
raise SnapshotNotFoundError
return np.genfromtxt(filename, delimiter=",")
if dtype == str:
filename = self.file_name_txt
filename = self._file_name_txt
if not filename.is_file():
raise SnapshotNotFoundError()
with Path(self.file_name_txt).open() as f:
raise SnapshotNotFoundError
with Path(self._file_name_txt).open() as f:
value = f.read()
return value
raise ValueError(f"The dtype {dtype} is not supported for snapshot testing")

def assert_match(self, value, name="", **kwargs):
self.curr_snapshot = name or self.curr_snapshot_number
if self.update:
self.store(value)
def assert_match(self, value: Union[str, pd.DataFrame, np.ndarray], name: Optional[str] = None, **kwargs):
"""Assert that the value matches the snapshot.
This compares the value with a stored snapshot of the same name.
If no snapshot exists, it will be created.
The snapshot name is automatically generated from the test name and the `name` parameter passed to this
function.
If no name is passed, the name will be suffixed with a number, starting at 0.
If you have multiple snapshots in one test, we highly recommend passing a name to this function.
Otherwise, changing the order of the snapshots will break your tests.
Parameters
----------
value
The value to compare with the snapshot.
We support strings, numpy arrays and pandas DataFrames.
For other datatypes like floats or short lists, we recommend just using the standard pytest assertions
and hardcode the expected value.
name
Optional name suffix of the snapshot-file.
If not provided, the name will be suffixed with a number, starting at 0.
kwargs
Additional keyword arguments passed to the comparison function.
This is only supported for DataFrames and numpy arrays.
In these cases, they are passed to `assert_frame_equal` and `assert_array_almost_equal`, respectively.
"""
self.curr_snapshot = name or str(self.curr_snapshot_number)
if self._update:
self._store(value)
else:
value_dtype = type(value)
try:
prev_snapshot = self.retrieve(value_dtype)
prev_snapshot = self._retrieve(value_dtype)
except SnapshotNotFoundError:
self.store(value) # first time this test has been seen
self._store(value) # first time this test has been seen
except:
raise
else:
Expand Down

0 comments on commit 35ccb2f

Please sign in to comment.