diff --git a/src/scirpy/pl/_clonotypes.py b/src/scirpy/pl/_clonotypes.py index e6cb28ab0..5dc2a9b48 100644 --- a/src/scirpy/pl/_clonotypes.py +++ b/src/scirpy/pl/_clonotypes.py @@ -414,7 +414,7 @@ def _plot_clonotype_network_panel( color_by_n_cells, ): cell_indices = read_cell_indices(cell_indices) - + colorbar_title = "mean per dot" pie_colors = None cat_colors = None diff --git a/src/scirpy/tests/test_ir_query.py b/src/scirpy/tests/test_ir_query.py index dfc10150e..7c33e69cf 100644 --- a/src/scirpy/tests/test_ir_query.py +++ b/src/scirpy/tests/test_ir_query.py @@ -5,8 +5,6 @@ import pytest from mudata import MuData -from scirpy.util import read_cell_indices - from scirpy.pp import ir_dist from scirpy.tl._ir_query import ( _reduce_json, @@ -16,6 +14,7 @@ ir_query_annotate, ir_query_annotate_df, ) +from scirpy.util import read_cell_indices @pytest.mark.parametrize("metric", ["identity", "levenshtein"]) @@ -34,10 +33,10 @@ def test_ir_query(adata_cdr3, adata_cdr3_2, metric, key1, key2): tmp_key2 = f"ir_query_TESTDB_aa_{metric}" if key2 is None else key2 tmp_ad = adata_cdr3.mod["airr"] if isinstance(adata_cdr3, MuData) else adata_cdr3 - + cell_indices = read_cell_indices(tmp_ad.uns[tmp_key2]["cell_indices"]) cell_indices_reference = read_cell_indices(tmp_ad.uns[tmp_key2]["cell_indices_reference"]) - + assert tmp_ad.uns[tmp_key2]["distances"].shape == (4, 3) assert len(cell_indices) == 4 assert len(cell_indices_reference) == 3 diff --git a/src/scirpy/tl/_clonotypes.py b/src/scirpy/tl/_clonotypes.py index fdb1491a5..79a4bead5 100644 --- a/src/scirpy/tl/_clonotypes.py +++ b/src/scirpy/tl/_clonotypes.py @@ -1,4 +1,5 @@ import itertools +import json import random from collections.abc import Sequence from typing import Literal, cast @@ -9,7 +10,6 @@ import scipy.sparse as sp from anndata import AnnData from scanpy import logging -import json from scirpy.ir_dist import MetricType, _get_metric_key from scirpy.ir_dist._clonotype_neighbors import ClonotypeNeighbors @@ -604,9 +604,7 @@ def _graph_from_coordinates(adata: AnnData, clonotype_key: str, basis: str) -> t # map the cell-id to the corresponding row/col in the clonotype distance matrix cell_indices = read_cell_indices(clonotype_res["cell_indices"]) dist_idx, obs_names = zip( - *itertools.chain.from_iterable( - zip(itertools.repeat(i), obs_names) for i, obs_names in cell_indices.items() - ), + *itertools.chain.from_iterable(zip(itertools.repeat(i), obs_names) for i, obs_names in cell_indices.items()), strict=False, ) dist_idx_lookup = pd.DataFrame(index=obs_names, data=dist_idx, columns=["dist_idx"]) diff --git a/src/scirpy/tl/_ir_query.py b/src/scirpy/tl/_ir_query.py index fbfa7cec8..89bf5c975 100644 --- a/src/scirpy/tl/_ir_query.py +++ b/src/scirpy/tl/_ir_query.py @@ -7,11 +7,10 @@ import numpy as np import pandas as pd from scanpy import logging -import json from scirpy.ir_dist import MetricType, _get_metric_key from scirpy.ir_dist._clonotype_neighbors import ClonotypeNeighbors -from scirpy.util import DataHandler, _is_na, tqdm, read_cell_indices +from scirpy.util import DataHandler, _is_na, read_cell_indices, tqdm from ._clonotypes import _common_doc, _common_doc_parallelism, _doc_clonotype_definition, _validate_parameters diff --git a/src/scirpy/util/__init__.py b/src/scirpy/util/__init__.py index 789614579..d9924c60f 100644 --- a/src/scirpy/util/__init__.py +++ b/src/scirpy/util/__init__.py @@ -1,9 +1,10 @@ import contextlib +import json import os import warnings from collections.abc import Callable, Mapping, Sequence from textwrap import dedent -from typing import Any, Optional, Union, cast, overload +from typing import Any, Literal, Optional, Union, cast, overload import awkward as ak import numpy as np @@ -16,9 +17,6 @@ from scipy.sparse import issparse from tqdm.auto import tqdm -from typing import Literal -import json - # reexport tqdm (here was previously a workaround for https://github.com/tqdm/tqdm/issues/1082) __all__ = ["tqdm"] @@ -609,7 +607,8 @@ def _get_usable_cpus(n_jobs: int = 0, use_numba: bool = False): return usable_cpus -def read_cell_indices(cell_indices: Union[dict[str, np.ndarray[str]], str]) -> dict[str,list[str]]: + +def read_cell_indices(cell_indices: dict[str, np.ndarray[str]] | str) -> dict[str, list[str]]: """ The datatype of the cell_indices Mapping (clonotype_id -> cell_ids) that is stored to the anndata.uns attribute after the ´define_clonotype_clusters´ function has changed from dict[str, np.ndarray[str] to @@ -617,9 +616,11 @@ def read_cell_indices(cell_indices: Union[dict[str, np.ndarray[str]], str]) -> d want that older anndata objects with the dict[str, np.ndarray[str] datatype can be used. So we use this function to read the cell_indices from the anndata object to support both formats. """ - if(isinstance(cell_indices, str)): # new format + if isinstance(cell_indices, str): # new format return json.loads(cell_indices) - elif(isinstance(cell_indices, dict)): # old format + elif isinstance(cell_indices, dict): # old format return {k: v.tolist() for k, v in cell_indices.items()} - else: # unsupported format - raise TypeError(f"Unsupported type for cell_indices: {type(cell_indices)}. Expected str (json) or dict[str, np.ndarray[str]].") + else: # unsupported format + raise TypeError( + f"Unsupported type for cell_indices: {type(cell_indices)}. Expected str (json) or dict[str, np.ndarray[str]]." + )