@@ -37,14 +37,15 @@ class SampleFromDatasets(base.TFDataPipeline):
37
37
stop_on_empty_dataset: If True, the iteration will stop on the first empty
38
38
dataset.
39
39
rerandomize_each_iteration: If True, the mixture will reshuffle the datasets
40
- each time it is iterated over.
40
+ each time it is iterated over. By default (`None`), reshuffle unless
41
+ all the sub-datasets have shuffling disabled (`shuffle=False`).
41
42
"""
42
43
43
44
datasets : list [base .TFDataPipeline ]
44
45
_ : dataclasses .KW_ONLY
45
46
weights : None | list [float | int ] = None
46
47
stop_on_empty_dataset : bool = False
47
- rerandomize_each_iteration : bool = True
48
+ rerandomize_each_iteration : bool | None = None
48
49
49
50
def __post_init__ (self ):
50
51
if not all (isinstance (ds , base .TFDataPipeline ) for ds in self .datasets ):
@@ -60,12 +61,19 @@ def ds_for_current_process(self, rng: random.PRNGKey) -> tf.data.Dataset:
60
61
for i , ds in enumerate (self .datasets )
61
62
]
62
63
64
+ # Do not re-randomize if all datasets are in eval mode (shuffle=False)
65
+ # This is so iterating twice on evaluation yields examples in deterministic
66
+ # order (due to an issue with `tf.data`: b/362450807).
67
+ rerandomize_each_iteration = any (
68
+ not hasattr (ds , 'shuffle' ) or ds .shuffle for ds in self .datasets
69
+ )
70
+
63
71
ds = tf .data .Dataset .sample_from_datasets (
64
72
datasets ,
65
73
weights = self .weights ,
66
74
seed = int (rng .fold_in ('sample_from_datasets' ).bits ()),
67
75
stop_on_empty_dataset = self .stop_on_empty_dataset ,
68
- rerandomize_each_iteration = self . rerandomize_each_iteration ,
76
+ rerandomize_each_iteration = rerandomize_each_iteration ,
69
77
)
70
78
return ds
71
79
0 commit comments