Remove __len__() for IterableDataset classes (#50)
* Removed len and simplified

* Update warning

* Fix warning

* Remove unintentional flycheck files
wfondrie authored Apr 29, 2024
1 parent 15d52f4 commit b8be2e2
Showing 5 changed files with 30 additions and 22 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [v0.4.4]

### Changed
- Partially revert length changes to `SpectrumDataset` and `AnnotatedSpectrumDataset`. We removed `__len__` from both due to problems with PyTorch Lightning compatibility.
- Simplify dataset code by removing redundancy with `lance.pytorch.LanceDataset`.
- Improved warning message for skipped spectra.

## [v0.4.3]

### Changed
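As the changelog entries above note, `SpectrumDataset` and `AnnotatedSpectrumDataset` no longer define `__len__`; the spectrum count is read from the new `n_spectra` property instead. A minimal usage sketch under that change; the peak file and index path are placeholders, and the import assumes `SpectrumDataset` is exported from `depthcharge.data`:

```python
from depthcharge.data import SpectrumDataset

# "example.mgf" and "index.lance" are hypothetical placeholder paths.
dataset = SpectrumDataset("example.mgf", path="index.lance", batch_size=2)

# __len__ is gone, so ask for the number of indexed spectra directly.
print(dataset.n_spectra)

# Batches come from iterating the dataset itself.
for batch in dataset:
    ...  # each batch holds the spectra as PyTorch tensors
```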
7 changes: 4 additions & 3 deletions depthcharge/data/parsers.py
@@ -3,6 +3,7 @@
from __future__ import annotations

import logging
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable
from os import PathLike
@@ -223,10 +224,10 @@ def iter_batches(self, batch_size: int | None) -> pa.RecordBatch:
yield self._yield_batch()

if n_skipped:
LOGGER.warning(
"Skipped %d spectra with invalid information", n_skipped
warnings.warn(
f"Skipped {n_skipped} spectra with invalid information."
f"Last error was: \n {str(last_exc)}"
)
LOGGER.debug("Last error: %s", str(last_exc))

def _update_batch(self, entry: dict) -> None:
"""Update the batch.
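The parsers.py change above swaps the logger call for `warnings.warn()`, so the skipped-spectra message is now a regular Python warning that callers can capture or filter with the standard `warnings` module. A small sketch of that pattern; the parsing step in the middle is a placeholder:

```python
import warnings

# Capture the "Skipped N spectra ..." warning instead of letting it print.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # ... build or iterate the dataset here so the peak files get parsed ...
    pass

for w in caught:
    print(f"{w.category.__name__}: {w.message}")
```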
28 changes: 14 additions & 14 deletions depthcharge/data/spectrum_datasets.py
@@ -4,7 +4,6 @@

import copy
import logging
import math
import uuid
from collections.abc import Generator, Iterable
from os import PathLike
@@ -77,6 +76,8 @@ class SpectrumDataset(LanceDataset):
----------
peak_files : list of str
path : Path
n_spectra : int
dataset : lance.LanceDataset
"""

@@ -118,11 +119,11 @@ def __init__(
elif not self._path.exists():
raise ValueError("No spectra were provided")

self._dataset = lance.dataset(str(self._path))
dataset = lance.dataset(str(self._path))
if "to_tensor_fn" not in kwargs:
kwargs["to_tensor_fn"] = self._to_tensor

super().__init__(self._dataset, batch_size, **kwargs)
super().__init__(dataset, batch_size, **kwargs)

def add_spectra(
self,
@@ -144,7 +145,7 @@ def add_spectra(
"""
spectra = utils.listify(spectra)
batch = next(_get_records(spectra, **self._init_kwargs))
self._dataset = lance.write_dataset(
self.dataset = lance.write_dataset(
_get_records(spectra, **self._parse_kwargs),
self._path,
mode="append",
@@ -170,26 +171,23 @@ def __getitem__(self, idx: int) -> dict[str, Any]:
PyTorch tensors if the nested data type is compatible.
"""
return self._to_tensor(self._dataset.take(utils.listify(idx)))

def __len__(self) -> int:
"""The number of batches in the lance dataset."""
num = self._dataset.count_rows()
if self.samples:
num = min(self.samples, num)

return math.ceil(num / self.batch_size)
return self._to_tensor(self.dataset.take(utils.listify(idx)))

def __del__(self) -> None:
"""Cleanup the temporary directory."""
if self._tmpdir is not None:
self._tmpdir.cleanup()

@property
def n_spectra(self) -> int:
"""The number of spectra in the Lance dataset."""
return self.dataset.count_rows()

@property
def peak_files(self) -> list[str]:
"""The files currently in the lance dataset."""
return (
self._dataset.to_table(columns=["peak_file"])
self.dataset.to_table(columns=["peak_file"])
.column(0)
.unique()
.to_pylist()
@@ -320,6 +318,8 @@ class AnnotatedSpectrumDataset(SpectrumDataset):
----------
peak_files : list of str
path : Path
n_spectra : int
dataset : lance.LanceDataset
tokenizer : PeptideTokenizer
The tokenizer for the annotations.
annotations : str
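A practical consequence of the removed `__len__`: a `DataLoader` over these datasets no longer reports a length (the PyTorch Lightning friction mentioned in the changelog), so batch counts come from iteration. A sketch, assuming `dataset` is a `SpectrumDataset` constructed as in the tests below:

```python
from torch.utils.data import DataLoader

# The dataset batches internally, so it is wrapped as-is (as in the tests).
loader = DataLoader(dataset)

# len(loader) raises TypeError for an IterableDataset without __len__;
# count batches by iterating instead.
n_batches = sum(1 for _ in loader)
print(n_batches, dataset.n_spectra)
```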
8 changes: 4 additions & 4 deletions tests/unit_tests/test_data/test_datasets.py
@@ -28,10 +28,10 @@ def tokenizer():
def test_addition(mgf_small, tmp_path):
"""Testing adding a file."""
dataset = SpectrumDataset(mgf_small, path=tmp_path / "test", batch_size=1)
assert len(dataset) == 2
assert dataset.n_spectra == 2

dataset = dataset.add_spectra(mgf_small)
assert len(dataset) == 4
assert dataset.n_spectra == 4


def test_indexing(tokenizer, mgf_small, tmp_path):
@@ -197,7 +197,7 @@ def test_pickle(tokenizer, tmp_path, mgf_small):
with pkl_file.open("rb") as pkl:
loaded = pickle.load(pkl)

assert len(dataset) == len(loaded)
assert dataset.n_spectra == loaded.n_spectra

dataset = AnnotatedSpectrumDataset(
[mgf_small],
@@ -214,4 +214,4 @@
with pkl_file.open("rb") as pkl:
loaded = pickle.load(pkl)

assert len(dataset) == len(loaded)
assert dataset.n_spectra == loaded.n_spectra
2 changes: 1 addition & 1 deletion tests/unit_tests/test_data/test_loaders.py
@@ -37,7 +37,7 @@ def test_spectrum_loader_samples(mgf_small, tmp_path, samples, batches):
)

loaded = list(DataLoader(dset))
assert len(dset) == batches
assert dset.n_spectra == 2
assert len(loaded) == batches


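The removed `__len__` computed `ceil(rows / batch_size)` with an optional `samples` cap; the same number can still be derived by hand when a test or training loop needs it. A sketch reusing names from the parametrized test above; `dset`, `samples`, and `batch_size` are assumed to be in scope:

```python
import math

from torch.utils.data import DataLoader

# Mirror of the logic that previously lived in __len__ (see the removed lines
# above): row count, optionally capped by `samples`, divided by the batch size.
n_rows = dset.n_spectra
if samples:
    n_rows = min(samples, n_rows)
expected_batches = math.ceil(n_rows / batch_size)

assert len(list(DataLoader(dset))) == expected_batches
```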
