Skip to content

Commit

Permalink
Merge branch 'result_storage' of https://github.com/felixpetschko/scirpy
Browse files Browse the repository at this point in the history
 into result_storage
  • Loading branch information
felixpetschko committed Sep 18, 2024
2 parents bca630e + 78124f1 commit 6166305
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 20 deletions.
2 changes: 1 addition & 1 deletion src/scirpy/pl/_clonotypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ def _plot_clonotype_network_panel(
color_by_n_cells,
):
cell_indices = read_cell_indices(cell_indices)

colorbar_title = "mean per dot"
pie_colors = None
cat_colors = None
Expand Down
7 changes: 3 additions & 4 deletions src/scirpy/tests/test_ir_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import pytest
from mudata import MuData

from scirpy.util import read_cell_indices

from scirpy.pp import ir_dist
from scirpy.tl._ir_query import (
_reduce_json,
Expand All @@ -16,6 +14,7 @@
ir_query_annotate,
ir_query_annotate_df,
)
from scirpy.util import read_cell_indices


@pytest.mark.parametrize("metric", ["identity", "levenshtein"])
Expand All @@ -34,10 +33,10 @@ def test_ir_query(adata_cdr3, adata_cdr3_2, metric, key1, key2):

tmp_key2 = f"ir_query_TESTDB_aa_{metric}" if key2 is None else key2
tmp_ad = adata_cdr3.mod["airr"] if isinstance(adata_cdr3, MuData) else adata_cdr3

cell_indices = read_cell_indices(tmp_ad.uns[tmp_key2]["cell_indices"])
cell_indices_reference = read_cell_indices(tmp_ad.uns[tmp_key2]["cell_indices_reference"])

assert tmp_ad.uns[tmp_key2]["distances"].shape == (4, 3)
assert len(cell_indices) == 4
assert len(cell_indices_reference) == 3
Expand Down
6 changes: 2 additions & 4 deletions src/scirpy/tl/_clonotypes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
import json
import random
from collections.abc import Sequence
from typing import Literal, cast
Expand All @@ -9,7 +10,6 @@
import scipy.sparse as sp
from anndata import AnnData
from scanpy import logging
import json

from scirpy.ir_dist import MetricType, _get_metric_key
from scirpy.ir_dist._clonotype_neighbors import ClonotypeNeighbors
Expand Down Expand Up @@ -604,9 +604,7 @@ def _graph_from_coordinates(adata: AnnData, clonotype_key: str, basis: str) -> t
# map the cell-id to the corresponding row/col in the clonotype distance matrix
cell_indices = read_cell_indices(clonotype_res["cell_indices"])
dist_idx, obs_names = zip(
*itertools.chain.from_iterable(
zip(itertools.repeat(i), obs_names) for i, obs_names in cell_indices.items()
),
*itertools.chain.from_iterable(zip(itertools.repeat(i), obs_names) for i, obs_names in cell_indices.items()),
strict=False,
)
dist_idx_lookup = pd.DataFrame(index=obs_names, data=dist_idx, columns=["dist_idx"])
Expand Down
3 changes: 1 addition & 2 deletions src/scirpy/tl/_ir_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
import numpy as np
import pandas as pd
from scanpy import logging
import json

from scirpy.ir_dist import MetricType, _get_metric_key
from scirpy.ir_dist._clonotype_neighbors import ClonotypeNeighbors
from scirpy.util import DataHandler, _is_na, tqdm, read_cell_indices
from scirpy.util import DataHandler, _is_na, read_cell_indices, tqdm

from ._clonotypes import _common_doc, _common_doc_parallelism, _doc_clonotype_definition, _validate_parameters

Expand Down
19 changes: 10 additions & 9 deletions src/scirpy/util/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import contextlib
import json
import os
import warnings
from collections.abc import Callable, Mapping, Sequence
from textwrap import dedent
from typing import Any, Optional, Union, cast, overload
from typing import Any, Literal, Optional, Union, cast, overload

import awkward as ak
import numpy as np
Expand All @@ -16,9 +17,6 @@
from scipy.sparse import issparse
from tqdm.auto import tqdm

from typing import Literal
import json

# reexport tqdm (here was previously a workaround for https://github.com/tqdm/tqdm/issues/1082)
__all__ = ["tqdm"]

Expand Down Expand Up @@ -609,17 +607,20 @@ def _get_usable_cpus(n_jobs: int = 0, use_numba: bool = False):

return usable_cpus

def read_cell_indices(cell_indices: Union[dict[str, np.ndarray[str]], str]) -> dict[str,list[str]]:

def read_cell_indices(cell_indices: dict[str, np.ndarray[str]] | str) -> dict[str, list[str]]:
"""
The datatype of the cell_indices Mapping (clonotype_id -> cell_ids) that is stored to the anndata.uns
attribute after the ´define_clonotype_clusters´ function has changed from dict[str, np.ndarray[str] to
str (json) due to performance considerations regarding the writing speed of the anndata object. But we still
want that older anndata objects with the dict[str, np.ndarray[str] datatype can be used. So we use this function
to read the cell_indices from the anndata object to support both formats.
"""
if(isinstance(cell_indices, str)): # new format
if isinstance(cell_indices, str): # new format
return json.loads(cell_indices)
elif(isinstance(cell_indices, dict)): # old format
elif isinstance(cell_indices, dict): # old format
return {k: v.tolist() for k, v in cell_indices.items()}
else: # unsupported format
raise TypeError(f"Unsupported type for cell_indices: {type(cell_indices)}. Expected str (json) or dict[str, np.ndarray[str]].")
else: # unsupported format
raise TypeError(
f"Unsupported type for cell_indices: {type(cell_indices)}. Expected str (json) or dict[str, np.ndarray[str]]."
)

0 comments on commit 6166305

Please sign in to comment.