diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 58b25a9..2a60c35 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -35,10 +35,6 @@ jobs: os: macos-latest # macos builds sometimes get stuck starting "python -m tox". experimental: true - - python-version: 3.5 - # Latest os version that supports Python 3.5 - os: ubuntu-20.04 - experimental: false - python-version: 3.6 # Latest os version that supports Python 3.6 os: ubuntu-20.04 diff --git a/MANIFEST.in b/MANIFEST.in index 1ade348..a6e1af8 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,6 @@ include LICENSE include README.rst +include pytest_snapshot/py.typed recursive-exclude * __pycache__ recursive-exclude * *.py[co] diff --git a/pytest_snapshot/_utils.py b/pytest_snapshot/_utils.py index 1236317..68cef84 100644 --- a/pytest_snapshot/_utils.py +++ b/pytest_snapshot/_utils.py @@ -1,12 +1,17 @@ import os import re from pathlib import Path +from typing import Dict, List, Tuple, TypeVar, Union, cast import pytest SIMPLE_VERSION_REGEX = re.compile(r'([0-9]+)\.([0-9]+)\.([0-9]+)') ILLEGAL_FILENAME_CHARS = r'\/:*?"<>|' +_K = TypeVar("_K") +_V = TypeVar("_V") +_RecursiveDict = Dict[_K, Union["_RecursiveDict", _V]] + def shorten_path(path: Path) -> Path: """ @@ -44,7 +49,7 @@ def might_be_valid_filename(s: str) -> bool: ) -def simple_version_parse(version: str): +def simple_version_parse(version: str) -> Tuple[int, ...]: """ Returns a 3 tuple of the versions major, minor, and patch. Raises a value error if the version string is unsupported. @@ -75,7 +80,7 @@ def _pytest_expected_on_right() -> bool: return pytest_version >= (5, 4, 0) -def flatten_dict(d: dict): +def flatten_dict(d: _RecursiveDict[_K, _V]) -> List[Tuple[List[_K], _V]]: """ Returns the flattened dict representation of the given dict. @@ -91,22 +96,24 @@ def flatten_dict(d: dict): [(['a'], 1), (['b', 'c'], 2)] """ assert type(d) is dict - result = [] - _flatten_dict(d, result, []) + result: List[Tuple[List[_K], _V]] = [] + _flatten_dict(d, result, []) # type: ignore[misc] return result -def _flatten_dict(obj, result, prefix): - if type(obj) is dict: - for k, v in obj.items(): - prefix.append(k) - _flatten_dict(v, result, prefix) - prefix.pop() - else: - result.append((list(prefix), obj)) +def _flatten_dict( + obj: _RecursiveDict[_K, _V], result: List[Tuple[List[_K], _V]], prefix: List[_K] +) -> None: + for k, v in obj.items(): + prefix.append(k) + if type(v) is dict: + _flatten_dict(cast(_RecursiveDict[_K, _V], v), result, prefix) + else: + result.append((list(prefix), cast(_V, v))) + prefix.pop() -def flatten_filesystem_dict(d): +def flatten_filesystem_dict(d: _RecursiveDict[str, _V]) -> Dict[str, _V]: """ Returns the flattened dict of a nested dictionary structure describing a filesystem. diff --git a/pytest_snapshot/plugin.py b/pytest_snapshot/plugin.py index d9e6347..02b6eed 100644 --- a/pytest_snapshot/plugin.py +++ b/pytest_snapshot/plugin.py @@ -2,17 +2,27 @@ import os import re from pathlib import Path -from typing import Union +from typing import Any, AnyStr, Callable, Iterator, List, Optional, Tuple, Union import pytest -import _pytest.python -from pytest_snapshot._utils import shorten_path, get_valid_filename, _pytest_expected_on_right, flatten_filesystem_dict +try: + from pytest import Parser as _Parser +except ImportError: + from _pytest.config.argparsing import Parser as _Parser + +try: + from pytest import FixtureRequest as _FixtureRequest +except ImportError: + from _pytest.fixtures import FixtureRequest as _FixtureRequest + +from pytest_snapshot._utils import shorten_path, get_valid_filename, _pytest_expected_on_right +from pytest_snapshot._utils import flatten_filesystem_dict, _RecursiveDict PARAMETRIZED_TEST_REGEX = re.compile(r'^.*?\[(.*)]$') -def pytest_addoption(parser): +def pytest_addoption(parser: _Parser) -> None: group = parser.getgroup('snapshot') group.addoption( '--snapshot-update', @@ -27,7 +37,9 @@ def pytest_addoption(parser): @pytest.fixture -def snapshot(request): +def snapshot(request: _FixtureRequest) -> Iterator["Snapshot"]: + # FIXME Properly handle different node type + assert isinstance(request.node, pytest.Function) default_snapshot_dir = _get_default_snapshot_dir(request.node) with Snapshot(request.config.option.snapshot_update, @@ -36,7 +48,7 @@ def snapshot(request): yield snapshot -def _assert_equal(value, snapshot) -> None: +def _assert_equal(value: AnyStr, snapshot: AnyStr) -> None: if _pytest_expected_on_right(): assert value == snapshot else: @@ -68,12 +80,12 @@ def _file_decode(data: bytes) -> str: class Snapshot: - _snapshot_update = None # type: bool - _allow_snapshot_deletion = None # type: bool - _created_snapshots = None # type: List[Path] - _updated_snapshots = None # type: List[Path] - _snapshots_to_delete = None # type: List[Path] - _snapshot_dir = None # type: Path + _snapshot_update: bool + _allow_snapshot_deletion: bool + _created_snapshots: List[Path] + _updated_snapshots: List[Path] + _snapshots_to_delete: List[Path] + _snapshot_dir: Path def __init__(self, snapshot_update: bool, allow_snapshot_deletion: bool, snapshot_dir: Path): self._snapshot_update = snapshot_update @@ -83,10 +95,10 @@ def __init__(self, snapshot_update: bool, allow_snapshot_deletion: bool, snapsho self._updated_snapshots = [] self._snapshots_to_delete = [] - def __enter__(self): + def __enter__(self) -> "Snapshot": return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, *_: Any) -> None: if self._created_snapshots or self._updated_snapshots or self._snapshots_to_delete: message_lines = ['Snapshot directory was modified: {}'.format(shorten_path(self.snapshot_dir)), ' (verify that the changes are expected before committing them to version control)'] @@ -112,14 +124,14 @@ def __exit__(self, exc_type, exc_val, exc_tb): pytest.fail('\n'.join(message_lines), pytrace=False) @property - def snapshot_dir(self): + def snapshot_dir(self) -> Path: return self._snapshot_dir @snapshot_dir.setter - def snapshot_dir(self, value): + def snapshot_dir(self, value: Union[str, 'os.PathLike[str]']) -> None: self._snapshot_dir = Path(value).absolute() - def _snapshot_path(self, snapshot_name: Union[str, Path]) -> Path: + def _snapshot_path(self, snapshot_name: Union[str, 'os.PathLike[str]']) -> Path: """ Returns the absolute path to the given snapshot. """ @@ -135,7 +147,11 @@ def _snapshot_path(self, snapshot_name: Union[str, Path]) -> Path: return snapshot_path - def _get_compare_encode_decode(self, value: Union[str, bytes]): + def _get_compare_encode_decode(self, value: AnyStr) -> Tuple[ + Callable[[AnyStr, AnyStr], None], + Callable[[AnyStr], bytes], + Callable[[bytes], AnyStr] + ]: """ Returns a 3-tuple of a compare function, an encoding function, and a decoding function. @@ -147,11 +163,12 @@ def _get_compare_encode_decode(self, value: Union[str, bytes]): if isinstance(value, str): return _assert_equal, _file_encode, _file_decode elif isinstance(value, bytes): - return _assert_equal, lambda x: x, lambda x: x + noop: Callable[[bytes], bytes] = lambda x: x + return _assert_equal, noop, noop else: raise TypeError('value must be str or bytes') - def assert_match(self, value: Union[str, bytes], snapshot_name: Union[str, Path]): + def assert_match(self, value: AnyStr, snapshot_name: Union[str, 'os.PathLike[str]']) -> None: """ Asserts that ``value`` equals the current value of the snapshot with the given ``snapshot_name``. @@ -185,6 +202,7 @@ def assert_match(self, value: Union[str, bytes], snapshot_name: Union[str, Path] else: if encoded_expected_value is not None: expected_value = decode(encoded_expected_value) + snapshot_diff_msg: Optional[str] try: compare(value, expected_value) except AssertionError as e: @@ -202,7 +220,11 @@ def assert_match(self, value: Union[str, bytes], snapshot_name: Union[str, Path] "snapshot {} doesn't exist. (run pytest with --snapshot-update to create it)".format( shorten_path(snapshot_path))) - def assert_match_dir(self, dir_dict: dict, snapshot_dir_name: Union[str, Path]): + def assert_match_dir( + self, + dir_dict: _RecursiveDict[str, Union[bytes, str]], + snapshot_dir_name: Union[str, 'os.PathLike[str]'] + ) -> None: """ Asserts that the values in dir_dict equal the current values in the given snapshot directory. @@ -214,7 +236,7 @@ def assert_match_dir(self, dir_dict: dict, snapshot_dir_name: Union[str, Path]): raise TypeError('dir_dict must be a dictionary') snapshot_dir_path = self._snapshot_path(snapshot_dir_name) - values_by_filename = flatten_filesystem_dict(dir_dict) + values_by_filename = flatten_filesystem_dict(dir_dict) # type: ignore[misc] if snapshot_dir_path.is_dir(): existing_names = {p.relative_to(snapshot_dir_path).as_posix() for p in snapshot_dir_path.rglob('*') if p.is_file()} @@ -242,10 +264,10 @@ def assert_match_dir(self, dir_dict: dict, snapshot_dir_name: Union[str, Path]): # Call assert_match to add, update, or assert equality for all snapshot files in the directory. for name, value in values_by_filename.items(): - self.assert_match(value, snapshot_dir_path.joinpath(name)) + self.assert_match(value, snapshot_dir_path.joinpath(name)) # pyright: ignore -def _get_default_snapshot_dir(node: _pytest.python.Function) -> Path: +def _get_default_snapshot_dir(node: pytest.Function) -> Path: """ Returns the default snapshot directory for the pytest test. """ diff --git a/pytest_snapshot/py.typed b/pytest_snapshot/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/setup.cfg b/setup.cfg index f26e9b5..8bc2336 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,6 +30,8 @@ classifiers = [options] packages = pytest_snapshot +zip_safe = False +include_package_data = True python_requires = >=3.5 install_requires = pytest >= 3.0.0 diff --git a/tox.ini b/tox.ini index b4a1c95..f548c16 100644 --- a/tox.ini +++ b/tox.ini @@ -3,10 +3,12 @@ envlist = # Pytest <6.2.5 not supported on Python >=3.10 py{36,37,38,39}-pytest{3,4,5}-coverage - py{35,36,37,38,39,310,311,312,3}-pytest{6,}-coverage + py{36,37,38,39,310,311,312,3}-pytest{6,}-coverage # Coverage is slow in pypy pypy3-pytest{6,} flake8 + pyright + mypy [testenv] deps = @@ -32,18 +34,27 @@ skip_install = true deps = flake8 commands = flake8 pytest_snapshot setup.py tests +[testenv:pyright] +deps = pyright +commands = pyright --verifytypes pytest_snapshot --ignoreexternal + +[testenv:mypy] +deps = + mypy + py +commands = mypy -p pytest_snapshot + [flake8] max-line-length = 120 [gh-actions] python = - 3.5: py35 3.6: py36 3.7: py37 3.8: py38 3.9: py39 3.10: py310 3.11: py311 - 3.12: py312, flake8 + 3.12: py312, flake8, mypy, pyright 3: py3 pypy-3.10: pypy3