diff --git a/.bumpversion.cfg b/.bumpversion.cfg deleted file mode 100644 index 8c316abdb..000000000 --- a/.bumpversion.cfg +++ /dev/null @@ -1,7 +0,0 @@ -[bumpversion] -current_version = 0.1.1 -commit = True -tag = True -parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+) -serialize = - {major}.{minor}.{patch} diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index ae21f6b53..d9a2cbf41 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -19,7 +19,7 @@ jobs: strategy: fail-fast: false matrix: - # TODO(michalk8): in the future, lint the docs + # TODO(michalk8): enable in the future lint-kind: [code] # , docs] steps: diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index cf790bb84..c57a981c1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -19,22 +19,28 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - python: [3.8] + python: ["3.8", "3.10"] + include: + - os: macos-latest + python: "3.9" + steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python }} - name: Install pip dependencies - # TODO(michalk8): remove tox-gh dependency, update tox.ini run: | python -m pip install --upgrade pip - pip install tox tox-gh-actions + pip install tox + - name: Test run: | - tox -vv + tox -e py-${{ matrix.python }} + env: + PYTEST_ADDOPTS: -vv -n 2 - name: Upload coverage uses: codecov/codecov-action@v3 diff --git a/.mypy.ini b/.mypy.ini deleted file mode 100644 index b1830176c..000000000 --- a/.mypy.ini +++ /dev/null @@ -1,35 +0,0 @@ -[mypy] -mypy_path = moscot -python_version = 3.9 -plugins = numpy.typing.mypy_plugin - -ignore_errors = False - -warn_redundant_casts = True -warn_unused_configs = True -warn_unused_ignores = True - -disallow_untyped_calls = False -disallow_untyped_defs = True -disallow_incomplete_defs = True -disallow_any_generics = True - -strict_optional = True -strict_equality = True -warn_return_any = False -warn_unreachable = False -check_untyped_defs = True -no_implicit_optional = True -no_implicit_reexport = True -no_warn_no_return = True - -show_error_codes = True -show_column_numbers = True -error_summary = True -ignore_missing_imports = True - -disable_error_code = assignment, comparison-overlap, no-untyped-def - -[mypy-tests.*] -ignore_errors = True -disable_error_code = assignment, comparison-overlap diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6ab1412f1..cefb6e97e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,20 +4,21 @@ default_language_version: default_stages: - commit - push -minimum_pre_commit_version: 2.14.0 +minimum_pre_commit_version: 3.0.0 repos: - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.991 + rev: v1.1.1 hooks: - id: mypy additional_dependencies: [numpy>=1.21.0, jax] + files: ^src - repo: https://github.com/psf/black rev: 23.1.0 hooks: - id: black additional_dependencies: [toml] - repo: https://github.com/pre-commit/mirrors-prettier - rev: v3.0.0-alpha.4 + rev: v3.0.0-alpha.6 hooks: - id: prettier language_version: system @@ -27,95 +28,42 @@ repos: - id: isort additional_dependencies: [toml] args: [--order-by-type] - - repo: https://github.com/asottile/yesqa - rev: v1.4.0 - hooks: - - id: yesqa - additional_dependencies: - [ - flake8-tidy-imports, - flake8-docstrings, - flake8-rst-docstrings, - flake8-comprehensions, - flake8-bugbear, - flake8-blind-except, - flake8-builtins, - flake8-pytest-style, - flake8-string-format, - ] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: - - id: detect-private-key - id: check-merge-conflict - id: check-ast - - id: check-symlinks - id: check-added-large-files - - id: check-executables-have-shebangs - - id: fix-encoding-pragma - args: [--remove] - id: end-of-file-fixer - id: mixed-line-ending args: [--fix=lf] - id: trailing-whitespace - exclude: ^.bumpversion.cfg$ - - id: check-case-conflict - id: check-docstring-first - id: check-yaml - id: check-toml - - id: requirements-txt-fixer - - repo: https://github.com/myint/autoflake - rev: v2.0.1 - hooks: - - id: autoflake - args: - [ - --in-place, - --remove-all-unused-imports, - --remove-unused-variable, - --ignore-init-module-imports, - ] - - repo: https://github.com/pycqa/flake8.git - rev: 6.0.0 - hooks: - - id: flake8 - additional_dependencies: - [ - flake8-tidy-imports, - flake8-docstrings, - flake8-rst-docstrings, - flake8-comprehensions, - flake8-bugbear, - flake8-blind-except, - flake8-builtins, - flake8-pytest-style, - flake8-string-format, - ] - - repo: https://github.com/jumanjihouse/pre-commit-hooks - rev: 3.0.0 + - repo: https://github.com/asottile/pyupgrade + rev: v3.3.1 hooks: - - id: script-must-have-extension - name: Check executable files use .sh extension - types: [shell, executable] + - id: pyupgrade + args: [--py3-plus, --py38-plus, --keep-runtime-typing] - repo: https://github.com/asottile/blacken-docs rev: 1.13.0 hooks: - id: blacken-docs additional_dependencies: [black==23.1.0] - - repo: https://github.com/asottile/pyupgrade - rev: v3.3.1 + - repo: https://github.com/rstcheck/rstcheck + rev: v6.1.1 hooks: - - id: pyupgrade - args: [--py3-plus, --py38-plus, --keep-runtime-typing] - - repo: https://github.com/pre-commit/pygrep-hooks - rev: v1.10.0 - hooks: - - id: python-no-eval - - id: python-check-blanket-noqa - - id: rst-backticks - - id: rst-directive-colons - - id: rst-inline-touching-normal + - id: rstcheck + additional_dependencies: [tomli] + args: [--config=pyproject.toml] - repo: https://github.com/PyCQA/doc8 rev: v1.1.1 hooks: - id: doc8 + - repo: https://github.com/charliermarsh/ruff-pre-commit + # Ruff version. + rev: v0.0.252 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] diff --git a/.prettierignore b/.prettierignore deleted file mode 100644 index a188e0692..000000000 --- a/.prettierignore +++ /dev/null @@ -1 +0,0 @@ -docs/* diff --git a/.rstcheck.cfg b/.rstcheck.cfg deleted file mode 100644 index e0e4d9ce4..000000000 --- a/.rstcheck.cfg +++ /dev/null @@ -1,3 +0,0 @@ -[rstcheck] -ignore_messages=Unknown target name:.*|No (directive|role) entry for "(auto)?(bibliography|nbgallery|class|method|meth||property|function|func|mod|module|attr|cite)" in module "docutils\.parsers\.rst\.languages\.en"\. -report=info diff --git a/MANIFEST.in b/MANIFEST.in index ec6ae01e5..d4dbf536d 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,6 +2,5 @@ include src/moscot/utils/data/allTFs_dmel.txt include src/moscot/utils/data/allTFs_hg38.txt include src/moscot/utils/data/allTFs_mm.txt prune docs -prune resources +prune tests prune .github -prune tests/data diff --git a/README.rst b/README.rst index f008442cf..8cd9cb596 100644 --- a/README.rst +++ b/README.rst @@ -11,7 +11,7 @@ single-cell genomics. It can be used for - prototyping of new OT models in single-cell genomics **moscot** is powered by -`OTT <https://ott-jax.readthedocs.io/en/latest/>`_ which is a JAX-based Optimal +`OTT <https://ott-jax.readthedocs.io>`_ which is a JAX-based Optimal Transport toolkit that supports just-in-time compilation, GPU acceleration, automatic differentiation and linear memory complexity for OT problems. @@ -39,7 +39,7 @@ If used with GPU, additionally run:: Resources --------- -Please have a look at our `documentation <https://moscot.readthedocs.io/en/latest/index.html/>`_ +Please have a look at our `documentation <https://moscot.readthedocs.io>`_ Reference --------- diff --git a/.github/.codecov.yml b/codecov.yml similarity index 52% rename from .github/.codecov.yml rename to codecov.yml index c729b3eed..fbe5cfa5b 100644 --- a/.github/.codecov.yml +++ b/codecov.yml @@ -1,18 +1,18 @@ codecov: - require_ci_to_pass: false + require_ci_to_pass: true strict_yaml_branch: main coverage: - range: 90..100 + range: "80...100" status: project: default: - target: 1 + target: 75% + threshold: 1% patch: off comment: - layout: reach, diff, files + layout: "reach, diff, files" behavior: default require_changes: true - branches: - - main + branches: [main] diff --git a/docs/Makefile b/docs/Makefile index b08c1b5bb..8948b67df 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -6,7 +6,7 @@ SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build SOURCEDIR = . -BUILDDIR = build +BUILDDIR = _build # Put it first so that "make" without argument is like "make help". help: diff --git a/docs/_templates/autosummary/base.rst b/docs/_templates/autosummary/base.rst deleted file mode 100644 index e81ca6450..000000000 --- a/docs/_templates/autosummary/base.rst +++ /dev/null @@ -1,5 +0,0 @@ -:github_url: {{ fullname }} - -{% extends "!autosummary/base.rst" %} - -.. http://www.sphinx-doc.org/en/stable/ext/autosummary.html#customizing-templates diff --git a/docs/_templates/autosummary/class.rst b/docs/_templates/autosummary/class.rst index 8529158f5..5a4db15dd 100644 --- a/docs/_templates/autosummary/class.rst +++ b/docs/_templates/autosummary/class.rst @@ -1,33 +1,29 @@ -:github_url: {{ fullname }} - -{{ fullname | escape | underline}} +{{ fullname | escape | underline }} .. currentmodule:: {{ module }} .. autoclass:: {{ objname }} + {% block methods %} + {%- if methods %} + .. rubric:: {{ _('Methods') }} - {% block attributes %} - {% if attributes %} - .. rubric:: Attributes - - .. autosummary:: - :toctree: . - {% for item in attributes %} - ~{{ fullname }}.{{ item }} - {%- endfor %} - {% endif %} - {% endblock %} - - {% block methods %} - {% if methods %} - .. rubric:: Methods + .. autosummary:: + :toctree: . + {% for item in methods %} + {%- if item not in ['__init__', 'tree_flatten', 'tree_unflatten', 'bind'] %} + ~{{ name }}.{{ item }} + {%- endif %} + {%- endfor %} + {%- endif %} + {%- endblock %} + {% block attributes %} + {%- if attributes %} + .. rubric:: {{ _('Attributes') }} - .. autosummary:: - :toctree: . - {% for item in methods %} - {%- if item != '__init__' %} - ~{{ fullname }}.{{ item }} - {%- endif -%} - {%- endfor %} - {% endif %} - {% endblock %} + .. autosummary:: + :toctree: . + {% for item in attributes %} + ~{{ name }}.{{ item }} + {%- endfor %} + {%- endif %} + {% endblock %} diff --git a/docs/_templates/autosummary/function.rst b/docs/_templates/autosummary/function.rst deleted file mode 100644 index 097114364..000000000 --- a/docs/_templates/autosummary/function.rst +++ /dev/null @@ -1,5 +0,0 @@ -:github_url: {{ fullname }} - -{{ fullname | escape | underline}} - -.. autofunction:: {{ fullname }} diff --git a/docs/_templates/breadcrumbs.html b/docs/_templates/breadcrumbs.html deleted file mode 100644 index 4ecb013f8..000000000 --- a/docs/_templates/breadcrumbs.html +++ /dev/null @@ -1,4 +0,0 @@ -{%- extends "sphinx_rtd_theme/breadcrumbs.html" %} - -{% block breadcrumbs_aside %} -{% endblock %} diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 7b48e8428..861d2392d 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -1,9 +1,6 @@ Datasets ######## -Datasets -~~~~~~~~ - .. currentmodule:: moscot.datasets .. autosummary:: diff --git a/docs/api/developer.rst b/docs/api/developer.rst index 841da4bbb..9a2d8baa2 100644 --- a/docs/api/developer.rst +++ b/docs/api/developer.rst @@ -1,9 +1,8 @@ Developer ######### - OTT Backend -~~~~~~~~~~~~ +~~~~~~~~~~~ .. autosummary:: :toctree: genapi diff --git a/docs/api/plotting.rst b/docs/api/plotting.rst index 35a0f8579..41294334a 100644 --- a/docs/api/plotting.rst +++ b/docs/api/plotting.rst @@ -1,9 +1,6 @@ Plotting ######## -Plotting -~~~~~~~~ - .. currentmodule:: moscot.plotting .. autosummary:: diff --git a/pyproject.toml b/pyproject.toml index 588b1dbbd..f5e147671 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", "Topic :: Scientific/Engineering :: Bio-Informatics", "Topic :: Scientific/Engineering :: Mathematics" ] @@ -43,9 +42,7 @@ maintainers = [ {name = "Giovanni Palla", email = "giovanni.palla@helmholtz-muenchen.de"}, {name = "Michal Klein", email = "michal.klein@helmholtz-muenchen.de"} ] -urls.Documentation = "https://moscot.readthedocs.io/" -urls.Source = "https://github.com/theislab/moscot" -urls.Home-page = "https://github.com/theislab/moscot" + dependencies = [ "numpy>=1.20.0", @@ -53,9 +50,9 @@ dependencies = [ "pandas>=1.4.0", "networkx>=2.6.3", # https://github.com/scverse/scanpy/issues/2411 - "matplotlib>=3.4.0,<3.7", + "matplotlib>=3.5.0", "anndata>=0.8.0", - "scanpy>=1.9.0", + "scanpy>=1.9.3", "wrapt>=1.13.2", "docrep>=0.3.2", "ott-jax>=0.4.0", @@ -67,8 +64,15 @@ spatial = [ "squidpy>=1.2.3" ] dev = [ - "tox>=3.24.0", - "pre-commit>=2.14.0" + "pre-commit>=3.0.0", + "tox>=4", +] +test = [ + "pytest>=7", + "pytest-xdist>=3", + "pytest-mock>=3.5.0", + "pytest-cov>=4", + "coverage[toml]>=7", ] docs = [ "sphinx>=5.1.1", @@ -82,40 +86,83 @@ docs = [ "sphinx_design>=0.3.0", ] -test = [ - "pytest>=7.1.2", - "pytest-mock>=3.5.0", - "pytest-cov>=4.0.0", -] +[project.urls] +Homepage = "https://github.com/theislab/moscot" +Download = "https://moscot.readthedocs.io/en/latest/installation.html" +"Bug Tracker" = "https://github.com/theislab/moscot/issues" +Documentation = "https://moscot.readthedocs.io" +"Source Code" = "https://github.com/theislab/moscot" [tool.setuptools] package-dir = {"" = "src"} -packages = {find = {where = ["src"], namespaces = false}} +include-package-data = true [tool.setuptools_scm] +[tool.ruff] +exclude = [ + ".git", + "__pycache__", + "build", + "docs/_build", + "dist" +] +ignore = [ + # Do not assign a lambda expression, use a def -> lambda expression assignments are convenient + "E731", + # allow I, O, l as variable names -> I is the identity matrix, i, j, k, l is reasonable indexing notation + "E741", + # Missing docstring in public package + "D104", + # Missing docstring in public module + "D100", + # Missing docstring in __init__ + "D107", + # Missing docstring in magic method + "D105", +] +line-length = 120 +select = [ + "D", # flake8-docstrings + # TODO(michalk8): enable this in https://github.com/theislab/moscot/issues/483 + # "I", # isort + "E", # pycodestyle + "F", # pyflakes + "W", # pycodestyle + "Q", # flake8-quotes + "SIM", # flake8-simplify + "NPY", # NumPy-specific rules + "PT", # flake8-pytest-style + "B", # flake8-bugbear + "UP", # pyupgrade + "C4", # flake8-comprehensions + "BLE", # flake8-blind-except + "T20", # flake8-print + "RET", # flake8-raise +] +unfixable = ["B", "UP", "C4", "BLE", "T20", "RET"] +target-version = "py38" +[tool.ruff.per-file-ignores] +"tests/*" = ["D"] +"*/__init__.py" = ["F401"] +"docs/*" = ["D"] +[tool.ruff.pydocstyle] +convention = "numpy" +[tool.ruff.pyupgrade] +# Preserve types, even if a file imports `from __future__ import annotations`. +keep-runtime-typing = true +[tool.ruff.flake8-tidy-imports] +# Disallow all relative imports. +ban-relative-imports = "parents" +[tool.ruff.flake8-quotes] +inline-quotes = "double" + [tool.black] line-length = 120 target-version = ['py38'] include = '\.pyi?$' -exclude = ''' -( - /( - \.eggs - | \.git - | \.hg - | \.mypy_cache - | \.tox - | \.venv - | _build - | buck-out - | build - | dist - )/ - -) -''' +# TODO(michalk8): simplify in https://github.com/theislab/moscot/issues/483 [tool.isort] profile = "black" py_version = "38" @@ -145,6 +192,147 @@ force_alphabetical_sort_within_sections = true lexicographical = true [tool.pytest.ini_options] -markers = [ - "fast: marks tests as fask", +markers = ["fast: marks tests as fask"] +xfail_strict = true +filterwarnings = [ + "ignore:X.dtype being converted:FutureWarning", + "ignore:No data for colormapping:UserWarning", + "ignore:jax\\.experimental\\.pjit\\.PartitionSpec:DeprecationWarning", +] + +[tool.coverage.run] +branch = true +parallel = true +source = ["src/"] +omit = [ + "*/__init__.py", + "*/_version.py", +] + +[tool.coverage.report] +exclude_lines = [ + '\#.*pragma:\s*no.?cover', + "^if __name__ == .__main__.:$", + '^\s*raise AssertionError\b', + '^\s*raise NotImplementedError\b', + '^\s*return NotImplemented\b', ] +precision = 2 +show_missing = true +skip_empty = true +sort = "Miss" + +[tool.rstcheck] +ignore_directives = [ + "toctree", + "currentmodule", + "autosummary", + "automodule", + "autoclass", + "bibliography", + "card", +] + +[tool.mypy] +mypy_path = "$MYPY_CONFIG_FILE_DIR/src" +python_version = "3.9" +plugins = "numpy.typing.mypy_plugin" + +ignore_errors = false + +warn_redundant_casts = true +warn_unused_configs = true +warn_unused_ignores = true + +disallow_untyped_calls = false +disallow_untyped_defs = true +disallow_incomplete_defs = true +disallow_any_generics = true + +strict_optional = true +strict_equality = true +warn_return_any = false +warn_unreachable = false +check_untyped_defs = true +no_implicit_optional = true +no_implicit_reexport = true +no_warn_no_return = true + +show_error_codes = true +show_column_numbers = true +error_summary = true +ignore_missing_imports = true + +disable_error_code = ["assignment", "comparison-overlap", "no-untyped-def"] + +[tool.doc8] +max_line_length = 120 + +[tool.tox] +legacy_tox_ini = """ +[tox] +min_version = 4.0 +env_list = lint-code,py{3.8,3.9,3.10,3.11} +skip_missing_interpreters = true + +[testenv] +extras = test +pass_env = PYTEST_*,CI +commands = + python -m pytest {tty:--color=yes} {posargs: \ + --cov={env_site_packages_dir}{/}moscot --cov-config={tox_root}{/}pyproject.toml \ + --no-cov-on-fail --cov-report=xml --cov-report=term-missing:skip-covered} + +[testenv:lint-code] +description = Lint the code. +deps = pre-commit>=3.0.0 +skip_install = true +commands = + pre-commit run --all-files --show-diff-on-failure + +[testenv:lint-docs] +description = Lint the documentation. +deps = +extras = docs +ignore_errors = true +allowlist_externals = make +pass_env = PYENCHANT_LIBRARY_PATH +set_env = SPHINXOPTS = -W -q --keep-going +changedir = {tox_root}{/}docs +commands = + make linkcheck {posargs} + make spelling {posargs} + +[testenv:clean-docs] +description = Remove the documentation. +skip_install = true +changedir = {tox_root}{/}docs +allowlist_externals = make +commands = + make clean + +[testenv:build-docs] +description = Build the documentation. +use_develop = true +deps = +extras = docs +allowlist_externals = make +changedir = {tox_root}{/}docs +commands = + make html {posargs} +commands_post = + python -c 'import pathlib; print("Documentation is under:", pathlib.Path("{tox_root}") / "docs" / "_build" / "html" / "index.html")' + +[testenv:build-package] +description = Build the package. +deps = + build + twine +allowlist_externals = rm +commands = + rm -rf {tox_root}{/}dist + python -m build --sdist --wheel --outdir {tox_root}{/}dist{/} {posargs:} + python -m twine check {tox_root}{/}dist{/}* +commands_post = + python -c 'import pathlib; print(f"Package is under:", pathlib.Path("{tox_root}") / "dist")' +""" diff --git a/src/moscot/__init__.py b/src/moscot/__init__.py index 2e2e8f776..fd672cdb5 100644 --- a/src/moscot/__init__.py +++ b/src/moscot/__init__.py @@ -9,8 +9,8 @@ import moscot.problems try: - __version__ = metadata.version(__name__) md = metadata.metadata(__name__) + __version__ = md.get("version", "") __author__ = md.get("Author", "") __maintainer__ = md.get("Maintainer-email", "") except ImportError: diff --git a/src/moscot/_constants/_enum.py b/src/moscot/_constants/_enum.py index 42b0e48d9..522ecc83f 100644 --- a/src/moscot/_constants/_enum.py +++ b/src/moscot/_constants/_enum.py @@ -26,7 +26,7 @@ def wrapper(*args: Any, **kwargs: Any) -> "ErrorFormatterABC": if not issubclass(cls, ErrorFormatterABC): raise TypeError(f"Class `{cls}` must be subtype of `ErrorFormatterABC`.") - elif not len(cls.__members__): # type: ignore[attr-defined] + if not len(cls.__members__): # type: ignore[attr-defined] # empty enum, for class hierarchy return func @@ -41,7 +41,7 @@ def __call__(cls, *args: Any, **kwargs: Any) -> Any: def __new__(cls, clsname: str, superclasses: Tuple[type], attributedict: Dict[str, Any]) -> "ABCEnumMeta": res = super().__new__(cls, clsname, superclasses, attributedict) # type: ignore[arg-type] - res.__new__ = _pretty_raise_enum(res, res.__new__) # type: ignore[assignment,arg-type] + res.__new__ = _pretty_raise_enum(res, res.__new__) # type: ignore[method-assign,arg-type] return res diff --git a/src/moscot/_constants/_key.py b/src/moscot/_constants/_key.py index e5869ed53..de44b270f 100644 --- a/src/moscot/_constants/_key.py +++ b/src/moscot/_constants/_key.py @@ -1,4 +1,5 @@ from typing import Any, Set, List, Callable, Optional +import contextlib import numpy as np @@ -56,8 +57,8 @@ def __init__(self, adata: AnnData, n: Optional[int] = None, where: str = "obs"): self._keys: List[str] = [] def _generate_random_keys(self): - def generator(): - return f"RNG_COL_{np.random.randint(2 ** 16)}" + def generator() -> str: + return f"RNG_COL_{np.random.RandomState().randint(2 ** 16)}" where = getattr(self._adata, self._where) names: List[str] = [] @@ -76,13 +77,9 @@ def __enter__(self): return self._keys def __exit__(self, exc_type, exc_val, exc_tb): - for key in self._keys: - try: - getattr(self._adata, self._where).drop(key, axis="columns", inplace=True) - except KeyError: - pass - if self._where == "obs": - try: + with contextlib.suppress(KeyError): + for key in self._keys: + df = getattr(self._adata, self._where) + df.drop(key, axis="columns", inplace=True) + if self._where == "obs": del self._adata.uns[f"{key}_colors"] - except KeyError: - pass diff --git a/src/moscot/_utils.py b/src/moscot/_utils.py index 29ea56068..317437461 100644 --- a/src/moscot/_utils.py +++ b/src/moscot/_utils.py @@ -58,7 +58,6 @@ def parallelize( ------- The result depending on ``callable``, ``extractor`` and ``as_array``. """ - if show_progress_bar: try: from tqdm.auto import tqdm @@ -186,7 +185,6 @@ def _np_apply_along_axis(func1d, axis: int, arr: ArrayLike) -> ArrayLike: ------- The reduced array. """ - assert arr.ndim == 2 assert axis in [0, 1] diff --git a/src/moscot/backends/ott/_output.py b/src/moscot/backends/ott/_output.py index 953ec2979..7755d5878 100644 --- a/src/moscot/backends/ott/_output.py +++ b/src/moscot/backends/ott/_output.py @@ -80,8 +80,7 @@ def plot_costs( if save is not None: fig.savefig(save) - if return_fig: - return fig + return fig if return_fig else None @d.dedent def plot_errors( @@ -123,8 +122,7 @@ def plot_errors( if save is not None: fig.savefig(save) - if return_fig: - return fig + return fig if return_fig else None def _plot_lines( self, @@ -179,7 +177,7 @@ def to(self, device: Optional[Device_t] = None) -> "OTTOutput": try: device = jax.devices(device)[idx] except IndexError: - raise IndexError(f"Unable to fetch the device with `id={idx}`.") + raise IndexError(f"Unable to fetch the device with `id={idx}`.") from None return OTTOutput(jax.device_put(self._output, device)) @@ -205,4 +203,4 @@ def rank(self) -> int: return len(lin_output.g) if isinstance(lin_output, OTTLRSinkhornOutput) else -1 def _ones(self, n: int) -> ArrayLike: - return jnp.ones((n,)) + return jnp.ones((n,)) # type: ignore[return-value] diff --git a/src/moscot/backends/ott/_solver.py b/src/moscot/backends/ott/_solver.py index d51a04982..d30dd0ef4 100644 --- a/src/moscot/backends/ott/_solver.py +++ b/src/moscot/backends/ott/_solver.py @@ -102,7 +102,7 @@ def _assert2d(arr: Optional[ArrayLike], *, allow_reshape: bool = True) -> Option return None arr: ArrayLike = jnp.asarray(arr.A if issparse(arr) else arr) # type: ignore[attr-defined, no-redef] if allow_reshape and arr.ndim == 1: - return jnp.reshape(arr, (-1, 1)) + return jnp.reshape(arr, (-1, 1)) # type: ignore[return-value] if arr.ndim != 2: raise ValueError(f"Expected array to have 2 dimensions, found `{arr.ndim}`.") return arr diff --git a/src/moscot/datasets/_datasets.py b/src/moscot/datasets/_datasets.py index 0125cd98d..b4c7f1afb 100644 --- a/src/moscot/datasets/_datasets.py +++ b/src/moscot/datasets/_datasets.py @@ -95,7 +95,9 @@ def hspc( path: PathLike = "~/.cache/moscot/hspc.h5ad", **kwargs: Any, ) -> AnnData: # pragma: no cover - """Subsampled and processed data from the `NeurIPS Multimodal Single-Cell Integration Challenge \ + """CD34+ hematopoietic stem and progenitor cells from 4 healthy human donors. + + From the `NeurIPS Multimodal Single-Cell Integration Challenge <https://www.kaggle.com/competitions/open-problems-multimodal/data>`_. 4000 cells were randomly selected after filtering the multiome training data of the donor `31800`. diff --git a/src/moscot/plotting/_plotting.py b/src/moscot/plotting/_plotting.py index e34b67f9e..bb7739a19 100644 --- a/src/moscot/plotting/_plotting.py +++ b/src/moscot/plotting/_plotting.py @@ -18,7 +18,7 @@ @d_plotting.dedent def cell_transition( - inp: Union[AnnData, Tuple[AnnData, AnnData], CompoundProblem], + inp: Union[AnnData, Tuple[AnnData, AnnData], CompoundProblem], # type: ignore[type-arg] uns_key: str = PlottingKeys.CELL_TRANSITION, row_labels: Optional[str] = None, col_labels: Optional[str] = None, @@ -62,7 +62,9 @@ def cell_transition( try: _ = adata1.uns[AdataKeys.UNS][PlottingKeys.CELL_TRANSITION][key] except KeyError: - raise KeyError(f"No data found in `adata.uns[{AdataKeys.UNS!r}][{PlottingKeys.CELL_TRANSITION!r}][{key!r}]`.") + raise KeyError( + f"No data found in `adata.uns[{AdataKeys.UNS!r}][{PlottingKeys.CELL_TRANSITION!r}][{key!r}]`." + ) from None data = adata1.uns[AdataKeys.UNS][PlottingKeys.CELL_TRANSITION][key] fig = _heatmap( @@ -87,8 +89,7 @@ def cell_transition( cbar_kwargs=cbar_kwargs, **kwargs, ) - if return_fig: - return fig + return fig if return_fig else None @d_plotting.dedent @@ -140,7 +141,7 @@ def sankey( try: _ = adata.uns[AdataKeys.UNS][PlottingKeys.SANKEY][key] except KeyError: - raise KeyError(f"No data found in `adata.uns[{AdataKeys.UNS!r}][{PlottingKeys.SANKEY!r}][{key!r}]`.") + raise KeyError(f"No data found in `adata.uns[{AdataKeys.UNS!r}][{PlottingKeys.SANKEY!r}][{key!r}]`.") from None data = adata.uns[AdataKeys.UNS][PlottingKeys.SANKEY][key] fig = _sankey( @@ -160,13 +161,12 @@ def sankey( ) if save: fig.figure.savefig(save) - if return_fig: - return fig + return fig if return_fig else None @d_plotting.dedent def push( - inp: Union[AnnData, TemporalProblem, LineageProblem, CompoundProblem], + inp: Union[AnnData, TemporalProblem, LineageProblem, CompoundProblem], # type: ignore[type-arg] uns_key: Optional[str] = None, time_points: Optional[Sequence[float]] = None, basis: str = "umap", @@ -245,13 +245,12 @@ def push( suptitle_fontsize=suptitle_fontsize, **kwargs, ) - if return_fig: - return fig.figure + return fig.figure if return_fig else None @d_plotting.dedent def pull( - inp: Union[AnnData, TemporalProblem, LineageProblem, CompoundProblem], + inp: Union[AnnData, TemporalProblem, LineageProblem, CompoundProblem], # type: ignore[type-arg] uns_key: Optional[str] = None, time_points: Optional[Sequence[float]] = None, basis: str = "umap", @@ -330,5 +329,4 @@ def pull( suptitle_fontsize=suptitle_fontsize, **kwargs, ) - if return_fig: - return fig.figure + return fig.figure if return_fig else None diff --git a/src/moscot/plotting/_utils.py b/src/moscot/plotting/_utils.py index 1c6a3332c..e3223058b 100644 --- a/src/moscot/plotting/_utils.py +++ b/src/moscot/plotting/_utils.py @@ -345,7 +345,9 @@ def _contrasting_color(r: int, g: int, b: int) -> str: return "#000000" if r * 0.299 + g * 0.587 + b * 0.114 > 186 else "#ffffff" -def _input_to_adatas(inp: Union[AnnData, Tuple[AnnData, AnnData], CompoundProblem]) -> Tuple[AnnData, AnnData]: +def _input_to_adatas( + inp: Union[AnnData, Tuple[AnnData, AnnData], CompoundProblem] # type: ignore[type-arg] +) -> Tuple[AnnData, AnnData]: if isinstance(inp, CompoundProblem): return inp.adata, inp.adata if isinstance(inp, AnnData): @@ -509,6 +511,4 @@ def _create_col_colors(adata: AnnData, obs_col: str, subset: Union[str, List[str h, _, v = mcolors.rgb_to_hsv(mcolors.to_rgb(color)) end_color = mcolors.hsv_to_rgb([h, 1, v]) - col_cmap = mcolors.LinearSegmentedColormap.from_list("category_cmap", ["darkgrey", end_color]) - - return col_cmap + return mcolors.LinearSegmentedColormap.from_list("category_cmap", ["darkgrey", end_color]) diff --git a/src/moscot/problems/_utils.py b/src/moscot/problems/_utils.py index 8417e1c9a..8350d5e0f 100644 --- a/src/moscot/problems/_utils.py +++ b/src/moscot/problems/_utils.py @@ -27,7 +27,10 @@ def require_solution( @wrapt.decorator def require_prepare( - wrapped: Callable[[Any], Any], instance: "BaseCompoundProblem", args: Tuple[Any, ...], kwargs: Mapping[str, Any] + wrapped: Callable[[Any], Any], + instance: "BaseCompoundProblem", # type: ignore[type-arg] + args: Tuple[Any, ...], + kwargs: Mapping[str, Any], ) -> Any: """Check whether problem has been prepared.""" if instance.problems is None: @@ -96,12 +99,15 @@ def handle_joint_attr( "y_key": joint_attr["key"], } return xy, kwargs - if joint_attr.get("tag", None) == "cost_matrix": # if this is True we have custom cost matrix or moscot cost - if len(joint_attr) == 2 or kwargs.get("attr", None) == "obsp": # in this case we have a custom cost matrix - joint_attr.setdefault("cost", "custom") - joint_attr.setdefault("attr", "obsp") - kwargs["xy_callback"] = "cost-matrix" - kwargs.setdefault("xy_callback_kwargs", {"key": joint_attr["key"]}) + + # if this is True we have custom cost matrix or moscot cost - in this case we have a custom cost matrix + if joint_attr.get("tag", None) == "cost_matrix" and ( + len(joint_attr) == 2 or kwargs.get("attr", None) == "obsp" + ): + joint_attr.setdefault("cost", "custom") + joint_attr.setdefault("attr", "obsp") + kwargs["xy_callback"] = "cost-matrix" + kwargs.setdefault("xy_callback_kwargs", {"key": joint_attr["key"]}) kwargs.setdefault("xy_callback_kwargs", {}) return joint_attr, kwargs raise TypeError(f"Expected `joint_attr` to be either `str` or `dict`, found `{type(joint_attr)}`.") @@ -138,3 +144,4 @@ def handle_cost( y = dict(y) y["cost"] = cost["y"] return xy, x, y + raise TypeError(type(cost)) diff --git a/src/moscot/problems/base/_utils.py b/src/moscot/problems/base/_utils.py index 62f116174..4c4d65f6a 100644 --- a/src/moscot/problems/base/_utils.py +++ b/src/moscot/problems/base/_utils.py @@ -153,19 +153,6 @@ def _get_categories_from_adata( return adata[adata.obs[key] == key_value].obs[annotation_key] -def _get_problem_key( - source: Optional[Any] = None, # TODO(@MUCDK) using `K` induces circular import, resolve - target: Optional[Any] = None, # TODO(@MUCDK) using `K` induces circular import, resolve -) -> Tuple[Any, Any]: # TODO(@MUCDK) using `K` induces circular import, resolve - if source is not None and target is not None: - return (source, target) - elif source is None and target is not None: - return ("src", target) # TODO(@MUCDK) make package constant - elif source is not None and target is None: - return (source, "ref") # TODO(@MUCDK) make package constant - return ("src", "ref") - - def _order_transition_matrix_helper( tm: pd.DataFrame, rows_verified: List[str], @@ -189,43 +176,42 @@ def _order_transition_matrix( target_annotations_ordered: Optional[List[str]], forward: bool, ) -> pd.DataFrame: + # TODO(michalk8): simplify if target_annotations_ordered is not None or source_annotations_ordered is not None: if forward: - tm = _order_transition_matrix_helper( + return _order_transition_matrix_helper( tm=tm, rows_verified=source_annotations_verified, cols_verified=target_annotations_verified, row_order=source_annotations_ordered, col_order=target_annotations_ordered, - ) - else: - tm = _order_transition_matrix_helper( - tm=tm, - rows_verified=target_annotations_verified, - cols_verified=source_annotations_verified, - row_order=target_annotations_ordered, - col_order=source_annotations_ordered, - ) - return tm.T if forward else tm - elif target_annotations_verified == source_annotations_verified: + ).T + return _order_transition_matrix_helper( + tm=tm, + rows_verified=target_annotations_verified, + cols_verified=source_annotations_verified, + row_order=target_annotations_ordered, + col_order=source_annotations_ordered, + ) + + if target_annotations_verified == source_annotations_verified: annotations_ordered = tm.columns.sort_values() if forward: - tm = _order_transition_matrix_helper( + return _order_transition_matrix_helper( tm=tm, rows_verified=source_annotations_verified, cols_verified=target_annotations_verified, row_order=annotations_ordered, col_order=annotations_ordered, - ) - else: - tm = _order_transition_matrix_helper( - tm=tm, - rows_verified=target_annotations_verified, - cols_verified=source_annotations_verified, - row_order=annotations_ordered, - col_order=annotations_ordered, - ) - return tm.T if forward else tm + ).T + return _order_transition_matrix_helper( + tm=tm, + rows_verified=target_annotations_verified, + cols_verified=source_annotations_verified, + row_order=annotations_ordered, + col_order=annotations_ordered, + ) + return tm if forward else tm.T @@ -284,7 +270,6 @@ def _correlation_test( - ``ci_low`` - lower bound of the ``confidence_level`` correlation confidence interval. - ``ci_high`` - upper bound of the ``confidence_level`` correlation confidence interval. """ - corr, pvals, ci_low, ci_high = _correlation_test_helper( X.T, Y.values, diff --git a/src/moscot/problems/generic/_generic.py b/src/moscot/problems/generic/_generic.py index 38d8586ac..a597cdf13 100644 --- a/src/moscot/problems/generic/_generic.py +++ b/src/moscot/problems/generic/_generic.py @@ -62,7 +62,7 @@ def prepare( -------- %(ex_prepare)s """ - self.batch_key = key + self.batch_key = key # type: ignore[misc] xy, kwargs = handle_joint_attr(joint_attr, kwargs) xy, _, _ = handle_cost(xy=xy, cost=cost) return super().prepare( @@ -129,7 +129,7 @@ def solve( -------- %(ex_solve_linear)s """ - return super().solve( + return super().solve( # type: ignore[return-value] epsilon=epsilon, tau_a=tau_a, tau_b=tau_b, @@ -155,7 +155,7 @@ def solve( @property def _base_problem_type(self) -> Type[B]: - return OTProblem + return OTProblem # type: ignore[return-value] @property def _valid_policies(self) -> Tuple[str, ...]: @@ -219,7 +219,7 @@ def prepare( -------- %(ex_prepare)s """ - self.batch_key = key + self.batch_key = key # type: ignore[misc] GW_updated: List[Dict[str, Any]] = [{}] * 2 for i, z in enumerate([GW_x, GW_y]): @@ -299,7 +299,7 @@ def solve( -------- %(ex_solve_quadratic)s """ - return super().solve( + return super().solve( # type: ignore[return-value] alpha=alpha, epsilon=epsilon, tau_a=tau_a, @@ -325,7 +325,7 @@ def solve( @property def _base_problem_type(self) -> Type[B]: - return OTProblem + return OTProblem # type: ignore[return-value] @property def _valid_policies(self) -> Tuple[str, ...]: diff --git a/src/moscot/problems/generic/_mixins.py b/src/moscot/problems/generic/_mixins.py index 80a46adb9..57566a404 100644 --- a/src/moscot/problems/generic/_mixins.py +++ b/src/moscot/problems/generic/_mixins.py @@ -143,8 +143,7 @@ def push( } self.adata.obs[key_added] = self._flatten(result, key=self.batch_key) Key.uns.set_plotting_vars(self.adata, PlottingKeys.PUSH, key_added, plot_vars) - if return_data: - return result + return result if return_data else None @d_mixins.dedent def pull( @@ -197,8 +196,7 @@ def pull( } self.adata.obs[key_added] = self._flatten(result, key=self.batch_key) Key.uns.set_plotting_vars(self.adata, PlottingKeys.PULL, key_added, plot_vars) - if return_data: - return result + return result if return_data else None @property def batch_key(self: GenericAnalysisMixinProtocol[K, B]) -> Optional[str]: diff --git a/src/moscot/problems/space/_mapping.py b/src/moscot/problems/space/_mapping.py index 9c82dcd2d..bc6abb866 100644 --- a/src/moscot/problems/space/_mapping.py +++ b/src/moscot/problems/space/_mapping.py @@ -229,7 +229,7 @@ def filtered_vars(self) -> Optional[Sequence[str]]: @filtered_vars.setter def filtered_vars(self, value: Optional[Sequence[str]]) -> None: - self._filtered_vars = self._filter_vars(var_names=value) + self._filtered_vars = self._filter_vars(var_names=value) # type: ignore[misc] @property def _base_problem_type(self) -> Type[B]: diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 19ecb0674..29f3bbc01 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -85,11 +85,7 @@ def _filter_vars( ) -> Optional[List[str]]: ... - def _cell_transition( - self: AnalysisMixinProtocol[K, B], - *args: Any, - **kwargs: Any, - ) -> pd.DataFrame: + def _cell_transition(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: ... @@ -101,14 +97,13 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._spatial_key: Optional[str] = None self._batch_key: Optional[str] = None - def _interpolate_scheme( + def _interpolate_scheme( # type: ignore[misc] self: SpatialAlignmentMixinProtocol[K, B], reference: K, mode: Literal["warp", "affine"], spatial_key: Optional[str] = None, ) -> Tuple[Dict[K, ArrayLike], Optional[Dict[K, Optional[ArrayLike]]]]: """Scheme for interpolation.""" - # get reference src = self._subset_spatial(reference, spatial_key=spatial_key) transport_maps: Dict[K, ArrayLike] = {reference: src} @@ -118,12 +113,9 @@ def _interpolate_scheme( transport_metadata = {reference: np.diag((1, 1))} # 2d data # get policy - if isinstance(reference, str): - reference_ = [reference] - else: - reference_ = reference + reference_ = [reference] if isinstance(reference, str) else reference full_steps = self._policy._graph - starts = set(chain.from_iterable(full_steps)) - set(reference_) # type: ignore[arg-type] + starts = set(chain.from_iterable(full_steps)) - set(reference_) # type: ignore[call-overload] if mode == AlignmentMode.AFFINE: _transport = self._affine @@ -149,7 +141,7 @@ def _interpolate_scheme( return transport_maps, (transport_metadata if mode == "affine" else None) @d.dedent - def align( + def align( # type: ignore[misc] self: SpatialAlignmentMixinProtocol[K, B], reference: K, mode: Literal["warp", "affine"] = "warp", @@ -193,10 +185,10 @@ def align( self.adata.uns[self.spatial_key]["alignment_metadata"] = aligned_metadata if not inplace: return aligned_basis - self.adata.obsm[f"{self.spatial_key}_{mode}"] = aligned_basis + self.adata.obsm[f"{self.spatial_key}_{mode}"] = aligned_basis # noqa: RET503 @d_mixins.dedent - def cell_transition( + def cell_transition( # type: ignore[misc] self: SpatialAlignmentMixinProtocol[K, B], source: K, target: K, @@ -255,7 +247,7 @@ def spatial_key(self) -> Optional[str]: return self._spatial_key @spatial_key.setter - def spatial_key(self: SpatialAlignmentMixinProtocol[K, B], key: Optional[str]) -> None: + def spatial_key(self: SpatialAlignmentMixinProtocol[K, B], key: Optional[str]) -> None: # type: ignore[misc] if key is not None and key not in self.adata.obsm: raise KeyError(f"Unable to find spatial data in `adata.obsm[{key!r}]`.") self._spatial_key = key @@ -267,11 +259,11 @@ def batch_key(self) -> Optional[str]: @batch_key.setter def batch_key(self, key: Optional[str]) -> None: - if key is not None and key not in self.adata.obs: + if key is not None and key not in self.adata.obs: # type: ignore[attr-defined] raise KeyError(f"Unable to find batch data in `adata.obs[{key!r}]`.") self._batch_key = key - def _subset_spatial( + def _subset_spatial( # type: ignore[misc] self: SpatialAlignmentMixinProtocol[K, B], k: K, spatial_key: Optional[str] = None ) -> ArrayLike: if spatial_key is None: @@ -308,7 +300,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._batch_key: Optional[str] = None self._spatial_key: Optional[str] = None - def _filter_vars( + def _filter_vars( # type: ignore[misc] self: SpatialMappingMixinProtocol[K, B], var_names: Optional[Sequence[str]] = None, ) -> Optional[List[str]]: @@ -330,7 +322,7 @@ def _filter_vars( raise ValueError("Some variable are missing in the single-cell or the spatial `AnnData`.") - def correlate( + def correlate( # type: ignore[misc] self: SpatialMappingMixinProtocol[K, B], var_names: Optional[List[str]] = None, corr_method: Literal["pearson", "spearman"] = "pearson", @@ -382,7 +374,7 @@ def correlate( return corrs - def impute( + def impute( # type: ignore[misc] self: SpatialMappingMixinProtocol[K, B], var_names: Optional[Sequence[Any]] = None, device: Optional[Device_t] = None, @@ -415,7 +407,7 @@ def impute( return adata_pred @d.dedent - def spatial_correspondence( + def spatial_correspondence( # type: ignore[misc] self: SpatialMappingMixinProtocol[K, B], interval: Union[ArrayLike, int] = 10, max_dist: Optional[int] = None, @@ -453,8 +445,7 @@ def _get_features( if key is not None: return getattr(adata, att)[key] - else: - return getattr(adata, att) + return getattr(adata, att) if self.batch_key is not None: out_list = [] @@ -480,14 +471,13 @@ def _get_features( out = pd.concat(out_list, axis=0) out[self.batch_key] = pd.Categorical(out[self.batch_key]) return out - else: - spatial = self.adata.obsm[self.spatial_key] - features = _get_features(self.adata, attr) - out = _compute_correspondence(spatial, features, interval, max_dist) - return out + + spatial = self.adata.obsm[self.spatial_key] + features = _get_features(self.adata, attr) + return _compute_correspondence(spatial, features, interval, max_dist) @d_mixins.dedent - def cell_transition( + def cell_transition( # type: ignore[misc] self: SpatialMappingMixinProtocol[K, B], source: K, target: Optional[K] = None, @@ -546,7 +536,7 @@ def batch_key(self) -> Optional[str]: @batch_key.setter def batch_key(self, key: Optional[str]) -> None: - if key is not None and key not in self.adata.obs: + if key is not None and key not in self.adata.obs: # type: ignore[attr-defined] raise KeyError(f"Unable to find batch data in `adata.obs[{key!r}]`.") self._batch_key = key @@ -556,7 +546,7 @@ def spatial_key(self) -> Optional[str]: return self._spatial_key @spatial_key.setter - def spatial_key(self: SpatialAlignmentMixinProtocol[K, B], key: Optional[str]) -> None: + def spatial_key(self: SpatialAlignmentMixinProtocol[K, B], key: Optional[str]) -> None: # type: ignore[misc] if key is not None and key not in self.adata.obsm: raise KeyError(f"Unable to find spatial data in `adata.obsm[{key!r}]`.") self._spatial_key = key @@ -570,7 +560,6 @@ def _compute_correspondence( ) -> pd.DataFrame: if isinstance(interval, int): # prepare support - spatial.shape[0] hull = ConvexHull(spatial) area = hull.volume if max_dist is None: @@ -582,6 +571,7 @@ def _compute_correspondence( def pdist(row_idx: ArrayLike, col_idx: float, feat: ArrayLike) -> Any: if len(row_idx) > 0: return pairwise_distances(feat[row_idx, :], feat[[col_idx], :]).mean() # type: ignore[index] + return np.nan vpdist = np.vectorize(pdist, excluded=["feat"]) features = features.A if sp.issparse(features) else features # type: ignore[attr-defined] diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index 645532bea..2e1ffb99e 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -184,7 +184,6 @@ def cell_transition( ----- %(notes_cell_transition)s """ - if TYPE_CHECKING: assert isinstance(self.temporal_key, str) return self._cell_transition( @@ -294,8 +293,7 @@ def sankey( "captions": [str(t) for t in tuples], } Key.uns.set_plotting_vars(self.adata, PlottingKeys.SANKEY, key_added, plot_vars) - if return_data: - return cell_transitions_updated + return cell_transitions_updated if return_data else None @d_mixins.dedent def push( @@ -353,8 +351,7 @@ def push( } self.adata.obs[key_added] = self._flatten(result, key=self.temporal_key) Key.uns.set_plotting_vars(self.adata, PlottingKeys.PUSH, key_added, plot_vars) - if return_data: - return result + return result if return_data else None @d_mixins.dedent def pull( @@ -411,8 +408,7 @@ def pull( } self.adata.obs[key_added] = self._flatten(result, key=self.temporal_key) Key.uns.set_plotting_vars(self.adata, PlottingKeys.PULL, key_added, plot_vars) - if return_data: - return result + return result if return_data else None @property def prior_growth_rates(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFrame]: @@ -542,14 +538,14 @@ def _get_data( break else: raise ValueError(f"No data found for `{source}` time point.") - for src, tgt in self.problems.keys(): + for src, tgt in self.problems: if src == intermediate: intermediate_data = self.problems[src, tgt].xy.data_src # type: ignore[union-attr] intermediate_adata = self.problems[src, tgt].adata_src break else: raise ValueError(f"No data found for `{intermediate}` time point.") - for src, tgt in self.problems.keys(): + for src, tgt in self.problems: if tgt == target: target_data = self.problems[src, tgt].xy.data_tgt # type: ignore[union-attr] break @@ -841,12 +837,11 @@ def _interpolate_gex_randomly( else: row_probability = growth_rates ** (1 - interpolation_parameter) row_probability /= np.sum(row_probability) - result = ( + return ( source_data[rng.choice(len(source_data), size=number_cells, p=row_probability), :] * (1 - interpolation_parameter) + target_data[rng.choice(len(target_data), size=number_cells), :] * interpolation_parameter ) - return result @staticmethod def _get_interp_param( diff --git a/src/moscot/solvers/_base_solver.py b/src/moscot/solvers/_base_solver.py index 0a007ca8a..9e2dbbff5 100644 --- a/src/moscot/solvers/_base_solver.py +++ b/src/moscot/solvers/_base_solver.py @@ -42,9 +42,9 @@ def solver(self, *, backend: Literal["ott"] = "ott", **kwargs: Any) -> "BaseSolv from moscot.backends.ott import GWSolver, SinkhornSolver # type: ignore[attr-defined] if self == ProblemKind.LINEAR: - return SinkhornSolver(**kwargs) + return SinkhornSolver(**kwargs) # type: ignore[return-value] if self == ProblemKind.QUAD: - return GWSolver(**kwargs) + return GWSolver(**kwargs) # type: ignore[return-value] raise NotImplementedError(f"Unable to create solver for `{self}` problem.") raise NotImplementedError(f"Backend `{backend}` is not yet implemented.") diff --git a/src/moscot/utils/_subset_policy.py b/src/moscot/utils/_subset_policy.py index 5e7565a25..a42b2f3a3 100644 --- a/src/moscot/utils/_subset_policy.py +++ b/src/moscot/utils/_subset_policy.py @@ -2,6 +2,7 @@ from typing import Any, Set, Dict, List, Tuple, Union, Generic, Literal, TypeVar, Hashable, Iterable, Optional, Sequence from operator import gt, lt from itertools import product +import contextlib import pandas as pd import networkx as nx @@ -229,10 +230,8 @@ def remove_node(self, node: Tuple[K, K]) -> None: """ if self._graph is None: raise RuntimeError("Construct the policy graph first.") - try: + with contextlib.suppress(KeyError): self._graph.remove(node) - except KeyError: - pass class OrderedPolicy(SubsetPolicy[K], ABC): diff --git a/src/moscot/utils/_tagged_array.py b/src/moscot/utils/_tagged_array.py index 5ceb3a983..bbf6b94e9 100644 --- a/src/moscot/utils/_tagged_array.py +++ b/src/moscot/utils/_tagged_array.py @@ -60,9 +60,9 @@ def _extract_data( if key is not None: data = data[key] except KeyError: - raise KeyError(f"Unable to fetch data from `{modifier}`.") + raise KeyError(f"Unable to fetch data from `{modifier}`.") from None except IndexError: - raise IndexError(f"Unable to fetch data from `{modifier}`.") + raise IndexError(f"Unable to fetch data from `{modifier}`.") from None if sp.issparse(data): logger.warning(f"Densifying data in `{modifier}`") diff --git a/tests/_utils.py b/tests/_utils.py index 5df750c79..a9121214d 100644 --- a/tests/_utils.py +++ b/tests/_utils.py @@ -49,10 +49,9 @@ def _ones(self, n: int) -> ArrayLike: def _make_adata(grid: ArrayLike, n: int, seed) -> List[AnnData]: - rng = np.random.default_rng(seed) + rng = np.random.RandomState(seed) X = rng.normal(size=(100, 60)) - adatas = [AnnData(X=csr_matrix(X), obsm={"spatial": grid.copy()}, dtype=X.dtype) for _ in range(n)] - return adatas + return [AnnData(X=csr_matrix(X), obsm={"spatial": grid.copy()}, dtype=X.dtype) for _ in range(n)] def _adata_spatial_split(adata: AnnData) -> Tuple[AnnData, AnnData]: @@ -67,8 +66,7 @@ def _make_grid(grid_size: int) -> ArrayLike: x1s = np.linspace(*xlimits, num=grid_size) x2s = np.linspace(*ylimits, num=grid_size) X1, X2 = np.meshgrid(x1s, x2s) - X_orig_single = np.vstack([X1.ravel(), X2.ravel()]).T - return X_orig_single + return np.vstack([X1.ravel(), X2.ravel()]).T class Problem(CompoundProblem[Any, OTProblem]): diff --git a/tests/conftest.py b/tests/conftest.py index 8e19f3692..7895ce224 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,10 @@ from math import cos, sin from typing import Tuple, Optional -import random from scipy.sparse import csr_matrix import pandas as pd import pytest +import matplotlib.pyplot as plt from jax.config import config import numpy as np @@ -12,13 +12,13 @@ from anndata import AnnData import scanpy as sc -import anndata as ad from tests._utils import Geom_t, _make_grid, _make_adata ANGLES = (0, 30, 60) +# TODO(michalk8): consider passing this via env config.update("jax_enable_x64", True) @@ -30,6 +30,13 @@ def pytest_sessionstart() -> None: sc.set_figure_params(dpi=40, color_map="viridis") +@pytest.fixture(autouse=True) +def _close_figure(): + # prevent `RuntimeWarning: More than 20 figures have been opened.` + yield + plt.close() + + @pytest.fixture() def x() -> Geom_t: rng = np.random.RandomState(0) @@ -55,7 +62,7 @@ def y() -> Geom_t: @pytest.fixture() -def xy() -> Geom_t: +def xy() -> Tuple[Geom_t, Geom_t]: rng = np.random.RandomState(2) n = 20 # number of points in the first distribution n2 = 30 # number of points in the second distribution @@ -119,7 +126,7 @@ def adata_time() -> AnnData: adata.obs["batch"] = rng.choice((0, 1, 2), len(adata)) adata.obs["left_marginals"] = np.ones(len(adata)) adata.obs["right_marginals"] = np.ones(len(adata)) - adata.obs["celltype"] = np.random.choice(["A", "B", "C"], size=len(adata)) + adata.obs["celltype"] = rng.choice(["A", "B", "C"], size=len(adata)) # genes from mouse/human proliferation/apoptosis genes = ["ANLN", "ANP32E", "ATAD2", "Mcm4", "Smc4", "Gtse1", "ADD1", "AIFM3", "ANKH", "Ercc5", "Serpinb5", "Inhbb"] # genes which are transcription factors, 3 from drosophila, 2 from human, 1 from mouse @@ -145,21 +152,23 @@ def create_marginals(n: int, m: int, *, uniform: bool = False, seed: Optional[in @pytest.fixture() def gt_temporal_adata() -> AnnData: - return _gt_temporal_adata.copy() + adata = _gt_temporal_adata.copy() + adata.obs_names_make_unique() + return adata @pytest.fixture() def adata_space_rotate() -> AnnData: - seed = random.randint(0, 422) + rng = np.random.RandomState(31) grid = _make_grid(10) - adatas = _make_adata(grid, n=len(ANGLES), seed=seed) + adatas = _make_adata(grid, n=len(ANGLES), seed=32) for adata, angle in zip(adatas, ANGLES): theta = np.deg2rad(angle) rot = np.array([[cos(theta), -sin(theta)], [sin(theta), cos(theta)]]) adata.obsm["spatial"] = np.dot(adata.obsm["spatial"], rot) - adata = ad.concat(adatas, label="batch") - adata.obs["celltype"] = np.random.choice(["A", "B", "C"], size=len(adata)) + adata = adatas[0].concatenate(*adatas[1:], batch_key="batch") + adata.obs["celltype"] = rng.choice(["A", "B", "C"], size=len(adata)) adata.uns["spatial"] = {} adata.obs_names_make_unique() sc.pp.pca(adata) @@ -168,11 +177,10 @@ def adata_space_rotate() -> AnnData: @pytest.fixture() def adata_mapping() -> AnnData: - seed = random.randint(0, 422) grid = _make_grid(10) - adataref, adata1, adata2 = _make_adata(grid, n=3, seed=seed) + adataref, adata1, adata2 = _make_adata(grid, n=3, seed=17) sc.pp.pca(adataref, n_comps=30) - adata = ad.concat([adataref, adata1, adata2], label="batch", join="outer") + adata = adataref.concatenate(adata1, adata2, batch_key="batch", join="outer") adata.obs_names_make_unique() return adata diff --git a/tests/data/generate_gt_temporal_data.py b/tests/data/generate_gt_temporal_data.py index 41f8e73d2..a2fcb0852 100644 --- a/tests/data/generate_gt_temporal_data.py +++ b/tests/data/generate_gt_temporal_data.py @@ -7,7 +7,7 @@ raise ImportError( "Please install WOT from commit hash`ca5e94f05699997b01cf5ae13383f9810f0613f6`" + "with `pip install git+https://github.com/broadinstitute/wot.git@ca5e94f05699997b01cf5ae13383f9810f0613f6`" - ) + ) from None import os diff --git a/tests/data/regression_tests_spatial.py b/tests/data/regression_tests_spatial.py index 969a41c08..544b60a31 100644 --- a/tests/data/regression_tests_spatial.py +++ b/tests/data/regression_tests_spatial.py @@ -34,8 +34,7 @@ def adata_mapping() -> AnnData: adataref, adata1, adata2 = _make_adata(grid, n=3) sc.pp.pca(adataref) - adata = ad.concat([adataref, adata1, adata2], label="batch", join="outer") - return adata + return ad.concat([adataref, adata1, adata2], label="batch", join="outer") def _make_grid(grid_size: int) -> ArrayLike: @@ -43,8 +42,7 @@ def _make_grid(grid_size: int) -> ArrayLike: x1s = np.linspace(*xlimits, num=grid_size) # type: ignore [call-overload] x2s = np.linspace(*ylimits, num=grid_size) # type: ignore [call-overload] X1, X2 = np.meshgrid(x1s, x2s) - X_orig_single = np.vstack([X1.ravel(), X2.ravel()]).T - return X_orig_single + return np.vstack([X1.ravel(), X2.ravel()]).T def _make_adata(grid: ArrayLike, n: int) -> List[AnnData]: diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py index 315779488..d4be344ca 100644 --- a/tests/datasets/test_dataset.py +++ b/tests/datasets/test_dataset.py @@ -1,50 +1,14 @@ -from types import FunctionType from typing import Mapping, Optional -from pathlib import Path -from http.client import RemoteDisconnected -import warnings import pytest import networkx as nx import numpy as np -from anndata import AnnData, OldFormatWarning - from moscot.datasets import simulate_data -import moscot as mt - - -class TestDatasetsImports: - @pytest.mark.parametrize("func", mt.datasets._datasets.__all__) - def test_import(self, func): - assert hasattr(mt.datasets, func), dir(mt.datasets) - fn = getattr(mt.datasets, func) - - assert isinstance(fn, FunctionType) - - -# TODO(michalk8): parse the code and xfail iff server issue -class TestDatasetsDownload: - @pytest.mark.timeout(120) - def test_sim_align(self, tmp_path: Path): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=OldFormatWarning) - try: - adata = mt.datasets.sim_align(tmp_path / "foo") - - assert isinstance(adata, AnnData) - assert adata.shape == (1200, 500) - except RemoteDisconnected as e: - pytest.xfail(str(e)) class TestSimulateData: - @pytest.mark.fast() - def test_returns_adata(self): - result = simulate_data() - assert isinstance(result, AnnData) - @pytest.mark.fast() @pytest.mark.parametrize("n_distributions", [2, 4]) @pytest.mark.parametrize("key", ["batch", "day"]) diff --git a/tests/plotting/test_utils.py b/tests/plotting/test_utils.py index 97de9a580..de98e20e0 100644 --- a/tests/plotting/test_utils.py +++ b/tests/plotting/test_utils.py @@ -1,5 +1,4 @@ from typing import List, Optional -import os import pytest import matplotlib as mpl @@ -29,70 +28,47 @@ def test_input_to_adatas_adata(self, adata_time: AnnData): np.testing.assert_array_equal(adata1.X.A, adata_time.X.A) np.testing.assert_array_equal(adata2.X.A, adata_time.X.A) - @pytest.mark.parametrize("save", [None, "tests/data/test_plot.png"]) @pytest.mark.parametrize("return_fig", [True, False]) - def test_cell_transition(self, adata_pl_cell_transition: AnnData, return_fig: bool, save: Optional[str]): - if save: - if os.path.exists(save): - os.remove(save) - fig = msc.plotting.cell_transition(adata_pl_cell_transition, return_fig=return_fig, save=save) + def test_cell_transition(self, adata_pl_cell_transition: AnnData, return_fig: bool): + fig = msc.plotting.cell_transition(adata_pl_cell_transition, return_fig=return_fig) if return_fig: assert fig is not None assert isinstance(fig, mpl.figure.Figure) else: assert fig is None - if save: - assert os.path.exists(save) @pytest.mark.parametrize("time_points", [None, [0]]) @pytest.mark.parametrize("return_fig", [True, False]) - @pytest.mark.parametrize("save", [None, "tests/data/test_plot.png"]) - def test_push( - self, adata_pl_push: AnnData, time_points: Optional[List[int]], return_fig: bool, save: Optional[str] - ): - if save: - if os.path.exists(save): - os.remove(save) - - fig = msc.plotting.push(adata_pl_push, time_points=time_points, return_fig=return_fig, save=save) + def test_push(self, adata_pl_push: AnnData, time_points: Optional[List[int]], return_fig: bool): + fig = msc.plotting.push(adata_pl_push, time_points=time_points, return_fig=return_fig) if return_fig: assert fig is not None assert isinstance(fig, mpl.figure.Figure) else: assert fig is None - if save: - assert os.path.exists(save) @pytest.mark.parametrize("time_points", [None, [0]]) @pytest.mark.parametrize("return_fig", [True, False]) - @pytest.mark.parametrize("save", [None, "tests/data/test_plot.png"]) def test_pull( - self, adata_pl_pull: AnnData, time_points: Optional[List[int]], return_fig: bool, save: Optional[str] + self, + adata_pl_pull: AnnData, + time_points: Optional[List[int]], + return_fig: bool, ): - if save: - if os.path.exists(save): - os.remove(save) - fig = msc.plotting.pull(adata_pl_pull, time_points=time_points, return_fig=return_fig, save=save) + fig = msc.plotting.pull(adata_pl_pull, time_points=time_points, return_fig=return_fig) if return_fig: assert fig is not None assert isinstance(fig, mpl.figure.Figure) else: assert fig is None - if save: - assert os.path.exists(save) - @pytest.mark.parametrize("save", [None, "tests/data/test_plot.png"]) @pytest.mark.parametrize("return_fig", [True, False]) @pytest.mark.parametrize("interpolate_color", [True, False]) - def test_sankey(self, adata_pl_sankey: AnnData, return_fig: bool, save: Optional[str], interpolate_color: bool): - if save: - if os.path.exists(save): - os.remove(save) + def test_sankey(self, adata_pl_sankey: AnnData, return_fig: bool, interpolate_color: bool): fig = msc.plotting.sankey( adata_pl_sankey, return_fig=return_fig, - save=save, interpolate_color=interpolate_color, ) if return_fig: @@ -100,5 +76,3 @@ def test_sankey(self, adata_pl_sankey: AnnData, return_fig: bool, save: Optional assert isinstance(fig, mpl.figure.Figure) else: assert fig is None - if save: - assert os.path.exists(save) diff --git a/tests/problems/generic/test_fgw_problem.py b/tests/problems/generic/test_fgw_problem.py index a24a82317..b4c5cec4b 100644 --- a/tests/problems/generic/test_fgw_problem.py +++ b/tests/problems/generic/test_fgw_problem.py @@ -8,9 +8,9 @@ from anndata import AnnData -from moscot.problems.base import OTProblem # type:ignore[attr-defined] +from moscot.problems.base import OTProblem from moscot.solvers._output import BaseSolverOutput -from moscot.problems.generic import GWProblem # type:ignore[attr-defined] +from moscot.problems.generic import GWProblem from tests.problems.conftest import ( fgw_args_1, fgw_args_2, @@ -48,7 +48,7 @@ def test_prepare(self, adata_space_rotate: AnnData): assert key in expected_keys assert isinstance(problem[key], OTProblem) - def test_solve_balanced(self, adata_space_rotate: AnnData): # type: ignore[no-untyped-def] + def test_solve_balanced(self, adata_space_rotate: AnnData): eps = 0.5 adata_space_rotate = adata_space_rotate[adata_space_rotate.obs["batch"].isin(("0", "1"))].copy() expected_keys = [("0", "1"), ("1", "2")] diff --git a/tests/problems/generic/test_gw_problem.py b/tests/problems/generic/test_gw_problem.py index ee5611a90..c1f0a60a9 100644 --- a/tests/problems/generic/test_gw_problem.py +++ b/tests/problems/generic/test_gw_problem.py @@ -8,9 +8,9 @@ from anndata import AnnData -from moscot.problems.base import OTProblem # type:ignore[attr-defined] +from moscot.problems.base import OTProblem from moscot.solvers._output import BaseSolverOutput -from moscot.problems.generic import GWProblem # type:ignore[attr-defined] +from moscot.problems.generic import GWProblem from tests.problems.conftest import ( gw_args_1, gw_args_2, diff --git a/tests/problems/generic/test_mixins.py b/tests/problems/generic/test_mixins.py index 3d63538d1..892b100b3 100644 --- a/tests/problems/generic/test_mixins.py +++ b/tests/problems/generic/test_mixins.py @@ -29,7 +29,7 @@ def test_sample_from_tmap_pipeline( problem[10, 10.5]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"]) if interpolation_parameter is not None and not 0 <= interpolation_parameter <= 1: - with np.testing.assert_raises(ValueError): + with pytest.raises(ValueError, match=r"^Expected interpolation"): problem._sample_from_tmap( 10, 10.5, @@ -40,7 +40,7 @@ def test_sample_from_tmap_pipeline( interpolation_parameter=interpolation_parameter, ) elif interpolation_parameter is None and account_for_unbalancedness: - with np.testing.assert_raises(ValueError): + with pytest.raises(ValueError, match=r"^When accounting for unbalancedness"): problem._sample_from_tmap( 10, 10.5, @@ -118,8 +118,8 @@ def test_cell_transition_aggregation_cell_forward(self, gt_temporal_adata: AnnDa df_res = df_res.div(df_res.sum(axis=1), axis=0) - ctr_ordered = ctr.sort_index().sort_index(1) - df_res_ordered = df_res.sort_index().sort_index(1) + ctr_ordered = ctr.sort_index().sort_index(axis=1) + df_res_ordered = df_res.sort_index().sort_index(axis=1) np.testing.assert_allclose( ctr_ordered.values.astype(float), df_res_ordered.values.astype(float), rtol=RTOL, atol=ATOL ) @@ -159,8 +159,8 @@ def test_cell_transition_aggregation_cell_backward(self, gt_temporal_adata: AnnD df_res = df_res.div(df_res.sum(axis=0), axis=1) - ctr_ordered = ctr.sort_index().sort_index(1) - df_res_ordered = df_res.sort_index().sort_index(1) + ctr_ordered = ctr.sort_index().sort_index(axis=1) + df_res_ordered = df_res.sort_index().sort_index(axis=1) np.testing.assert_allclose( ctr_ordered.values.astype(float), df_res_ordered.values.astype(float), rtol=RTOL, atol=ATOL ) diff --git a/tests/problems/generic/test_sinkhorn_problem.py b/tests/problems/generic/test_sinkhorn_problem.py index 2af8ab263..f3ebed20f 100644 --- a/tests/problems/generic/test_sinkhorn_problem.py +++ b/tests/problems/generic/test_sinkhorn_problem.py @@ -44,7 +44,7 @@ def test_prepare(self, adata_time: AnnData): assert key in expected_keys assert isinstance(problem[key], OTProblem) - def test_solve_balanced(self, adata_time: AnnData): # type: ignore[no-untyped-def] + def test_solve_balanced(self, adata_time: AnnData): eps = 0.5 expected_keys = [(0, 1), (1, 2)] problem = SinkhornProblem(adata=adata_time) @@ -126,10 +126,7 @@ def test_pass_arguments(self, adata_time: AnnData, args_to_check: Mapping[str, A el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val) assert el == args_to_check[arg] - if args_to_check["rank"] == -1: - args = pointcloud_args - else: - args = lr_pointcloud_args + args = pointcloud_args if args_to_check["rank"] == -1 else lr_pointcloud_args for arg, val in args.items(): el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val) assert hasattr(geom, val) diff --git a/tests/problems/space/test_alignment_problem.py b/tests/problems/space/test_alignment_problem.py index 5ddb67539..ad5f22223 100644 --- a/tests/problems/space/test_alignment_problem.py +++ b/tests/problems/space/test_alignment_problem.py @@ -75,7 +75,7 @@ def test_solve_balanced( if initializer == "random": # kwargs["kwargs_init"] = {"key": 0} # kwargs["key"] = 0 - return 0 # TODO(@MUCDK) fix after refactoring + return # TODO(@MUCDK) fix after refactoring ap = ( AlignmentProblem(adata=adata_space_rotate) .prepare(batch_key="batch") @@ -86,9 +86,12 @@ def test_solve_balanced( if initializer != "random": # TODO: is this valid? assert ap[prob_key].solution.converged + # TODO(michalk8): use np.testing assert np.allclose(*(sol.cost for sol in ap.solutions.values())) assert np.all([sol.converged for sol in ap.solutions.values()]) - assert np.all([np.all(~np.isnan(sol.transport_matrix)) for sol in ap.solutions.values()]) + np.testing.assert_array_equal( + [np.all(np.isfinite(sol.transport_matrix)) for sol in ap.solutions.values()], True + ) def test_solve_unbalanced(self, adata_space_rotate: AnnData): tau_a, tau_b = [0.8, 1] diff --git a/tests/problems/space/test_mapping_problem.py b/tests/problems/space/test_mapping_problem.py index a2d33f745..fabefe352 100644 --- a/tests/problems/space/test_mapping_problem.py +++ b/tests/problems/space/test_mapping_problem.py @@ -21,6 +21,7 @@ ) from moscot.solvers._base_solver import ProblemKind +# TODO(michalk8): should be made relative to tests root SOLUTIONS_PATH = Path("./tests/data/mapping_solutions.pkl") # base is moscot @@ -97,7 +98,7 @@ def test_solve_balanced( if initializer == "random": # kwargs["kwargs_init"] = {"key": 0} # kwargs["key"] = 0 - return 0 # TODO(@MUCDK) fix after refactoring + return # TODO(@MUCDK) fix after refactoring mp = MappingProblem(adataref, adatasp) mp = mp.prepare(batch_key="batch", sc_attr=sc_attr, var_names=var_names) mp = mp.solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs) diff --git a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py index c1c22ffc7..2911bc4a1 100644 --- a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py +++ b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py @@ -139,9 +139,9 @@ def test_apoptosis_key_pipeline(self, adata_spatio_temporal: AnnData): @pytest.mark.fast() @pytest.mark.parametrize("scaling", [0.1, 1, 4]) def test_proliferation_key_c_pipeline(self, adata_spatio_temporal: AnnData, scaling: float): - keys = np.sort(np.unique(adata_spatio_temporal.obs["time"].values)) - adata_spatio_temporal = adata_spatio_temporal[adata_spatio_temporal.obs["time"].isin([keys[0], keys[1]])] - delta = keys[1] - keys[0] + key0, key1, *_ = np.sort(np.unique(adata_spatio_temporal.obs["time"].values)) + adata_spatio_temporal = adata_spatio_temporal[adata_spatio_temporal.obs["time"].isin([key0, key1])].copy() + delta = key1 - key0 problem = SpatioTemporalProblem(adata_spatio_temporal) assert problem.proliferation_key is None @@ -149,12 +149,10 @@ def test_proliferation_key_c_pipeline(self, adata_spatio_temporal: AnnData, scal assert problem.proliferation_key == "proliferation" problem = problem.prepare(time_key="time", marginal_kwargs={"scaling": scaling}) - prolif = adata_spatio_temporal[adata_spatio_temporal.obs["time"] == keys[0]].obs["proliferation"] - apopt = adata_spatio_temporal[adata_spatio_temporal.obs["time"] == keys[0]].obs["apoptosis"] + prolif = adata_spatio_temporal[adata_spatio_temporal.obs["time"] == key0].obs["proliferation"] + apopt = adata_spatio_temporal[adata_spatio_temporal.obs["time"] == key0].obs["apoptosis"] expected_marginals = np.exp((prolif - apopt) * delta / scaling) - print("problem[keys[0], keys[1]]._prior_growth", problem[keys[0], keys[1]]._prior_growth) - print("expected_marginals", expected_marginals) - np.testing.assert_allclose(problem[keys[0], keys[1]]._prior_growth, expected_marginals, rtol=RTOL, atol=ATOL) + np.testing.assert_allclose(problem[key0, key1]._prior_growth, expected_marginals, rtol=RTOL, atol=ATOL) def test_growth_rates_pipeline(self, adata_spatio_temporal: AnnData): problem = SpatioTemporalProblem(adata=adata_spatio_temporal) diff --git a/tests/problems/time/test_lineage_problem.py b/tests/problems/time/test_lineage_problem.py index e0aa7ad8b..a1c234458 100644 --- a/tests/problems/time/test_lineage_problem.py +++ b/tests/problems/time/test_lineage_problem.py @@ -44,9 +44,8 @@ def test_prepare(self, adata_time_barcodes: AnnData): assert isinstance(problem[key], BirthDeathProblem) def test_solve_balanced(self, adata_time_barcodes: AnnData): - eps = 0.5 - expected_keys = [(0, 1)] - adata_time_barcodes = adata_time_barcodes[adata_time_barcodes.obs["time"].isin((0, 1))] + eps, key = 0.5, (0, 1) + adata_time_barcodes = adata_time_barcodes[adata_time_barcodes.obs["time"].isin(key)].copy() problem = LineageProblem(adata=adata_time_barcodes) problem = problem.prepare( time_key="time", @@ -57,7 +56,7 @@ def test_solve_balanced(self, adata_time_barcodes: AnnData): for key, subsol in problem.solutions.items(): assert isinstance(subsol, BaseSolverOutput) - assert key in expected_keys + assert key == key def test_solve_unbalanced(self, adata_time_barcodes: AnnData): taus = [9e-1, 1e-2] @@ -152,9 +151,9 @@ def test_apoptosis_key_pipeline(self, adata_time_barcodes: AnnData): @pytest.mark.fast() @pytest.mark.parametrize("scaling", [0.1, 1, 4]) def test_proliferation_key_c_pipeline(self, adata_time_barcodes: AnnData, scaling: float): - keys = np.sort(np.unique(adata_time_barcodes.obs["time"].values)) - adata_time_barcodes = adata_time_barcodes[adata_time_barcodes.obs["time"].isin([keys[0], keys[1]])] - delta = keys[1] - keys[0] + key0, key1, *_ = np.sort(np.unique(adata_time_barcodes.obs["time"].values)) + adata_time_barcodes = adata_time_barcodes[adata_time_barcodes.obs["time"].isin([key0, key1])].copy() + delta = key1 - key0 problem = LineageProblem(adata_time_barcodes) assert problem.proliferation_key is None @@ -167,10 +166,10 @@ def test_proliferation_key_c_pipeline(self, adata_time_barcodes: AnnData, scalin policy="sequential", marginal_kwargs={"scaling": scaling}, ) - prolif = adata_time_barcodes[adata_time_barcodes.obs["time"] == keys[0]].obs["proliferation"] - apopt = adata_time_barcodes[adata_time_barcodes.obs["time"] == keys[0]].obs["apoptosis"] + prolif = adata_time_barcodes[adata_time_barcodes.obs["time"] == key0].obs["proliferation"] + apopt = adata_time_barcodes[adata_time_barcodes.obs["time"] == key0].obs["apoptosis"] expected_marginals = np.exp((prolif - apopt) * delta / scaling) - np.testing.assert_allclose(problem[keys[0], keys[1]]._prior_growth, expected_marginals, rtol=RTOL, atol=ATOL) + np.testing.assert_allclose(problem[key0, key1]._prior_growth, expected_marginals, rtol=RTOL, atol=ATOL) @pytest.mark.fast() def test_barcodes_pipeline(self, adata_time_barcodes: AnnData): diff --git a/tests/problems/time/test_mixins.py b/tests/problems/time/test_mixins.py index c7fd0695c..61d1fc91c 100644 --- a/tests/problems/time/test_mixins.py +++ b/tests/problems/time/test_mixins.py @@ -89,7 +89,6 @@ def test_cell_transition_regression(self, gt_temporal_adata: AnnData, forward: b key_1 = config["key_1"] key_2 = config["key_2"] key_3 = config["key_3"] - set(gt_temporal_adata.obs["cell_type"].cat.categories) problem = TemporalProblem(gt_temporal_adata) problem = problem.prepare(key) assert set(problem.problems.keys()) == {(key_1, key_2), (key_2, key_3)} @@ -116,18 +115,17 @@ def test_cell_transition_regression(self, gt_temporal_adata: AnnData, forward: b assert result.shape == expected_shape marginal = result.sum(axis=forward == 1).values present_cell_type_marginal = marginal[marginal > 0] - np.testing.assert_almost_equal(present_cell_type_marginal, np.ones(len(present_cell_type_marginal)), decimal=5) + np.testing.assert_allclose(present_cell_type_marginal, 1.0) direction = "forward" if forward else "backward" gt = gt_temporal_adata.uns[f"cell_transition_10_105_{direction}"] gt = gt.sort_index() result = result.sort_index() result = result[gt.columns] - np.testing.assert_almost_equal(result.values, gt.values, decimal=4) + np.testing.assert_allclose(result.values, gt.values, rtol=1e-6, atol=1e-6) def test_compute_time_point_distances_pipeline(self, adata_time: AnnData): - problem = TemporalProblem(adata_time) - problem.prepare("time") + problem = TemporalProblem(adata_time).prepare("time") distance_source_intermediate, distance_intermediate_target = problem.compute_time_point_distances( source=0, intermediate=1, @@ -318,8 +316,7 @@ def test_cell_transition_regression_notparam( self, adata_time_with_tmap: AnnData, ): # TODO(MUCDK): please check. - problem = TemporalProblem(adata_time_with_tmap) - problem = problem.prepare("time") + problem = TemporalProblem(adata_time_with_tmap).prepare("time") problem[0, 1]._solution = MockSolverOutput(adata_time_with_tmap.uns["transport_matrix"]) result = problem.cell_transition( @@ -329,8 +326,8 @@ def test_cell_transition_regression_notparam( target_groups="cell_type", forward=True, ) - res = result.sort_index().sort_index(1) - df_expected = adata_time_with_tmap.uns["cell_transition_gt"].sort_index().sort_index(1) + res = result.sort_index().sort_index(axis=1) + df_expected = adata_time_with_tmap.uns["cell_transition_gt"].sort_index().sort_index(axis=1) np.testing.assert_almost_equal(res.values, df_expected.values, decimal=8) @pytest.mark.fast() diff --git a/tests/problems/time/test_temporal_problem.py b/tests/problems/time/test_temporal_problem.py index eab8738a6..cb2908fcd 100644 --- a/tests/problems/time/test_temporal_problem.py +++ b/tests/problems/time/test_temporal_problem.py @@ -139,9 +139,9 @@ def test_apoptosis_key_pipeline(self, adata_time: AnnData): @pytest.mark.fast() @pytest.mark.parametrize("scaling", [0.1, 1, 4]) def test_proliferation_key_c_pipeline(self, adata_time: AnnData, scaling: float): - keys = np.sort(np.unique(adata_time.obs["time"].values)) - adata_time = adata_time[adata_time.obs["time"].isin([keys[0], keys[1]])] - delta = keys[1] - keys[0] + key0, key1, *_ = np.sort(np.unique(adata_time.obs["time"].values)) + adata_time = adata_time[adata_time.obs["time"].isin([key0, key1])].copy() + delta = key1 - key0 problem = TemporalProblem(adata_time) assert problem.proliferation_key is None @@ -149,14 +149,13 @@ def test_proliferation_key_c_pipeline(self, adata_time: AnnData, scaling: float) assert problem.proliferation_key == "proliferation" problem = problem.prepare(time_key="time", marginal_kwargs={"scaling": scaling}) - prolif = adata_time[adata_time.obs["time"] == keys[0]].obs["proliferation"] - apopt = adata_time[adata_time.obs["time"] == keys[0]].obs["apoptosis"] + prolif = adata_time[adata_time.obs["time"] == key0].obs["proliferation"] + apopt = adata_time[adata_time.obs["time"] == key0].obs["apoptosis"] expected_marginals = np.exp((prolif - apopt) * delta / scaling) - np.testing.assert_allclose(problem[keys[0], keys[1]]._prior_growth, expected_marginals, rtol=RTOL, atol=ATOL) + np.testing.assert_allclose(problem[key0, key1]._prior_growth, expected_marginals, rtol=RTOL, atol=ATOL) def test_cell_costs_source_pipeline(self, adata_time: AnnData): - problem = TemporalProblem(adata=adata_time) - problem = problem.prepare("time") + problem = TemporalProblem(adata=adata_time).prepare("time") problem = problem.solve(max_iterations=2) cell_costs_source = problem.cell_costs_source @@ -272,10 +271,7 @@ def test_pass_arguments(self, adata_time: AnnData, args_to_check: Mapping[str, A el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val) assert el == args_to_check[arg] - if args_to_check["rank"] == -1: - args = pointcloud_args - else: - args = lr_pointcloud_args + args = pointcloud_args if args_to_check["rank"] == -1 else lr_pointcloud_args for arg, val in args.items(): el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val) assert hasattr(geom, val) diff --git a/tox.ini b/tox.ini deleted file mode 100644 index a68c586f5..000000000 --- a/tox.ini +++ /dev/null @@ -1,151 +0,0 @@ -[flake8] -per-file-ignores = - */__init__.py: D104, F401 - tests/*: D - docs/*: D,B,A - src/moscot/_docs.py: D - src/moscot/backends/ott/_solver.py: D101, D102 - src/moscot/utils/_subset_policy.py: D - src/moscot/solvers/_output.py: D101, D102, D105, D106, D107, A002, A003 - src/moscot/problems/_compound_problem.py: D101, D102, D105, D106, D107, A002, A003 - tests/data/*: D,B,E - src/moscot/costs/_costs.py: D - src/moscot/problems/_compound_problem.py: RST, D - src/moscot/solvers/_tagged_array.py: D - src/moscot/backends/ott/_output.py: D - src/moscot/_constants/_constants.py: D - src/moscot/_constants/_key.py: D - src/moscot/problems/base/_base_problem.py: D - src/moscot/problems/base/_problem_manager.py: D - src/moscot/_constants/_enum.py: D - src/moscot/_docs/*.py: D -# D104: Missing docstring in public package -# F401: <package> imported but unused -max_line_length = 120 -filename = *.py -# D202 No blank lines allowed after function docstring -# D107 Missing docstring in __init__ -# B008 Do not perform function calls in argument defaults -# W503 line break before binary operator -# D105 Missing docstring in magic method -# E203 whitespace before ':' -# F405 ... may be undefined, or defined from star imports: ... -# RST306 Unknown target name -# D106 Missing docstring in public nested class -ignore = D202,D107,B008,W503,D105,E203,F405,RST306,RST304,E741 -exclude = - .git - __pycache__ - .tox - build - dist - setup.py -ban-relative-imports = true - -[doc8] -max-line-length = 120 -ignore-path = .tox,build,dist -quiet = 1 - -[gh-actions] -python = - 3.8: py38 - 3.9: py39 - 3.10: py310 - 3.11: py311 - -[gh-actions:env] -PLATFORM = - ubuntu-latest: linux - macos-latest: macos - -[tox] -isolated_build = True -envlist = - lint-code - py{38,39,310,311}-{linux,macos} -skip_missing_interpreters = true - -[coverage:run] -branch = true -#TODO(michalk8): enable once using pytest-xdist -parallel = false -source = moscot -omit = */__init__.py - -[coverage:paths] -source = - moscot - */site-packages/moscot - -[coverage:report] -exclude_lines = - \#.*pragma:\s*no.?cover - - ^if __name__ == .__main__.:$ - - ^\s*raise AssertionError\b - ^\s*raise NotImplementedError\b - ^\s*return NotImplemented\b -show_missing = true -precision = 2 -skip_empty = True -sort = Miss - -[pytest] -addopts = --strict-markers -python_files = test_*.py -testpaths = tests/ -xfail_strict = true - -[testenv] -platform = - linux: linux - macos: (osx|darwin) -# TODO(michalk8): Cython+POT not necessary, just convenient for fixtures -deps = - pytest - pytest-mock - pytest-cov -usedevelop = true -passenv = TOXENV,CI,CODECOV_*,GITHUB_ACTIONS,PYTEST_FLAGS -commands = - python -m pytest --cov --cov-append --cov-report=term-missing --cov-config={toxinidir}/tox.ini {posargs:-vv} {env:PYTEST_FLAGS:} - -[testenv:lint-code] -description = Lint the code. -deps = pre-commit>=2.14.0 -skip_install = true -commands = - pre-commit run --all-files --show-diff-on-failure - -[testenv:clean-docs] -description = Remove the documentation. -deps = -skip_install = true -changedir = {tox_root}/docs -allowlist_externals = make -commands = - make clean - -[testenv:build-docs] -description = Build the documentation. -deps = -use_develop = true -extras = docs -allowlist_externals = make -commands = - make html -C {tox_root}{/}docs {posargs} -commands_post = - python -c 'import pathlib; print(f"Documentation is under:", pathlib.Path(f"{tox_root}") / "docs" / "build" / "html" / "index.html")' - -[testenv:build-package] -description = Build the package. -deps = - build - twine -commands = - python -m build --sdist --wheel --outdir {tox_root}{/}dist{/} {posargs:} - twine check {tox_root}{/}dist{/}* -commands_post = - python -c 'import pathlib; print(f"Package is under:", pathlib.Path("{tox_root}") / "dist")'