Skip to content

Commit 0be4457

Browse files
ConchylicultorThe kauldron Authors
authored and
The kauldron Authors
committed
Do not re-randomize when shuffle=False
PiperOrigin-RevId: 667939794
1 parent 53da6eb commit 0be4457

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

kauldron/data/kmix/mixture.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,15 @@ class SampleFromDatasets(base.TFDataPipeline):
3737
stop_on_empty_dataset: If True, the iteration will stop on the first empty
3838
dataset.
3939
rerandomize_each_iteration: If True, the mixture will reshuffle the datasets
40-
each time it is iterated over.
40+
each time it is iterated over. By default (`None`), only reshuffle when
41+
all the sub-datasets are also shuffled.
4142
"""
4243

4344
datasets: list[base.TFDataPipeline]
4445
_: dataclasses.KW_ONLY
4546
weights: None | list[float | int] = None
4647
stop_on_empty_dataset: bool = False
47-
rerandomize_each_iteration: bool = True
48+
rerandomize_each_iteration: bool | None = None
4849

4950
def __post_init__(self):
5051
if not all(isinstance(ds, base.TFDataPipeline) for ds in self.datasets):
@@ -60,12 +61,19 @@ def ds_for_current_process(self, rng: random.PRNGKey) -> tf.data.Dataset:
6061
for i, ds in enumerate(self.datasets)
6162
]
6263

64+
# Do not re-randomize if all datasets are in eval mode (shuffle=False)
65+
# This is so iterating twice on evaluation yield examples in deterministic
66+
# order (due to an issue with `tf.data`: b/362450807).
67+
rerandomize_each_iteration = any(
68+
not hasattr(ds, 'shuffle') or ds.shuffle for ds in self.datasets
69+
)
70+
6371
ds = tf.data.Dataset.sample_from_datasets(
6472
datasets,
6573
weights=self.weights,
6674
seed=int(rng.fold_in('sample_from_datasets').bits()),
6775
stop_on_empty_dataset=self.stop_on_empty_dataset,
68-
rerandomize_each_iteration=self.rerandomize_each_iteration,
76+
rerandomize_each_iteration=rerandomize_each_iteration,
6977
)
7078
return ds
7179

0 commit comments

Comments
 (0)