From 9a37ad27c407c5762c50f5a1615b423e2d340407 Mon Sep 17 00:00:00 2001
From: samsja <55492238+samsja@users.noreply.github.com>
Date: Mon, 21 Oct 2024 18:35:36 -0700
Subject: [PATCH] add skip dataset (#126)

* add skip dataset

* add skip dataset

* add skip dataset
---
 configs/test.toml          | 18 ++++++++
 scripts/skip_data.py       | 87 ++++++++++++++++++++++++++++++++++++++
 scripts/worker7            |  7 +++
 src/zeroband/checkpoint.py | 12 ++++--
 4 files changed, 120 insertions(+), 4 deletions(-)
 create mode 100644 configs/test.toml
 create mode 100644 scripts/skip_data.py
 create mode 100644 scripts/worker7

diff --git a/configs/test.toml b/configs/test.toml
new file mode 100644
index 00000000..99b6c832
--- /dev/null
+++ b/configs/test.toml
@@ -0,0 +1,18 @@
+name_model = "debugmodel"
+project = "debug_150m_zero_band"
+type_model = "llama2"
+
+[train]
+micro_bs = 4 # change this based on the GPU
+
+[data]
+seq_length = 8192
+dataset_name_or_paths = "PrimeIntellect/fineweb-edu,PrimeIntellect/fineweb,PrimeIntellect/StackV1-popular,mlfoundations/dclm-baseline-1.0-parquet,open-web-math/open-web-math"
+dataset_ratio = "55:10:20:10:5"
+num_workers = 8
+
+[optim]
+batch_size = 128
+warmup_steps = 1000
+total_steps = 88_000
+lr = 4e-4
\ No newline at end of file
diff --git a/scripts/skip_data.py b/scripts/skip_data.py
new file mode 100644
index 00000000..4301ffc5
--- /dev/null
+++ b/scripts/skip_data.py
@@ -0,0 +1,87 @@
+"""
+This script simulates a training run to exhaust the datasets and recover the dataloader checkpoint.
+
+It has the same API as the training script. The only difference is that you probably want to change total_steps and set a data_path.
+
+It can load the same config file as the real run so that the setup is identical.
+
+Example:
+```
+uv run torchrun --nproc_per_node=4 scripts/skip_data.py @configs/150M/3090.toml --optim.total_steps 100 --ckpt.data_path out_data
+```
+
+"""
+
+import os
+import torch
+from pydantic_config import parse_argv
+
+
+from transformers import AutoTokenizer
+from zeroband.checkpoint import CkptManager
+
+from zeroband.train import Config
+
+from zeroband.data import get_dataloader
+
+from zeroband.utils.world_info import get_world_info
+from zeroband.utils.logging import get_logger
+
+
+def skip_data(config: Config):
+    # batch_size is the total batch size for all GPUs
+    assert config.optim.batch_size % world_info.local_world_size == 0
+    batch_size = config.optim.batch_size // world_info.local_world_size
+
+    assert batch_size % config.train.micro_bs == 0
+    gradient_accumulation_steps = batch_size // config.train.micro_bs
+
+    if config.type_model == "llama2":
+        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True)
+    elif config.type_model == "llama3":
+        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", use_fast=True)
+    else:
+        raise ValueError(f"Model type {config.type_model} not supported")
+
+    logger.debug("tokenizer loaded")
+
+    train_dataloader = get_dataloader(
+        tokenizer=tokenizer,
+        world_size=world_info.world_size,
+        rank=world_info.rank,
+        batch_size=config.train.micro_bs,
+        data_config=config.data,
+    )
+
+    train_dataloader_iterator = iter(train_dataloader)
+
+    logger.info("starting skipping data up to step: %d", config.optim.total_steps)
+
+    total_steps = 0
+
+    while True:
+        num_inner_steps = config.diloco.inner_steps if config.diloco is not None else 1
+
+        for _inner_step in range(num_inner_steps):
+            for _ in range(gradient_accumulation_steps):
+                next(train_dataloader_iterator)
+
+        total_steps += num_inner_steps
+        logger.info("total steps: %d", total_steps)
+        if total_steps >= config.optim.total_steps:
+            break
+
+    CkptManager.save_data_v2(os.path.join(config.ckpt.data_path, "data"), train_dataloader, world_info.local_rank)
+
+    logger.info("skipped data up to step: %d", config.optim.total_steps)
+
+
+if __name__ == "__main__":
+    torch.manual_seed(42)
+
+    world_info = get_world_info()
+    logger = get_logger()
+
+    config = Config(**parse_argv())
+
+    skip_data(config)
diff --git a/scripts/worker7 b/scripts/worker7
new file mode 100644
index 00000000..683564a1
--- /dev/null
+++ b/scripts/worker7
@@ -0,0 +1,7 @@
+GLOBAL_RANK=3
+GLOBAL_UNIQUE_ID=worker7
+GLOBAL_ADDR=100.91.104.70
+GLOBAL_PORT=8989
+GLOO_SOCKET_IFNAME=tailscale0
+GLOBAL_WORLD_SIZE=10
+ZERO_BAND_GLOBAL_PG_TIMEOUT_SECONDS=3600
\ No newline at end of file
diff --git a/src/zeroband/checkpoint.py b/src/zeroband/checkpoint.py
index a0cecdf0..46bfd441 100644
--- a/src/zeroband/checkpoint.py
+++ b/src/zeroband/checkpoint.py
@@ -361,10 +361,7 @@ def _save(self, ckpt_path: str):
 
         if self.config.data_version == "v2":
             data_path = os.path.join(ckpt_path, "data")
-            os.makedirs(data_path, exist_ok=True)
-            with open(os.path.join(data_path, f"_{self.world_info.local_rank}.pt"), "wb") as f:
-                state = {"data_loader": self.dataloader.state_dict()}
-                torch.save(state, f)
+            self.save_data_v2(data_path, self.dataloader, self.world_info.local_rank)
 
         non_error_barrier()
 
@@ -381,6 +378,13 @@ def _save(self, ckpt_path: str):
 
         gc.collect()
 
+    @staticmethod
+    def save_data_v2(data_path: str, dataloader, local_rank: int):
+        os.makedirs(data_path, exist_ok=True)
+        with open(os.path.join(data_path, f"_{local_rank}.pt"), "wb") as f:
+            state = {"data_loader": dataloader.state_dict()}
+            torch.save(state, f)
+
     def _async_save_remote(self, ckpt_path: str, remote_ckpt_path: str, blocking: bool = True) -> None:
         """asyncronously rsync a ckpt folder to a remote location. Using fsspec
         to handle remote cloud storage without to install specific libraries (e.g. s3fs).