From f1554f1d9f81be10078133b320c9f13212b083ba Mon Sep 17 00:00:00 2001 From: Martin Kim Date: Thu, 21 Mar 2024 14:22:52 -0700 Subject: [PATCH] [feat] option to return csr tensors in datapipe --- .../experimental/ml/pytorch.py | 35 +++++++++++-------- .../tests/experimental/ml/test_pytorch.py | 16 ++++++--- 2 files changed, 32 insertions(+), 19 deletions(-) diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py index 748bef804..75fd6ed1f 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py +++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py @@ -5,7 +5,7 @@ from datetime import timedelta from math import ceil from time import time -from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple +from typing import Any, Dict, Iterator, List, Literal, Optional, Sequence, Tuple import numpy as np import numpy.typing as npt @@ -226,7 +226,7 @@ def __init__( batch_size: int, encoders: Dict[str, LabelEncoder], stats: Stats, - return_sparse_X: bool, + X_format: Literal["dense", "csr", "coo"], use_eager_fetch: bool, shuffle_rng: Optional[Generator] = None, ) -> None: @@ -238,7 +238,7 @@ def __init__( self.soma_chunk = None self.var_joinids = var_joinids self.batch_size = batch_size - self.return_sparse_X = return_sparse_X + self.X_format = X_format self.encoders = encoders self.stats = stats self.max_process_mem_usage_bytes = 0 @@ -272,8 +272,15 @@ def __next__(self) -> ObsAndXDatum: # `to_numpy()` avoids copying the numpy array data obs_tensor = torch.from_numpy(obs_encoded.to_numpy()) - if not self.return_sparse_X: + if self.X_format == "dense": X_tensor = torch.from_numpy(X.todense()) + elif self.X_format == "csr": + X_tensor = torch.sparse_csr_tensor( + crow_indices=torch.as_tensor(X.indptr), + col_indices=torch.as_tensor(X.indices), + values=torch.as_tensor(X.data), + size=X.shape, + ) else: coo = X.tocoo() @@ -350,7 +357,7 @@ class ExperimentDataPipe(pipes.IterDataPipe[Dataset[ObsAndXDatum]]): # type: ig [2416, 0, 4], [2417, 0, 3]], dtype=torch.int64)) - The ``return_sparse_X`` parameter controls whether the ``X`` data is returned as a dense or sparse + The ``X_format`` parameter controls whether the ``X`` data is returned as a dense or sparse :class:`torch.Tensor`. If the model supports use of sparse :class:`torch.Tensor`\ s, this will reduce memory usage. The ``obs_column_names`` parameter determines the data columns that are returned in the ``obs`` Tensor. The first @@ -390,7 +397,7 @@ def __init__( batch_size: int = 1, shuffle: bool = False, seed: Optional[int] = None, - return_sparse_X: bool = False, + X_format: Literal["dense", "csr", "coo"] = "dense", soma_chunk_size: Optional[int] = None, use_eager_fetch: bool = True, ) -> None: @@ -433,11 +440,11 @@ def __init__( The random seed used for shuffling. Defaults to ``None`` (no seed). This *must* be specified when using :class:`torch.nn.parallel.DistributedDataParallel` to ensure data partitions are disjoint across worker processes. - return_sparse_X: - Controls whether the ``X`` data is returned as a dense or sparse :class:`torch.Tensor`. As ``X`` data is - very sparse, setting this to ``True`` will reduce memory usage, if the model supports use of sparse - :class:`torch.Tensor`\ s. Defaults to ``False``, since sparse :class:`torch.Tensor`\ s are still - experimental in PyTorch. + X_format: + Controls whether the ``X`` data is returned as a dense or sparse :class:`torch.Tensor`. Must be one of + ``"dense"``, ``"csr"``, or ``"coo"``. As ``X`` data is very sparse, setting this to ``"coo"`` or + ``"csr"`` will reduce memory usage, if the model supports use of sparse :class:`torch.Tensor`\ s. + Defaults to ``"dense"``, since sparse :class:`torch.Tensor`\ s are still experimental in PyTorch. soma_chunk_size: The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This impacts two aspects of this class's behavior: 1) The maximum memory utilization, with larger values providing @@ -463,7 +470,7 @@ def __init__( self.var_query = var_query self.obs_column_names = obs_column_names self.batch_size = batch_size - self.return_sparse_X = return_sparse_X + self.X_format = X_format self.soma_chunk_size = soma_chunk_size self.use_eager_fetch = use_eager_fetch self._stats = Stats() @@ -545,7 +552,7 @@ def __iter__(self) -> Iterator[ObsAndXDatum]: pytorch_logger.debug(f"Using {self.soma_chunk_size=}") if ( - self.return_sparse_X + self.X_format != "dense" and torch.utils.data.get_worker_info() and torch.utils.data.get_worker_info().num_workers > 0 ): @@ -583,7 +590,7 @@ def __iter__(self) -> Iterator[ObsAndXDatum]: batch_size=self.batch_size, encoders=self.obs_encoders, stats=self._stats, - return_sparse_X=self.return_sparse_X, + X_format=self.X_format, use_eager_fetch=self.use_eager_fetch, shuffle_rng=self._shuffle_rng, ) diff --git a/api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py b/api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py index f87b282cc..7cdfe7969 100644 --- a/api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py +++ b/api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py @@ -1,6 +1,6 @@ import pathlib import sys -from typing import Callable, List, Optional, Sequence, Union +from typing import Callable, List, Literal, Optional, Sequence, Union from unittest.mock import patch import numpy as np @@ -278,13 +278,16 @@ def test_batching__empty_query_result(soma_experiment: Experiment, use_eager_fet "obs_range,var_range,X_value_gen,use_eager_fetch", [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) -def test_sparse_output__non_batched(soma_experiment: Experiment, use_eager_fetch: bool) -> None: +@pytest.mark.parametrize("X_format", ("coo", "csr")) +def test_sparse_output__non_batched( + soma_experiment: Experiment, use_eager_fetch: bool, X_format: Literal["dense", "csr", "coo"] +) -> None: exp_data_pipe = ExperimentDataPipe( soma_experiment, measurement_name="RNA", X_name="raw", obs_column_names=["label"], - return_sparse_X=True, + X_format=X_format, use_eager_fetch=use_eager_fetch, ) batch_iter = iter(exp_data_pipe) @@ -300,14 +303,17 @@ def test_sparse_output__non_batched(soma_experiment: Experiment, use_eager_fetch "obs_range,var_range,X_value_gen,use_eager_fetch", [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) -def test_sparse_output__batched(soma_experiment: Experiment, use_eager_fetch: bool) -> None: +@pytest.mark.parametrize("X_format", ("coo", "csr")) +def test_sparse_output__batched( + soma_experiment: Experiment, use_eager_fetch: bool, X_format: Literal["dense", "csr", "coo"] +) -> None: exp_data_pipe = ExperimentDataPipe( soma_experiment, measurement_name="RNA", X_name="raw", obs_column_names=["label"], batch_size=3, - return_sparse_X=True, + X_format=X_format, use_eager_fetch=use_eager_fetch, ) batch_iter = iter(exp_data_pipe)