diff --git a/pyproject.toml b/pyproject.toml index bfa5f53..6a60417 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,9 +56,6 @@ version = {attr = "tiledbsoma_ml.__version__"} [tool.setuptools.package-data] "tiledbsoma_ml" = ["py.typed"] -[tool.setuptools_scm] -root = "../../.." - [tool.mypy] show_error_codes = true ignore_missing_imports = true diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index c81798d..a5561fb 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -328,15 +328,13 @@ def __iter__(self) -> Iterator[XObsDatum]: experimental """ - if ( - self.return_sparse_X - and torch.utils.data.get_worker_info() - and torch.utils.data.get_worker_info().num_workers > 0 - ): - raise NotImplementedError( - "torch does not work with sparse tensors in multi-processing mode " - "(see https://github.com/pytorch/pytorch/issues/20248)" - ) + if self.return_sparse_X: + worker_info = torch.utils.data.get_worker_info() + if worker_info and worker_info.num_workers > 0: + raise NotImplementedError( + "torch does not work with sparse tensors in multi-processing mode " + "(see https://github.com/pytorch/pytorch/issues/20248)" + ) world_size, rank = _get_distributed_world_rank() n_workers, worker_id = _get_worker_world_rank() @@ -426,7 +424,7 @@ def set_epoch(self, epoch: int) -> None: def __getitem__(self, index: int) -> XObsDatum: raise NotImplementedError( - "``ExperimentAxisQueryIterable can only be iterated - does not support mapping" + "`ExperimentAxisQueryIterable` can only be iterated - does not support mapping" ) def _io_batch_iter( diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 89b56d9..05bf6ca 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -8,7 +8,7 @@ import sys from functools import partial from pathlib import Path -from typing import Any, Callable, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union from unittest.mock import patch import numpy as np @@ -35,15 +35,21 @@ # These control which classes are tested (for most, but not all tests). # Centralized to allow easy add/delete of specific test parameters. -PipeClassType = Union[ - ExperimentAxisQueryIterable, +IterableWrapperType = Union[ + Type[ExperimentAxisQueryIterDataPipe], + Type[ExperimentAxisQueryIterableDataset], +] +IterableWrappers = ( ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset, +) +PipeClassType = Union[ + Type[ExperimentAxisQueryIterable], + IterableWrapperType, ] PipeClasses = ( ExperimentAxisQueryIterable, - ExperimentAxisQueryIterDataPipe, - ExperimentAxisQueryIterableDataset, + *IterableWrappers, ) XValueGen = Callable[[range, range], spmatrix] @@ -450,16 +456,14 @@ def test_batching__partial_soma_batches_are_concatenated( @pytest.mark.parametrize( "obs_range,var_range,X_value_gen", [(6, 3, pytorch_x_value_gen)] ) -@pytest.mark.parametrize( - "PipeClass", - (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset), -) +@pytest.mark.parametrize("PipeClass", IterableWrappers) def test_multiprocessing__returns_full_result( - PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: IterableWrapperType, soma_experiment: Experiment, ) -> None: - """Tests the ExperimentAxisQueryIterDataPipe provides all data, as collected from multiple processes that are managed by a - PyTorch DataLoader with multiple workers configured.""" + """Tests that ``ExperimentAxisQueryIterDataPipe`` / ``ExperimentAxisQueryIterableDataset`` + provide all data, as collected from multiple processes that are managed by a PyTorch DataLoader + with multiple workers configured.""" with soma_experiment.axis_query(measurement_name="RNA") as query: dp = PipeClass( query, @@ -467,7 +471,7 @@ def test_multiprocessing__returns_full_result( obs_column_names=["soma_joinid", "label"], io_batch_size=3, # two chunks, one per worker ) - # Note we're testing the ExperimentAxisQueryIterDataPipe via a DataLoader, since this is what sets up the multiprocessing + # Wrap with a DataLoader, which sets up the multiprocessing dl = experiment_dataloader(dp, num_workers=2) full_result = list(iter(dl)) @@ -593,12 +597,9 @@ def test_distributed_and_multiprocessing__returns_data_partition_for_rank( "obs_range,var_range,X_value_gen,use_eager_fetch", [(3, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) -@pytest.mark.parametrize( - "PipeClass", - (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset), -) +@pytest.mark.parametrize("PipeClass", IterableWrappers) def test_experiment_dataloader__non_batched( - PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: IterableWrapperType, soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -624,12 +625,9 @@ def test_experiment_dataloader__non_batched( "obs_range,var_range,X_value_gen,use_eager_fetch", [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) -@pytest.mark.parametrize( - "PipeClass", - (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset), -) +@pytest.mark.parametrize("PipeClass", IterableWrappers) def test_experiment_dataloader__batched( - PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: IterableWrapperType, soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -656,12 +654,9 @@ def test_experiment_dataloader__batched( for use_eager_fetch in (True, False) ], ) -@pytest.mark.parametrize( - "PipeClass", - (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset), -) +@pytest.mark.parametrize("PipeClass", IterableWrappers) def test_experiment_dataloader__batched_length( - PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: IterableWrapperType, soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -682,12 +677,9 @@ def test_experiment_dataloader__batched_length( "obs_range,var_range,X_value_gen,batch_size", [(10, 3, pytorch_x_value_gen, batch_size) for batch_size in (1, 3, 10)], ) -@pytest.mark.parametrize( - "PipeClass", - (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset), -) +@pytest.mark.parametrize("PipeClass", IterableWrappers) def test_experiment_dataloader__collate_fn( - PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: IterableWrapperType, soma_experiment: Experiment, batch_size: int, ) -> None: