diff --git a/.bumpversion.cfg b/.bumpversion.cfg
deleted file mode 100644
index 8c316abdb..000000000
--- a/.bumpversion.cfg
+++ /dev/null
@@ -1,7 +0,0 @@
-[bumpversion]
-current_version = 0.1.1
-commit = True
-tag = True
-parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
-serialize = 
-	{major}.{minor}.{patch}
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index ae21f6b53..d9a2cbf41 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -19,7 +19,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        # TODO(michalk8): in the future, lint the docs
+        # TODO(michalk8): enable in the future
         lint-kind: [code] # , docs]
 
     steps:
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index cf790bb84..c57a981c1 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -19,22 +19,28 @@ jobs:
       fail-fast: false
       matrix:
         os: [ubuntu-latest]
-        python: [3.8]
+        python: ["3.8", "3.10"]
+        include:
+          - os: macos-latest
+            python: "3.9"
+
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - name: Set up Python ${{ matrix.python }}
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v4
         with:
           python-version: ${{ matrix.python }}
 
       - name: Install pip dependencies
-        # TODO(michalk8): remove tox-gh dependency, update tox.ini
         run: |
           python -m pip install --upgrade pip
-          pip install tox tox-gh-actions
+          pip install tox
+
       - name: Test
         run: |
-          tox -vv
+          tox -e py-${{ matrix.python }}
+        env:
+          PYTEST_ADDOPTS: -vv -n 2
 
       - name: Upload coverage
         uses: codecov/codecov-action@v3
diff --git a/.mypy.ini b/.mypy.ini
deleted file mode 100644
index b1830176c..000000000
--- a/.mypy.ini
+++ /dev/null
@@ -1,35 +0,0 @@
-[mypy]
-mypy_path = moscot
-python_version = 3.9
-plugins = numpy.typing.mypy_plugin
-
-ignore_errors = False
-
-warn_redundant_casts = True
-warn_unused_configs = True
-warn_unused_ignores = True
-
-disallow_untyped_calls = False
-disallow_untyped_defs = True
-disallow_incomplete_defs = True
-disallow_any_generics = True
-
-strict_optional = True
-strict_equality = True
-warn_return_any = False
-warn_unreachable = False
-check_untyped_defs = True
-no_implicit_optional = True
-no_implicit_reexport = True
-no_warn_no_return = True
-
-show_error_codes = True
-show_column_numbers = True
-error_summary = True
-ignore_missing_imports = True
-
-disable_error_code = assignment, comparison-overlap, no-untyped-def
-
-[mypy-tests.*]
-ignore_errors = True
-disable_error_code = assignment, comparison-overlap
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 6ab1412f1..cefb6e97e 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -4,20 +4,21 @@ default_language_version:
 default_stages:
   - commit
   - push
-minimum_pre_commit_version: 2.14.0
+minimum_pre_commit_version: 3.0.0
 repos:
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v0.991
+    rev: v1.1.1
     hooks:
       - id: mypy
         additional_dependencies: [numpy>=1.21.0, jax]
+        files: ^src
   - repo: https://github.com/psf/black
     rev: 23.1.0
     hooks:
       - id: black
         additional_dependencies: [toml]
   - repo: https://github.com/pre-commit/mirrors-prettier
-    rev: v3.0.0-alpha.4
+    rev: v3.0.0-alpha.6
     hooks:
       - id: prettier
         language_version: system
@@ -27,95 +28,42 @@ repos:
       - id: isort
         additional_dependencies: [toml]
         args: [--order-by-type]
-  - repo: https://github.com/asottile/yesqa
-    rev: v1.4.0
-    hooks:
-      - id: yesqa
-        additional_dependencies:
-          [
-            flake8-tidy-imports,
-            flake8-docstrings,
-            flake8-rst-docstrings,
-            flake8-comprehensions,
-            flake8-bugbear,
-            flake8-blind-except,
-            flake8-builtins,
-            flake8-pytest-style,
-            flake8-string-format,
-          ]
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v4.4.0
     hooks:
-      - id: detect-private-key
       - id: check-merge-conflict
       - id: check-ast
-      - id: check-symlinks
       - id: check-added-large-files
-      - id: check-executables-have-shebangs
-      - id: fix-encoding-pragma
-        args: [--remove]
       - id: end-of-file-fixer
       - id: mixed-line-ending
         args: [--fix=lf]
       - id: trailing-whitespace
-        exclude: ^.bumpversion.cfg$
-      - id: check-case-conflict
       - id: check-docstring-first
       - id: check-yaml
       - id: check-toml
-      - id: requirements-txt-fixer
-  - repo: https://github.com/myint/autoflake
-    rev: v2.0.1
-    hooks:
-      - id: autoflake
-        args:
-          [
-            --in-place,
-            --remove-all-unused-imports,
-            --remove-unused-variable,
-            --ignore-init-module-imports,
-          ]
-  - repo: https://github.com/pycqa/flake8.git
-    rev: 6.0.0
-    hooks:
-      - id: flake8
-        additional_dependencies:
-          [
-            flake8-tidy-imports,
-            flake8-docstrings,
-            flake8-rst-docstrings,
-            flake8-comprehensions,
-            flake8-bugbear,
-            flake8-blind-except,
-            flake8-builtins,
-            flake8-pytest-style,
-            flake8-string-format,
-          ]
-  - repo: https://github.com/jumanjihouse/pre-commit-hooks
-    rev: 3.0.0
+  - repo: https://github.com/asottile/pyupgrade
+    rev: v3.3.1
     hooks:
-      - id: script-must-have-extension
-        name: Check executable files use .sh extension
-        types: [shell, executable]
+      - id: pyupgrade
+        args: [--py3-plus, --py38-plus, --keep-runtime-typing]
   - repo: https://github.com/asottile/blacken-docs
     rev: 1.13.0
     hooks:
       - id: blacken-docs
         additional_dependencies: [black==23.1.0]
-  - repo: https://github.com/asottile/pyupgrade
-    rev: v3.3.1
+  - repo: https://github.com/rstcheck/rstcheck
+    rev: v6.1.1
     hooks:
-      - id: pyupgrade
-        args: [--py3-plus, --py38-plus, --keep-runtime-typing]
-  - repo: https://github.com/pre-commit/pygrep-hooks
-    rev: v1.10.0
-    hooks:
-      - id: python-no-eval
-      - id: python-check-blanket-noqa
-      - id: rst-backticks
-      - id: rst-directive-colons
-      - id: rst-inline-touching-normal
+      - id: rstcheck
+        additional_dependencies: [tomli]
+        args: [--config=pyproject.toml]
   - repo: https://github.com/PyCQA/doc8
     rev: v1.1.1
     hooks:
       - id: doc8
+  - repo: https://github.com/charliermarsh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.0.252
+    hooks:
+      - id: ruff
+        args: [--fix, --exit-non-zero-on-fix]
diff --git a/.prettierignore b/.prettierignore
deleted file mode 100644
index a188e0692..000000000
--- a/.prettierignore
+++ /dev/null
@@ -1 +0,0 @@
-docs/*
diff --git a/.rstcheck.cfg b/.rstcheck.cfg
deleted file mode 100644
index e0e4d9ce4..000000000
--- a/.rstcheck.cfg
+++ /dev/null
@@ -1,3 +0,0 @@
-[rstcheck]
-ignore_messages=Unknown target name:.*|No (directive|role) entry for "(auto)?(bibliography|nbgallery|class|method|meth||property|function|func|mod|module|attr|cite)" in module "docutils\.parsers\.rst\.languages\.en"\.
-report=info
diff --git a/MANIFEST.in b/MANIFEST.in
index ec6ae01e5..d4dbf536d 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -2,6 +2,5 @@ include src/moscot/utils/data/allTFs_dmel.txt
 include src/moscot/utils/data/allTFs_hg38.txt
 include src/moscot/utils/data/allTFs_mm.txt
 prune docs
-prune resources
+prune tests
 prune .github
-prune tests/data
diff --git a/README.rst b/README.rst
index f008442cf..8cd9cb596 100644
--- a/README.rst
+++ b/README.rst
@@ -11,7 +11,7 @@ single-cell genomics. It can be used for
 - prototyping of new OT models in single-cell genomics
 
 **moscot** is powered by
-`OTT <https://ott-jax.readthedocs.io/en/latest/>`_ which is a JAX-based Optimal
+`OTT <https://ott-jax.readthedocs.io>`_ which is a JAX-based Optimal
 Transport toolkit that supports just-in-time compilation, GPU acceleration, automatic
 differentiation and linear memory complexity for OT problems.
 
@@ -39,7 +39,7 @@ If used with GPU, additionally run::
 Resources
 ---------
 
-Please have a look at our `documentation <https://moscot.readthedocs.io/en/latest/index.html/>`_
+Please have a look at our `documentation <https://moscot.readthedocs.io>`_
 
 Reference
 ---------
diff --git a/.github/.codecov.yml b/codecov.yml
similarity index 52%
rename from .github/.codecov.yml
rename to codecov.yml
index c729b3eed..fbe5cfa5b 100644
--- a/.github/.codecov.yml
+++ b/codecov.yml
@@ -1,18 +1,18 @@
 codecov:
-  require_ci_to_pass: false
+  require_ci_to_pass: true
   strict_yaml_branch: main
 
 coverage:
-  range: 90..100
+  range: "80...100"
   status:
     project:
       default:
-        target: 1
+        target: 75%
+        threshold: 1%
     patch: off
 
 comment:
-  layout: reach, diff, files
+  layout: "reach, diff, files"
   behavior: default
   require_changes: true
-  branches:
-    - main
+  branches: [main]
diff --git a/docs/Makefile b/docs/Makefile
index b08c1b5bb..8948b67df 100644
--- a/docs/Makefile
+++ b/docs/Makefile
@@ -6,7 +6,7 @@
 SPHINXOPTS    ?=
 SPHINXBUILD   ?= sphinx-build
 SOURCEDIR     = .
-BUILDDIR      = build
+BUILDDIR      = _build
 
 # Put it first so that "make" without argument is like "make help".
 help:
diff --git a/docs/_templates/autosummary/base.rst b/docs/_templates/autosummary/base.rst
deleted file mode 100644
index e81ca6450..000000000
--- a/docs/_templates/autosummary/base.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-:github_url: {{ fullname }}
-
-{% extends "!autosummary/base.rst" %}
-
-.. http://www.sphinx-doc.org/en/stable/ext/autosummary.html#customizing-templates
diff --git a/docs/_templates/autosummary/class.rst b/docs/_templates/autosummary/class.rst
index 8529158f5..5a4db15dd 100644
--- a/docs/_templates/autosummary/class.rst
+++ b/docs/_templates/autosummary/class.rst
@@ -1,33 +1,29 @@
-:github_url: {{ fullname }}
-
-{{ fullname | escape | underline}}
+{{ fullname | escape | underline }}
 
 .. currentmodule:: {{ module }}
 
 .. autoclass:: {{ objname }}
+    {% block methods %}
+    {%- if methods %}
+    .. rubric:: {{ _('Methods') }}
 
-   {% block attributes %}
-   {% if attributes %}
-   .. rubric:: Attributes
-
-   .. autosummary::
-      :toctree: .
-   {% for item in attributes %}
-      ~{{ fullname }}.{{ item }}
-   {%- endfor %}
-   {% endif %}
-   {% endblock %}
-
-   {% block methods %}
-   {% if methods %}
-   .. rubric:: Methods
+    .. autosummary::
+        :toctree: .
+    {% for item in methods %}
+    {%- if item not in ['__init__', 'tree_flatten', 'tree_unflatten', 'bind'] %}
+        ~{{ name }}.{{ item }}
+    {%- endif %}
+    {%- endfor %}
+    {%- endif %}
+    {%- endblock %}
+    {% block attributes %}
+    {%- if attributes %}
+    .. rubric:: {{ _('Attributes') }}
 
-   .. autosummary::
-      :toctree: .
-   {% for item in methods %}
-      {%- if item != '__init__' %}
-      ~{{ fullname }}.{{ item }}
-      {%- endif -%}
-   {%- endfor %}
-   {% endif %}
-   {% endblock %}
+    .. autosummary::
+        :toctree: .
+    {% for item in attributes %}
+        ~{{ name }}.{{ item }}
+    {%- endfor %}
+    {%- endif %}
+    {% endblock %}
diff --git a/docs/_templates/autosummary/function.rst b/docs/_templates/autosummary/function.rst
deleted file mode 100644
index 097114364..000000000
--- a/docs/_templates/autosummary/function.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-:github_url: {{ fullname }}
-
-{{ fullname | escape | underline}}
-
-.. autofunction:: {{ fullname }}
diff --git a/docs/_templates/breadcrumbs.html b/docs/_templates/breadcrumbs.html
deleted file mode 100644
index 4ecb013f8..000000000
--- a/docs/_templates/breadcrumbs.html
+++ /dev/null
@@ -1,4 +0,0 @@
-{%- extends "sphinx_rtd_theme/breadcrumbs.html" %}
-
-{% block breadcrumbs_aside %}
-{% endblock %}
diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index 7b48e8428..861d2392d 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -1,9 +1,6 @@
 Datasets
 ########
 
-Datasets
-~~~~~~~~
-
 .. currentmodule:: moscot.datasets
 
 .. autosummary::
diff --git a/docs/api/developer.rst b/docs/api/developer.rst
index 841da4bbb..9a2d8baa2 100644
--- a/docs/api/developer.rst
+++ b/docs/api/developer.rst
@@ -1,9 +1,8 @@
 Developer
 #########
 
-
 OTT Backend
-~~~~~~~~~~~~
+~~~~~~~~~~~
 
 .. autosummary::
     :toctree: genapi
diff --git a/docs/api/plotting.rst b/docs/api/plotting.rst
index 35a0f8579..41294334a 100644
--- a/docs/api/plotting.rst
+++ b/docs/api/plotting.rst
@@ -1,9 +1,6 @@
 Plotting
 ########
 
-Plotting
-~~~~~~~~
-
 .. currentmodule:: moscot.plotting
 
 .. autosummary::
diff --git a/pyproject.toml b/pyproject.toml
index 588b1dbbd..f5e147671 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,7 +22,6 @@ classifiers = [
     "Programming Language :: Python :: 3.8",
     "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
-    "Programming Language :: Python :: 3.11",
     "Topic :: Scientific/Engineering :: Bio-Informatics",
     "Topic :: Scientific/Engineering :: Mathematics"
 ]
@@ -43,9 +42,7 @@ maintainers = [
     {name = "Giovanni Palla", email = "giovanni.palla@helmholtz-muenchen.de"},
     {name = "Michal Klein", email = "michal.klein@helmholtz-muenchen.de"}
 ]
-urls.Documentation = "https://moscot.readthedocs.io/"
-urls.Source = "https://github.com/theislab/moscot"
-urls.Home-page = "https://github.com/theislab/moscot"
+
 
 dependencies = [
     "numpy>=1.20.0",
@@ -53,9 +50,9 @@ dependencies = [
     "pandas>=1.4.0",
     "networkx>=2.6.3",
     # https://github.com/scverse/scanpy/issues/2411
-    "matplotlib>=3.4.0,<3.7",
+    "matplotlib>=3.5.0",
     "anndata>=0.8.0",
-    "scanpy>=1.9.0",
+    "scanpy>=1.9.3",
     "wrapt>=1.13.2",
     "docrep>=0.3.2",
     "ott-jax>=0.4.0",
@@ -67,8 +64,15 @@ spatial = [
     "squidpy>=1.2.3"
 ]
 dev = [
-    "tox>=3.24.0",
-    "pre-commit>=2.14.0"
+    "pre-commit>=3.0.0",
+    "tox>=4",
+]
+test = [
+    "pytest>=7",
+    "pytest-xdist>=3",
+    "pytest-mock>=3.5.0",
+    "pytest-cov>=4",
+    "coverage[toml]>=7",
 ]
 docs = [
     "sphinx>=5.1.1",
@@ -82,40 +86,83 @@ docs = [
     "sphinx_design>=0.3.0",
 ]
 
-test = [
-    "pytest>=7.1.2",
-    "pytest-mock>=3.5.0",
-    "pytest-cov>=4.0.0",
-]
+[project.urls]
+Homepage = "https://github.com/theislab/moscot"
+Download = "https://moscot.readthedocs.io/en/latest/installation.html"
+"Bug Tracker" = "https://github.com/theislab/moscot/issues"
+Documentation = "https://moscot.readthedocs.io"
+"Source Code" = "https://github.com/theislab/moscot"
 
 [tool.setuptools]
 package-dir = {"" = "src"}
-packages = {find = {where = ["src"], namespaces = false}}
+include-package-data = true
 
 [tool.setuptools_scm]
 
+[tool.ruff]
+exclude = [
+    ".git",
+    "__pycache__",
+    "build",
+    "docs/_build",
+    "dist"
+]
+ignore = [
+    # Do not assign a lambda expression, use a def -> lambda expression assignments are convenient
+    "E731",
+    # allow I, O, l as variable names -> I is the identity matrix, i, j, k, l is reasonable indexing notation
+    "E741",
+    # Missing docstring in public package
+    "D104",
+    # Missing docstring in public module
+    "D100",
+    # Missing docstring in __init__
+    "D107",
+    # Missing docstring in magic method
+    "D105",
+]
+line-length = 120
+select = [
+    "D", # flake8-docstrings
+    # TODO(michalk8): enable this in https://github.com/theislab/moscot/issues/483
+    # "I", # isort
+    "E", # pycodestyle
+    "F", # pyflakes
+    "W", # pycodestyle
+    "Q", # flake8-quotes
+    "SIM", # flake8-simplify
+    "NPY",  # NumPy-specific rules
+    "PT",  # flake8-pytest-style
+    "B", # flake8-bugbear
+    "UP", # pyupgrade
+    "C4", # flake8-comprehensions
+    "BLE", # flake8-blind-except
+    "T20",  # flake8-print
+    "RET", # flake8-raise
+]
+unfixable = ["B", "UP", "C4", "BLE", "T20", "RET"]
+target-version = "py38"
+[tool.ruff.per-file-ignores]
+"tests/*" = ["D"]
+"*/__init__.py" = ["F401"]
+"docs/*" = ["D"]
+[tool.ruff.pydocstyle]
+convention = "numpy"
+[tool.ruff.pyupgrade]
+# Preserve types, even if a file imports `from __future__ import annotations`.
+keep-runtime-typing = true
+[tool.ruff.flake8-tidy-imports]
+# Disallow relative imports that reach into parent packages.
+ban-relative-imports = "parents"
+[tool.ruff.flake8-quotes]
+inline-quotes = "double"
+
 [tool.black]
 line-length = 120
 target-version = ['py38']
 include = '\.pyi?$'
-exclude = '''
-(
-  /(
-      \.eggs
-    | \.git
-    | \.hg
-    | \.mypy_cache
-    | \.tox
-    | \.venv
-    | _build
-    | buck-out
-    | build
-    | dist
-  )/
-
-)
-'''
 
+# TODO(michalk8): simplify in https://github.com/theislab/moscot/issues/483
 [tool.isort]
 profile = "black"
 py_version = "38"
@@ -145,6 +192,147 @@ force_alphabetical_sort_within_sections = true
 lexicographical = true
 
 [tool.pytest.ini_options]
-markers = [
-    "fast: marks tests as fask",
+markers = ["fast: marks tests as fask"]
+xfail_strict = true
+filterwarnings = [
+    "ignore:X.dtype being converted:FutureWarning",
+    "ignore:No data for colormapping:UserWarning",
+    "ignore:jax\\.experimental\\.pjit\\.PartitionSpec:DeprecationWarning",
+]
+
+[tool.coverage.run]
+branch = true
+parallel = true
+source = ["src/"]
+omit = [
+    "*/__init__.py",
+    "*/_version.py",
+]
+
+[tool.coverage.report]
+exclude_lines = [
+    '\#.*pragma:\s*no.?cover',
+    "^if __name__ == .__main__.:$",
+    '^\s*raise AssertionError\b',
+    '^\s*raise NotImplementedError\b',
+    '^\s*return NotImplemented\b',
 ]
+precision = 2
+show_missing = true
+skip_empty = true
+sort = "Miss"
+
+[tool.rstcheck]
+ignore_directives = [
+    "toctree",
+    "currentmodule",
+    "autosummary",
+    "automodule",
+    "autoclass",
+    "bibliography",
+    "card",
+]
+
+[tool.mypy]
+mypy_path = "$MYPY_CONFIG_FILE_DIR/src"
+python_version = "3.9"
+plugins = "numpy.typing.mypy_plugin"
+
+ignore_errors = false
+
+warn_redundant_casts = true
+warn_unused_configs = true
+warn_unused_ignores = true
+
+disallow_untyped_calls = false
+disallow_untyped_defs = true
+disallow_incomplete_defs = true
+disallow_any_generics = true
+
+strict_optional = true
+strict_equality = true
+warn_return_any = false
+warn_unreachable = false
+check_untyped_defs = true
+no_implicit_optional = true
+no_implicit_reexport = true
+no_warn_no_return = true
+
+show_error_codes = true
+show_column_numbers = true
+error_summary = true
+ignore_missing_imports = true
+
+disable_error_code = ["assignment", "comparison-overlap", "no-untyped-def"]
+
+[tool.doc8]
+max_line_length = 120
+
+[tool.tox]
+legacy_tox_ini = """
+[tox]
+min_version = 4.0
+env_list = lint-code,py{3.8,3.9,3.10,3.11}
+skip_missing_interpreters = true
+
+[testenv]
+extras = test
+pass_env = PYTEST_*,CI
+commands =
+    python -m pytest {tty:--color=yes} {posargs: \
+        --cov={env_site_packages_dir}{/}moscot --cov-config={tox_root}{/}pyproject.toml \
+        --no-cov-on-fail --cov-report=xml --cov-report=term-missing:skip-covered}
+
+[testenv:lint-code]
+description = Lint the code.
+deps = pre-commit>=3.0.0
+skip_install = true
+commands =
+    pre-commit run --all-files --show-diff-on-failure
+
+[testenv:lint-docs]
+description = Lint the documentation.
+deps =
+extras = docs
+ignore_errors = true
+allowlist_externals = make
+pass_env = PYENCHANT_LIBRARY_PATH
+set_env = SPHINXOPTS = -W -q --keep-going
+changedir = {tox_root}{/}docs
+commands =
+    make linkcheck {posargs}
+    make spelling {posargs}
+
+[testenv:clean-docs]
+description = Remove the documentation.
+skip_install = true
+changedir = {tox_root}{/}docs
+allowlist_externals = make
+commands =
+    make clean
+
+[testenv:build-docs]
+description = Build the documentation.
+use_develop = true
+deps =
+extras = docs
+allowlist_externals = make
+changedir = {tox_root}{/}docs
+commands =
+    make html {posargs}
+commands_post =
+    python -c 'import pathlib; print("Documentation is under:", pathlib.Path("{tox_root}") / "docs" / "_build" / "html" / "index.html")'
+
+[testenv:build-package]
+description = Build the package.
+deps =
+    build
+    twine
+allowlist_externals = rm
+commands =
+    rm -rf {tox_root}{/}dist
+    python -m build --sdist --wheel --outdir {tox_root}{/}dist{/} {posargs:}
+    python -m twine check {tox_root}{/}dist{/}*
+commands_post =
+    python -c 'import pathlib; print(f"Package is under:", pathlib.Path("{tox_root}") / "dist")'
+"""
diff --git a/src/moscot/__init__.py b/src/moscot/__init__.py
index 2e2e8f776..fd672cdb5 100644
--- a/src/moscot/__init__.py
+++ b/src/moscot/__init__.py
@@ -9,8 +9,8 @@
 import moscot.problems
 
 try:
-    __version__ = metadata.version(__name__)
     md = metadata.metadata(__name__)
+    __version__ = md.get("version", "")
     __author__ = md.get("Author", "")
     __maintainer__ = md.get("Maintainer-email", "")
 except ImportError:
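
The rewritten block above pulls the version out of the distribution metadata instead of calling `metadata.version` separately. A small sketch of the same pattern, assuming any installed distribution name ("numpy" here is only an example):

    from importlib import metadata

    md = metadata.metadata("numpy")     # email.message.Message-like object
    version = md.get("version", "")     # header lookup is case-insensitive
    author = md.get("Author", "")
    print(version, author)
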
diff --git a/src/moscot/_constants/_enum.py b/src/moscot/_constants/_enum.py
index 42b0e48d9..522ecc83f 100644
--- a/src/moscot/_constants/_enum.py
+++ b/src/moscot/_constants/_enum.py
@@ -26,7 +26,7 @@ def wrapper(*args: Any, **kwargs: Any) -> "ErrorFormatterABC":
 
     if not issubclass(cls, ErrorFormatterABC):
         raise TypeError(f"Class `{cls}` must be subtype of `ErrorFormatterABC`.")
-    elif not len(cls.__members__):  # type: ignore[attr-defined]
+    if not len(cls.__members__):  # type: ignore[attr-defined]
         # empty enum, for class hierarchy
         return func
 
@@ -41,7 +41,7 @@ def __call__(cls, *args: Any, **kwargs: Any) -> Any:
 
     def __new__(cls, clsname: str, superclasses: Tuple[type], attributedict: Dict[str, Any]) -> "ABCEnumMeta":
         res = super().__new__(cls, clsname, superclasses, attributedict)  # type: ignore[arg-type]
-        res.__new__ = _pretty_raise_enum(res, res.__new__)  # type: ignore[assignment,arg-type]
+        res.__new__ = _pretty_raise_enum(res, res.__new__)  # type: ignore[method-assign,arg-type]
         return res
 
 
diff --git a/src/moscot/_constants/_key.py b/src/moscot/_constants/_key.py
index e5869ed53..de44b270f 100644
--- a/src/moscot/_constants/_key.py
+++ b/src/moscot/_constants/_key.py
@@ -1,4 +1,5 @@
 from typing import Any, Set, List, Callable, Optional
+import contextlib
 
 import numpy as np
 
@@ -56,8 +57,8 @@ def __init__(self, adata: AnnData, n: Optional[int] = None, where: str = "obs"):
         self._keys: List[str] = []
 
     def _generate_random_keys(self):
-        def generator():
-            return f"RNG_COL_{np.random.randint(2 ** 16)}"
+        def generator() -> str:
+            return f"RNG_COL_{np.random.RandomState().randint(2 ** 16)}"
 
         where = getattr(self._adata, self._where)
         names: List[str] = []
@@ -76,13 +77,9 @@ def __enter__(self):
         return self._keys
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        for key in self._keys:
-            try:
-                getattr(self._adata, self._where).drop(key, axis="columns", inplace=True)
-            except KeyError:
-                pass
-            if self._where == "obs":
-                try:
+        with contextlib.suppress(KeyError):
+            for key in self._keys:
+                df = getattr(self._adata, self._where)
+                df.drop(key, axis="columns", inplace=True)
+                if self._where == "obs":
                     del self._adata.uns[f"{key}_colors"]
-                except KeyError:
-                    pass
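
One nuance of the `contextlib.suppress` rewrite above: placed around the whole loop, the first `KeyError` ends the cleanup for all remaining keys, whereas the original try/except skipped only the failing key. A toy sketch (plain dicts, not the AnnData cleanup itself) showing the difference:

    import contextlib

    keys = ["a", "missing", "b"]

    # Suppress around the loop: the KeyError for "missing" exits the block,
    # so "b" is never removed.
    store = {"a": 1, "b": 2}
    with contextlib.suppress(KeyError):
        for key in keys:
            del store[key]
    assert store == {"b": 2}

    # Suppress inside the loop: only the failing key is skipped.
    store = {"a": 1, "b": 2}
    for key in keys:
        with contextlib.suppress(KeyError):
            del store[key]
    assert store == {}
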
diff --git a/src/moscot/_utils.py b/src/moscot/_utils.py
index 29ea56068..317437461 100644
--- a/src/moscot/_utils.py
+++ b/src/moscot/_utils.py
@@ -58,7 +58,6 @@ def parallelize(
     -------
     The result depending on ``callable``, ``extractor`` and ``as_array``.
     """
-
     if show_progress_bar:
         try:
             from tqdm.auto import tqdm
@@ -186,7 +185,6 @@ def _np_apply_along_axis(func1d, axis: int, arr: ArrayLike) -> ArrayLike:
     -------
     The reduced array.
     """
-
     assert arr.ndim == 2
     assert axis in [0, 1]
 
diff --git a/src/moscot/backends/ott/_output.py b/src/moscot/backends/ott/_output.py
index 953ec2979..7755d5878 100644
--- a/src/moscot/backends/ott/_output.py
+++ b/src/moscot/backends/ott/_output.py
@@ -80,8 +80,7 @@ def plot_costs(
 
         if save is not None:
             fig.savefig(save)
-        if return_fig:
-            return fig
+        return fig if return_fig else None
 
     @d.dedent
     def plot_errors(
@@ -123,8 +122,7 @@ def plot_errors(
 
         if save is not None:
             fig.savefig(save)
-        if return_fig:
-            return fig
+        return fig if return_fig else None
 
     def _plot_lines(
         self,
@@ -179,7 +177,7 @@ def to(self, device: Optional[Device_t] = None) -> "OTTOutput":
             try:
                 device = jax.devices(device)[idx]
             except IndexError:
-                raise IndexError(f"Unable to fetch the device with `id={idx}`.")
+                raise IndexError(f"Unable to fetch the device with `id={idx}`.") from None
 
         return OTTOutput(jax.device_put(self._output, device))
 
@@ -205,4 +203,4 @@ def rank(self) -> int:
         return len(lin_output.g) if isinstance(lin_output, OTTLRSinkhornOutput) else -1
 
     def _ones(self, n: int) -> ArrayLike:
-        return jnp.ones((n,))
+        return jnp.ones((n,))  # type: ignore[return-value]
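
The `raise ... from None` additions above suppress exception chaining, so users see only the rephrased error rather than a "During handling of the above exception, another exception occurred" traceback. A minimal sketch of the pattern, using a hypothetical helper rather than the OTT backend itself:

    def fetch_device(devices, idx):
        try:
            return devices[idx]
        except IndexError:
            # `from None` hides the original IndexError from the traceback.
            raise IndexError(f"Unable to fetch the device with `id={idx}`.") from None

    # fetch_device(["cpu"], 3)  # raises only the rephrased IndexError
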
diff --git a/src/moscot/backends/ott/_solver.py b/src/moscot/backends/ott/_solver.py
index d51a04982..d30dd0ef4 100644
--- a/src/moscot/backends/ott/_solver.py
+++ b/src/moscot/backends/ott/_solver.py
@@ -102,7 +102,7 @@ def _assert2d(arr: Optional[ArrayLike], *, allow_reshape: bool = True) -> Option
             return None
         arr: ArrayLike = jnp.asarray(arr.A if issparse(arr) else arr)  # type: ignore[attr-defined, no-redef]
         if allow_reshape and arr.ndim == 1:
-            return jnp.reshape(arr, (-1, 1))
+            return jnp.reshape(arr, (-1, 1))  # type: ignore[return-value]
         if arr.ndim != 2:
             raise ValueError(f"Expected array to have 2 dimensions, found `{arr.ndim}`.")
         return arr
diff --git a/src/moscot/datasets/_datasets.py b/src/moscot/datasets/_datasets.py
index 0125cd98d..b4c7f1afb 100644
--- a/src/moscot/datasets/_datasets.py
+++ b/src/moscot/datasets/_datasets.py
@@ -95,7 +95,9 @@ def hspc(
     path: PathLike = "~/.cache/moscot/hspc.h5ad",
     **kwargs: Any,
 ) -> AnnData:  # pragma: no cover
-    """Subsampled and processed data from the `NeurIPS Multimodal Single-Cell Integration Challenge \
+    """CD34+ hematopoietic stem and progenitor cells from 4 healthy human donors.
+
+    From the `NeurIPS Multimodal Single-Cell Integration Challenge
     <https://www.kaggle.com/competitions/open-problems-multimodal/data>`_.
 
     4000 cells were randomly selected after filtering the multiome training data of the donor `31800`.
diff --git a/src/moscot/plotting/_plotting.py b/src/moscot/plotting/_plotting.py
index e34b67f9e..bb7739a19 100644
--- a/src/moscot/plotting/_plotting.py
+++ b/src/moscot/plotting/_plotting.py
@@ -18,7 +18,7 @@
 
 @d_plotting.dedent
 def cell_transition(
-    inp: Union[AnnData, Tuple[AnnData, AnnData], CompoundProblem],
+    inp: Union[AnnData, Tuple[AnnData, AnnData], CompoundProblem],  # type: ignore[type-arg]
     uns_key: str = PlottingKeys.CELL_TRANSITION,
     row_labels: Optional[str] = None,
     col_labels: Optional[str] = None,
@@ -62,7 +62,9 @@ def cell_transition(
     try:
         _ = adata1.uns[AdataKeys.UNS][PlottingKeys.CELL_TRANSITION][key]
     except KeyError:
-        raise KeyError(f"No data found in `adata.uns[{AdataKeys.UNS!r}][{PlottingKeys.CELL_TRANSITION!r}][{key!r}]`.")
+        raise KeyError(
+            f"No data found in `adata.uns[{AdataKeys.UNS!r}][{PlottingKeys.CELL_TRANSITION!r}][{key!r}]`."
+        ) from None
 
     data = adata1.uns[AdataKeys.UNS][PlottingKeys.CELL_TRANSITION][key]
     fig = _heatmap(
@@ -87,8 +89,7 @@ def cell_transition(
         cbar_kwargs=cbar_kwargs,
         **kwargs,
     )
-    if return_fig:
-        return fig
+    return fig if return_fig else None
 
 
 @d_plotting.dedent
@@ -140,7 +141,7 @@ def sankey(
     try:
         _ = adata.uns[AdataKeys.UNS][PlottingKeys.SANKEY][key]
     except KeyError:
-        raise KeyError(f"No data found in `adata.uns[{AdataKeys.UNS!r}][{PlottingKeys.SANKEY!r}][{key!r}]`.")
+        raise KeyError(f"No data found in `adata.uns[{AdataKeys.UNS!r}][{PlottingKeys.SANKEY!r}][{key!r}]`.") from None
 
     data = adata.uns[AdataKeys.UNS][PlottingKeys.SANKEY][key]
     fig = _sankey(
@@ -160,13 +161,12 @@ def sankey(
     )
     if save:
         fig.figure.savefig(save)
-    if return_fig:
-        return fig
+    return fig if return_fig else None
 
 
 @d_plotting.dedent
 def push(
-    inp: Union[AnnData, TemporalProblem, LineageProblem, CompoundProblem],
+    inp: Union[AnnData, TemporalProblem, LineageProblem, CompoundProblem],  # type: ignore[type-arg]
     uns_key: Optional[str] = None,
     time_points: Optional[Sequence[float]] = None,
     basis: str = "umap",
@@ -245,13 +245,12 @@ def push(
         suptitle_fontsize=suptitle_fontsize,
         **kwargs,
     )
-    if return_fig:
-        return fig.figure
+    return fig.figure if return_fig else None
 
 
 @d_plotting.dedent
 def pull(
-    inp: Union[AnnData, TemporalProblem, LineageProblem, CompoundProblem],
+    inp: Union[AnnData, TemporalProblem, LineageProblem, CompoundProblem],  # type: ignore[type-arg]
     uns_key: Optional[str] = None,
     time_points: Optional[Sequence[float]] = None,
     basis: str = "umap",
@@ -330,5 +329,4 @@ def pull(
         suptitle_fontsize=suptitle_fontsize,
         **kwargs,
     )
-    if return_fig:
-        return fig.figure
+    return fig.figure if return_fig else None
diff --git a/src/moscot/plotting/_utils.py b/src/moscot/plotting/_utils.py
index 1c6a3332c..e3223058b 100644
--- a/src/moscot/plotting/_utils.py
+++ b/src/moscot/plotting/_utils.py
@@ -345,7 +345,9 @@ def _contrasting_color(r: int, g: int, b: int) -> str:
     return "#000000" if r * 0.299 + g * 0.587 + b * 0.114 > 186 else "#ffffff"
 
 
-def _input_to_adatas(inp: Union[AnnData, Tuple[AnnData, AnnData], CompoundProblem]) -> Tuple[AnnData, AnnData]:
+def _input_to_adatas(
+    inp: Union[AnnData, Tuple[AnnData, AnnData], CompoundProblem]  # type: ignore[type-arg]
+) -> Tuple[AnnData, AnnData]:
     if isinstance(inp, CompoundProblem):
         return inp.adata, inp.adata
     if isinstance(inp, AnnData):
@@ -509,6 +511,4 @@ def _create_col_colors(adata: AnnData, obs_col: str, subset: Union[str, List[str
     h, _, v = mcolors.rgb_to_hsv(mcolors.to_rgb(color))
     end_color = mcolors.hsv_to_rgb([h, 1, v])
 
-    col_cmap = mcolors.LinearSegmentedColormap.from_list("category_cmap", ["darkgrey", end_color])
-
-    return col_cmap
+    return mcolors.LinearSegmentedColormap.from_list("category_cmap", ["darkgrey", end_color])
diff --git a/src/moscot/problems/_utils.py b/src/moscot/problems/_utils.py
index 8417e1c9a..8350d5e0f 100644
--- a/src/moscot/problems/_utils.py
+++ b/src/moscot/problems/_utils.py
@@ -27,7 +27,10 @@ def require_solution(
 
 @wrapt.decorator
 def require_prepare(
-    wrapped: Callable[[Any], Any], instance: "BaseCompoundProblem", args: Tuple[Any, ...], kwargs: Mapping[str, Any]
+    wrapped: Callable[[Any], Any],
+    instance: "BaseCompoundProblem",  # type: ignore[type-arg]
+    args: Tuple[Any, ...],
+    kwargs: Mapping[str, Any],
 ) -> Any:
     """Check whether problem has been prepared."""
     if instance.problems is None:
@@ -96,12 +99,15 @@ def handle_joint_attr(
                 "y_key": joint_attr["key"],
             }
             return xy, kwargs
-        if joint_attr.get("tag", None) == "cost_matrix":  # if this is True we have custom cost matrix or moscot cost
-            if len(joint_attr) == 2 or kwargs.get("attr", None) == "obsp":  # in this case we have a custom cost matrix
-                joint_attr.setdefault("cost", "custom")
-                joint_attr.setdefault("attr", "obsp")
-                kwargs["xy_callback"] = "cost-matrix"
-                kwargs.setdefault("xy_callback_kwargs", {"key": joint_attr["key"]})
+
+        # "cost_matrix" tag: custom cost matrix or moscot cost; the extra condition implies a custom cost matrix
+        if joint_attr.get("tag", None) == "cost_matrix" and (
+            len(joint_attr) == 2 or kwargs.get("attr", None) == "obsp"
+        ):
+            joint_attr.setdefault("cost", "custom")
+            joint_attr.setdefault("attr", "obsp")
+            kwargs["xy_callback"] = "cost-matrix"
+            kwargs.setdefault("xy_callback_kwargs", {"key": joint_attr["key"]})
         kwargs.setdefault("xy_callback_kwargs", {})
         return joint_attr, kwargs
     raise TypeError(f"Expected `joint_attr` to be either `str` or `dict`, found `{type(joint_attr)}`.")
@@ -138,3 +144,4 @@ def handle_cost(
             y = dict(y)
             y["cost"] = cost["y"]
         return xy, x, y
+    raise TypeError(f"Unable to interpret `cost` of type `{type(cost)}`.")
diff --git a/src/moscot/problems/base/_utils.py b/src/moscot/problems/base/_utils.py
index 62f116174..4c4d65f6a 100644
--- a/src/moscot/problems/base/_utils.py
+++ b/src/moscot/problems/base/_utils.py
@@ -153,19 +153,6 @@ def _get_categories_from_adata(
     return adata[adata.obs[key] == key_value].obs[annotation_key]
 
 
-def _get_problem_key(
-    source: Optional[Any] = None,  # TODO(@MUCDK) using `K` induces circular import, resolve
-    target: Optional[Any] = None,  # TODO(@MUCDK) using `K` induces circular import, resolve
-) -> Tuple[Any, Any]:  # TODO(@MUCDK) using `K` induces circular import, resolve
-    if source is not None and target is not None:
-        return (source, target)
-    elif source is None and target is not None:
-        return ("src", target)  # TODO(@MUCDK) make package constant
-    elif source is not None and target is None:
-        return (source, "ref")  # TODO(@MUCDK) make package constant
-    return ("src", "ref")
-
-
 def _order_transition_matrix_helper(
     tm: pd.DataFrame,
     rows_verified: List[str],
@@ -189,43 +176,42 @@ def _order_transition_matrix(
     target_annotations_ordered: Optional[List[str]],
     forward: bool,
 ) -> pd.DataFrame:
+    # TODO(michalk8): simplify
     if target_annotations_ordered is not None or source_annotations_ordered is not None:
         if forward:
-            tm = _order_transition_matrix_helper(
+            return _order_transition_matrix_helper(
                 tm=tm,
                 rows_verified=source_annotations_verified,
                 cols_verified=target_annotations_verified,
                 row_order=source_annotations_ordered,
                 col_order=target_annotations_ordered,
-            )
-        else:
-            tm = _order_transition_matrix_helper(
-                tm=tm,
-                rows_verified=target_annotations_verified,
-                cols_verified=source_annotations_verified,
-                row_order=target_annotations_ordered,
-                col_order=source_annotations_ordered,
-            )
-        return tm.T if forward else tm
-    elif target_annotations_verified == source_annotations_verified:
+            ).T
+        return _order_transition_matrix_helper(
+            tm=tm,
+            rows_verified=target_annotations_verified,
+            cols_verified=source_annotations_verified,
+            row_order=target_annotations_ordered,
+            col_order=source_annotations_ordered,
+        )
+
+    if target_annotations_verified == source_annotations_verified:
         annotations_ordered = tm.columns.sort_values()
         if forward:
-            tm = _order_transition_matrix_helper(
+            return _order_transition_matrix_helper(
                 tm=tm,
                 rows_verified=source_annotations_verified,
                 cols_verified=target_annotations_verified,
                 row_order=annotations_ordered,
                 col_order=annotations_ordered,
-            )
-        else:
-            tm = _order_transition_matrix_helper(
-                tm=tm,
-                rows_verified=target_annotations_verified,
-                cols_verified=source_annotations_verified,
-                row_order=annotations_ordered,
-                col_order=annotations_ordered,
-            )
-        return tm.T if forward else tm
+            ).T
+        return _order_transition_matrix_helper(
+            tm=tm,
+            rows_verified=target_annotations_verified,
+            cols_verified=source_annotations_verified,
+            row_order=annotations_ordered,
+            col_order=annotations_ordered,
+        )
+
     return tm if forward else tm.T
 
 
@@ -284,7 +270,6 @@ def _correlation_test(
         - ``ci_low`` - lower bound of the ``confidence_level`` correlation confidence interval.
         - ``ci_high`` - upper bound of the ``confidence_level`` correlation confidence interval.
     """
-
     corr, pvals, ci_low, ci_high = _correlation_test_helper(
         X.T,
         Y.values,
diff --git a/src/moscot/problems/generic/_generic.py b/src/moscot/problems/generic/_generic.py
index 38d8586ac..a597cdf13 100644
--- a/src/moscot/problems/generic/_generic.py
+++ b/src/moscot/problems/generic/_generic.py
@@ -62,7 +62,7 @@ def prepare(
         --------
         %(ex_prepare)s
         """
-        self.batch_key = key
+        self.batch_key = key  # type: ignore[misc]
         xy, kwargs = handle_joint_attr(joint_attr, kwargs)
         xy, _, _ = handle_cost(xy=xy, cost=cost)
         return super().prepare(
@@ -129,7 +129,7 @@ def solve(
         --------
         %(ex_solve_linear)s
         """
-        return super().solve(
+        return super().solve(  # type: ignore[return-value]
             epsilon=epsilon,
             tau_a=tau_a,
             tau_b=tau_b,
@@ -155,7 +155,7 @@ def solve(
 
     @property
     def _base_problem_type(self) -> Type[B]:
-        return OTProblem
+        return OTProblem  # type: ignore[return-value]
 
     @property
     def _valid_policies(self) -> Tuple[str, ...]:
@@ -219,7 +219,7 @@ def prepare(
         --------
         %(ex_prepare)s
         """
-        self.batch_key = key
+        self.batch_key = key  # type: ignore[misc]
 
         GW_updated: List[Dict[str, Any]] = [{}] * 2
         for i, z in enumerate([GW_x, GW_y]):
@@ -299,7 +299,7 @@ def solve(
         --------
         %(ex_solve_quadratic)s
         """
-        return super().solve(
+        return super().solve(  # type: ignore[return-value]
             alpha=alpha,
             epsilon=epsilon,
             tau_a=tau_a,
@@ -325,7 +325,7 @@ def solve(
 
     @property
     def _base_problem_type(self) -> Type[B]:
-        return OTProblem
+        return OTProblem  # type: ignore[return-value]
 
     @property
     def _valid_policies(self) -> Tuple[str, ...]:
diff --git a/src/moscot/problems/generic/_mixins.py b/src/moscot/problems/generic/_mixins.py
index 80a46adb9..57566a404 100644
--- a/src/moscot/problems/generic/_mixins.py
+++ b/src/moscot/problems/generic/_mixins.py
@@ -143,8 +143,7 @@ def push(
             }
             self.adata.obs[key_added] = self._flatten(result, key=self.batch_key)
             Key.uns.set_plotting_vars(self.adata, PlottingKeys.PUSH, key_added, plot_vars)
-        if return_data:
-            return result
+        return result if return_data else None
 
     @d_mixins.dedent
     def pull(
@@ -197,8 +196,7 @@ def pull(
             }
             self.adata.obs[key_added] = self._flatten(result, key=self.batch_key)
             Key.uns.set_plotting_vars(self.adata, PlottingKeys.PULL, key_added, plot_vars)
-        if return_data:
-            return result
+        return result if return_data else None
 
     @property
     def batch_key(self: GenericAnalysisMixinProtocol[K, B]) -> Optional[str]:
diff --git a/src/moscot/problems/space/_mapping.py b/src/moscot/problems/space/_mapping.py
index 9c82dcd2d..bc6abb866 100644
--- a/src/moscot/problems/space/_mapping.py
+++ b/src/moscot/problems/space/_mapping.py
@@ -229,7 +229,7 @@ def filtered_vars(self) -> Optional[Sequence[str]]:
 
     @filtered_vars.setter
     def filtered_vars(self, value: Optional[Sequence[str]]) -> None:
-        self._filtered_vars = self._filter_vars(var_names=value)
+        self._filtered_vars = self._filter_vars(var_names=value)  # type: ignore[misc]
 
     @property
     def _base_problem_type(self) -> Type[B]:
diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py
index 19ecb0674..29f3bbc01 100644
--- a/src/moscot/problems/space/_mixins.py
+++ b/src/moscot/problems/space/_mixins.py
@@ -85,11 +85,7 @@ def _filter_vars(
     ) -> Optional[List[str]]:
         ...
 
-    def _cell_transition(
-        self: AnalysisMixinProtocol[K, B],
-        *args: Any,
-        **kwargs: Any,
-    ) -> pd.DataFrame:
+    def _cell_transition(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame:
         ...
 
 
@@ -101,14 +97,13 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         self._spatial_key: Optional[str] = None
         self._batch_key: Optional[str] = None
 
-    def _interpolate_scheme(
+    def _interpolate_scheme(  # type: ignore[misc]
         self: SpatialAlignmentMixinProtocol[K, B],
         reference: K,
         mode: Literal["warp", "affine"],
         spatial_key: Optional[str] = None,
     ) -> Tuple[Dict[K, ArrayLike], Optional[Dict[K, Optional[ArrayLike]]]]:
         """Scheme for interpolation."""
-
         # get reference
         src = self._subset_spatial(reference, spatial_key=spatial_key)
         transport_maps: Dict[K, ArrayLike] = {reference: src}
@@ -118,12 +113,9 @@ def _interpolate_scheme(
             transport_metadata = {reference: np.diag((1, 1))}  # 2d data
 
         # get policy
-        if isinstance(reference, str):
-            reference_ = [reference]
-        else:
-            reference_ = reference
+        reference_ = [reference] if isinstance(reference, str) else reference
         full_steps = self._policy._graph
-        starts = set(chain.from_iterable(full_steps)) - set(reference_)  # type: ignore[arg-type]
+        starts = set(chain.from_iterable(full_steps)) - set(reference_)  # type: ignore[call-overload]
 
         if mode == AlignmentMode.AFFINE:
             _transport = self._affine
@@ -149,7 +141,7 @@ def _interpolate_scheme(
         return transport_maps, (transport_metadata if mode == "affine" else None)
 
     @d.dedent
-    def align(
+    def align(  # type: ignore[misc]
         self: SpatialAlignmentMixinProtocol[K, B],
         reference: K,
         mode: Literal["warp", "affine"] = "warp",
@@ -193,10 +185,10 @@ def align(
             self.adata.uns[self.spatial_key]["alignment_metadata"] = aligned_metadata
         if not inplace:
             return aligned_basis
-        self.adata.obsm[f"{self.spatial_key}_{mode}"] = aligned_basis
+        self.adata.obsm[f"{self.spatial_key}_{mode}"] = aligned_basis  # noqa: RET503
 
     @d_mixins.dedent
-    def cell_transition(
+    def cell_transition(  # type: ignore[misc]
         self: SpatialAlignmentMixinProtocol[K, B],
         source: K,
         target: K,
@@ -255,7 +247,7 @@ def spatial_key(self) -> Optional[str]:
         return self._spatial_key
 
     @spatial_key.setter
-    def spatial_key(self: SpatialAlignmentMixinProtocol[K, B], key: Optional[str]) -> None:
+    def spatial_key(self: SpatialAlignmentMixinProtocol[K, B], key: Optional[str]) -> None:  # type: ignore[misc]
         if key is not None and key not in self.adata.obsm:
             raise KeyError(f"Unable to find spatial data in `adata.obsm[{key!r}]`.")
         self._spatial_key = key
@@ -267,11 +259,11 @@ def batch_key(self) -> Optional[str]:
 
     @batch_key.setter
     def batch_key(self, key: Optional[str]) -> None:
-        if key is not None and key not in self.adata.obs:
+        if key is not None and key not in self.adata.obs:  # type: ignore[attr-defined]
             raise KeyError(f"Unable to find batch data in `adata.obs[{key!r}]`.")
         self._batch_key = key
 
-    def _subset_spatial(
+    def _subset_spatial(  # type: ignore[misc]
         self: SpatialAlignmentMixinProtocol[K, B], k: K, spatial_key: Optional[str] = None
     ) -> ArrayLike:
         if spatial_key is None:
@@ -308,7 +300,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         self._batch_key: Optional[str] = None
         self._spatial_key: Optional[str] = None
 
-    def _filter_vars(
+    def _filter_vars(  # type: ignore[misc]
         self: SpatialMappingMixinProtocol[K, B],
         var_names: Optional[Sequence[str]] = None,
     ) -> Optional[List[str]]:
@@ -330,7 +322,7 @@ def _filter_vars(
 
         raise ValueError("Some variable are missing in the single-cell or the spatial `AnnData`.")
 
-    def correlate(
+    def correlate(  # type: ignore[misc]
         self: SpatialMappingMixinProtocol[K, B],
         var_names: Optional[List[str]] = None,
         corr_method: Literal["pearson", "spearman"] = "pearson",
@@ -382,7 +374,7 @@ def correlate(
 
         return corrs
 
-    def impute(
+    def impute(  # type: ignore[misc]
         self: SpatialMappingMixinProtocol[K, B],
         var_names: Optional[Sequence[Any]] = None,
         device: Optional[Device_t] = None,
@@ -415,7 +407,7 @@ def impute(
         return adata_pred
 
     @d.dedent
-    def spatial_correspondence(
+    def spatial_correspondence(  # type: ignore[misc]
         self: SpatialMappingMixinProtocol[K, B],
         interval: Union[ArrayLike, int] = 10,
         max_dist: Optional[int] = None,
@@ -453,8 +445,7 @@ def _get_features(
 
             if key is not None:
                 return getattr(adata, att)[key]
-            else:
-                return getattr(adata, att)
+            return getattr(adata, att)
 
         if self.batch_key is not None:
             out_list = []
@@ -480,14 +471,13 @@ def _get_features(
             out = pd.concat(out_list, axis=0)
             out[self.batch_key] = pd.Categorical(out[self.batch_key])
             return out
-        else:
-            spatial = self.adata.obsm[self.spatial_key]
-            features = _get_features(self.adata, attr)
-            out = _compute_correspondence(spatial, features, interval, max_dist)
-            return out
+
+        spatial = self.adata.obsm[self.spatial_key]
+        features = _get_features(self.adata, attr)
+        return _compute_correspondence(spatial, features, interval, max_dist)
 
     @d_mixins.dedent
-    def cell_transition(
+    def cell_transition(  # type: ignore[misc]
         self: SpatialMappingMixinProtocol[K, B],
         source: K,
         target: Optional[K] = None,
@@ -546,7 +536,7 @@ def batch_key(self) -> Optional[str]:
 
     @batch_key.setter
     def batch_key(self, key: Optional[str]) -> None:
-        if key is not None and key not in self.adata.obs:
+        if key is not None and key not in self.adata.obs:  # type: ignore[attr-defined]
             raise KeyError(f"Unable to find batch data in `adata.obs[{key!r}]`.")
         self._batch_key = key
 
@@ -556,7 +546,7 @@ def spatial_key(self) -> Optional[str]:
         return self._spatial_key
 
     @spatial_key.setter
-    def spatial_key(self: SpatialAlignmentMixinProtocol[K, B], key: Optional[str]) -> None:
+    def spatial_key(self: SpatialAlignmentMixinProtocol[K, B], key: Optional[str]) -> None:  # type: ignore[misc]
         if key is not None and key not in self.adata.obsm:
             raise KeyError(f"Unable to find spatial data in `adata.obsm[{key!r}]`.")
         self._spatial_key = key
@@ -570,7 +560,6 @@ def _compute_correspondence(
 ) -> pd.DataFrame:
     if isinstance(interval, int):
         # prepare support
-        spatial.shape[0]
         hull = ConvexHull(spatial)
         area = hull.volume
         if max_dist is None:
@@ -582,6 +571,7 @@ def _compute_correspondence(
     def pdist(row_idx: ArrayLike, col_idx: float, feat: ArrayLike) -> Any:
         if len(row_idx) > 0:
             return pairwise_distances(feat[row_idx, :], feat[[col_idx], :]).mean()  # type: ignore[index]
+        return np.nan
 
     vpdist = np.vectorize(pdist, excluded=["feat"])
     features = features.A if sp.issparse(features) else features  # type: ignore[attr-defined]
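
The `pdist` helper above now returns `np.nan` for empty index sets and is wrapped in `np.vectorize` with the feature matrix excluded from broadcasting. A toy sketch of that `excluded=` pattern, with made-up names:

    import numpy as np

    def offset_lookup(idx: int, offset: float, feat: np.ndarray) -> float:
        # `feat` is excluded from vectorization, so it is passed through whole.
        return float(feat[idx] + offset)

    vec = np.vectorize(offset_lookup, excluded=["feat"])
    feat = np.arange(5, dtype=float)
    print(vec([0, 2, 4], 0.5, feat=feat))  # [0.5 2.5 4.5]
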
diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py
index 645532bea..2e1ffb99e 100644
--- a/src/moscot/problems/time/_mixins.py
+++ b/src/moscot/problems/time/_mixins.py
@@ -184,7 +184,6 @@ def cell_transition(
         -----
         %(notes_cell_transition)s
         """
-
         if TYPE_CHECKING:
             assert isinstance(self.temporal_key, str)
         return self._cell_transition(
@@ -294,8 +293,7 @@ def sankey(
                 "captions": [str(t) for t in tuples],
             }
             Key.uns.set_plotting_vars(self.adata, PlottingKeys.SANKEY, key_added, plot_vars)
-        if return_data:
-            return cell_transitions_updated
+        return cell_transitions_updated if return_data else None
 
     @d_mixins.dedent
     def push(
@@ -353,8 +351,7 @@ def push(
             }
             self.adata.obs[key_added] = self._flatten(result, key=self.temporal_key)
             Key.uns.set_plotting_vars(self.adata, PlottingKeys.PUSH, key_added, plot_vars)
-        if return_data:
-            return result
+        return result if return_data else None
 
     @d_mixins.dedent
     def pull(
@@ -411,8 +408,7 @@ def pull(
             }
             self.adata.obs[key_added] = self._flatten(result, key=self.temporal_key)
             Key.uns.set_plotting_vars(self.adata, PlottingKeys.PULL, key_added, plot_vars)
-        if return_data:
-            return result
+        return result if return_data else None
 
     @property
     def prior_growth_rates(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFrame]:
@@ -542,14 +538,14 @@ def _get_data(
                 break
         else:
             raise ValueError(f"No data found for `{source}` time point.")
-        for src, tgt in self.problems.keys():
+        for src, tgt in self.problems:
             if src == intermediate:
                 intermediate_data = self.problems[src, tgt].xy.data_src  # type: ignore[union-attr]
                 intermediate_adata = self.problems[src, tgt].adata_src
                 break
         else:
             raise ValueError(f"No data found for `{intermediate}` time point.")
-        for src, tgt in self.problems.keys():
+        for src, tgt in self.problems:
             if tgt == target:
                 target_data = self.problems[src, tgt].xy.data_tgt  # type: ignore[union-attr]
                 break
@@ -841,12 +837,11 @@ def _interpolate_gex_randomly(
         else:
             row_probability = growth_rates ** (1 - interpolation_parameter)
         row_probability /= np.sum(row_probability)
-        result = (
+        return (
             source_data[rng.choice(len(source_data), size=number_cells, p=row_probability), :]
             * (1 - interpolation_parameter)
             + target_data[rng.choice(len(target_data), size=number_cells), :] * interpolation_parameter
         )
-        return result
 
     @staticmethod
     def _get_interp_param(
diff --git a/src/moscot/solvers/_base_solver.py b/src/moscot/solvers/_base_solver.py
index 0a007ca8a..9e2dbbff5 100644
--- a/src/moscot/solvers/_base_solver.py
+++ b/src/moscot/solvers/_base_solver.py
@@ -42,9 +42,9 @@ def solver(self, *, backend: Literal["ott"] = "ott", **kwargs: Any) -> "BaseSolv
             from moscot.backends.ott import GWSolver, SinkhornSolver  # type: ignore[attr-defined]
 
             if self == ProblemKind.LINEAR:
-                return SinkhornSolver(**kwargs)
+                return SinkhornSolver(**kwargs)  # type: ignore[return-value]
             if self == ProblemKind.QUAD:
-                return GWSolver(**kwargs)
+                return GWSolver(**kwargs)  # type: ignore[return-value]
             raise NotImplementedError(f"Unable to create solver for `{self}` problem.")
 
         raise NotImplementedError(f"Backend `{backend}` is not yet implemented.")
diff --git a/src/moscot/utils/_subset_policy.py b/src/moscot/utils/_subset_policy.py
index 5e7565a25..a42b2f3a3 100644
--- a/src/moscot/utils/_subset_policy.py
+++ b/src/moscot/utils/_subset_policy.py
@@ -2,6 +2,7 @@
 from typing import Any, Set, Dict, List, Tuple, Union, Generic, Literal, TypeVar, Hashable, Iterable, Optional, Sequence
 from operator import gt, lt
 from itertools import product
+import contextlib
 
 import pandas as pd
 import networkx as nx
@@ -229,10 +230,8 @@ def remove_node(self, node: Tuple[K, K]) -> None:
         """
         if self._graph is None:
             raise RuntimeError("Construct the policy graph first.")
-        try:
+        with contextlib.suppress(KeyError):
             self._graph.remove(node)
-        except KeyError:
-            pass
 
 
 class OrderedPolicy(SubsetPolicy[K], ABC):
diff --git a/src/moscot/utils/_tagged_array.py b/src/moscot/utils/_tagged_array.py
index 5ceb3a983..bbf6b94e9 100644
--- a/src/moscot/utils/_tagged_array.py
+++ b/src/moscot/utils/_tagged_array.py
@@ -60,9 +60,9 @@ def _extract_data(
             if key is not None:
                 data = data[key]
         except KeyError:
-            raise KeyError(f"Unable to fetch data from `{modifier}`.")
+            raise KeyError(f"Unable to fetch data from `{modifier}`.") from None
         except IndexError:
-            raise IndexError(f"Unable to fetch data from `{modifier}`.")
+            raise IndexError(f"Unable to fetch data from `{modifier}`.") from None
 
         if sp.issparse(data):
             logger.warning(f"Densifying data in `{modifier}`")
diff --git a/tests/_utils.py b/tests/_utils.py
index 5df750c79..a9121214d 100644
--- a/tests/_utils.py
+++ b/tests/_utils.py
@@ -49,10 +49,9 @@ def _ones(self, n: int) -> ArrayLike:
 
 
 def _make_adata(grid: ArrayLike, n: int, seed) -> List[AnnData]:
-    rng = np.random.default_rng(seed)
+    rng = np.random.RandomState(seed)
     X = rng.normal(size=(100, 60))
-    adatas = [AnnData(X=csr_matrix(X), obsm={"spatial": grid.copy()}, dtype=X.dtype) for _ in range(n)]
-    return adatas
+    return [AnnData(X=csr_matrix(X), obsm={"spatial": grid.copy()}, dtype=X.dtype) for _ in range(n)]
 
 
 def _adata_spatial_split(adata: AnnData) -> Tuple[AnnData, AnnData]:
@@ -67,8 +66,7 @@ def _make_grid(grid_size: int) -> ArrayLike:
     x1s = np.linspace(*xlimits, num=grid_size)
     x2s = np.linspace(*ylimits, num=grid_size)
     X1, X2 = np.meshgrid(x1s, x2s)
-    X_orig_single = np.vstack([X1.ravel(), X2.ravel()]).T
-    return X_orig_single
+    return np.vstack([X1.ravel(), X2.ravel()]).T
 
 
 class Problem(CompoundProblem[Any, OTProblem]):
diff --git a/tests/conftest.py b/tests/conftest.py
index 8e19f3692..7895ce224 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,10 +1,10 @@
 from math import cos, sin
 from typing import Tuple, Optional
-import random
 
 from scipy.sparse import csr_matrix
 import pandas as pd
 import pytest
+import matplotlib.pyplot as plt
 
 from jax.config import config
 import numpy as np
@@ -12,13 +12,13 @@
 
 from anndata import AnnData
 import scanpy as sc
-import anndata as ad
 
 from tests._utils import Geom_t, _make_grid, _make_adata
 
 ANGLES = (0, 30, 60)
 
 
+# TODO(michalk8): consider passing this via env
 config.update("jax_enable_x64", True)
 
 
@@ -30,6 +30,13 @@ def pytest_sessionstart() -> None:
     sc.set_figure_params(dpi=40, color_map="viridis")
 
 
+@pytest.fixture(autouse=True)
+def _close_figure():
+    # prevent `RuntimeWarning: More than 20 figures have been opened.`
+    yield
+    plt.close("all")
+
+
 @pytest.fixture()
 def x() -> Geom_t:
     rng = np.random.RandomState(0)
@@ -55,7 +62,7 @@ def y() -> Geom_t:
 
 
 @pytest.fixture()
-def xy() -> Geom_t:
+def xy() -> Tuple[Geom_t, Geom_t]:
     rng = np.random.RandomState(2)
     n = 20  # number of points in the first distribution
     n2 = 30  # number of points in the second distribution
@@ -119,7 +126,7 @@ def adata_time() -> AnnData:
     adata.obs["batch"] = rng.choice((0, 1, 2), len(adata))
     adata.obs["left_marginals"] = np.ones(len(adata))
     adata.obs["right_marginals"] = np.ones(len(adata))
-    adata.obs["celltype"] = np.random.choice(["A", "B", "C"], size=len(adata))
+    adata.obs["celltype"] = rng.choice(["A", "B", "C"], size=len(adata))
     # genes from mouse/human proliferation/apoptosis
     genes = ["ANLN", "ANP32E", "ATAD2", "Mcm4", "Smc4", "Gtse1", "ADD1", "AIFM3", "ANKH", "Ercc5", "Serpinb5", "Inhbb"]
     # genes which are transcription factors, 3 from drosophila, 2 from human, 1 from mouse
@@ -145,21 +152,23 @@ def create_marginals(n: int, m: int, *, uniform: bool = False, seed: Optional[in
 
 @pytest.fixture()
 def gt_temporal_adata() -> AnnData:
-    return _gt_temporal_adata.copy()
+    adata = _gt_temporal_adata.copy()
+    adata.obs_names_make_unique()
+    return adata
 
 
 @pytest.fixture()
 def adata_space_rotate() -> AnnData:
-    seed = random.randint(0, 422)
+    rng = np.random.RandomState(31)
     grid = _make_grid(10)
-    adatas = _make_adata(grid, n=len(ANGLES), seed=seed)
+    adatas = _make_adata(grid, n=len(ANGLES), seed=32)
     for adata, angle in zip(adatas, ANGLES):
         theta = np.deg2rad(angle)
         rot = np.array([[cos(theta), -sin(theta)], [sin(theta), cos(theta)]])
         adata.obsm["spatial"] = np.dot(adata.obsm["spatial"], rot)
 
-    adata = ad.concat(adatas, label="batch")
-    adata.obs["celltype"] = np.random.choice(["A", "B", "C"], size=len(adata))
+    adata = adatas[0].concatenate(*adatas[1:], batch_key="batch")
+    adata.obs["celltype"] = rng.choice(["A", "B", "C"], size=len(adata))
     adata.uns["spatial"] = {}
     adata.obs_names_make_unique()
     sc.pp.pca(adata)
@@ -168,11 +177,10 @@ def adata_space_rotate() -> AnnData:
 
 @pytest.fixture()
 def adata_mapping() -> AnnData:
-    seed = random.randint(0, 422)
     grid = _make_grid(10)
-    adataref, adata1, adata2 = _make_adata(grid, n=3, seed=seed)
+    adataref, adata1, adata2 = _make_adata(grid, n=3, seed=17)
     sc.pp.pca(adataref, n_comps=30)
 
-    adata = ad.concat([adataref, adata1, adata2], label="batch", join="outer")
+    adata = adataref.concatenate(adata1, adata2, batch_key="batch", join="outer")
     adata.obs_names_make_unique()
     return adata
diff --git a/tests/data/generate_gt_temporal_data.py b/tests/data/generate_gt_temporal_data.py
index 41f8e73d2..a2fcb0852 100644
--- a/tests/data/generate_gt_temporal_data.py
+++ b/tests/data/generate_gt_temporal_data.py
@@ -7,7 +7,7 @@
     raise ImportError(
         "Please install WOT from commit hash`ca5e94f05699997b01cf5ae13383f9810f0613f6`"
         + "with `pip install git+https://github.com/broadinstitute/wot.git@ca5e94f05699997b01cf5ae13383f9810f0613f6`"
-    )
+    ) from None
 
 import os
 
diff --git a/tests/data/regression_tests_spatial.py b/tests/data/regression_tests_spatial.py
index 969a41c08..544b60a31 100644
--- a/tests/data/regression_tests_spatial.py
+++ b/tests/data/regression_tests_spatial.py
@@ -34,8 +34,7 @@ def adata_mapping() -> AnnData:
     adataref, adata1, adata2 = _make_adata(grid, n=3)
     sc.pp.pca(adataref)
 
-    adata = ad.concat([adataref, adata1, adata2], label="batch", join="outer")
-    return adata
+    return ad.concat([adataref, adata1, adata2], label="batch", join="outer")
 
 
 def _make_grid(grid_size: int) -> ArrayLike:
@@ -43,8 +42,7 @@ def _make_grid(grid_size: int) -> ArrayLike:
     x1s = np.linspace(*xlimits, num=grid_size)  # type: ignore [call-overload]
     x2s = np.linspace(*ylimits, num=grid_size)  # type: ignore [call-overload]
     X1, X2 = np.meshgrid(x1s, x2s)
-    X_orig_single = np.vstack([X1.ravel(), X2.ravel()]).T
-    return X_orig_single
+    return np.vstack([X1.ravel(), X2.ravel()]).T
 
 
 def _make_adata(grid: ArrayLike, n: int) -> List[AnnData]:
diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py
index 315779488..d4be344ca 100644
--- a/tests/datasets/test_dataset.py
+++ b/tests/datasets/test_dataset.py
@@ -1,50 +1,14 @@
-from types import FunctionType
 from typing import Mapping, Optional
-from pathlib import Path
-from http.client import RemoteDisconnected
-import warnings
 
 import pytest
 import networkx as nx
 
 import numpy as np
 
-from anndata import AnnData, OldFormatWarning
-
 from moscot.datasets import simulate_data
-import moscot as mt
-
-
-class TestDatasetsImports:
-    @pytest.mark.parametrize("func", mt.datasets._datasets.__all__)
-    def test_import(self, func):
-        assert hasattr(mt.datasets, func), dir(mt.datasets)
-        fn = getattr(mt.datasets, func)
-
-        assert isinstance(fn, FunctionType)
-
-
-# TODO(michalk8): parse the code and xfail iff server issue
-class TestDatasetsDownload:
-    @pytest.mark.timeout(120)
-    def test_sim_align(self, tmp_path: Path):
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore", category=OldFormatWarning)
-            try:
-                adata = mt.datasets.sim_align(tmp_path / "foo")
-
-                assert isinstance(adata, AnnData)
-                assert adata.shape == (1200, 500)
-            except RemoteDisconnected as e:
-                pytest.xfail(str(e))
 
 
 class TestSimulateData:
-    @pytest.mark.fast()
-    def test_returns_adata(self):
-        result = simulate_data()
-        assert isinstance(result, AnnData)
-
     @pytest.mark.fast()
     @pytest.mark.parametrize("n_distributions", [2, 4])
     @pytest.mark.parametrize("key", ["batch", "day"])
diff --git a/tests/plotting/test_utils.py b/tests/plotting/test_utils.py
index 97de9a580..de98e20e0 100644
--- a/tests/plotting/test_utils.py
+++ b/tests/plotting/test_utils.py
@@ -1,5 +1,4 @@
 from typing import List, Optional
-import os
 
 import pytest
 import matplotlib as mpl
@@ -29,70 +28,47 @@ def test_input_to_adatas_adata(self, adata_time: AnnData):
         np.testing.assert_array_equal(adata1.X.A, adata_time.X.A)
         np.testing.assert_array_equal(adata2.X.A, adata_time.X.A)
 
-    @pytest.mark.parametrize("save", [None, "tests/data/test_plot.png"])
     @pytest.mark.parametrize("return_fig", [True, False])
-    def test_cell_transition(self, adata_pl_cell_transition: AnnData, return_fig: bool, save: Optional[str]):
-        if save:
-            if os.path.exists(save):
-                os.remove(save)
-        fig = msc.plotting.cell_transition(adata_pl_cell_transition, return_fig=return_fig, save=save)
+    def test_cell_transition(self, adata_pl_cell_transition: AnnData, return_fig: bool):
+        fig = msc.plotting.cell_transition(adata_pl_cell_transition, return_fig=return_fig)
         if return_fig:
             assert fig is not None
             assert isinstance(fig, mpl.figure.Figure)
         else:
             assert fig is None
-        if save:
-            assert os.path.exists(save)
 
     @pytest.mark.parametrize("time_points", [None, [0]])
     @pytest.mark.parametrize("return_fig", [True, False])
-    @pytest.mark.parametrize("save", [None, "tests/data/test_plot.png"])
-    def test_push(
-        self, adata_pl_push: AnnData, time_points: Optional[List[int]], return_fig: bool, save: Optional[str]
-    ):
-        if save:
-            if os.path.exists(save):
-                os.remove(save)
-
-        fig = msc.plotting.push(adata_pl_push, time_points=time_points, return_fig=return_fig, save=save)
+    def test_push(self, adata_pl_push: AnnData, time_points: Optional[List[int]], return_fig: bool):
+        fig = msc.plotting.push(adata_pl_push, time_points=time_points, return_fig=return_fig)
 
         if return_fig:
             assert fig is not None
             assert isinstance(fig, mpl.figure.Figure)
         else:
             assert fig is None
-        if save:
-            assert os.path.exists(save)
 
     @pytest.mark.parametrize("time_points", [None, [0]])
     @pytest.mark.parametrize("return_fig", [True, False])
-    @pytest.mark.parametrize("save", [None, "tests/data/test_plot.png"])
     def test_pull(
-        self, adata_pl_pull: AnnData, time_points: Optional[List[int]], return_fig: bool, save: Optional[str]
+        self,
+        adata_pl_pull: AnnData,
+        time_points: Optional[List[int]],
+        return_fig: bool,
     ):
-        if save:
-            if os.path.exists(save):
-                os.remove(save)
-        fig = msc.plotting.pull(adata_pl_pull, time_points=time_points, return_fig=return_fig, save=save)
+        fig = msc.plotting.pull(adata_pl_pull, time_points=time_points, return_fig=return_fig)
         if return_fig:
             assert fig is not None
             assert isinstance(fig, mpl.figure.Figure)
         else:
             assert fig is None
-        if save:
-            assert os.path.exists(save)
 
-    @pytest.mark.parametrize("save", [None, "tests/data/test_plot.png"])
     @pytest.mark.parametrize("return_fig", [True, False])
     @pytest.mark.parametrize("interpolate_color", [True, False])
-    def test_sankey(self, adata_pl_sankey: AnnData, return_fig: bool, save: Optional[str], interpolate_color: bool):
-        if save:
-            if os.path.exists(save):
-                os.remove(save)
+    def test_sankey(self, adata_pl_sankey: AnnData, return_fig: bool, interpolate_color: bool):
         fig = msc.plotting.sankey(
             adata_pl_sankey,
             return_fig=return_fig,
-            save=save,
             interpolate_color=interpolate_color,
         )
         if return_fig:
@@ -100,5 +76,3 @@ def test_sankey(self, adata_pl_sankey: AnnData, return_fig: bool, save: Optional
             assert isinstance(fig, mpl.figure.Figure)
         else:
             assert fig is None
-        if save:
-            assert os.path.exists(save)
diff --git a/tests/problems/generic/test_fgw_problem.py b/tests/problems/generic/test_fgw_problem.py
index a24a82317..b4c5cec4b 100644
--- a/tests/problems/generic/test_fgw_problem.py
+++ b/tests/problems/generic/test_fgw_problem.py
@@ -8,9 +8,9 @@
 
 from anndata import AnnData
 
-from moscot.problems.base import OTProblem  # type:ignore[attr-defined]
+from moscot.problems.base import OTProblem
 from moscot.solvers._output import BaseSolverOutput
-from moscot.problems.generic import GWProblem  # type:ignore[attr-defined]
+from moscot.problems.generic import GWProblem
 from tests.problems.conftest import (
     fgw_args_1,
     fgw_args_2,
@@ -48,7 +48,7 @@ def test_prepare(self, adata_space_rotate: AnnData):
             assert key in expected_keys
             assert isinstance(problem[key], OTProblem)
 
-    def test_solve_balanced(self, adata_space_rotate: AnnData):  # type: ignore[no-untyped-def]
+    def test_solve_balanced(self, adata_space_rotate: AnnData):
         eps = 0.5
         adata_space_rotate = adata_space_rotate[adata_space_rotate.obs["batch"].isin(("0", "1"))].copy()
         expected_keys = [("0", "1"), ("1", "2")]
diff --git a/tests/problems/generic/test_gw_problem.py b/tests/problems/generic/test_gw_problem.py
index ee5611a90..c1f0a60a9 100644
--- a/tests/problems/generic/test_gw_problem.py
+++ b/tests/problems/generic/test_gw_problem.py
@@ -8,9 +8,9 @@
 
 from anndata import AnnData
 
-from moscot.problems.base import OTProblem  # type:ignore[attr-defined]
+from moscot.problems.base import OTProblem
 from moscot.solvers._output import BaseSolverOutput
-from moscot.problems.generic import GWProblem  # type:ignore[attr-defined]
+from moscot.problems.generic import GWProblem
 from tests.problems.conftest import (
     gw_args_1,
     gw_args_2,
diff --git a/tests/problems/generic/test_mixins.py b/tests/problems/generic/test_mixins.py
index 3d63538d1..892b100b3 100644
--- a/tests/problems/generic/test_mixins.py
+++ b/tests/problems/generic/test_mixins.py
@@ -29,7 +29,7 @@ def test_sample_from_tmap_pipeline(
         problem[10, 10.5]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"])
 
         if interpolation_parameter is not None and not 0 <= interpolation_parameter <= 1:
-            with np.testing.assert_raises(ValueError):
+            with pytest.raises(ValueError, match=r"^Expected interpolation"):
                 problem._sample_from_tmap(
                     10,
                     10.5,
@@ -40,7 +40,7 @@ def test_sample_from_tmap_pipeline(
                     interpolation_parameter=interpolation_parameter,
                 )
         elif interpolation_parameter is None and account_for_unbalancedness:
-            with np.testing.assert_raises(ValueError):
+            with pytest.raises(ValueError, match=r"^When accounting for unbalancedness"):
                 problem._sample_from_tmap(
                     10,
                     10.5,
@@ -118,8 +118,8 @@ def test_cell_transition_aggregation_cell_forward(self, gt_temporal_adata: AnnDa
 
         df_res = df_res.div(df_res.sum(axis=1), axis=0)
 
-        ctr_ordered = ctr.sort_index().sort_index(1)
-        df_res_ordered = df_res.sort_index().sort_index(1)
+        ctr_ordered = ctr.sort_index().sort_index(axis=1)
+        df_res_ordered = df_res.sort_index().sort_index(axis=1)
         np.testing.assert_allclose(
             ctr_ordered.values.astype(float), df_res_ordered.values.astype(float), rtol=RTOL, atol=ATOL
         )
@@ -159,8 +159,8 @@ def test_cell_transition_aggregation_cell_backward(self, gt_temporal_adata: AnnD
 
         df_res = df_res.div(df_res.sum(axis=0), axis=1)
 
-        ctr_ordered = ctr.sort_index().sort_index(1)
-        df_res_ordered = df_res.sort_index().sort_index(1)
+        ctr_ordered = ctr.sort_index().sort_index(axis=1)
+        df_res_ordered = df_res.sort_index().sort_index(axis=1)
         np.testing.assert_allclose(
             ctr_ordered.values.astype(float), df_res_ordered.values.astype(float), rtol=RTOL, atol=ATOL
         )
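
Note that `pytest.raises(..., match=...)` applies `re.search` to the formatted exception message, which is why the anchored patterns above begin with `^`; a minimal sketch with a hypothetical error message:

    import pytest

    def test_invalid_parameter_raises():
        with pytest.raises(ValueError, match=r"^Expected interpolation"):
            raise ValueError("Expected interpolation parameter to be in [0, 1].")
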
diff --git a/tests/problems/generic/test_sinkhorn_problem.py b/tests/problems/generic/test_sinkhorn_problem.py
index 2af8ab263..f3ebed20f 100644
--- a/tests/problems/generic/test_sinkhorn_problem.py
+++ b/tests/problems/generic/test_sinkhorn_problem.py
@@ -44,7 +44,7 @@ def test_prepare(self, adata_time: AnnData):
             assert key in expected_keys
             assert isinstance(problem[key], OTProblem)
 
-    def test_solve_balanced(self, adata_time: AnnData):  # type: ignore[no-untyped-def]
+    def test_solve_balanced(self, adata_time: AnnData):
         eps = 0.5
         expected_keys = [(0, 1), (1, 2)]
         problem = SinkhornProblem(adata=adata_time)
@@ -126,10 +126,7 @@ def test_pass_arguments(self, adata_time: AnnData, args_to_check: Mapping[str, A
             el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val)
             assert el == args_to_check[arg]
 
-        if args_to_check["rank"] == -1:
-            args = pointcloud_args
-        else:
-            args = lr_pointcloud_args
+        args = pointcloud_args if args_to_check["rank"] == -1 else lr_pointcloud_args
         for arg, val in args.items():
             el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val)
             assert hasattr(geom, val)
diff --git a/tests/problems/space/test_alignment_problem.py b/tests/problems/space/test_alignment_problem.py
index 5ddb67539..ad5f22223 100644
--- a/tests/problems/space/test_alignment_problem.py
+++ b/tests/problems/space/test_alignment_problem.py
@@ -75,7 +75,7 @@ def test_solve_balanced(
             if initializer == "random":
                 # kwargs["kwargs_init"] = {"key": 0}
                 # kwargs["key"] = 0
-                return 0  # TODO(@MUCDK) fix after refactoring
+                return  # TODO(@MUCDK) fix after refactoring
         ap = (
             AlignmentProblem(adata=adata_space_rotate)
             .prepare(batch_key="batch")
@@ -86,9 +86,12 @@ def test_solve_balanced(
             if initializer != "random":  # TODO: is this valid?
                 assert ap[prob_key].solution.converged
 
+        # TODO(michalk8): use np.testing
         assert np.allclose(*(sol.cost for sol in ap.solutions.values()))
         assert np.all([sol.converged for sol in ap.solutions.values()])
-        assert np.all([np.all(~np.isnan(sol.transport_matrix)) for sol in ap.solutions.values()])
+        np.testing.assert_array_equal(
+            [np.all(np.isfinite(sol.transport_matrix)) for sol in ap.solutions.values()], True
+        )
 
     def test_solve_unbalanced(self, adata_space_rotate: AnnData):
         tau_a, tau_b = [0.8, 1]
diff --git a/tests/problems/space/test_mapping_problem.py b/tests/problems/space/test_mapping_problem.py
index a2d33f745..fabefe352 100644
--- a/tests/problems/space/test_mapping_problem.py
+++ b/tests/problems/space/test_mapping_problem.py
@@ -21,6 +21,7 @@
 )
 from moscot.solvers._base_solver import ProblemKind
 
+# TODO(michalk8): should be made relative to tests root
 SOLUTIONS_PATH = Path("./tests/data/mapping_solutions.pkl")  # base is moscot
 
 
@@ -97,7 +98,7 @@ def test_solve_balanced(
             if initializer == "random":
                 # kwargs["kwargs_init"] = {"key": 0}
                 # kwargs["key"] = 0
-                return 0  # TODO(@MUCDK) fix after refactoring
+                return  # TODO(@MUCDK) fix after refactoring
         mp = MappingProblem(adataref, adatasp)
         mp = mp.prepare(batch_key="batch", sc_attr=sc_attr, var_names=var_names)
         mp = mp.solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs)
diff --git a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py
index c1c22ffc7..2911bc4a1 100644
--- a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py
+++ b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py
@@ -139,9 +139,9 @@ def test_apoptosis_key_pipeline(self, adata_spatio_temporal: AnnData):
     @pytest.mark.fast()
     @pytest.mark.parametrize("scaling", [0.1, 1, 4])
     def test_proliferation_key_c_pipeline(self, adata_spatio_temporal: AnnData, scaling: float):
-        keys = np.sort(np.unique(adata_spatio_temporal.obs["time"].values))
-        adata_spatio_temporal = adata_spatio_temporal[adata_spatio_temporal.obs["time"].isin([keys[0], keys[1]])]
-        delta = keys[1] - keys[0]
+        key0, key1, *_ = np.sort(np.unique(adata_spatio_temporal.obs["time"].values))
+        adata_spatio_temporal = adata_spatio_temporal[adata_spatio_temporal.obs["time"].isin([key0, key1])].copy()
+        delta = key1 - key0
         problem = SpatioTemporalProblem(adata_spatio_temporal)
         assert problem.proliferation_key is None
 
@@ -149,12 +149,10 @@ def test_proliferation_key_c_pipeline(self, adata_spatio_temporal: AnnData, scal
         assert problem.proliferation_key == "proliferation"
 
         problem = problem.prepare(time_key="time", marginal_kwargs={"scaling": scaling})
-        prolif = adata_spatio_temporal[adata_spatio_temporal.obs["time"] == keys[0]].obs["proliferation"]
-        apopt = adata_spatio_temporal[adata_spatio_temporal.obs["time"] == keys[0]].obs["apoptosis"]
+        prolif = adata_spatio_temporal[adata_spatio_temporal.obs["time"] == key0].obs["proliferation"]
+        apopt = adata_spatio_temporal[adata_spatio_temporal.obs["time"] == key0].obs["apoptosis"]
         expected_marginals = np.exp((prolif - apopt) * delta / scaling)
-        print("problem[keys[0], keys[1]]._prior_growth", problem[keys[0], keys[1]]._prior_growth)
-        print("expected_marginals", expected_marginals)
-        np.testing.assert_allclose(problem[keys[0], keys[1]]._prior_growth, expected_marginals, rtol=RTOL, atol=ATOL)
+        np.testing.assert_allclose(problem[key0, key1]._prior_growth, expected_marginals, rtol=RTOL, atol=ATOL)
 
     def test_growth_rates_pipeline(self, adata_spatio_temporal: AnnData):
         problem = SpatioTemporalProblem(adata=adata_spatio_temporal)
diff --git a/tests/problems/time/test_lineage_problem.py b/tests/problems/time/test_lineage_problem.py
index e0aa7ad8b..a1c234458 100644
--- a/tests/problems/time/test_lineage_problem.py
+++ b/tests/problems/time/test_lineage_problem.py
@@ -44,9 +44,8 @@ def test_prepare(self, adata_time_barcodes: AnnData):
             assert isinstance(problem[key], BirthDeathProblem)
 
     def test_solve_balanced(self, adata_time_barcodes: AnnData):
-        eps = 0.5
-        expected_keys = [(0, 1)]
-        adata_time_barcodes = adata_time_barcodes[adata_time_barcodes.obs["time"].isin((0, 1))]
+        eps, key = 0.5, (0, 1)
+        adata_time_barcodes = adata_time_barcodes[adata_time_barcodes.obs["time"].isin(key)].copy()
         problem = LineageProblem(adata=adata_time_barcodes)
         problem = problem.prepare(
             time_key="time",
@@ -57,7 +56,7 @@ def test_solve_balanced(self, adata_time_barcodes: AnnData):
 
         for key, subsol in problem.solutions.items():
             assert isinstance(subsol, BaseSolverOutput)
-            assert key in expected_keys
+            assert key == (0, 1)
 
     def test_solve_unbalanced(self, adata_time_barcodes: AnnData):
         taus = [9e-1, 1e-2]
@@ -152,9 +151,9 @@ def test_apoptosis_key_pipeline(self, adata_time_barcodes: AnnData):
     @pytest.mark.fast()
     @pytest.mark.parametrize("scaling", [0.1, 1, 4])
     def test_proliferation_key_c_pipeline(self, adata_time_barcodes: AnnData, scaling: float):
-        keys = np.sort(np.unique(adata_time_barcodes.obs["time"].values))
-        adata_time_barcodes = adata_time_barcodes[adata_time_barcodes.obs["time"].isin([keys[0], keys[1]])]
-        delta = keys[1] - keys[0]
+        key0, key1, *_ = np.sort(np.unique(adata_time_barcodes.obs["time"].values))
+        adata_time_barcodes = adata_time_barcodes[adata_time_barcodes.obs["time"].isin([key0, key1])].copy()
+        delta = key1 - key0
         problem = LineageProblem(adata_time_barcodes)
         assert problem.proliferation_key is None
 
@@ -167,10 +166,10 @@ def test_proliferation_key_c_pipeline(self, adata_time_barcodes: AnnData, scalin
             policy="sequential",
             marginal_kwargs={"scaling": scaling},
         )
-        prolif = adata_time_barcodes[adata_time_barcodes.obs["time"] == keys[0]].obs["proliferation"]
-        apopt = adata_time_barcodes[adata_time_barcodes.obs["time"] == keys[0]].obs["apoptosis"]
+        prolif = adata_time_barcodes[adata_time_barcodes.obs["time"] == key0].obs["proliferation"]
+        apopt = adata_time_barcodes[adata_time_barcodes.obs["time"] == key0].obs["apoptosis"]
         expected_marginals = np.exp((prolif - apopt) * delta / scaling)
-        np.testing.assert_allclose(problem[keys[0], keys[1]]._prior_growth, expected_marginals, rtol=RTOL, atol=ATOL)
+        np.testing.assert_allclose(problem[key0, key1]._prior_growth, expected_marginals, rtol=RTOL, atol=ATOL)
 
     @pytest.mark.fast()
     def test_barcodes_pipeline(self, adata_time_barcodes: AnnData):
diff --git a/tests/problems/time/test_mixins.py b/tests/problems/time/test_mixins.py
index c7fd0695c..61d1fc91c 100644
--- a/tests/problems/time/test_mixins.py
+++ b/tests/problems/time/test_mixins.py
@@ -89,7 +89,6 @@ def test_cell_transition_regression(self, gt_temporal_adata: AnnData, forward: b
         key_1 = config["key_1"]
         key_2 = config["key_2"]
         key_3 = config["key_3"]
-        set(gt_temporal_adata.obs["cell_type"].cat.categories)
         problem = TemporalProblem(gt_temporal_adata)
         problem = problem.prepare(key)
         assert set(problem.problems.keys()) == {(key_1, key_2), (key_2, key_3)}
@@ -116,18 +115,17 @@ def test_cell_transition_regression(self, gt_temporal_adata: AnnData, forward: b
         assert result.shape == expected_shape
         marginal = result.sum(axis=forward == 1).values
         present_cell_type_marginal = marginal[marginal > 0]
-        np.testing.assert_almost_equal(present_cell_type_marginal, np.ones(len(present_cell_type_marginal)), decimal=5)
+        np.testing.assert_allclose(present_cell_type_marginal, 1.0)
 
         direction = "forward" if forward else "backward"
         gt = gt_temporal_adata.uns[f"cell_transition_10_105_{direction}"]
         gt = gt.sort_index()
         result = result.sort_index()
         result = result[gt.columns]
-        np.testing.assert_almost_equal(result.values, gt.values, decimal=4)
+        np.testing.assert_allclose(result.values, gt.values, rtol=1e-6, atol=1e-6)
 
     def test_compute_time_point_distances_pipeline(self, adata_time: AnnData):
-        problem = TemporalProblem(adata_time)
-        problem.prepare("time")
+        problem = TemporalProblem(adata_time).prepare("time")
         distance_source_intermediate, distance_intermediate_target = problem.compute_time_point_distances(
             source=0,
             intermediate=1,
@@ -318,8 +316,7 @@ def test_cell_transition_regression_notparam(
         self,
         adata_time_with_tmap: AnnData,
     ):  # TODO(MUCDK): please check.
-        problem = TemporalProblem(adata_time_with_tmap)
-        problem = problem.prepare("time")
+        problem = TemporalProblem(adata_time_with_tmap).prepare("time")
         problem[0, 1]._solution = MockSolverOutput(adata_time_with_tmap.uns["transport_matrix"])
 
         result = problem.cell_transition(
@@ -329,8 +326,8 @@ def test_cell_transition_regression_notparam(
             target_groups="cell_type",
             forward=True,
         )
-        res = result.sort_index().sort_index(1)
-        df_expected = adata_time_with_tmap.uns["cell_transition_gt"].sort_index().sort_index(1)
+        res = result.sort_index().sort_index(axis=1)
+        df_expected = adata_time_with_tmap.uns["cell_transition_gt"].sort_index().sort_index(axis=1)
         np.testing.assert_almost_equal(res.values, df_expected.values, decimal=8)
 
     @pytest.mark.fast()
diff --git a/tests/problems/time/test_temporal_problem.py b/tests/problems/time/test_temporal_problem.py
index eab8738a6..cb2908fcd 100644
--- a/tests/problems/time/test_temporal_problem.py
+++ b/tests/problems/time/test_temporal_problem.py
@@ -139,9 +139,9 @@ def test_apoptosis_key_pipeline(self, adata_time: AnnData):
     @pytest.mark.fast()
     @pytest.mark.parametrize("scaling", [0.1, 1, 4])
     def test_proliferation_key_c_pipeline(self, adata_time: AnnData, scaling: float):
-        keys = np.sort(np.unique(adata_time.obs["time"].values))
-        adata_time = adata_time[adata_time.obs["time"].isin([keys[0], keys[1]])]
-        delta = keys[1] - keys[0]
+        key0, key1, *_ = np.sort(np.unique(adata_time.obs["time"].values))
+        adata_time = adata_time[adata_time.obs["time"].isin([key0, key1])].copy()
+        delta = key1 - key0
         problem = TemporalProblem(adata_time)
         assert problem.proliferation_key is None
 
@@ -149,14 +149,13 @@ def test_proliferation_key_c_pipeline(self, adata_time: AnnData, scaling: float)
         assert problem.proliferation_key == "proliferation"
 
         problem = problem.prepare(time_key="time", marginal_kwargs={"scaling": scaling})
-        prolif = adata_time[adata_time.obs["time"] == keys[0]].obs["proliferation"]
-        apopt = adata_time[adata_time.obs["time"] == keys[0]].obs["apoptosis"]
+        prolif = adata_time[adata_time.obs["time"] == key0].obs["proliferation"]
+        apopt = adata_time[adata_time.obs["time"] == key0].obs["apoptosis"]
         expected_marginals = np.exp((prolif - apopt) * delta / scaling)
-        np.testing.assert_allclose(problem[keys[0], keys[1]]._prior_growth, expected_marginals, rtol=RTOL, atol=ATOL)
+        np.testing.assert_allclose(problem[key0, key1]._prior_growth, expected_marginals, rtol=RTOL, atol=ATOL)
 
     def test_cell_costs_source_pipeline(self, adata_time: AnnData):
-        problem = TemporalProblem(adata=adata_time)
-        problem = problem.prepare("time")
+        problem = TemporalProblem(adata=adata_time).prepare("time")
         problem = problem.solve(max_iterations=2)
 
         cell_costs_source = problem.cell_costs_source
@@ -272,10 +271,7 @@ def test_pass_arguments(self, adata_time: AnnData, args_to_check: Mapping[str, A
             el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val)
             assert el == args_to_check[arg]
 
-        if args_to_check["rank"] == -1:
-            args = pointcloud_args
-        else:
-            args = lr_pointcloud_args
+        args = pointcloud_args if args_to_check["rank"] == -1 else lr_pointcloud_args
         for arg, val in args.items():
             el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val)
             assert hasattr(geom, val)
diff --git a/tox.ini b/tox.ini
deleted file mode 100644
index a68c586f5..000000000
--- a/tox.ini
+++ /dev/null
@@ -1,151 +0,0 @@
-[flake8]
-per-file-ignores =
-    */__init__.py: D104, F401
-    tests/*: D
-    docs/*: D,B,A
-    src/moscot/_docs.py: D
-    src/moscot/backends/ott/_solver.py: D101, D102
-    src/moscot/utils/_subset_policy.py: D
-    src/moscot/solvers/_output.py: D101, D102, D105, D106, D107, A002, A003
-    src/moscot/problems/_compound_problem.py: D101, D102, D105, D106, D107, A002, A003
-    tests/data/*: D,B,E
-    src/moscot/costs/_costs.py: D
-    src/moscot/problems/_compound_problem.py: RST, D
-    src/moscot/solvers/_tagged_array.py: D
-    src/moscot/backends/ott/_output.py: D
-    src/moscot/_constants/_constants.py: D
-    src/moscot/_constants/_key.py: D
-    src/moscot/problems/base/_base_problem.py: D
-    src/moscot/problems/base/_problem_manager.py: D
-    src/moscot/_constants/_enum.py: D
-    src/moscot/_docs/*.py: D
-# D104: Missing docstring in public package
-# F401: <package> imported but unused
-max_line_length = 120
-filename = *.py
-# D202 No blank lines allowed after function docstring
-# D107 Missing docstring in __init__
-# B008 Do not perform function calls in argument defaults
-# W503 line break before binary operator
-# D105 Missing docstring in magic method
-# E203 whitespace before ':'
-# F405 ... may be undefined, or defined from star imports: ...
-# RST306 Unknown target name
-# D106 Missing docstring in public nested class
-ignore = D202,D107,B008,W503,D105,E203,F405,RST306,RST304,E741
-exclude =
-    .git
-    __pycache__
-    .tox
-    build
-    dist
-    setup.py
-ban-relative-imports = true
-
-[doc8]
-max-line-length = 120
-ignore-path = .tox,build,dist
-quiet = 1
-
-[gh-actions]
-python =
-    3.8: py38
-    3.9: py39
-    3.10: py310
-    3.11: py311
-
-[gh-actions:env]
-PLATFORM =
-    ubuntu-latest: linux
-    macos-latest: macos
-
-[tox]
-isolated_build = True
-envlist =
-    lint-code
-    py{38,39,310,311}-{linux,macos}
-skip_missing_interpreters = true
-
-[coverage:run]
-branch = true
-#TODO(michalk8): enable once using pytest-xdist
-parallel = false
-source = moscot
-omit = */__init__.py
-
-[coverage:paths]
-source =
-    moscot
-    */site-packages/moscot
-
-[coverage:report]
-exclude_lines =
-    \#.*pragma:\s*no.?cover
-
-    ^if __name__ == .__main__.:$
-
-    ^\s*raise AssertionError\b
-    ^\s*raise NotImplementedError\b
-    ^\s*return NotImplemented\b
-show_missing = true
-precision = 2
-skip_empty = True
-sort = Miss
-
-[pytest]
-addopts = --strict-markers
-python_files = test_*.py
-testpaths = tests/
-xfail_strict = true
-
-[testenv]
-platform =
-    linux: linux
-    macos: (osx|darwin)
-# TODO(michalk8): Cython+POT not necessary, just convenient for fixtures
-deps =
-    pytest
-    pytest-mock
-    pytest-cov
-usedevelop = true
-passenv = TOXENV,CI,CODECOV_*,GITHUB_ACTIONS,PYTEST_FLAGS
-commands =
-    python -m pytest --cov --cov-append --cov-report=term-missing --cov-config={toxinidir}/tox.ini {posargs:-vv} {env:PYTEST_FLAGS:}
-
-[testenv:lint-code]
-description = Lint the code.
-deps = pre-commit>=2.14.0
-skip_install = true
-commands =
-    pre-commit run --all-files --show-diff-on-failure
-
-[testenv:clean-docs]
-description = Remove the documentation.
-deps =
-skip_install = true
-changedir = {tox_root}/docs
-allowlist_externals = make
-commands =
-    make clean
-
-[testenv:build-docs]
-description = Build the documentation.
-deps =
-use_develop = true
-extras = docs
-allowlist_externals = make
-commands =
-    make html -C {tox_root}{/}docs {posargs}
-commands_post =
-    python -c 'import pathlib; print(f"Documentation is under:", pathlib.Path(f"{tox_root}") / "docs" / "build" / "html" / "index.html")'
-
-[testenv:build-package]
-description = Build the package.
-deps =
-    build
-    twine
-commands =
-    python -m build --sdist --wheel --outdir {tox_root}{/}dist{/} {posargs:}
-    twine check {tox_root}{/}dist{/}*
-commands_post =
-    python -c 'import pathlib; print(f"Package is under:", pathlib.Path("{tox_root}") / "dist")'