From a185a6059db6769e3a87c3ea23f8a7e7fbc9ac00 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Tue, 22 Oct 2024 08:06:26 +0800 Subject: [PATCH] hack: reverse data file option --- src/zeroband/data.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/zeroband/data.py b/src/zeroband/data.py index 662e1c1c..b4689482 100644 --- a/src/zeroband/data.py +++ b/src/zeroband/data.py @@ -32,6 +32,7 @@ class DataConfig(BaseConfig): dataset_ratio: Optional[str] = None data_rank: Optional[int] = None data_world_size: Optional[int] = None + reverse_data_files: bool = False class FakeTokenizedDataset(IterableDataset): @@ -194,6 +195,7 @@ def _load_datasets( data_world_size: Optional[int] = None, streaming: bool = True, probabilities: Optional[List[float]] = None, + reverse_data_files: bool = False, ) -> Dataset: logger.debug(dataset_names) ds_args = [] @@ -202,8 +204,11 @@ def _load_datasets( _ds_args = {"path": _ds_name} if _ds_config: _ds_args["name"] = _ds_config + _data_files = _get_datafiles(_ds_name, _ds_config, split) + if reverse_data_files: + _data_files = _data_files[::-1] + _ds_args["data_files"] = _data_files if data_rank is not None and data_world_size is not None: - _data_files = _get_datafiles(_ds_name, _ds_config, split) _ds_args["data_files"] = _data_files[data_rank::data_world_size] ds_args.append(_ds_args) @@ -244,6 +249,7 @@ def load_all_datasets(data_config: DataConfig, split: str, max_samples: Optional data_world_size=data_config.data_world_size, streaming=data_config.streaming, probabilities=_get_probabilities(data_config), + reverse_data_files=data_config.reverse_data_files, ) if max_samples is not None and data_config.streaming: if data_config.max_train_samples is not None: