From 0be4457565191322d7eeb44fbcdfb7f1bdac24b5 Mon Sep 17 00:00:00 2001 From: Etienne Pot Date: Tue, 27 Aug 2024 05:03:39 -0700 Subject: [PATCH] Do not re-randomize when shuffle=False PiperOrigin-RevId: 667939794 --- kauldron/data/kmix/mixture.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/kauldron/data/kmix/mixture.py b/kauldron/data/kmix/mixture.py index 32b6baa8..381c442a 100644 --- a/kauldron/data/kmix/mixture.py +++ b/kauldron/data/kmix/mixture.py @@ -37,14 +37,15 @@ class SampleFromDatasets(base.TFDataPipeline): stop_on_empty_dataset: If True, the iteration will stop on the first empty dataset. rerandomize_each_iteration: If True, the mixture will reshuffle the datasets - each time it is iterated over. + each time it is iterated over. By default (`None`), only reshuffle when + at least one of the sub-datasets is also shuffled. """ datasets: list[base.TFDataPipeline] _: dataclasses.KW_ONLY weights: None | list[float | int] = None stop_on_empty_dataset: bool = False - rerandomize_each_iteration: bool = True + rerandomize_each_iteration: bool | None = None def __post_init__(self): if not all(isinstance(ds, base.TFDataPipeline) for ds in self.datasets): @@ -60,12 +61,19 @@ def ds_for_current_process(self, rng: random.PRNGKey) -> tf.data.Dataset: for i, ds in enumerate(self.datasets) ] + # Do not re-randomize if all datasets are in eval mode (shuffle=False). + # This is so iterating twice on evaluation yields examples in deterministic + # order (due to an issue with `tf.data`: b/362450807). + rerandomize_each_iteration = any( + not hasattr(ds, 'shuffle') or ds.shuffle for ds in self.datasets + ) + ds = tf.data.Dataset.sample_from_datasets( datasets, weights=self.weights, seed=int(rng.fold_in('sample_from_datasets').bits()), stop_on_empty_dataset=self.stop_on_empty_dataset, - rerandomize_each_iteration=self.rerandomize_each_iteration, + rerandomize_each_iteration=rerandomize_each_iteration, ) return ds