From f0d8b98a5545c41e820c9bf26f36493c92d182f7 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Tue, 3 Dec 2024 21:49:08 +0000 Subject: [PATCH] add data rank split --- configs/10B/H100.toml | 2 ++ configs/10B/H100_cooldown.toml | 1 + src/zeroband/data.py | 19 ++++++++++++++++--- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/configs/10B/H100.toml b/configs/10B/H100.toml index 70505d4d..d58e6098 100644 --- a/configs/10B/H100.toml +++ b/configs/10B/H100.toml @@ -25,6 +25,8 @@ dataset_name_or_paths = "/data/datasets/fineweb-edu,/data/datasets/fineweb,/data dataset_ratio = "55:10:20:10:5" num_workers = 4 reverse_data_files = true +split_by_data_rank = false # the 10b training assumes that data was already split by data rank. Keeping this for backward compatibility + [diloco] inner_steps = 100 diff --git a/configs/10B/H100_cooldown.toml b/configs/10B/H100_cooldown.toml index c01d6db4..9132b1e8 100644 --- a/configs/10B/H100_cooldown.toml +++ b/configs/10B/H100_cooldown.toml @@ -26,6 +26,7 @@ dataset_name_or_paths = "/data/datasets/fineweb-edu,/data/datasets/fineweb,/data dataset_ratio = "80:10:10" num_workers = 4 reverse_data_files = false +split_by_data_rank = false # the 10b training assumes that data was already split by data rank. 
Keeping this for backward compatibility [diloco] inner_steps = 100 diff --git a/src/zeroband/data.py b/src/zeroband/data.py index c34b1118..5a6fa621 100644 --- a/src/zeroband/data.py +++ b/src/zeroband/data.py @@ -36,6 +36,7 @@ class DataConfig(BaseConfig): data_rank: Optional[int] = None data_world_size: Optional[int] = None reverse_data_files: bool = False + split_by_data_rank: bool = True class FakeTokenizedDataset(IterableDataset): @@ -393,14 +394,26 @@ def _get_probabilities(data_config: DataConfig) -> Optional[List[float]]: def load_all_datasets( - data_config: DataConfig, split: str, tokenizer: PreTrainedTokenizer, rank: int, world_size: int + data_config: DataConfig, + split: str, + tokenizer: PreTrainedTokenizer, + rank: int, + world_size: int, ) -> InterleaveDataset: """Load all datasets and interleave them""" + + if data_config.split_by_data_rank: + split_rank = data_config.data_rank * world_size + rank + split_world_size = data_config.data_world_size * world_size + else: + split_rank = rank + split_world_size = world_size + ds = _load_datasets( dataset_names=data_config.dataset_name_or_paths, split=split, - data_rank=rank, - data_world_size=world_size, + data_rank=split_rank, + data_world_size=split_world_size, probabilities=_get_probabilities(data_config), reverse_data_files=data_config.reverse_data_files, tokenizer=tokenizer,