diff --git a/configs/10B/H100.toml b/configs/10B/H100.toml index 70505d4..d58e609 100644 --- a/configs/10B/H100.toml +++ b/configs/10B/H100.toml @@ -25,6 +25,8 @@ dataset_name_or_paths = "/data/datasets/fineweb-edu,/data/datasets/fineweb,/data dataset_ratio = "55:10:20:10:5" num_workers = 4 reverse_data_files = true +split_by_data_rank = false # the 10B training assumes that the data was already split by data rank. Keeping this for backward compatibility + [diloco] inner_steps = 100 diff --git a/configs/10B/H100_cooldown.toml b/configs/10B/H100_cooldown.toml index c01d6db..9132b1e 100644 --- a/configs/10B/H100_cooldown.toml +++ b/configs/10B/H100_cooldown.toml @@ -26,6 +26,7 @@ dataset_name_or_paths = "/data/datasets/fineweb-edu,/data/datasets/fineweb,/data dataset_ratio = "80:10:10" num_workers = 4 reverse_data_files = false +split_by_data_rank = false # the 10B training assumes that the data was already split by data rank. Keeping this for backward compatibility [diloco] inner_steps = 100 diff --git a/src/zeroband/data.py b/src/zeroband/data.py index c34b111..5a6fa62 100644 --- a/src/zeroband/data.py +++ b/src/zeroband/data.py @@ -36,6 +36,7 @@ class DataConfig(BaseConfig): data_rank: Optional[int] = None data_world_size: Optional[int] = None reverse_data_files: bool = False + split_by_data_rank = True class FakeTokenizedDataset(IterableDataset): @@ -393,14 +394,26 @@ def _get_probabilities(data_config: DataConfig) -> Optional[List[float]]: def load_all_datasets( - data_config: DataConfig, split: str, tokenizer: PreTrainedTokenizer, rank: int, world_size: int + data_config: DataConfig, + split: str, + tokenizer: PreTrainedTokenizer, + rank: int, + world_size: int, ) -> InterleaveDataset: """Load all datasets and interleave them""" + + if data_config.split_by_data_rank: + split_rank = data_config.data_rank * world_size + rank + split_world_size = data_config.data_world_size * world_size + else: + split_rank = rank + split_world_size = world_size + ds = _load_datasets( 
dataset_names=data_config.dataset_name_or_paths, split=split, - data_rank=rank, - data_world_size=world_size, + data_rank=split_rank, + data_world_size=split_world_size, probabilities=_get_probabilities(data_config), reverse_data_files=data_config.reverse_data_files, tokenizer=tokenizer,