From 723bcfbafae14c137ebad644d084ca4a78f99989 Mon Sep 17 00:00:00 2001
From: Steffen Schneider
Date: Sat, 2 Aug 2025 15:29:18 +0200
Subject: [PATCH 1/8] Add support to sample more negatives

---
 cebra/data/base.py           | 16 +++++++++--
 cebra/data/multi_session.py  | 10 ++++---
 cebra/data/multiobjective.py | 12 ++++++---
 cebra/data/single_session.py | 52 ++++++++++++++++++++++++++++--------
 4 files changed, 70 insertions(+), 20 deletions(-)

diff --git a/cebra/data/base.py b/cebra/data/base.py
index 51199cec..92f82a6f 100644
--- a/cebra/data/base.py
+++ b/cebra/data/base.py
@@ -239,6 +239,12 @@ class Loader(abc.ABC, cebra.io.HasDevice):
     batch_size: int = dataclasses.field(default=None,
                                         doc="""The total batch size.""")
 
+    num_negatives: int = dataclasses.field(
+        default=None,
+        doc="""The number of negative samples to draw for each reference.
+        If not specified, the batch size is used."""
+    )
+
     def __post_init__(self):
         if self.num_steps is None or self.num_steps <= 0:
             raise ValueError(
@@ -255,11 +261,12 @@ def __len__(self):
         """The number of batches returned when calling as an iterator."""
         return self.num_steps
 
     def __iter__(self) -> Batch:
         for _ in range(len(self)):
-            index = self.get_indices(num_samples=self.batch_size)
+            index = self.get_indices(num_samples=self.batch_size,
+                                     num_negatives=self.num_negatives)
             yield self.dataset.load_batch(index)
 
     @abc.abstractmethod
-    def get_indices(self, num_samples: int):
+    def get_indices(self, num_samples: int, num_negatives: int = None):
         """Sample and return the specified number of indices.
 
         The elements of the returned `BatchIndex` will be used to index the
@@ -271,5 +278,10 @@ def get_indices(self, num_samples: int):
 
         Returns:
             batch indices for the reference, positive and negative sample.
+
+
+        Note:
+            From version 0.7.0 onwards, the ``num_negatives`` parameter allows
+            specifying a number of negative samples that differs from the batch size.
         """
         raise NotImplementedError()
diff --git a/cebra/data/multi_session.py b/cebra/data/multi_session.py
index f33ad6ec..9be1e5c4 100644
--- a/cebra/data/multi_session.py
+++ b/cebra/data/multi_session.py
@@ -155,10 +155,14 @@ def __post_init__(self):
         super().__post_init__()
         self.sampler = cebra.distributions.MultisessionSampler(
             self.dataset, self.time_offset)
+        if self.num_negatives is None:
+            self.num_negatives = self.batch_size
 
-    def get_indices(self, num_samples: int) -> List[BatchIndex]:
+    # NOTE(stes): In the longer run, we need to unify the API here; the num_samples argument
+    # is not used in the multi-session case, which is different from the single-session samplers.
+    def get_indices(self, num_samples) -> List[BatchIndex]:
         ref_idx = self.sampler.sample_prior(self.batch_size)
-        neg_idx = self.sampler.sample_prior(self.batch_size)
+        neg_idx = self.sampler.sample_prior(self.num_negatives)
         pos_idx, idx, idx_rev = self.sampler.sample_conditional(ref_idx)
 
         ref_idx = torch.from_numpy(ref_idx)
@@ -251,7 +255,7 @@ def get_indices(self, num_samples: int) -> BatchIndex:
             Batch indices for the reference, positive and negative samples.
""" ref_idx = self.sampler.sample_prior(self.batch_size) - neg_idx = self.sampler.sample_prior(self.batch_size) + neg_idx = self.sampler.sample_prior(self.num_negatives) pos_idx = self.sampler.sample_conditional(ref_idx) diff --git a/cebra/data/multiobjective.py b/cebra/data/multiobjective.py index f700d1c4..e7018f08 100644 --- a/cebra/data/multiobjective.py +++ b/cebra/data/multiobjective.py @@ -71,7 +71,7 @@ def __post_init__(self): def add_config(self, config): self.labels.append(config['label']) - def get_indices(self, num_samples: int): + def get_indices(self, num_samples: int, num_negatives: int = None): if self.sampling_mode_supervised == "ref_shared": reference_idx = self.prior.sample_prior(num_samples) else: @@ -142,11 +142,14 @@ def add_config(self, config): self.distributions.append(distribution) - def get_indices(self, num_samples: int): + def get_indices(self, num_samples: int, num_negatives: int = None): """Sample and return the specified number of indices.""" + if num_negatives is None: + num_negatives = num_samples + if self.sampling_mode_contrastive == "refneg_shared": - ref_and_neg = self.prior.sample_prior(num_samples * 2) + ref_and_neg = self.prior.sample_prior(num_samples + num_negatives) reference_idx = ref_and_neg[:num_samples] negative_idx = ref_and_neg[num_samples:] @@ -169,5 +172,6 @@ def get_indices(self, num_samples: int): def __iter__(self): for _ in range(len(self)): - index = self.get_indices(num_samples=self.batch_size) + index = self.get_indices(num_samples=self.batch_size, + num_negatives=self.num_negatives) yield self.dataset.load_batch_contrastive(index) diff --git a/cebra/data/single_session.py b/cebra/data/single_session.py index 7e4ad2fd..0e59183b 100644 --- a/cebra/data/single_session.py +++ b/cebra/data/single_session.py @@ -138,7 +138,9 @@ def _init_distribution(self): f"Invalid choice of prior distribution. Got '{self.prior}', but " f"only accept 'uniform' or 'empirical' as potential values.") - def get_indices(self, num_samples: int) -> BatchIndex: + def get_indices(self, + num_samples: int, + num_negatives: int = None) -> BatchIndex: """Samples indices for reference, positive and negative examples. The reference samples will be sampled from the empirical or uniform prior @@ -154,11 +156,16 @@ def get_indices(self, num_samples: int) -> BatchIndex: Args: num_samples: The number of samples (batch size) of the returned :py:class:`cebra.data.datatypes.BatchIndex`. + num_negatives: The number of negative samples. If None, defaults to num_samples. Returns: Indices for reference, positive and negatives samples. """ - reference_idx = self.distribution.sample_prior(num_samples * 2) + if num_negatives is None: + num_negatives = num_samples + + reference_idx = self.distribution.sample_prior(num_samples + + num_negatives) negative_idx = reference_idx[num_samples:] reference_idx = reference_idx[:num_samples] reference = self.index[reference_idx] @@ -246,7 +253,9 @@ def _init_distribution(self): else: raise ValueError(self.conditional) - def get_indices(self, num_samples: int) -> BatchIndex: + def get_indices(self, + num_samples: int, + num_negatives: int = None) -> BatchIndex: """Samples indices for reference, positive and negative examples. The reference and negative samples will be sampled uniformly from @@ -262,7 +271,11 @@ def get_indices(self, num_samples: int) -> BatchIndex: Returns: Indices for reference, positive and negatives samples. 
""" - reference_idx = self.distribution.sample_prior(num_samples * 2) + if num_negatives is None: + num_negatives = num_samples + + reference_idx = self.distribution.sample_prior(num_samples + + num_negatives) negative_idx = reference_idx[num_samples:] reference_idx = reference_idx[:num_samples] positive_idx = self.distribution.sample_conditional(reference_idx) @@ -305,7 +318,9 @@ def __post_init__(self): continuous=self.cindex, time_delta=self.time_offset) - def get_indices(self, num_samples: int) -> BatchIndex: + def get_indices(self, + num_samples: int, + num_negatives: int = None) -> BatchIndex: """Samples indices for reference, positive and negative examples. The reference and negative samples will be sampled uniformly from @@ -319,6 +334,7 @@ def get_indices(self, num_samples: int) -> BatchIndex: Args: num_samples: The number of samples (batch size) of the returned :py:class:`cebra.data.datatypes.BatchIndex`. + num_negatives: The number of negative samples. If None, defaults to num_samples. Returns: Indices for reference, positive and negatives samples. @@ -328,10 +344,16 @@ def get_indices(self, num_samples: int) -> BatchIndex: class. - Sample the negatives with matching discrete variable """ - reference_idx = self.distribution.sample_prior(num_samples) + if num_negatives is None: + num_negatives = num_samples + + reference_idx = self.distribution.sample_prior(num_samples + + num_negatives) + negative_idx = reference_idx[num_samples:] + reference_idx = reference_idx[:num_samples] return BatchIndex( reference=reference_idx, - negative=self.distribution.sample_prior(num_samples), + negative=negative_idx, positive=self.distribution.sample_conditional(reference_idx), ) @@ -421,11 +443,13 @@ def _init_time_distribution(self): else: raise ValueError - def get_indices(self, num_samples: int) -> BatchIndex: + def get_indices(self, + num_samples: int, + num_negatives: int = None) -> BatchIndex: """Samples indices for reference, positive and negative examples. The reference and negative samples will be sampled uniformly from - all available time steps, and a total of ``2*num_samples`` will be + all available time steps, and a total of ``num_samples + num_negatives`` will be returned for both. For the positive samples, ``num_samples`` are sampled according to the @@ -436,6 +460,7 @@ def get_indices(self, num_samples: int) -> BatchIndex: Args: num_samples: The number of samples (batch size) of the returned :py:class:`cebra.data.datatypes.BatchIndex`. + num_negatives: The number of negative samples. If None, defaults to num_samples. Returns: Indices for reference, positive and negatives samples. @@ -444,7 +469,11 @@ def get_indices(self, num_samples: int) -> BatchIndex: Add the ``empirical`` vs. ``discrete`` sampling modes to this class. """ - reference_idx = self.time_distribution.sample_prior(num_samples * 2) + if num_negatives is None: + num_negatives = num_samples + + reference_idx = self.time_distribution.sample_prior(num_samples + + num_negatives) negative_idx = reference_idx[num_samples:] reference_idx = reference_idx[:num_samples] behavior_positive_idx = self.behavior_distribution.sample_conditional( @@ -470,7 +499,7 @@ def __post_init__(self): def offset(self): return self.dataset.offset - def get_indices(self, num_samples=None) -> BatchIndex: + def get_indices(self, num_samples=None, num_negatives=None) -> BatchIndex: """Samples indices for reference, positive and negative examples. 
The reference indices are all available (valid, according to the @@ -491,6 +520,7 @@ def get_indices(self, num_samples=None) -> BatchIndex: class. """ assert num_samples is None + assert num_negatives is None reference_idx = torch.arange( self.offset.left, From dbabb6ef3214e78852467ddaa123d7f045270008 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sat, 2 Aug 2025 15:35:23 +0200 Subject: [PATCH 2/8] fix missing arg --- cebra/data/multi_session.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cebra/data/multi_session.py b/cebra/data/multi_session.py index 9be1e5c4..aa0b245b 100644 --- a/cebra/data/multi_session.py +++ b/cebra/data/multi_session.py @@ -160,7 +160,9 @@ def __post_init__(self): # NOTE(stes): In the longer run, we need to unify the API here; the num_samples argument # is not used in the multi-session case, which is different to the single session samples. - def get_indices(self, num_samples) -> List[BatchIndex]: + def get_indices(self, + num_samples: int, + num_negatives: int = None) -> List[BatchIndex]: ref_idx = self.sampler.sample_prior(self.batch_size) neg_idx = self.sampler.sample_prior(self.num_negatives) pos_idx, idx, idx_rev = self.sampler.sample_conditional(ref_idx) From 540b006eeacd53636a2e18e0c3fc1e9bd8df45a7 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sat, 2 Aug 2025 15:39:53 +0200 Subject: [PATCH 3/8] Fix multi-session samplers --- cebra/data/multi_session.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/cebra/data/multi_session.py b/cebra/data/multi_session.py index aa0b245b..e893fa77 100644 --- a/cebra/data/multi_session.py +++ b/cebra/data/multi_session.py @@ -198,8 +198,11 @@ class DiscreteMultiSessionDataLoader(MultiSessionLoader): # Overwrite sampler with the discrete implementation # Generalize MultisessionSampler to avoid doing this? def __post_init__(self): + # NOTE(stes): __post_init__ from superclass is intentionally not called. self.sampler = cebra.distributions.DiscreteMultisessionSampler( self.dataset) + if self.num_negatives is None: + self.num_negatives = self.batch_size @property def index(self): @@ -235,7 +238,9 @@ def __post_init__(self): self.sampler = cebra.distributions.UnifiedSampler( self.dataset, self.time_offset) - def get_indices(self, num_samples: int) -> BatchIndex: + def get_indices(self, + num_samples: int, + num_negatives: int = None) -> BatchIndex: """Sample and return the specified number of indices. The elements of the returned ``BatchIndex`` will be used to index the From 07212f2e8f52d1986862f9fe249a5e1b23809d6f Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sat, 2 Aug 2025 16:15:03 +0200 Subject: [PATCH 4/8] Improve sampling API --- cebra/data/base.py | 15 +++- cebra/data/multi_session.py | 15 ++-- cebra/data/multiobjective.py | 26 +++--- cebra/data/single_session.py | 77 +++++++---------- tests/test_loader.py | 158 ++++++++++++++++++++++------------- 5 files changed, 165 insertions(+), 126 deletions(-) diff --git a/cebra/data/base.py b/cebra/data/base.py index 92f82a6f..4ab73f7c 100644 --- a/cebra/data/base.py +++ b/cebra/data/base.py @@ -22,6 +22,7 @@ """Base classes for datasets and loaders.""" import abc +from typing import Iterator import literate_dataclasses as dataclasses import torch @@ -254,19 +255,25 @@ def __post_init__(self): raise ValueError( f"Batch size has to be None, or a non-negative value. Got {self.batch_size}." 
             )
+        if self.num_negatives is not None and self.num_negatives <= 0:
+            raise ValueError(
+                f"Number of negatives has to be None, or a positive value. Got {self.num_negatives}."
+            )
+
+        if self.num_negatives is None:
+            self.num_negatives = self.batch_size
 
     def __len__(self):
         """The number of batches returned when calling as an iterator."""
         return self.num_steps
 
-    def __iter__(self) -> Batch:
+    def __iter__(self) -> Iterator[Batch]:
         for _ in range(len(self)):
-            index = self.get_indices(num_samples=self.batch_size,
-                                     num_negatives=self.num_negatives)
+            index = self.get_indices()
             yield self.dataset.load_batch(index)
 
     @abc.abstractmethod
-    def get_indices(self, num_samples: int, num_negatives: int = None):
+    def get_indices(self):
         """Sample and return the specified number of indices.
 
         The elements of the returned `BatchIndex` will be used to index the
diff --git a/cebra/data/multi_session.py b/cebra/data/multi_session.py
index e893fa77..62aca06d 100644
--- a/cebra/data/multi_session.py
+++ b/cebra/data/multi_session.py
@@ -160,9 +160,7 @@ def __post_init__(self):
 
     # NOTE(stes): In the longer run, we need to unify the API here; the num_samples argument
     # is not used in the multi-session case, which is different from the single-session samplers.
-    def get_indices(self,
-                    num_samples: int,
-                    num_negatives: int = None) -> List[BatchIndex]:
+    def get_indices(self) -> List[BatchIndex]:
         ref_idx = self.sampler.sample_prior(self.batch_size)
         neg_idx = self.sampler.sample_prior(self.num_negatives)
         pos_idx, idx, idx_rev = self.sampler.sample_conditional(ref_idx)
@@ -238,9 +236,14 @@ def __post_init__(self):
         self.sampler = cebra.distributions.UnifiedSampler(
             self.dataset, self.time_offset)
 
-    def get_indices(self,
-                    num_samples: int,
-                    num_negatives: int = None) -> BatchIndex:
+        if self.batch_size < 2:
+            raise ValueError("UnifiedLoader does not support batch_size < 2.")
+
+        if self.num_negatives < 2:
+            raise ValueError(
+                "UnifiedLoader does not support num_negatives < 2.")
+
+    def get_indices(self) -> BatchIndex:
         """Sample and return the specified number of indices.
 
         The elements of the returned ``BatchIndex`` will be used to index the
diff --git a/cebra/data/multiobjective.py b/cebra/data/multiobjective.py
index e7018f08..4ccfb635 100644
--- a/cebra/data/multiobjective.py
+++ b/cebra/data/multiobjective.py
@@ -20,10 +20,13 @@
 # limitations under the License.
 #
 
+from typing import Iterator
+
 import literate_dataclasses as dataclasses
 
 import cebra.data as cebra_data
 import cebra.distributions
+from cebra.data.datatypes import Batch
 from cebra.data.datatypes import BatchIndex
 from cebra.distributions.continuous import Prior
 
@@ -71,9 +74,9 @@ def __post_init__(self):
     def add_config(self, config):
         self.labels.append(config['label'])
 
-    def get_indices(self, num_samples: int, num_negatives: int = None):
+    def get_indices(self) -> BatchIndex:
         if self.sampling_mode_supervised == "ref_shared":
-            reference_idx = self.prior.sample_prior(num_samples)
+            reference_idx = self.prior.sample_prior(self.batch_size)
         else:
             raise ValueError(
                 f"Sampling mode {self.sampling_mode_supervised} is not implemented."
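The loaders above share one sampling idiom: draw ``batch_size + num_negatives`` indices from the prior in a single call, then split the result into the reference and negative roles. A minimal sketch of that idiom, with a uniform ``torch.randint`` draw standing in for the ``cebra.distributions`` samplers used by the actual loaders (``split_prior_sample`` is illustrative, not part of the patch):

import torch


def split_prior_sample(num_timesteps: int,
                       batch_size: int,
                       num_negatives: int = None):
    # Mirror the loader behavior: fall back to the batch size
    # when the number of negatives is not specified.
    if num_negatives is None:
        num_negatives = batch_size
    # One prior draw covering both roles (the "refneg_shared" idiom).
    ref_and_neg = torch.randint(0, num_timesteps,
                                (batch_size + num_negatives,))
    reference_idx = ref_and_neg[:batch_size]
    negative_idx = ref_and_neg[batch_size:]
    return reference_idx, negative_idx


ref, neg = split_prior_sample(num_timesteps=1000, batch_size=32,
                              num_negatives=64)
assert ref.shape == (32,) and neg.shape == (64,)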
@@ -87,9 +90,9 @@ def get_indices(self, num_samples: int, num_negatives: int = None): return batch_index - def __iter__(self): + def __iter__(self) -> Iterator[Batch]: for _ in range(len(self)): - index = self.get_indices(num_samples=self.batch_size) + index = self.get_indices() yield self.dataset.load_batch_supervised(index, self.labels) @@ -142,16 +145,14 @@ def add_config(self, config): self.distributions.append(distribution) - def get_indices(self, num_samples: int, num_negatives: int = None): + def get_indices(self) -> BatchIndex: """Sample and return the specified number of indices.""" - if num_negatives is None: - num_negatives = num_samples - if self.sampling_mode_contrastive == "refneg_shared": - ref_and_neg = self.prior.sample_prior(num_samples + num_negatives) - reference_idx = ref_and_neg[:num_samples] - negative_idx = ref_and_neg[num_samples:] + ref_and_neg = self.prior.sample_prior(self.batch_size + + self.num_negatives) + reference_idx = ref_and_neg[:self.batch_size] + negative_idx = ref_and_neg[self.batch_size:] positives_idx = [] for distribution in self.distributions: @@ -172,6 +173,5 @@ def get_indices(self, num_samples: int, num_negatives: int = None): def __iter__(self): for _ in range(len(self)): - index = self.get_indices(num_samples=self.batch_size, - num_negatives=self.num_negatives) + index = self.get_indices() yield self.dataset.load_batch_contrastive(index) diff --git a/cebra/data/single_session.py b/cebra/data/single_session.py index 0e59183b..2daef64f 100644 --- a/cebra/data/single_session.py +++ b/cebra/data/single_session.py @@ -27,6 +27,7 @@ import abc import warnings +from typing import Iterator import literate_dataclasses as dataclasses import torch @@ -138,9 +139,7 @@ def _init_distribution(self): f"Invalid choice of prior distribution. Got '{self.prior}', but " f"only accept 'uniform' or 'empirical' as potential values.") - def get_indices(self, - num_samples: int, - num_negatives: int = None) -> BatchIndex: + def get_indices(self) -> BatchIndex: """Samples indices for reference, positive and negative examples. The reference samples will be sampled from the empirical or uniform prior @@ -161,13 +160,10 @@ def get_indices(self, Returns: Indices for reference, positive and negatives samples. """ - if num_negatives is None: - num_negatives = num_samples - - reference_idx = self.distribution.sample_prior(num_samples + - num_negatives) - negative_idx = reference_idx[num_samples:] - reference_idx = reference_idx[:num_samples] + reference_idx = self.distribution.sample_prior(self.batch_size + + self.num_negatives) + negative_idx = reference_idx[self.batch_size:] + reference_idx = reference_idx[:self.batch_size] reference = self.index[reference_idx] positive_idx = self.distribution.sample_conditional(reference) return BatchIndex(reference=reference_idx, @@ -253,9 +249,7 @@ def _init_distribution(self): else: raise ValueError(self.conditional) - def get_indices(self, - num_samples: int, - num_negatives: int = None) -> BatchIndex: + def get_indices(self) -> BatchIndex: """Samples indices for reference, positive and negative examples. The reference and negative samples will be sampled uniformly from @@ -271,13 +265,10 @@ def get_indices(self, Returns: Indices for reference, positive and negatives samples. 
""" - if num_negatives is None: - num_negatives = num_samples - - reference_idx = self.distribution.sample_prior(num_samples + - num_negatives) - negative_idx = reference_idx[num_samples:] - reference_idx = reference_idx[:num_samples] + reference_idx = self.distribution.sample_prior(self.batch_size + + self.num_negatives) + negative_idx = reference_idx[self.batch_size:] + reference_idx = reference_idx[:self.batch_size] positive_idx = self.distribution.sample_conditional(reference_idx) return BatchIndex(reference=reference_idx, positive=positive_idx, @@ -318,9 +309,7 @@ def __post_init__(self): continuous=self.cindex, time_delta=self.time_offset) - def get_indices(self, - num_samples: int, - num_negatives: int = None) -> BatchIndex: + def get_indices(self) -> BatchIndex: """Samples indices for reference, positive and negative examples. The reference and negative samples will be sampled uniformly from @@ -344,13 +333,10 @@ def get_indices(self, class. - Sample the negatives with matching discrete variable """ - if num_negatives is None: - num_negatives = num_samples - - reference_idx = self.distribution.sample_prior(num_samples + - num_negatives) - negative_idx = reference_idx[num_samples:] - reference_idx = reference_idx[:num_samples] + reference_idx = self.distribution.sample_prior(self.batch_size + + self.num_negatives) + negative_idx = reference_idx[self.batch_size:] + reference_idx = reference_idx[:self.batch_size] return BatchIndex( reference=reference_idx, negative=negative_idx, @@ -443,9 +429,7 @@ def _init_time_distribution(self): else: raise ValueError - def get_indices(self, - num_samples: int, - num_negatives: int = None) -> BatchIndex: + def get_indices(self) -> BatchIndex: """Samples indices for reference, positive and negative examples. The reference and negative samples will be sampled uniformly from @@ -469,13 +453,10 @@ def get_indices(self, Add the ``empirical`` vs. ``discrete`` sampling modes to this class. """ - if num_negatives is None: - num_negatives = num_samples - - reference_idx = self.time_distribution.sample_prior(num_samples + - num_negatives) - negative_idx = reference_idx[num_samples:] - reference_idx = reference_idx[:num_samples] + reference_idx = self.time_distribution.sample_prior(self.batch_size + + self.num_negatives) + negative_idx = reference_idx[self.batch_size:] + reference_idx = reference_idx[:self.batch_size] behavior_positive_idx = self.behavior_distribution.sample_conditional( reference_idx) time_positive_idx = self.time_distribution.sample_conditional( @@ -493,13 +474,18 @@ class FullDataLoader(ContinuousDataLoader): def __post_init__(self): super().__post_init__() - self.batch_size = None + + if self.batch_size is not None: + raise ValueError("Batch size cannot be set for FullDataLoader.") + if self.num_negatives is not None: + raise ValueError( + "Number of negatives cannot be set for FullDataLoader.") @property def offset(self): return self.dataset.offset - def get_indices(self, num_samples=None, num_negatives=None) -> BatchIndex: + def get_indices(self) -> BatchIndex: """Samples indices for reference, positive and negative examples. The reference indices are all available (valid, according to the @@ -519,8 +505,6 @@ def get_indices(self, num_samples=None, num_negatives=None) -> BatchIndex: Add the ``empirical`` vs. ``discrete`` sampling modes to this class. 
""" - assert num_samples is None - assert num_negatives is None reference_idx = torch.arange( self.offset.left, @@ -534,7 +518,6 @@ def get_indices(self, num_samples=None, num_negatives=None) -> BatchIndex: positive=positive_idx, negative=negative_idx) - def __iter__(self): + def __iter__(self) -> Iterator[BatchIndex]: for _ in range(len(self)): - index = self.get_indices(num_samples=self.batch_size) - yield index + yield self.get_indices() diff --git a/tests/test_loader.py b/tests/test_loader.py index cb6be9a7..8eaa9f4f 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -29,6 +29,43 @@ BATCH_SIZE = 32 NUMS_NEURAL = [3, 4, 5] +SINGLE_SESSION_LOADERS = [ + ("demo-discrete", cebra.data.DiscreteDataLoader), + ("demo-continuous", cebra.data.ContinuousDataLoader), + ("demo-mixed", cebra.data.MixedDataLoader), +] +MULTI_SESSION_LOADERS = [ + ("demo-continuous-multisession", + cebra.data.ContinuousMultiSessionDataLoader), + ("demo-discrete-multisession", cebra.data.DiscreteMultiSessionDataLoader), +] +LOADERS = SINGLE_SESSION_LOADERS + MULTI_SESSION_LOADERS + [ + ("demo-continuous-unified", cebra.data.UnifiedLoader), +] + + +def _setup_functional_loader_test(data_name, loader_initfunc, device, + batch_size, num_negatives): + data = cebra.datasets.init(data_name) + data.to(device) + if num_negatives == "do not pass": + loader = loader_initfunc(data, num_steps=10, batch_size=batch_size) + else: + loader = loader_initfunc(data, + num_steps=10, + batch_size=batch_size, + num_negatives=num_negatives) + + if num_negatives is None or num_negatives == "do not pass": + assert loader.num_negatives == batch_size + expected_num_negatives = batch_size + else: + assert loader.num_negatives == num_negatives + expected_num_negatives = num_negatives + + _assert_dataset_on_correct_device(loader, device) + + return loader, expected_num_negatives class LoadSpeed: @@ -135,16 +172,7 @@ def _to_str(val): @_util.parametrize_device -@pytest.mark.parametrize( - "data_name, loader_initfunc", - [ - ("demo-discrete", cebra.data.DiscreteDataLoader), - ("demo-continuous", cebra.data.ContinuousDataLoader), - ("demo-mixed", cebra.data.MixedDataLoader), - ("demo-continuous-multisession", cebra.data.MultiSessionLoader), - ("demo-continuous-unified", cebra.data.UnifiedLoader), - ], -) +@pytest.mark.parametrize("data_name, loader_initfunc", LOADERS) def test_device(data_name, loader_initfunc, device): if not torch.cuda.is_available(): pytest.skip("Test only possible with CUDA.") @@ -158,8 +186,7 @@ def test_device(data_name, loader_initfunc, device): assert loader.dataset == dataset _assert_device(loader.device, device) _assert_device(loader.dataset.device, device) - - _assert_device(loader.get_indices(10).reference.device, device) + _assert_device(loader.get_indices().reference.device, device) @_util.parametrize_device @@ -206,44 +233,34 @@ def _check_attributes(obj, is_list=False): @_util.parametrize_device -@pytest.mark.parametrize( - "data_name, loader_initfunc", - [ - ("demo-discrete", cebra.data.DiscreteDataLoader), - ("demo-continuous", cebra.data.ContinuousDataLoader), - ("demo-mixed", cebra.data.MixedDataLoader), - ], -) -def test_singlesession_loader(data_name, loader_initfunc, device): - data = cebra.datasets.init(data_name) - data.to(device) - loader = loader_initfunc(data, num_steps=10, batch_size=BATCH_SIZE) - _assert_dataset_on_correct_device(loader, device) +@pytest.mark.parametrize("data_name, loader_initfunc", SINGLE_SESSION_LOADERS) +@pytest.mark.parametrize("batch_size", [1, 32]) 
+@pytest.mark.parametrize("num_negatives", [None, 1, 32, "do not pass"]) +def test_singlesession_loader(data_name, loader_initfunc, device, batch_size, + num_negatives): - index = loader.get_indices(100) + loader, expected_num_negatives = _setup_functional_loader_test( + data_name, loader_initfunc, device, batch_size, num_negatives) + + index = loader.get_indices() _check_attributes(index) for batch in loader: _check_attributes(batch) - assert len(batch.positive) == BATCH_SIZE + assert len(batch.positive) == batch_size + assert len(batch.reference) == batch_size + assert len(batch.negative) == expected_num_negatives @_util.parametrize_device -@pytest.mark.parametrize( - "data_name, loader_initfunc", - [ - ("demo-continuous-multisession", - cebra.data.ContinuousMultiSessionDataLoader), - ("demo-discrete-multisession", - cebra.data.DiscreteMultiSessionDataLoader), - ], -) -def test_multisession_loader(data_name, loader_initfunc, device): - data = cebra.datasets.init(data_name) - data.to(device) - loader = loader_initfunc(data, num_steps=10, batch_size=BATCH_SIZE) +@pytest.mark.parametrize("data_name, loader_initfunc", MULTI_SESSION_LOADERS) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("num_negatives", [None, 1, 32, 33, "do not pass"]) +def test_multisession_loader(data_name, loader_initfunc, device, batch_size, + num_negatives): - _assert_dataset_on_correct_device(loader, device) + loader, expected_num_negatives = _setup_functional_loader_test( + data_name, loader_initfunc, device, batch_size, num_negatives) # Check the sampler assert hasattr(loader, "sampler") @@ -260,7 +277,7 @@ def test_multisession_loader(data_name, loader_initfunc, device): batch = next(iter(loader)) for i, n_neurons in enumerate(NUMS_NEURAL): - assert batch[i].reference.shape == (BATCH_SIZE, n_neurons, 10) + assert batch[i].reference.shape == (batch_size, n_neurons, 10) def _mix(array, idx): shape = array.shape @@ -276,18 +293,18 @@ def _process(batch, feature_dim=1): dim=0).repeat(1, 1, feature_dim) dummy_prediction = _process(batch, feature_dim=6) - assert dummy_prediction.shape == (3, BATCH_SIZE, 6) + assert dummy_prediction.shape == (3, batch_size, 6) _mix(dummy_prediction, batch[0].index) - index = loader.get_indices(100) - #print(index[0]) - #print(type(index)) + index = loader.get_indices() _check_attributes(index, is_list=False) for batch in loader: _check_attributes(batch, is_list=True) for session_batch in batch: - assert len(session_batch.positive) == BATCH_SIZE + assert len(session_batch.positive) == batch_size + assert len(session_batch.reference) == batch_size + assert len(session_batch.negative) == expected_num_negatives @_util.parametrize_device @@ -297,12 +314,14 @@ def _process(batch, feature_dim=1): ("demo-continuous-unified", cebra.data.UnifiedLoader), ], ) -def test_unified_loader(data_name, loader_initfunc, device): - data = cebra.datasets.init(data_name) - data.to(device) - loader = loader_initfunc(data, num_steps=10, batch_size=BATCH_SIZE) +# TODO(stes): unified sampler breaks for batch_size = 1; tested further below +@pytest.mark.parametrize("batch_size", [2, 32, 100]) +@pytest.mark.parametrize("num_negatives", [None, 2, 32, 33, "do not pass"]) +def test_unified_loader_sampler(data_name, loader_initfunc, device, batch_size, + num_negatives): - _assert_dataset_on_correct_device(loader, device) + loader, expected_num_negatives = _setup_functional_loader_test( + data_name, loader_initfunc, device, batch_size, num_negatives) # Check the sampler num_samples = 100 @@ 
-334,11 +353,38 @@ def test_unified_loader(data_name, loader_initfunc, device): pos_idx = loader.sampler.sample_conditional(all_ref_idx) assert pos_idx.shape == (len(NUMS_NEURAL), num_samples) + +@_util.parametrize_device +@pytest.mark.parametrize( + "data_name, loader_initfunc", + [ + ("demo-continuous-unified", cebra.data.UnifiedLoader), + ], +) +# TODO(stes): unified sampler breaks for batch_size = 1 +@pytest.mark.parametrize("batch_size", [1, 32, 100]) +@pytest.mark.parametrize("num_negatives", [None, 1, 32, 33, "do not pass"]) +def test_unified_loader(data_name, loader_initfunc, device, batch_size, + num_negatives): + + if batch_size == 1 or num_negatives == 1: + with pytest.raises(ValueError, + match=r"UnifiedLoader does not support .* < 2"): + _setup_functional_loader_test(data_name, loader_initfunc, device, + batch_size, num_negatives) + pytest.skip( + "UnifiedLoader does not support batch_size < 2 or num_negatives < 2." + ) + + loader, expected_num_negatives = _setup_functional_loader_test( + data_name, loader_initfunc, device, batch_size, num_negatives) + # Check the batch batch = next(iter(loader)) - assert batch.reference.shape == (BATCH_SIZE, sum(NUMS_NEURAL), 10) - assert batch.positive.shape == (BATCH_SIZE, sum(NUMS_NEURAL), 10) - assert batch.negative.shape == (BATCH_SIZE, sum(NUMS_NEURAL), 10) + assert batch.reference.shape == (batch_size, sum(NUMS_NEURAL), 10) + assert batch.positive.shape == (batch_size, sum(NUMS_NEURAL), 10) + assert batch.negative.shape == (expected_num_negatives, sum(NUMS_NEURAL), + 10) - index = loader.get_indices(100) + index = loader.get_indices() _check_attributes(index, is_list=False) From 6c2d55919e55ed94579d7013ac9909e4aa7abce3 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sat, 2 Aug 2025 16:25:51 +0200 Subject: [PATCH 5/8] add sklearn implementation --- cebra/data/multi_session.py | 4 ++-- cebra/integrations/sklearn/cebra.py | 5 +++++ tests/test_sklearn.py | 17 +++++++++++++++++ 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/cebra/data/multi_session.py b/cebra/data/multi_session.py index 62aca06d..c6561ee2 100644 --- a/cebra/data/multi_session.py +++ b/cebra/data/multi_session.py @@ -236,10 +236,10 @@ def __post_init__(self): self.sampler = cebra.distributions.UnifiedSampler( self.dataset, self.time_offset) - if self.batch_size < 2: + if self.batch_size is not None and self.batch_size < 2: raise ValueError("UnifiedLoader does not support batch_size < 2.") - if self.num_negatives < 2: + if self.num_negatives is not None and self.num_negatives < 2: raise ValueError( "UnifiedLoader does not support num_negatives < 2.") diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index 98e56747..b2b5460d 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -501,6 +501,9 @@ class CEBRA(TransformerMixin, BaseEstimator): A Tuple of masking types and their corresponding required masking values. The keys are the names of the Mask instances and formatting should be ``((key, value), (key, value))``. |Default:| ``None``. + num_negatives (int): + The number of negative samples to use for training. If ``None``, the number of negative samples + will be set to the batch size. |Default:| ``None``. Example: @@ -576,6 +579,7 @@ def __init__( ), masking_kwargs: Tuple[Tuple[str, Union[float, List[float], Tuple[float, ...]]], ...] 
= None, + num_negatives: int = None, ): self.__dict__.update(locals()) @@ -728,6 +732,7 @@ def _prepare_loader(self, dataset: cebra.data.Dataset, max_iterations: int, dataset=dataset, batch_size=self.batch_size, num_steps=max_iterations, + num_negatives=self.num_negatives, ), extra_kwargs=dict( time_offsets=self.time_offsets, diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index c3d2095c..dfa09dad 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -1544,3 +1544,20 @@ def test_last_incomplete_batch_smaller_than_offset(): model.fit(train.neural, train.continuous) _ = model.transform(train.neural, batch_size=300) + + +@pytest.mark.parametrize("batch_size,num_negatives", [ + (None, None), + (100, None), + (100, 100), +]) +def test_num_negatives(batch_size, num_negatives): + train = cebra.data.TensorDataset(neural=np.random.rand(20111, 100), + continuous=np.random.rand(20111, 2)) + + model = cebra.CEBRA(max_iterations=2, + batch_size=batch_size, + num_negatives=num_negatives, + device="cpu") + model.fit(train.neural, train.continuous) + _ = model.transform(train.neural) From 0dba5fcb615b0e447782ffa73c6ef052d5257c90 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sat, 2 Aug 2025 16:34:29 +0200 Subject: [PATCH 6/8] Update deprecation note --- cebra/data/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cebra/data/base.py b/cebra/data/base.py index 4ab73f7c..68f38d66 100644 --- a/cebra/data/base.py +++ b/cebra/data/base.py @@ -288,7 +288,8 @@ def get_indices(self): Note: - From version 0.7.0 onwards, `num_negatives` parameter was added to allow - specifying a different number of negative samples compared to the batch size. + From version 0.7.0 onwards, specifying the ``num_samples`` and + ``num_negatives`` directly was deprecated. Please set these + variables via the attributes ``batch_size`` and ``num_negatives``. """ raise NotImplementedError() From e259e45fe05e1763a282a5385f79216669cea769 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sat, 2 Aug 2025 16:38:29 +0200 Subject: [PATCH 7/8] update deprecation note --- cebra/data/base.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/cebra/data/base.py b/cebra/data/base.py index 68f38d66..518f7b4f 100644 --- a/cebra/data/base.py +++ b/cebra/data/base.py @@ -242,8 +242,8 @@ class Loader(abc.ABC, cebra.io.HasDevice): num_negatives: int = dataclasses.field( default=None, - doc="""The number of negative samples to draw for each reference. - If not specified, the batch size is used.""" + doc=("The number of negative samples to draw for each reference. " + "If not specified, the batch size is used."), ) def __post_init__(self): @@ -273,23 +273,23 @@ def __iter__(self) -> Iterator[Batch]: yield self.dataset.load_batch(index) @abc.abstractmethod - def get_indices(self): + def get_indices(self, num_samples: int = None): """Sample and return the specified number of indices. The elements of the returned `BatchIndex` will be used to index the `dataset` of this data loader. Args: - num_samples: The size of each of the reference, positive and - negative samples. + num_samples: Deprecated. Use ``batch_size`` on the instance level + instead. Returns: batch indices for the reference, positive and negative sample. - Note: - From version 0.7.0 onwards, specifying the ``num_samples`` and - ``num_negatives`` directly was deprecated. Please set these - variables via the attributes ``batch_size`` and ``num_negatives``. 
+            From version 0.7.0 onwards, specifying the ``num_samples``
+            directly is deprecated and will be removed in version 0.8.0.
+            Please set ``batch_size`` and ``num_negatives`` on the instance
+            level instead.
         """
         raise NotImplementedError()

From f2af3b61b1810fe0b11da6e28ed048fd08ad7366 Mon Sep 17 00:00:00 2001
From: Steffen Schneider
Date: Sat, 2 Aug 2025 17:09:45 +0200
Subject: [PATCH 8/8] Update GoF computation

---
 cebra/integrations/sklearn/cebra.py   |  7 +++++++
 cebra/integrations/sklearn/metrics.py | 14 +++++---------
 tests/test_sklearn_metrics.py         | 21 ++++++++++++++++-----
 3 files changed, 28 insertions(+), 14 deletions(-)

diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py
index b2b5460d..25ee6e05 100644
--- a/cebra/integrations/sklearn/cebra.py
+++ b/cebra/integrations/sklearn/cebra.py
@@ -596,6 +596,13 @@ def num_sessions(self) -> Optional[int]:
         """
         return self.num_sessions_
 
+    @property
+    def num_negatives_(self) -> int:
+        """The number of negative samples used during training."""
+        if self.num_negatives is None:
+            return self.batch_size
+        return self.num_negatives
+
     @property
     def state_dict_(self) -> dict:
         return self.solver_.state_dict()
diff --git a/cebra/integrations/sklearn/metrics.py b/cebra/integrations/sklearn/metrics.py
index d8fd791d..d072b0ae 100644
--- a/cebra/integrations/sklearn/metrics.py
+++ b/cebra/integrations/sklearn/metrics.py
@@ -100,12 +100,12 @@ def infonce_loss(
     solver.to(cebra_model.device_)
     avg_loss = solver.validation(loader=loader, session_id=session_id)
     if correct_by_batchsize:
-        if cebra_model.batch_size is None:
+        if cebra_model.num_negatives_ is None:
             raise ValueError(
                 "Batch size is None, please provide a model with a batch size to correct the InfoNCE."
             )
         else:
-            avg_loss = avg_loss - np.log(cebra_model.batch_size)
+            avg_loss = avg_loss - np.log(cebra_model.num_negatives_)
     return avg_loss
 
 
@@ -211,7 +211,7 @@ def infonce_to_goodness_of_fit(
     Args:
         infonce: The InfoNCE loss, either a single value or an iterable of values.
         model: The trained CEBRA model.
-        batch_size: The batch size used to train the model.
+        batch_size: The batch size (or number of negatives, if different from the batch size) used to train the model.
        num_sessions: The number of sessions used to train the model.
 
     Returns:
@@ -228,19 +228,15 @@ def infonce_to_goodness_of_fit(
         )
         if not hasattr(model, "state_dict_"):
             raise RuntimeError("Fit the CEBRA model first.")
-        if model.batch_size is None:
+        if model.num_negatives_ is None:
             raise ValueError(
                 "Computing the goodness of fit is not yet supported for "
                 "models trained on the full dataset (batchsize = None). ")
-        batch_size = model.batch_size
+        batch_size = model.num_negatives_
         num_sessions = model.num_sessions_
         if num_sessions is None:
             num_sessions = 1
 
-        if model.batch_size is None:
-            raise ValueError(
-                "Computing the goodness of fit is not yet supported for "
-                "models trained on the full dataset (batchsize = None). 
") else: if batch_size is None or num_sessions is None: raise ValueError( diff --git a/tests/test_sklearn_metrics.py b/tests/test_sklearn_metrics.py index 10c62453..bb71b1f1 100644 --- a/tests/test_sklearn_metrics.py +++ b/tests/test_sklearn_metrics.py @@ -482,14 +482,22 @@ def _fit_and_get_history(X, y): @pytest.mark.parametrize("seed", [42, 24, 10]) -def test_infonce_to_goodness_of_fit(seed): +@pytest.mark.parametrize("batch_size", [100, 200]) +@pytest.mark.parametrize("num_negatives", [None, 100, 200]) +def test_infonce_to_goodness_of_fit(seed, batch_size, num_negatives): """Test the conversion from InfoNCE loss to goodness of fit metric.""" + nats_to_bits = np.log2(np.e) + # Test with model cebra_model = cebra_sklearn_cebra.CEBRA( model_architecture="offset10-model", max_iterations=5, - batch_size=128, + batch_size=batch_size, + num_negatives=num_negatives, ) + if num_negatives is None: + num_negatives = batch_size + generator = torch.Generator().manual_seed(seed) X = torch.rand(1000, 50, dtype=torch.float32, generator=generator) cebra_model.fit(X) @@ -498,6 +506,7 @@ def test_infonce_to_goodness_of_fit(seed): gof = cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, model=cebra_model) assert isinstance(gof, float) + assert np.isclose(gof, (np.log(num_negatives) - 1.0) * nats_to_bits) # Test array of values infonce_values = np.array([1.0, 2.0, 3.0]) @@ -505,12 +514,14 @@ def test_infonce_to_goodness_of_fit(seed): infonce_values, model=cebra_model) assert isinstance(gof_array, np.ndarray) assert gof_array.shape == infonce_values.shape + assert np.allclose(gof_array, + (np.log(num_negatives) - infonce_values) * nats_to_bits) # Test with explicit batch_size and num_sessions - gof = cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, - batch_size=128, - num_sessions=1) + gof = cebra_sklearn_metrics.infonce_to_goodness_of_fit( + 1.0, batch_size=batch_size, num_sessions=1) assert isinstance(gof, float) + assert np.isclose(gof, (np.log(batch_size) - 1.0) * nats_to_bits) # Test error cases with pytest.raises(ValueError, match="batch_size.*should not be provided"):