diff --git a/src/zeroband/data.py b/src/zeroband/data.py index 662e1c1c..b4689482 100644 --- a/src/zeroband/data.py +++ b/src/zeroband/data.py @@ -32,6 +32,7 @@ class DataConfig(BaseConfig): dataset_ratio: Optional[str] = None data_rank: Optional[int] = None data_world_size: Optional[int] = None + reverse_data_files: bool = False class FakeTokenizedDataset(IterableDataset): @@ -194,6 +195,7 @@ def _load_datasets( data_world_size: Optional[int] = None, streaming: bool = True, probabilities: Optional[List[float]] = None, + reverse_data_files: bool = False, ) -> Dataset: logger.debug(dataset_names) ds_args = [] @@ -202,8 +204,11 @@ def _load_datasets( _ds_args = {"path": _ds_name} if _ds_config: _ds_args["name"] = _ds_config + _data_files = _get_datafiles(_ds_name, _ds_config, split) + if reverse_data_files: + _data_files = _data_files[::-1] + _ds_args["data_files"] = _data_files if data_rank is not None and data_world_size is not None: - _data_files = _get_datafiles(_ds_name, _ds_config, split) _ds_args["data_files"] = _data_files[data_rank::data_world_size] ds_args.append(_ds_args) @@ -244,6 +249,7 @@ def load_all_datasets(data_config: DataConfig, split: str, max_samples: Optional data_world_size=data_config.data_world_size, streaming=data_config.streaming, probabilities=_get_probabilities(data_config), + reverse_data_files=data_config.reverse_data_files, ) if max_samples is not None and data_config.streaming: if data_config.max_train_samples is not None: