add skip dataset (#126)
* add skip dataset

* add skip dataset

* add skip dataset
samsja authored Oct 22, 2024
1 parent a185a60 commit 9a37ad2
Showing 4 changed files with 120 additions and 4 deletions.
18 changes: 18 additions & 0 deletions configs/test.toml
@@ -0,0 +1,18 @@
name_model = "debugmodel"
project = "debug_150m_zero_band"
type_model = "llama2"

[train]
micro_bs = 4 # change this based on the GPU

[data]
seq_length = 8192
dataset_name_or_paths = "PrimeIntellect/fineweb-edu,PrimeIntellect/fineweb,PrimeIntellect/StackV1-popular,mlfoundations/dclm-baseline-1.0-parquet,open-web-math/open-web-math"
dataset_ratio = "55:10:20:10:5"
num_workers = 8

[optim]
batch_size = 128
warmup_steps = 1000
total_steps = 88_000
lr = 4e-4
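The `[data]` section above mixes five corpora according to `dataset_ratio`. As a rough sketch of the idea only (not this repo's implementation — `zeroband.data.get_dataloader` handles the actual mixing, and the tiny in-memory datasets below are stand-ins for the Hub datasets named in `dataset_name_or_paths`), a ratio string can be normalized into sampling probabilities and fed to Hugging Face `datasets.interleave_datasets`:

```python
from datasets import Dataset, interleave_datasets

# Turn "55:10:20:10:5" into sampling probabilities that sum to 1.
ratio = "55:10:20:10:5"
weights = [float(r) for r in ratio.split(":")]
probabilities = [w / sum(weights) for w in weights]

# Stand-in datasets; a real run would stream the datasets listed in dataset_name_or_paths.
sources = [
    Dataset.from_dict({"text": [f"source{i}-doc{j}" for j in range(100)]})
    for i in range(len(weights))
]

# Sample from the sources with the given probabilities (default: stop when the first is exhausted).
mixed = interleave_datasets(sources, probabilities=probabilities, seed=42)
print(mixed[:5]["text"])
```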
87 changes: 87 additions & 0 deletions scripts/skip_data.py
@@ -0,0 +1,87 @@
"""
This script simulates a training run to exhaust the datasets and recover the dataloader checkpoint.
It has the same API as the training script; the only differences are that you probably want to change total_steps and set a data_path.
It can load the same config file as the real run to reproduce its setup.
Example:
```
uv run torchrun --nproc_per_node=4 scripts/skip_data.py @configs/150M/3090.toml --optim.total_steps 100 --ckpt.data_path out_data
```
"""

import os
import torch
from pydantic_config import parse_argv


from transformers import AutoTokenizer
from zeroband.checkpoint import CkptManager

from zeroband.train import Config

from zeroband.data import get_dataloader

from zeroband.utils.world_info import get_world_info
from zeroband.utils.logging import get_logger


def skip_data(config: Config):
    # batch_size is the total batch size across the local GPUs
    assert config.optim.batch_size % world_info.local_world_size == 0
    batch_size = config.optim.batch_size // world_info.local_world_size

    assert batch_size % config.train.micro_bs == 0
    gradient_accumulation_steps = batch_size // config.train.micro_bs

    if config.type_model == "llama2":
        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True)
    elif config.type_model == "llama3":
        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", use_fast=True)
    else:
        raise ValueError(f"Model type {config.type_model} not supported")

    logger.debug("tokenizer loaded")

    train_dataloader = get_dataloader(
        tokenizer=tokenizer,
        world_size=world_info.world_size,
        rank=world_info.rank,
        batch_size=config.train.micro_bs,
        data_config=config.data,
    )

    train_dataloader_iterator = iter(train_dataloader)

    logger.info("starting to skip data up to step: %d", config.optim.total_steps)

    total_steps = 0

    while True:
        num_inner_steps = config.diloco.inner_steps if config.diloco is not None else 1

        for _inner_step in range(num_inner_steps):
            for _ in range(gradient_accumulation_steps):
                next(train_dataloader_iterator)

        total_steps += num_inner_steps
        logger.info("total steps: %d", total_steps)
        if total_steps >= config.optim.total_steps:
            break

    CkptManager.save_data_v2(os.path.join(config.ckpt.data_path, "data"), train_dataloader, world_info.local_rank)

    logger.info("skipped data up to step: %d", config.optim.total_steps)


if __name__ == "__main__":
    torch.manual_seed(42)

    world_info = get_world_info()
    logger = get_logger()

    config = Config(**parse_argv())

    skip_data(config)
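Each local rank therefore ends up with its dataloader state in `<data_path>/data/_<local_rank>.pt`, stored under the key `"data_loader"`. A quick sketch of inspecting one of those files (the `out_data` directory and rank 0 come from the example command in the docstring; the exact contents of the state dict depend on the dataloader implementation):

```python
import torch

# e.g. produced by a run with --ckpt.data_path out_data; this is rank 0's shard.
state = torch.load("out_data/data/_0.pt", weights_only=False)
print(list(state.keys()))          # ['data_loader']
print(type(state["data_loader"]))  # typically a dict returned by the dataloader's state_dict()
```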
7 changes: 7 additions & 0 deletions scripts/worker7
@@ -0,0 +1,7 @@
GLOBAL_RANK=3
GLOBAL_UNIQUE_ID=worker7
GLOBAL_ADDR=100.91.104.70
GLOBAL_PORT=8989
GLOO_SOCKET_IFNAME=tailscale0
GLOBAL_WORLD_SIZE=10
ZERO_BAND_GLOBAL_PG_TIMEOUT_SECONDS=3600
12 changes: 8 additions & 4 deletions src/zeroband/checkpoint.py
@@ -361,10 +361,7 @@ def _save(self, ckpt_path: str):

        if self.config.data_version == "v2":
            data_path = os.path.join(ckpt_path, "data")
-           os.makedirs(data_path, exist_ok=True)
-           with open(os.path.join(data_path, f"_{self.world_info.local_rank}.pt"), "wb") as f:
-               state = {"data_loader": self.dataloader.state_dict()}
-               torch.save(state, f)
+           self.save_data_v2(data_path, self.dataloader, self.world_info.local_rank)

        non_error_barrier()

@@ -381,6 +378,13 @@ def _save(self, ckpt_path: str):

        gc.collect()

+   @staticmethod
+   def save_data_v2(data_path: str, dataloader, local_rank: int):
+       os.makedirs(data_path, exist_ok=True)
+       with open(os.path.join(data_path, f"_{local_rank}.pt"), "wb") as f:
+           state = {"data_loader": dataloader.state_dict()}
+           torch.save(state, f)
+
    def _async_save_remote(self, ckpt_path: str, remote_ckpt_path: str, blocking: bool = True) -> None:
        """asyncronously rsync a ckpt folder to a remote location. Using fsspec to handle remote cloud storage without to install
        specific libraries (e.g. s3fs).
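For reference, the state written by `save_data_v2` would be restored on resume roughly as follows. This is a minimal sketch, not the repo's actual resume path: it assumes the dataloader exposes a matching `load_state_dict` (as stateful dataloaders such as `torchdata.stateful_dataloader.StatefulDataLoader` do), and the `load_data_v2` helper name is hypothetical.

```python
import os

import torch


def load_data_v2(data_path: str, dataloader, local_rank: int) -> None:
    # Hypothetical mirror of save_data_v2: read this rank's file and restore the dataloader state.
    with open(os.path.join(data_path, f"_{local_rank}.pt"), "rb") as f:
        state = torch.load(f, weights_only=False)
    dataloader.load_state_dict(state["data_loader"])
```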
