Skip to content

Commit

Permalink
Type-fixes / errata (#18)
Browse files Browse the repository at this point in the history
* `pyproject.toml`: rm `[tool.setuptools_scm]`

Copypasta from Census: https://github.com/chanzuckerberg/cellxgene-census/blob/v1.16.2/api/python/cellxgene_census/pyproject.toml#L76-L77

* error msg typo fix, worker_info simplification

`mypy` (outside `pre-commit`) didn't like the `get_worker_info().num_workers`

* `test_pytorch.py`: `PipeClassType` / `IterableWrapperType` fixes
  • Loading branch information
ryan-williams authored Oct 31, 2024
1 parent d972a84 commit 2e9e5b1
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 46 deletions.
3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,6 @@ version = {attr = "tiledbsoma_ml.__version__"}
[tool.setuptools.package-data]
"tiledbsoma_ml" = ["py.typed"]

[tool.setuptools_scm]
root = "../../.."

[tool.mypy]
show_error_codes = true
ignore_missing_imports = true
Expand Down
18 changes: 8 additions & 10 deletions src/tiledbsoma_ml/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,15 +328,13 @@ def __iter__(self) -> Iterator[XObsDatum]:
experimental
"""

if (
self.return_sparse_X
and torch.utils.data.get_worker_info()
and torch.utils.data.get_worker_info().num_workers > 0
):
raise NotImplementedError(
"torch does not work with sparse tensors in multi-processing mode "
"(see https://github.com/pytorch/pytorch/issues/20248)"
)
if self.return_sparse_X:
worker_info = torch.utils.data.get_worker_info()
if worker_info and worker_info.num_workers > 0:
raise NotImplementedError(
"torch does not work with sparse tensors in multi-processing mode "
"(see https://github.com/pytorch/pytorch/issues/20248)"
)

world_size, rank = _get_distributed_world_rank()
n_workers, worker_id = _get_worker_world_rank()
Expand Down Expand Up @@ -426,7 +424,7 @@ def set_epoch(self, epoch: int) -> None:

def __getitem__(self, index: int) -> XObsDatum:
raise NotImplementedError(
"``ExperimentAxisQueryIterable can only be iterated - does not support mapping"
"`ExperimentAxisQueryIterable` can only be iterated - does not support mapping"
)

def _io_batch_iter(
Expand Down
58 changes: 25 additions & 33 deletions tests/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import sys
from functools import partial
from pathlib import Path
from typing import Any, Callable, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union
from unittest.mock import patch

import numpy as np
Expand All @@ -35,15 +35,21 @@

# These control which classes are tested (for most, but not all tests).
# Centralized to allow easy add/delete of specific test parameters.
PipeClassType = Union[
ExperimentAxisQueryIterable,
IterableWrapperType = Union[
Type[ExperimentAxisQueryIterDataPipe],
Type[ExperimentAxisQueryIterableDataset],
]
IterableWrappers = (
ExperimentAxisQueryIterDataPipe,
ExperimentAxisQueryIterableDataset,
)
PipeClassType = Union[
Type[ExperimentAxisQueryIterable],
IterableWrapperType,
]
PipeClasses = (
ExperimentAxisQueryIterable,
ExperimentAxisQueryIterDataPipe,
ExperimentAxisQueryIterableDataset,
*IterableWrappers,
)
XValueGen = Callable[[range, range], spmatrix]

Expand Down Expand Up @@ -450,24 +456,22 @@ def test_batching__partial_soma_batches_are_concatenated(
@pytest.mark.parametrize(
"obs_range,var_range,X_value_gen", [(6, 3, pytorch_x_value_gen)]
)
@pytest.mark.parametrize(
"PipeClass",
(ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
)
@pytest.mark.parametrize("PipeClass", IterableWrappers)
def test_multiprocessing__returns_full_result(
PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
PipeClass: IterableWrapperType,
soma_experiment: Experiment,
) -> None:
"""Tests the ExperimentAxisQueryIterDataPipe provides all data, as collected from multiple processes that are managed by a
PyTorch DataLoader with multiple workers configured."""
"""Tests that ``ExperimentAxisQueryIterDataPipe`` / ``ExperimentAxisQueryIterableDataset``
provide all data, as collected from multiple processes that are managed by a PyTorch DataLoader
with multiple workers configured."""
with soma_experiment.axis_query(measurement_name="RNA") as query:
dp = PipeClass(
query,
X_name="raw",
obs_column_names=["soma_joinid", "label"],
io_batch_size=3, # two chunks, one per worker
)
# Note we're testing the ExperimentAxisQueryIterDataPipe via a DataLoader, since this is what sets up the multiprocessing
# Wrap with a DataLoader, which sets up the multiprocessing
dl = experiment_dataloader(dp, num_workers=2)

full_result = list(iter(dl))
Expand Down Expand Up @@ -593,12 +597,9 @@ def test_distributed_and_multiprocessing__returns_data_partition_for_rank(
"obs_range,var_range,X_value_gen,use_eager_fetch",
[(3, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)],
)
@pytest.mark.parametrize(
"PipeClass",
(ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
)
@pytest.mark.parametrize("PipeClass", IterableWrappers)
def test_experiment_dataloader__non_batched(
PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
PipeClass: IterableWrapperType,
soma_experiment: Experiment,
use_eager_fetch: bool,
) -> None:
Expand All @@ -624,12 +625,9 @@ def test_experiment_dataloader__non_batched(
"obs_range,var_range,X_value_gen,use_eager_fetch",
[(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)],
)
@pytest.mark.parametrize(
"PipeClass",
(ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
)
@pytest.mark.parametrize("PipeClass", IterableWrappers)
def test_experiment_dataloader__batched(
PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
PipeClass: IterableWrapperType,
soma_experiment: Experiment,
use_eager_fetch: bool,
) -> None:
Expand All @@ -656,12 +654,9 @@ def test_experiment_dataloader__batched(
for use_eager_fetch in (True, False)
],
)
@pytest.mark.parametrize(
"PipeClass",
(ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
)
@pytest.mark.parametrize("PipeClass", IterableWrappers)
def test_experiment_dataloader__batched_length(
PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
PipeClass: IterableWrapperType,
soma_experiment: Experiment,
use_eager_fetch: bool,
) -> None:
Expand All @@ -682,12 +677,9 @@ def test_experiment_dataloader__batched_length(
"obs_range,var_range,X_value_gen,batch_size",
[(10, 3, pytorch_x_value_gen, batch_size) for batch_size in (1, 3, 10)],
)
@pytest.mark.parametrize(
"PipeClass",
(ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
)
@pytest.mark.parametrize("PipeClass", IterableWrappers)
def test_experiment_dataloader__collate_fn(
PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
PipeClass: IterableWrapperType,
soma_experiment: Experiment,
batch_size: int,
) -> None:
Expand Down

0 comments on commit 2e9e5b1

Please sign in to comment.