[feat] option to return csr tensors in datapipe
martinkim0 committed Mar 21, 2024
1 parent 90cc9eb · commit f1554f1
Showing 2 changed files with 32 additions and 19 deletions.
@@ -5,7 +5,7 @@
 from datetime import timedelta
 from math import ceil
 from time import time
-from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple
+from typing import Any, Dict, Iterator, List, Literal, Optional, Sequence, Tuple
 
 import numpy as np
 import numpy.typing as npt
@@ -226,7 +226,7 @@ def __init__(
         batch_size: int,
         encoders: Dict[str, LabelEncoder],
         stats: Stats,
-        return_sparse_X: bool,
+        X_format: Literal["dense", "csr", "coo"],
         use_eager_fetch: bool,
         shuffle_rng: Optional[Generator] = None,
     ) -> None:
@@ -238,7 +238,7 @@ def __init__(
         self.soma_chunk = None
         self.var_joinids = var_joinids
         self.batch_size = batch_size
-        self.return_sparse_X = return_sparse_X
+        self.X_format = X_format
         self.encoders = encoders
         self.stats = stats
         self.max_process_mem_usage_bytes = 0
@@ -272,8 +272,15 @@ def __next__(self) -> ObsAndXDatum:
         # `to_numpy()` avoids copying the numpy array data
         obs_tensor = torch.from_numpy(obs_encoded.to_numpy())
 
-        if not self.return_sparse_X:
+        if self.X_format == "dense":
             X_tensor = torch.from_numpy(X.todense())
+        elif self.X_format == "csr":
+            X_tensor = torch.sparse_csr_tensor(
+                crow_indices=torch.as_tensor(X.indptr),
+                col_indices=torch.as_tensor(X.indices),
+                values=torch.as_tensor(X.data),
+                size=X.shape,
+            )
         else:
             coo = X.tocoo()
 
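The new ``csr`` branch above maps a SciPy CSR chunk's index and value buffers directly onto a CSR-layout torch tensor. A minimal standalone sketch of that conversion, assuming ``X`` is a ``scipy.sparse.csr_matrix`` as in the datapipe's SOMA chunks:

# Minimal sketch of the CSR conversion performed in the branch above.
# Assumes X is a scipy.sparse.csr_matrix, as in the datapipe's SOMA chunks.
import scipy.sparse
import torch

X = scipy.sparse.random(4, 3, density=0.5, format="csr", dtype="float32")

X_tensor = torch.sparse_csr_tensor(
    crow_indices=torch.as_tensor(X.indptr),  # row pointers, length nrows + 1
    col_indices=torch.as_tensor(X.indices),  # column index of each stored value
    values=torch.as_tensor(X.data),          # the stored values themselves
    size=X.shape,
)

# Round trip: densifying the CSR tensor recovers the original matrix.
assert torch.equal(X_tensor.to_dense(), torch.from_numpy(X.toarray()))

``torch.as_tensor`` reuses the underlying NumPy buffers where possible, so the ``indptr``, ``indices``, and ``data`` arrays need not be copied.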
@@ -350,7 +357,7 @@ class ExperimentDataPipe(pipes.IterDataPipe[Dataset[ObsAndXDatum]]): # type: ig
             [2416, 0, 4],
             [2417, 0, 3]], dtype=torch.int64))
 
-    The ``return_sparse_X`` parameter controls whether the ``X`` data is returned as a dense or sparse
+    The ``X_format`` parameter controls whether the ``X`` data is returned as a dense or sparse
     :class:`torch.Tensor`. If the model supports use of sparse :class:`torch.Tensor`\ s, this will reduce memory usage.
 
     The ``obs_column_names`` parameter determines the data columns that are returned in the ``obs`` Tensor. The first
@@ -390,7 +397,7 @@ def __init__(
         batch_size: int = 1,
         shuffle: bool = False,
         seed: Optional[int] = None,
-        return_sparse_X: bool = False,
+        X_format: Literal["dense", "csr", "coo"] = "dense",
         soma_chunk_size: Optional[int] = None,
         use_eager_fetch: bool = True,
     ) -> None:
@@ -433,11 +440,11 @@ def __init__(
                 The random seed used for shuffling. Defaults to ``None`` (no seed). This *must* be specified when using
                 :class:`torch.nn.parallel.DistributedDataParallel` to ensure data partitions are disjoint across worker
                 processes.
-            return_sparse_X:
-                Controls whether the ``X`` data is returned as a dense or sparse :class:`torch.Tensor`. As ``X`` data is
-                very sparse, setting this to ``True`` will reduce memory usage, if the model supports use of sparse
-                :class:`torch.Tensor`\ s. Defaults to ``False``, since sparse :class:`torch.Tensor`\ s are still
-                experimental in PyTorch.
+            X_format:
+                Controls whether the ``X`` data is returned as a dense or sparse :class:`torch.Tensor`. Must be one of
+                ``"dense"``, ``"csr"``, or ``"coo"``. As ``X`` data is very sparse, setting this to ``"coo"`` or
+                ``"csr"`` will reduce memory usage, if the model supports use of sparse :class:`torch.Tensor`\ s.
+                Defaults to ``"dense"``, since sparse :class:`torch.Tensor`\ s are still experimental in PyTorch.
             soma_chunk_size:
                 The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This impacts two aspects of
                 this class's behavior: 1) The maximum memory utilization, with larger values providing
@@ -463,7 +470,7 @@ def __init__(
         self.var_query = var_query
         self.obs_column_names = obs_column_names
         self.batch_size = batch_size
-        self.return_sparse_X = return_sparse_X
+        self.X_format = X_format
         self.soma_chunk_size = soma_chunk_size
         self.use_eager_fetch = use_eager_fetch
         self._stats = Stats()
@@ -545,7 +552,7 @@ def __iter__(self) -> Iterator[ObsAndXDatum]:
         pytorch_logger.debug(f"Using {self.soma_chunk_size=}")
 
         if (
-            self.return_sparse_X
+            self.X_format != "dense"
             and torch.utils.data.get_worker_info()
             and torch.utils.data.get_worker_info().num_workers > 0
         ):
@@ -583,7 +590,7 @@ def __iter__(self) -> Iterator[ObsAndXDatum]:
             batch_size=self.batch_size,
             encoders=self.obs_encoders,
             stats=self._stats,
-            return_sparse_X=self.return_sparse_X,
+            X_format=self.X_format,
             use_eager_fetch=self.use_eager_fetch,
             shuffle_rng=self._shuffle_rng,
         )
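For callers, the net effect is one keyword argument. A hedged usage sketch, assuming ``soma_experiment`` is an open SOMA Experiment (as in the tests below) and that ``ExperimentDataPipe`` is importable from the experimental ML package:

# Hypothetical usage sketch; the import path and the (X, obs) batch ordering
# are assumptions, not confirmed by the diff.
import torch
from cellxgene_census.experimental.ml import ExperimentDataPipe

exp_data_pipe = ExperimentDataPipe(
    soma_experiment,          # an open SOMA Experiment (assumed to exist)
    measurement_name="RNA",
    X_name="raw",
    obs_column_names=["label"],
    batch_size=3,
    X_format="csr",           # one of "dense", "csr", "coo"
)

X, obs = next(iter(exp_data_pipe))
assert X.layout == torch.sparse_csr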
api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py (16 changes: 11 additions & 5 deletions)
@@ -1,6 +1,6 @@
 import pathlib
 import sys
-from typing import Callable, List, Optional, Sequence, Union
+from typing import Callable, List, Literal, Optional, Sequence, Union
 from unittest.mock import patch
 
 import numpy as np
@@ -278,13 +278,16 @@ def test_batching__empty_query_result(soma_experiment: Experiment, use_eager_fet
     "obs_range,var_range,X_value_gen,use_eager_fetch",
     [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)],
 )
-def test_sparse_output__non_batched(soma_experiment: Experiment, use_eager_fetch: bool) -> None:
+@pytest.mark.parametrize("X_format", ("coo", "csr"))
+def test_sparse_output__non_batched(
+    soma_experiment: Experiment, use_eager_fetch: bool, X_format: Literal["dense", "csr", "coo"]
+) -> None:
     exp_data_pipe = ExperimentDataPipe(
         soma_experiment,
         measurement_name="RNA",
         X_name="raw",
         obs_column_names=["label"],
-        return_sparse_X=True,
+        X_format=X_format,
         use_eager_fetch=use_eager_fetch,
     )
     batch_iter = iter(exp_data_pipe)
@@ -300,14 +303,17 @@ def test_sparse_output__non_batched(soma_experiment: Experiment, use_eager_fetch
     "obs_range,var_range,X_value_gen,use_eager_fetch",
     [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)],
 )
-def test_sparse_output__batched(soma_experiment: Experiment, use_eager_fetch: bool) -> None:
+@pytest.mark.parametrize("X_format", ("coo", "csr"))
+def test_sparse_output__batched(
+    soma_experiment: Experiment, use_eager_fetch: bool, X_format: Literal["dense", "csr", "coo"]
+) -> None:
     exp_data_pipe = ExperimentDataPipe(
         soma_experiment,
         measurement_name="RNA",
         X_name="raw",
         obs_column_names=["label"],
         batch_size=3,
-        return_sparse_X=True,
+        X_format=X_format,
         use_eager_fetch=use_eager_fetch,
     )
     batch_iter = iter(exp_data_pipe)
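Both tests above now run once per sparse format via the added ``@pytest.mark.parametrize`` decorator. A small hypothetical helper illustrates the kind of layout assertion such parametrized tests can make (the actual assertions are elided in the diff above):

import torch

# Hypothetical helper, not part of the diff: map each X_format value to the
# torch layout it implies and check a batch's X tensor against it.
def assert_expected_layout(X: torch.Tensor, X_format: str) -> None:
    expected = {
        "dense": torch.strided,
        "csr": torch.sparse_csr,
        "coo": torch.sparse_coo,
    }[X_format]
    assert X.layout == expected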
