From 272e90330932efebbcd10afe90eba2b72137161e Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sat, 28 Sep 2024 16:45:19 +0000 Subject: [PATCH 01/20] fix vocab size for debugmodel and real data --- src/zeroband/train.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 0d93d57a..3cccad1f 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -96,9 +96,7 @@ def train(config: Config): model, model_config = get_model( config.name_model, config.type_model, - vocab_size=tokenizer.vocab_size - if config.name_model != "debugmodel" or not config.data.fake - else TEST_VOCAB_SIZE, + vocab_size=tokenizer.vocab_size if config.name_model != "debugmodel" or not config.data.fake else TEST_VOCAB_SIZE, ) if config.train.log_model_hash: From 5b01f13ad4598bbec21d40d52ef55c31635f152e Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Fri, 27 Sep 2024 23:11:14 +0000 Subject: [PATCH 02/20] add torchdata --- pyproject.toml | 3 ++- uv.lock | 23 +++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f5b1a711..18c7b780 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,8 @@ dependencies = [ "transformers>=4.44.2", "datasets>=3.0.0", "pydantic_config @ git+https://github.com/samsja/pydantic_config.git@e529c9c", - "einops" + "einops", + "torchdata>=0.8.0" ] [project.optional-dependencies] diff --git a/uv.lock b/uv.lock index f0ec766e..415fe509 100644 --- a/uv.lock +++ b/uv.lock @@ -1525,6 +1525,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ac/30/8b6f77ea4ce84f015ee024b8dfef0dac289396254e8bfd493906d4cbb848/torch-2.4.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:72b484d5b6cec1a735bf3fa5a1c4883d01748698c5e9cfdbeb4ffab7c7987e0d", size = 62123443 }, ] +[[package]] +name = "torchdata" +version = "0.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "requests" }, + { name = "torch" }, + { name = "urllib3" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/15/8a/3251c64214ab09d1c1756677f36e78f8cf0ce9dabb3a21386e78ef50540e/torchdata-0.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:082e27b2acb1768cb6a30ddd2f8d9c68e407164ce207194bf8bfa616d621a801", size = 4904801 }, + { url = "https://files.pythonhosted.org/packages/da/90/058fe345dfac8b50d2d0fdb421ce04c78c88b06a5f220dd8d64d424ccdbe/torchdata-0.8.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:44f7875a62f3fab86e2f8e5af92c4929f8f7390aa17bd697fdd0965723bc1e98", size = 2691733 }, + { url = "https://files.pythonhosted.org/packages/2f/54/d6f64a6e210ee50b68220d3b5564ffdda8bcc8d26c02a39a8a587caffe2f/torchdata-0.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:07e985d69c5692dda9181a8ef3e14c7f08b0460226f7cd4cf1c1bb0e6975700f", size = 1341187 }, + { url = "https://files.pythonhosted.org/packages/82/aa/4da6c725b03fb51c5a10405803308afd43970e66aad45e8cca872786ba1b/torchdata-0.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1635cecf4226fec8539c5d06ba764a48c41363ea0bbea09407ab379828527a8b", size = 4904783 }, + { url = "https://files.pythonhosted.org/packages/64/e8/c691e8e73dc6cbb09ba84ffb0341a6466d3184ff422cda07ebade3b929ef/torchdata-0.8.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:2d63d3fdcc68cf912c81709c8704b9cf435ba89bceed41a365e7362eb5740394", size = 2691483 }, + { url = "https://files.pythonhosted.org/packages/2c/f6/438a82c2f8d69114ef943c0b58f69f66ea5249bd7b2e4799d44f185f7797/torchdata-0.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:8db8a7cb946e82983517cff94317f1898128cbfe4f48821d0c3509c0cdafa4c9", size = 1341021 }, + { url = "https://files.pythonhosted.org/packages/3e/f7/2d1cd02ebcca73ff151dd94b0a08d30808574d944a360470b52a89f0be4e/torchdata-0.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4e3f7efac3d8a4bd4efcb1427869c04043d0a0a019f9aa1eb381bd6c6b321e62", size = 4905186 }, + { url = "https://files.pythonhosted.org/packages/ea/94/d9ac51405d4259094dfa0a1dc3fa4ed2efe057d194873c9f1ba1881b06c9/torchdata-0.8.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:bb878e243e58526a5b3ac54583f7c029ad643a34ade798800c1878c83f1c36ee", size = 2691660 }, + { url = "https://files.pythonhosted.org/packages/d2/c4/623f7237c69606d202870bc9e44a8ed9070cc3eb1ac03f02c457083aa746/torchdata-0.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:7a43fc7e8d3ae2632859f15d5439cd97b83af559fecd8963e5f09e08f93b81e2", size = 1341201 }, +] + [[package]] name = "tqdm" version = "4.66.5" @@ -1797,6 +1818,7 @@ dependencies = [ { name = "pydantic-config" }, { name = "setuptools" }, { name = "torch" }, + { name = "torchdata" }, { name = "transformers" }, ] @@ -1820,6 +1842,7 @@ requires-dist = [ { name = "pydantic-config", git = "https://github.com/samsja/pydantic_config.git?rev=e529c9c" }, { name = "setuptools" }, { name = "torch", specifier = "==2.4.1" }, + { name = "torchdata", specifier = ">=0.8.0" }, { name = "transformers", specifier = ">=4.44.2" }, { name = "wandb", marker = "extra == 'all'" }, ] From b01eb90a4e6f15d42fb207a8b29b356fd15d6de5 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sat, 28 Sep 2024 04:00:48 +0000 Subject: [PATCH 03/20] add ckpt v0 --- src/zeroband/checkpoint.py | 270 ++++++++++++++++++++++++++++++ src/zeroband/data.py | 3 +- src/zeroband/train.py | 68 ++++++-- tests/test_torchrun/test_train.py | 13 ++ 4 files changed, 339 insertions(+), 15 deletions(-) create mode 100644 src/zeroband/checkpoint.py diff --git a/src/zeroband/checkpoint.py b/src/zeroband/checkpoint.py new file mode 100644 index 00000000..3fe0d4ad --- /dev/null +++ b/src/zeroband/checkpoint.py @@ -0,0 +1,270 @@ +from dataclasses import dataclass +import time +from typing import Any +import torch +from torch import nn +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR +from torchdata.stateful_dataloader import StatefulDataLoader +import torch.distributed.checkpoint as dcp +from torch.distributed import ProcessGroup +from torch.distributed.checkpoint.state_dict import ( + set_optimizer_state_dict, + set_model_state_dict, + get_model_state_dict, + get_optimizer_state_dict, + StateDictOptions, +) +from torch.distributed.checkpoint.stateful import Stateful +from zeroband.utils.logging import get_logger +import warnings +import logging + +## code inspired by torchtitan https://github.com/pytorch/torchtitan/blob/main/torchtitan/checkpoint.py + + +GLOBAL_STATE_FILE = "global_state_dict.pt" + + +@dataclass +class TrainingProgress(Stateful): + total_tokens: int + outer_step: int + step: int + + def state_dict(self) -> dict[str, Any]: + return {"total_tokens": self.total_tokens, "outer_step": self.outer_step, "step": self.step} + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + self.total_tokens = state_dict["total_tokens"] + self.outer_step = state_dict["outer_step"] + self.step = state_dict["step"] + + +class ModelWrapper(Stateful): + def __init__(self, model: nn.Module) -> None: + self.model = model + + def state_dict(self) -> dict[str, Any]: + return get_model_state_dict(self.model) + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + set_model_state_dict(model=self.model, model_state_dict=state_dict, options=StateDictOptions(strict=False)) + + +class OptimizerWrapper(Stateful): + def __init__( + self, + model: nn.Module, + optim: torch.optim.Optimizer, + ) -> None: + self.model = model + self.optim = optim + + def state_dict(self) -> dict[str, Any]: + return get_optimizer_state_dict( + model=self.model, optimizers=self.optim, options=StateDictOptions(flatten_optimizer_state_dict=True) + ) + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + set_optimizer_state_dict( + model=self.model, optimizers=self.optim, optim_state_dict=state_dict, options=StateDictOptions(strict=False) + ) + + +class CkptManager: + """Its name CkptManager because I (sami) always misstyped chekcpoint.""" + + def __init__( + self, + model: nn.Module, + optimizer: Optimizer, + scheduler: LambdaLR, + dataloader: StatefulDataLoader, + training_progress: TrainingProgress, + process_group: ProcessGroup | None, + ): + self.model = ModelWrapper(model) + self.optimizer = OptimizerWrapper(model, optimizer) + self.scheduler = scheduler + self.dataloader = dataloader + self.training_progress = training_progress + + # states can only be stateful object, hence we need to wrap Model and Optimizer + self.states: dict[str, Stateful] = { + "model": self.model, + "optimizer": self.optimizer, + "scheduler": self.scheduler, + "dataloader": self.dataloader, + "training_progress": self.training_progress, + } + + self.process_group = process_group + self._logger = get_logger() + + def save(self, ckpt_path: str) -> None: + """ + Each rank will save the right shard of the model and optimizer. + + Saving is done inplace + """ + + time_start = time.perf_counter() + + catch_warning = self._logger.getEffectiveLevel() <= logging.INFO + # pytorch has an annoying warning when saving the optimizer state https://github.com/pytorch/pytorch/issues/136907 + # we can ignore it if we are not logging in DEBUG mode + + with warnings.catch_warnings(): + if catch_warning: + warnings.simplefilter("ignore") + + dcp.save(self.states, process_group=self.process_group, checkpoint_id=ckpt_path) + + self._logger.info(f"Saved checkpoint to {ckpt_path} in {time.perf_counter() - time_start} seconds") + + def load(self, resume_ckpt_path: str) -> None: + """ + loading should be done after fsdp wrap and optimizer init. + Each rank will load the right shard of the model and optimizer. + All rank will load the global states (scheduler, step, total_tokens, dataloader). + + Loading is done inplace + """ + time_start = time.perf_counter() + self.states = dcp.load(self.states, process_group=self.process_group, checkpoint_id=resume_ckpt_path) + self._logger.info(f"Loaded checkpoint from {resume_ckpt_path} in {time.perf_counter() - time_start} seconds") + + +# def save( +# checkpoint_path: str, +# model: torch.nn.Module, +# optimizer: torch.optim.Optimizer, +# scheduler: torch.optim.lr_scheduler.LambdaLR, +# outer_optimizer: torch.optim.Optimizer | None = None, +# scaler: torch.cuda.amp.GradScaler | None = None, +# loss: float | None = None, +# data_loader: StatefulDataLoader | None = None, +# save_global_state: bool = True, +# ): +# """Save the model and optimizer state to a checkpoint folderx + +# Args: +# checkpoint_path: the path to the checkpoint folder +# model: the model to save +# optimizer: the optimizer to save +# scheduler: the scheduler to save +# outer_optimizer: the outer optimizer to save +# loss: the loss to save +# data_loader: the data loader to save +# save_global_state: whether to save the global state +# """ +# rank = int(os.environ["RANK"]) + +# # 1. Save distributed states +# # fs_storage_writer = dcp.FsspecWriter(checkpoint_path, sync_files=False) +# # for some reason sync_files = True try to call stream.fileno which is not supported with gcp ffspec storage. + +# model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer) +# dcp_state_dict = { +# "model": model_state_dict, +# "optimizer": optimizer_state_dict, +# } +# dcp.save(dcp_state_dict, checkpoint_id=checkpoint_path) +# if data_loader is not None: +# rank_state_dict = {} +# rank_state_dict["data_loader"] = data_loader.state_dict() +# with open(os.path.join(checkpoint_path, f"__{rank}_0.pt"), "wb") as f: +# torch.save(rank_state_dict, f) + +# if not save_global_state: +# return + +# # 2. Save global states +# global_state_dict = {"scheduler": scheduler.state_dict(), "loss": loss if loss is not None else 0} +# if outer_optimizer is not None: +# global_state_dict["outer_optimizer"] = outer_optimizer.state_dict() +# if scaler is not None: +# global_state_dict["scaler"] = scaler.state_dict() + +# with open(os.path.join(checkpoint_path, GLOBAL_STATE_FILE), "wb") as f: +# torch.save(global_state_dict, f) + +# def load_checkpoint( +# checkpoint_path: str, +# model: torch.nn.Module, +# optimizer: torch.optim.Optimizer, +# scheduler: torch.optim.lr_scheduler.LambdaLR | None = None, +# outer_optimizer: torch.optim.Optimizer | None = None, +# scaler: torch.cuda.amp.GradScaler | None = None, +# data_loader: StatefulDataLoader | None = None, +# ) -> float: +# """Load the model and optimizer state from a checkpoint folder + +# Args: +# checkpoint_path: the path to the checkpoint folder +# model: the model to load +# optimizer: the optimizer to load +# scheduler: the scheduler to load +# outer_optimizer: the outer optimizer to load +# data_loader: the data loader to load + +# Returns: +# loss: the loss from the checkpoint +# """ +# rank = int(os.environ["RANK"]) +# # 1. Load distributed states +# # fs_storage_reader = dcp.FsspecReader(checkpoint_path) + +# model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer) +# dcp_state_dict = { +# "model": model_state_dict, +# "optimizer": optimizer_state_dict, +# } +# dcp.load(dcp_state_dict, checkpoint_id=checkpoint_path) +# set_state_dict( +# model, +# optimizer, +# model_state_dict=model_state_dict, +# optim_state_dict=optimizer_state_dict, +# ) +# if data_loader is not None: +# with open(os.path.join(checkpoint_path, f"__{rank}_0.pt"), "rb") as f: +# rank_state_dict = torch.load(f) +# data_loader.load_state_dict(rank_state_dict["data_loader"]) + +# # 2. Load global states +# with open(os.path.join(checkpoint_path, GLOBAL_STATE_FILE), "rb") as f: +# global_state_dict = torch.load(f) +# if scheduler is not None: +# scheduler.load_state_dict(global_state_dict["scheduler"]) +# optimizer.param_groups[0]["lr"] = scheduler.get_last_lr()[0] +# if outer_optimizer is not None: +# outer_optimizer.load_state_dict(global_state_dict["outer_optimizer"]) +# if scaler is not None: +# scaler.load_state_dict(global_state_dict["scaler"]) +# return global_state_dict["loss"] + + +# class CkptManager: +# """Its name CkptManager because I (sami) always misstyped chekcpoint. """ + +# def __init__(self, model: nn.Module, optimizer: Optimizer, scheduler: LambdaLR, dataloader: StatefulDataLoader, training_progress: TrainingProgress, process_group: ProcessGroup | None): + +# self.model = model +# self.optimizer = optimizer +# self.scheduler = scheduler +# self.dataloader = dataloader +# self.training_progress = training_progress + +# # states can only be stateful object, hence we need to wrap Model and Optimizer + +# self.process_group = process_group +# self._logger = get_logger() + +# def save(self, ckpt_path: str) -> None: +# save(checkpoint_path=ckpt_path, model=self.model, optimizer=self.optimizer, scheduler=self.scheduler, data_loader=self.dataloader) + + +# def load(self, resume_ckpt_path: str) -> None: +# load_checkpoint(checkpoint_path=resume_ckpt_path, model=self.model, optimizer=self.optimizer, scheduler=self.scheduler, data_loader=self.dataloader) diff --git a/src/zeroband/data.py b/src/zeroband/data.py index 61a1a986..1a093d1c 100644 --- a/src/zeroband/data.py +++ b/src/zeroband/data.py @@ -4,6 +4,7 @@ import torch from torch.utils.data import DataLoader from torch.utils.data import IterableDataset +from torchdata.stateful_dataloader import StatefulDataLoader from datasets import load_dataset from datasets.distributed import split_dataset_by_node @@ -79,7 +80,7 @@ def tokenize_function(data): data_collator = collate_causal_mask(max_seq_length=seq_length, pad_id=tokenizer.pad_token_id, ignore_index=-100) - return DataLoader( + return StatefulDataLoader( train_dataset, collate_fn=data_collator, batch_size=batch_size, diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 3cccad1f..ac81f705 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -26,6 +26,7 @@ from zeroband.models.llama import get_model from zeroband.utils.world_info import get_world_info from zeroband.utils.logging import get_logger +from zeroband.checkpoint import TrainingProgress class DataConfig(BaseConfig): @@ -53,6 +54,11 @@ class TrainConfig(BaseConfig): log_model_hash: bool = False +class CkptConfig(BaseConfig): + path: str + interval: int + + class Config(BaseConfig): # main config name_model: Literal["debugmodel", "150M", "271M", "1B", "7B", "13B", "26B", "70B"] = "150M" @@ -67,6 +73,9 @@ class Config(BaseConfig): optim: OptimConfig = OptimConfig() train: TrainConfig + ckpt: CkptConfig | None = None + resume: str | None = None + def train(config: Config): sharding_strategy = get_sharding_strategy(config.train.sharding_strategy) @@ -78,6 +87,11 @@ def train(config: Config): assert batch_size % config.train.micro_bs == 0 gradient_accumulation_steps = batch_size // config.train.micro_bs + if config.ckpt is not None and config.ckpt.interval is not None and config.diloco is not None: + assert ( + config.ckpt.interval % config.diloco.inner_steps == 0 + ), "ckpt interval must be a multiple of diloco inner steps as we only save at the end of an outer step" + tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True) tokenizer.pad_token = "" # todo(sami): remove padding tokens once we have context stuffing @@ -126,10 +140,7 @@ def train(config: Config): use_orig_params=True, process_group=elastic_device_mesh.local_pg if config.diloco is not None else None, ) - - if config.train.torch_compile: - model = torch.compile(model) - logger.debug("model compiled and fsdped") + logger.debug("model fsdped") # Setup optimizers inner_optimizer = torch.optim.AdamW( @@ -148,6 +159,26 @@ def train(config: Config): num_training_steps=config.optim.total_steps, ) + training_progress = TrainingProgress(total_tokens=0, outer_step=0, step=0) + + ckpt_manager = CkptManager( + model=model, + optimizer=inner_optimizer, + scheduler=scheduler, + dataloader=train_dataloader, + training_progress=training_progress, + process_group=elastic_device_mesh.local_pg if config.diloco is not None else None, + ) + + if config.train.torch_compile: + # we need to compile AFTER creating the CKPT manager, DON'T ASK ME WHY + model = torch.compile(model) + logger.debug("model compiled") + + if config.resume is not None: + # all is inplace + ckpt_manager.load(resume_ckpt_path=config.resume) + model.train() if world_info.rank == 0: @@ -156,7 +187,6 @@ def train(config: Config): train_dataloader_iterator = iter(train_dataloader) - outer_step = 0 num_inner_steps = config.diloco.inner_steps if config.diloco is not None else 1 perf_counter = PerfCounter(window_size=10) @@ -164,9 +194,9 @@ def train(config: Config): while True: if num_inner_steps > 1: # if we don't use diloco we don't print the outer step logs - logger.info(f"outer_step step: {outer_step}") + logger.info(f"outer_step step: {training_progress.outer_step}") - for inner_step in range(num_inner_steps): + for _inner_step in range(num_inner_steps): loss_batch = 0 for grad_acc_step in range(gradient_accumulation_steps): @@ -193,22 +223,24 @@ def train(config: Config): inner_optimizer.zero_grad() # logging - real_step = outer_step * num_inner_steps + inner_step + 1 # add + 1 because inner_step start at 0 + training_progress.step += 1 inner_lr = [group["lr"] for group in inner_optimizer.param_groups][0] dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg) # syncing loss across all data parallel rank within a nodes - perf_counter.count_tokens(config.data.seq_length * config.optim.batch_size) + new_tokens = config.data.seq_length * config.optim.batch_size + perf_counter.count_tokens(new_tokens) + training_progress.total_tokens += new_tokens metrics = { "Loss": loss_batch.item(), - "step": real_step, + "step": training_progress.step, "inner_lr": inner_lr, "Perplexity": torch.exp(loss_batch).item(), - "total_tokens": real_step * config.optim.batch_size * config.data.seq_length, + "total_tokens": training_progress.total_tokens, } - log = f"step: {real_step}, loss: {loss_batch.item():.4f}" + log = f"step: {training_progress.step}, loss: {loss_batch.item():.4f}" tokens_per_second = perf_counter.get_tokens_per_second() @@ -237,9 +269,17 @@ def train(config: Config): with FSDP.summon_full_params(model): logger.debug("Post diloco model: %s", get_module_signature(model)) - outer_step += 1 + training_progress.outer_step += 1 + + if ( + config.ckpt is not None + and training_progress.step > 0 + and training_progress.step % config.ckpt.interval == 0 + ): + # we only allow to checkpoint after a outer step. For non diloco training outer step = 1 anyway + ckpt_manager.save(config.ckpt.path) - if real_step >= config.optim.total_steps: + if training_progress.step >= config.optim.total_steps: # we only allow to break outisde of the inner loop. # This avoid ending the training in the middle of a the inner loop # Since ckpt strategy and all reduce is done at the outer loop level. diff --git a/tests/test_torchrun/test_train.py b/tests/test_torchrun/test_train.py index 79bbc61c..d116b2ca 100644 --- a/tests/test_torchrun/test_train.py +++ b/tests/test_torchrun/test_train.py @@ -1,5 +1,7 @@ import copy import os +from pathlib import Path +import shutil import subprocess import pytest import socket @@ -71,3 +73,14 @@ def test_multi_gpu_diloco_non_full_shard(strategy): # we don't test 1,1 and 2,1 because 1 solo gpu failed with fsdp num_gpus = [2, 2] _test_multi_gpu(num_gpus, "debug/diloco.toml", extra_args=["--train.sharding_strategy", strategy]) + + +## test ckpt + + +def test_ckpt(tmp_path: Path): + ckpt_path = "outputs" # for some reason tmp_path is not working + os.makedirs(ckpt_path, exist_ok=True) + _test_multi_gpu([1, 1], "debug/normal.toml", extra_args=["--ckpt.path", str(ckpt_path), "--ckpt.interval", "10"]) + _test_multi_gpu([1, 1], "debug/normal.toml", extra_args=["--resume", str(ckpt_path)]) + shutil.rmtree(ckpt_path) From 98e9565cd524983093530314d0b117ef1105cc41 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sat, 28 Sep 2024 18:35:03 +0000 Subject: [PATCH 04/20] fix real dataset loading error --- src/zeroband/checkpoint.py | 152 +++++-------------------------------- 1 file changed, 18 insertions(+), 134 deletions(-) diff --git a/src/zeroband/checkpoint.py b/src/zeroband/checkpoint.py index 3fe0d4ad..65757aaf 100644 --- a/src/zeroband/checkpoint.py +++ b/src/zeroband/checkpoint.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +import os import time from typing import Any import torch @@ -20,6 +21,8 @@ import warnings import logging +from zeroband.utils.world_info import get_world_info + ## code inspired by torchtitan https://github.com/pytorch/torchtitan/blob/main/torchtitan/checkpoint.py @@ -95,7 +98,7 @@ def __init__( "model": self.model, "optimizer": self.optimizer, "scheduler": self.scheduler, - "dataloader": self.dataloader, + # "dataloader": self.dataloader, # ignoring dataloader for now as each rank has its own dataloader "training_progress": self.training_progress, } @@ -115,12 +118,18 @@ def save(self, ckpt_path: str) -> None: # pytorch has an annoying warning when saving the optimizer state https://github.com/pytorch/pytorch/issues/136907 # we can ignore it if we are not logging in DEBUG mode + rank = get_world_info().local_rank + with warnings.catch_warnings(): if catch_warning: warnings.simplefilter("ignore") dcp.save(self.states, process_group=self.process_group, checkpoint_id=ckpt_path) + ## the next part is a fix so that each rank save a different dataloader rank. It not efficient because it reads the state two times from disk + with open(os.path.join(ckpt_path, f"__{rank}_0.pt"), "wb") as f: + torch.save({"data_loader": self.dataloader.state_dict()}, f) + self._logger.info(f"Saved checkpoint to {ckpt_path} in {time.perf_counter() - time_start} seconds") def load(self, resume_ckpt_path: str) -> None: @@ -133,138 +142,13 @@ def load(self, resume_ckpt_path: str) -> None: """ time_start = time.perf_counter() self.states = dcp.load(self.states, process_group=self.process_group, checkpoint_id=resume_ckpt_path) - self._logger.info(f"Loaded checkpoint from {resume_ckpt_path} in {time.perf_counter() - time_start} seconds") + rank = get_world_info().local_rank # todo check after on/off ramping pr which rank is good here + + ## the next part is a fix so that each rank save a different dataloader rank. It not efficient because it reads the state two times from disk + if self.dataloader is not None: + with open(os.path.join(resume_ckpt_path, f"__{rank}_0.pt"), "rb") as f: + rank_state_dict = torch.load(f) + self.dataloader.load_state_dict(rank_state_dict["data_loader"]) -# def save( -# checkpoint_path: str, -# model: torch.nn.Module, -# optimizer: torch.optim.Optimizer, -# scheduler: torch.optim.lr_scheduler.LambdaLR, -# outer_optimizer: torch.optim.Optimizer | None = None, -# scaler: torch.cuda.amp.GradScaler | None = None, -# loss: float | None = None, -# data_loader: StatefulDataLoader | None = None, -# save_global_state: bool = True, -# ): -# """Save the model and optimizer state to a checkpoint folderx - -# Args: -# checkpoint_path: the path to the checkpoint folder -# model: the model to save -# optimizer: the optimizer to save -# scheduler: the scheduler to save -# outer_optimizer: the outer optimizer to save -# loss: the loss to save -# data_loader: the data loader to save -# save_global_state: whether to save the global state -# """ -# rank = int(os.environ["RANK"]) - -# # 1. Save distributed states -# # fs_storage_writer = dcp.FsspecWriter(checkpoint_path, sync_files=False) -# # for some reason sync_files = True try to call stream.fileno which is not supported with gcp ffspec storage. - -# model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer) -# dcp_state_dict = { -# "model": model_state_dict, -# "optimizer": optimizer_state_dict, -# } -# dcp.save(dcp_state_dict, checkpoint_id=checkpoint_path) -# if data_loader is not None: -# rank_state_dict = {} -# rank_state_dict["data_loader"] = data_loader.state_dict() -# with open(os.path.join(checkpoint_path, f"__{rank}_0.pt"), "wb") as f: -# torch.save(rank_state_dict, f) - -# if not save_global_state: -# return - -# # 2. Save global states -# global_state_dict = {"scheduler": scheduler.state_dict(), "loss": loss if loss is not None else 0} -# if outer_optimizer is not None: -# global_state_dict["outer_optimizer"] = outer_optimizer.state_dict() -# if scaler is not None: -# global_state_dict["scaler"] = scaler.state_dict() - -# with open(os.path.join(checkpoint_path, GLOBAL_STATE_FILE), "wb") as f: -# torch.save(global_state_dict, f) - -# def load_checkpoint( -# checkpoint_path: str, -# model: torch.nn.Module, -# optimizer: torch.optim.Optimizer, -# scheduler: torch.optim.lr_scheduler.LambdaLR | None = None, -# outer_optimizer: torch.optim.Optimizer | None = None, -# scaler: torch.cuda.amp.GradScaler | None = None, -# data_loader: StatefulDataLoader | None = None, -# ) -> float: -# """Load the model and optimizer state from a checkpoint folder - -# Args: -# checkpoint_path: the path to the checkpoint folder -# model: the model to load -# optimizer: the optimizer to load -# scheduler: the scheduler to load -# outer_optimizer: the outer optimizer to load -# data_loader: the data loader to load - -# Returns: -# loss: the loss from the checkpoint -# """ -# rank = int(os.environ["RANK"]) -# # 1. Load distributed states -# # fs_storage_reader = dcp.FsspecReader(checkpoint_path) - -# model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer) -# dcp_state_dict = { -# "model": model_state_dict, -# "optimizer": optimizer_state_dict, -# } -# dcp.load(dcp_state_dict, checkpoint_id=checkpoint_path) -# set_state_dict( -# model, -# optimizer, -# model_state_dict=model_state_dict, -# optim_state_dict=optimizer_state_dict, -# ) -# if data_loader is not None: -# with open(os.path.join(checkpoint_path, f"__{rank}_0.pt"), "rb") as f: -# rank_state_dict = torch.load(f) -# data_loader.load_state_dict(rank_state_dict["data_loader"]) - -# # 2. Load global states -# with open(os.path.join(checkpoint_path, GLOBAL_STATE_FILE), "rb") as f: -# global_state_dict = torch.load(f) -# if scheduler is not None: -# scheduler.load_state_dict(global_state_dict["scheduler"]) -# optimizer.param_groups[0]["lr"] = scheduler.get_last_lr()[0] -# if outer_optimizer is not None: -# outer_optimizer.load_state_dict(global_state_dict["outer_optimizer"]) -# if scaler is not None: -# scaler.load_state_dict(global_state_dict["scaler"]) -# return global_state_dict["loss"] - - -# class CkptManager: -# """Its name CkptManager because I (sami) always misstyped chekcpoint. """ - -# def __init__(self, model: nn.Module, optimizer: Optimizer, scheduler: LambdaLR, dataloader: StatefulDataLoader, training_progress: TrainingProgress, process_group: ProcessGroup | None): - -# self.model = model -# self.optimizer = optimizer -# self.scheduler = scheduler -# self.dataloader = dataloader -# self.training_progress = training_progress - -# # states can only be stateful object, hence we need to wrap Model and Optimizer - -# self.process_group = process_group -# self._logger = get_logger() - -# def save(self, ckpt_path: str) -> None: -# save(checkpoint_path=ckpt_path, model=self.model, optimizer=self.optimizer, scheduler=self.scheduler, data_loader=self.dataloader) - - -# def load(self, resume_ckpt_path: str) -> None: -# load_checkpoint(checkpoint_path=resume_ckpt_path, model=self.model, optimizer=self.optimizer, scheduler=self.scheduler, data_loader=self.dataloader) + self._logger.info(f"Loaded checkpoint from {resume_ckpt_path} in {time.perf_counter() - time_start} seconds") From 986aa20c487355d3e41c4f6ff177b360cf192b43 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sat, 28 Sep 2024 20:57:07 +0000 Subject: [PATCH 05/20] ckpt save in the right step folder --- src/zeroband/checkpoint.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/src/zeroband/checkpoint.py b/src/zeroband/checkpoint.py index 65757aaf..c06bebdb 100644 --- a/src/zeroband/checkpoint.py +++ b/src/zeroband/checkpoint.py @@ -76,7 +76,17 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: class CkptManager: - """Its name CkptManager because I (sami) always misstyped chekcpoint.""" + """Its name CkptManager because I (sami) always misstyped chekcpoint. + + Checkpoint are saved in a folder with the following structure: + ckpt_path/ + step_0/ + _0_0.pt + _1_0.pt + ... + step_1/ + ... + """ def __init__( self, @@ -114,6 +124,7 @@ def save(self, ckpt_path: str) -> None: time_start = time.perf_counter() + ckpt_path = self._get_ckpt_folder_name(ckpt_path, self.training_progress.step) catch_warning = self._logger.getEffectiveLevel() <= logging.INFO # pytorch has an annoying warning when saving the optimizer state https://github.com/pytorch/pytorch/issues/136907 # we can ignore it if we are not logging in DEBUG mode @@ -138,9 +149,12 @@ def load(self, resume_ckpt_path: str) -> None: Each rank will load the right shard of the model and optimizer. All rank will load the global states (scheduler, step, total_tokens, dataloader). + `resume_ckpt_path` should point to a specific step and not to the base ckpt folder. Example: `ckpt_path/step_100` + Loading is done inplace """ time_start = time.perf_counter() + self.states = dcp.load(self.states, process_group=self.process_group, checkpoint_id=resume_ckpt_path) rank = get_world_info().local_rank # todo check after on/off ramping pr which rank is good here @@ -152,3 +166,10 @@ def load(self, resume_ckpt_path: str) -> None: self.dataloader.load_state_dict(rank_state_dict["data_loader"]) self._logger.info(f"Loaded checkpoint from {resume_ckpt_path} in {time.perf_counter() - time_start} seconds") + + @staticmethod + def _get_ckpt_folder_name(ckpt_path: str, step: int) -> str: + """ + The ckpt folder can contains multiple ckpt with different step name. This function return the sub directory name for the ckpt with the given step. + """ + return os.path.join(ckpt_path, f"step_{step}") From 7ccce1b8e321509e8bdbd80926dfa64445974344 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sat, 28 Sep 2024 20:59:40 +0000 Subject: [PATCH 06/20] change total tokens diloco --- src/zeroband/train.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/zeroband/train.py b/src/zeroband/train.py index ac81f705..df7bf775 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -231,7 +231,13 @@ def train(config: Config): new_tokens = config.data.seq_length * config.optim.batch_size perf_counter.count_tokens(new_tokens) - training_progress.total_tokens += new_tokens + + if config.diloco is not None: + training_progress.total_tokens += new_tokens + else: + # we count the total tokens with respect to all diloco workers + # might need to tweak this as some worker might fail to join the all reduce later + training_progress.total_tokens += new_tokens * elastic_device_mesh.global_pg.size() metrics = { "Loss": loss_batch.item(), From 09d59f0fdec418a21a7b6e429d1b37918456bf7f Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sat, 28 Sep 2024 21:06:00 +0000 Subject: [PATCH 07/20] add tests --- tests/test_torchrun/test_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_torchrun/test_train.py b/tests/test_torchrun/test_train.py index d116b2ca..e3b99893 100644 --- a/tests/test_torchrun/test_train.py +++ b/tests/test_torchrun/test_train.py @@ -81,6 +81,6 @@ def test_multi_gpu_diloco_non_full_shard(strategy): def test_ckpt(tmp_path: Path): ckpt_path = "outputs" # for some reason tmp_path is not working os.makedirs(ckpt_path, exist_ok=True) - _test_multi_gpu([1, 1], "debug/normal.toml", extra_args=["--ckpt.path", str(ckpt_path), "--ckpt.interval", "10"]) - _test_multi_gpu([1, 1], "debug/normal.toml", extra_args=["--resume", str(ckpt_path)]) + _test_multi_gpu([1, 1], "debug/normal.toml", extra_args=["--ckpt.path", f"{ckpt_path}/", "--ckpt.interval", "10"]) + _test_multi_gpu([1, 1], "debug/normal.toml", extra_args=["--resume", f"{ckpt_path}/step_10"]) shutil.rmtree(ckpt_path) From 10b621546f1a482a16ea412d2f8db77220f01eac Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sat, 28 Sep 2024 21:46:13 +0000 Subject: [PATCH 08/20] refactor easier step --- src/zeroband/checkpoint.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/zeroband/checkpoint.py b/src/zeroband/checkpoint.py index c06bebdb..95a6aae6 100644 --- a/src/zeroband/checkpoint.py +++ b/src/zeroband/checkpoint.py @@ -124,7 +124,7 @@ def save(self, ckpt_path: str) -> None: time_start = time.perf_counter() - ckpt_path = self._get_ckpt_folder_name(ckpt_path, self.training_progress.step) + ckpt_path = os.path.join(ckpt_path, f"step_{self.training_progress.step}") catch_warning = self._logger.getEffectiveLevel() <= logging.INFO # pytorch has an annoying warning when saving the optimizer state https://github.com/pytorch/pytorch/issues/136907 # we can ignore it if we are not logging in DEBUG mode @@ -166,10 +166,3 @@ def load(self, resume_ckpt_path: str) -> None: self.dataloader.load_state_dict(rank_state_dict["data_loader"]) self._logger.info(f"Loaded checkpoint from {resume_ckpt_path} in {time.perf_counter() - time_start} seconds") - - @staticmethod - def _get_ckpt_folder_name(ckpt_path: str, step: int) -> str: - """ - The ckpt folder can contains multiple ckpt with different step name. This function return the sub directory name for the ckpt with the given step. - """ - return os.path.join(ckpt_path, f"step_{step}") From d7119aebc6d8f050c09bff9710b225ec11d9efe9 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sun, 29 Sep 2024 03:09:20 +0000 Subject: [PATCH 09/20] use fsspec rsync for bbacking up the ckpt to remote --- pyproject.toml | 3 +- src/zeroband/checkpoint.py | 7 +- src/zeroband/train.py | 4 +- uv.lock | 231 +++++++++++++++++++++++++++++++++++++ 4 files changed, 242 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 18c7b780..c963ea3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,8 @@ dependencies = [ "datasets>=3.0.0", "pydantic_config @ git+https://github.com/samsja/pydantic_config.git@e529c9c", "einops", - "torchdata>=0.8.0" + "torchdata>=0.8.0", + "fsspec[gcs]>=2024.3.1", ] [project.optional-dependencies] diff --git a/src/zeroband/checkpoint.py b/src/zeroband/checkpoint.py index 95a6aae6..db4adf94 100644 --- a/src/zeroband/checkpoint.py +++ b/src/zeroband/checkpoint.py @@ -2,6 +2,7 @@ import os import time from typing import Any +from fsspec.generic import rsync as rsync_fsspec import torch from torch import nn from torch.optim import Optimizer @@ -115,7 +116,7 @@ def __init__( self.process_group = process_group self._logger = get_logger() - def save(self, ckpt_path: str) -> None: + def save(self, ckpt_path: str, remote_ckpt_path: str | None = None) -> None: """ Each rank will save the right shard of the model and optimizer. @@ -143,6 +144,10 @@ def save(self, ckpt_path: str) -> None: self._logger.info(f"Saved checkpoint to {ckpt_path} in {time.perf_counter() - time_start} seconds") + if remote_ckpt_path is not None: + remote_ckpt_path = os.path.join(remote_ckpt_path, f"step_{self.training_progress.step}") + rsync_fsspec(ckpt_path, remote_ckpt_path) + def load(self, resume_ckpt_path: str) -> None: """ loading should be done after fsdp wrap and optimizer init. diff --git a/src/zeroband/train.py b/src/zeroband/train.py index df7bf775..05ded3ed 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -58,6 +58,8 @@ class CkptConfig(BaseConfig): path: str interval: int + remote_path: str | None = None # could be a s3 path + class Config(BaseConfig): # main config @@ -283,7 +285,7 @@ def train(config: Config): and training_progress.step % config.ckpt.interval == 0 ): # we only allow to checkpoint after a outer step. For non diloco training outer step = 1 anyway - ckpt_manager.save(config.ckpt.path) + ckpt_manager.save(config.ckpt.path, config.ckpt.remote_path) if training_progress.step >= config.optim.total_steps: # we only allow to break outisde of the inner loop. diff --git a/uv.lock b/uv.lock index 415fe509..84348844 100644 --- a/uv.lock +++ b/uv.lock @@ -140,6 +140,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6a/21/5b6702a7f963e95456c0de2d495f67bf5fd62840ac655dc451586d23d39a/attrs-24.2.0-py3-none-any.whl", hash = "sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2", size = 63001 }, ] +[[package]] +name = "cachetools" +version = "5.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c3/38/a0f315319737ecf45b4319a8cd1f3a908e29d9277b46942263292115eee7/cachetools-5.5.0.tar.gz", hash = "sha256:2cc24fb4cbe39633fb7badd9db9ca6295d766d9c2995f245725a46715d050f2a", size = 27661 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/07/14f8ad37f2d12a5ce41206c21820d8cb6561b728e51fad4530dff0552a67/cachetools-5.5.0-py3-none-any.whl", hash = "sha256:02134e8439cdc2ffb62023ce1debca2944c3f289d66bb17ead3ab3dede74b292", size = 9524 }, +] + [[package]] name = "certifi" version = "2024.8.30" @@ -258,6 +267,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a5/52/45dab187f03d48c765b94db0464f5c10431756e47ae4cc6a8029a7d57a36/datasets-3.0.0-py3-none-any.whl", hash = "sha256:c23fefb6c953dcb1cd5f6deb6c502729c733ef98791e0c3f2d80c7ca2d9a01dd", size = 474265 }, ] +[[package]] +name = "decorator" +version = "5.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/66/0c/8d907af351aa16b42caae42f9d6aa37b900c67308052d10fdce809f8d952/decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330", size = 35016 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d5/50/83c593b07763e1161326b3b8c6686f0f4b0f24d5526546bee538c89837d6/decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186", size = 9073 }, +] + [[package]] name = "dill" version = "0.3.8" @@ -379,10 +397,31 @@ wheels = [ ] [package.optional-dependencies] +gcs = [ + { name = "gcsfs" }, +] http = [ { name = "aiohttp" }, ] +[[package]] +name = "gcsfs" +version = "2024.6.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, + { name = "decorator" }, + { name = "fsspec" }, + { name = "google-auth" }, + { name = "google-auth-oauthlib" }, + { name = "google-cloud-storage" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7f/b1/c5ae16ad1d499f0cf10e3306f717eadae30dba64ec29236077b8fe661e7c/gcsfs-2024.6.1.tar.gz", hash = "sha256:e8858c7a893b2265e9bfce2fe270a024a2e348c74c23528801db388fc0224ed7", size = 79259 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bf/65/f467159d42a2ce4191f1f6ff319e75b5a14ab2cc080062dbf5821a80244c/gcsfs-2024.6.1-py2.py3-none-any.whl", hash = "sha256:13fd18095425e54e248870594fd155812723966b1bda3b102b3a5c44ec436a03", size = 34866 }, +] + [[package]] name = "gitdb" version = "4.0.11" @@ -407,6 +446,129 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e9/bd/cc3a402a6439c15c3d4294333e13042b915bbeab54edc457c723931fed3f/GitPython-3.1.43-py3-none-any.whl", hash = "sha256:eec7ec56b92aad751f9912a73404bc02ba212a23adb2c7098ee668417051a1ff", size = 207337 }, ] +[[package]] +name = "google-api-core" +version = "2.20.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-auth" }, + { name = "googleapis-common-protos" }, + { name = "proto-plus" }, + { name = "protobuf" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c8/5c/31c1742a53b79c8a0c4757b5fae2e8ab9c519cbd7b98c587d4294e1d2d16/google_api_core-2.20.0.tar.gz", hash = "sha256:f74dff1889ba291a4b76c5079df0711810e2d9da81abfdc99957bc961c1eb28f", size = 152583 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/dc/6143f67acf5f30717c9e1b1c48fc04c0f59b869be046e6639d3f171640ae/google_api_core-2.20.0-py3-none-any.whl", hash = "sha256:ef0591ef03c30bb83f79b3d0575c3f31219001fc9c5cf37024d08310aeffed8a", size = 142162 }, +] + +[[package]] +name = "google-auth" +version = "2.35.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cachetools" }, + { name = "pyasn1-modules" }, + { name = "rsa" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a1/37/c854a8b1b1020cf042db3d67577c6f84cd1e8ff6515e4f5498ae9e444ea5/google_auth-2.35.0.tar.gz", hash = "sha256:f4c64ed4e01e8e8b646ef34c018f8bf3338df0c8e37d8b3bba40e7f574a3278a", size = 267223 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/27/1f/3a72917afcb0d5cd842cbccb81bf7a8a7b45b4c66d8dc4556ccb3b016bfc/google_auth-2.35.0-py2.py3-none-any.whl", hash = "sha256:25df55f327ef021de8be50bad0dfd4a916ad0de96da86cd05661c9297723ad3f", size = 208968 }, +] + +[[package]] +name = "google-auth-oauthlib" +version = "1.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-auth" }, + { name = "requests-oauthlib" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cc/0f/1772edb8d75ecf6280f1c7f51cbcebe274e8b17878b382f63738fd96cee5/google_auth_oauthlib-1.2.1.tar.gz", hash = "sha256:afd0cad092a2eaa53cd8e8298557d6de1034c6cb4a740500b5357b648af97263", size = 24970 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1a/8e/22a28dfbd218033e4eeaf3a0533b2b54852b6530da0c0fe934f0cc494b29/google_auth_oauthlib-1.2.1-py2.py3-none-any.whl", hash = "sha256:2d58a27262d55aa1b87678c3ba7142a080098cbc2024f903c62355deb235d91f", size = 24930 }, +] + +[[package]] +name = "google-cloud-core" +version = "2.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core" }, + { name = "google-auth" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b8/1f/9d1e0ba6919668608570418a9a51e47070ac15aeff64261fb092d8be94c0/google-cloud-core-2.4.1.tar.gz", hash = "sha256:9b7749272a812bde58fff28868d0c5e2f585b82f37e09a1f6ed2d4d10f134073", size = 35587 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5e/0f/2e2061e3fbcb9d535d5da3f58cc8de4947df1786fe6a1355960feb05a681/google_cloud_core-2.4.1-py2.py3-none-any.whl", hash = "sha256:a9e6a4422b9ac5c29f79a0ede9485473338e2ce78d91f2370c01e730eab22e61", size = 29233 }, +] + +[[package]] +name = "google-cloud-storage" +version = "2.18.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core" }, + { name = "google-auth" }, + { name = "google-cloud-core" }, + { name = "google-crc32c" }, + { name = "google-resumable-media" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d6/b7/1554cdeb55d9626a4b8720746cba8119af35527b12e1780164f9ba0f659a/google_cloud_storage-2.18.2.tar.gz", hash = "sha256:aaf7acd70cdad9f274d29332673fcab98708d0e1f4dceb5a5356aaef06af4d99", size = 5532864 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fc/da/95db7bd4f0bd1644378ac1702c565c0210b004754d925a74f526a710c087/google_cloud_storage-2.18.2-py2.py3-none-any.whl", hash = "sha256:97a4d45c368b7d401ed48c4fdfe86e1e1cb96401c9e199e419d289e2c0370166", size = 130466 }, +] + +[[package]] +name = "google-crc32c" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/67/72/c3298da1a3773102359c5a78f20dae8925f5ea876e37354415f68594a6fb/google_crc32c-1.6.0.tar.gz", hash = "sha256:6eceb6ad197656a1ff49ebfbbfa870678c75be4344feb35ac1edf694309413dc", size = 14472 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1a/be/d7846cb50e17bf72a70ea2d8159478ac5de0f1170b10cac279f50079e78d/google_crc32c-1.6.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:5bcc90b34df28a4b38653c36bb5ada35671ad105c99cfe915fb5bed7ad6924aa", size = 30267 }, + { url = "https://files.pythonhosted.org/packages/84/3b/29cadae166132e4991087a49dc88906a1d3d5ec22b80f63bc4bc7b6e0431/google_crc32c-1.6.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:d9e9913f7bd69e093b81da4535ce27af842e7bf371cde42d1ae9e9bd382dc0e9", size = 30113 }, + { url = "https://files.pythonhosted.org/packages/18/a9/49a7b2c4b7cc69d15778a820734f9beb647b1b4cf1a629ca43e3d3a54c70/google_crc32c-1.6.0-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a184243544811e4a50d345838a883733461e67578959ac59964e43cca2c791e7", size = 37702 }, + { url = "https://files.pythonhosted.org/packages/4b/aa/52538cceddefc7c2d66c6bd59dfe67a50f65a4952f441f91049e4188eb57/google_crc32c-1.6.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:236c87a46cdf06384f614e9092b82c05f81bd34b80248021f729396a78e55d7e", size = 32847 }, + { url = "https://files.pythonhosted.org/packages/b1/2c/1928413d3faae74ae0d7bdba648cf36ed6b03328c562b47046af016b7249/google_crc32c-1.6.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ebab974b1687509e5c973b5c4b8b146683e101e102e17a86bd196ecaa4d099fc", size = 37844 }, + { url = "https://files.pythonhosted.org/packages/d6/f4/f62fa405e442b37c5676973b759dd6e56cd8d58a5c78662912456526f716/google_crc32c-1.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:50cf2a96da226dcbff8671233ecf37bf6e95de98b2a2ebadbfdf455e6d05df42", size = 33444 }, + { url = "https://files.pythonhosted.org/packages/7d/14/ab47972ac79b6e7b03c8be3a7ef44b530a60e69555668dbbf08fc5692a98/google_crc32c-1.6.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:f7a1fc29803712f80879b0806cb83ab24ce62fc8daf0569f2204a0cfd7f68ed4", size = 30267 }, + { url = "https://files.pythonhosted.org/packages/54/7d/738cb0d25ee55629e7d07da686decf03864a366e5e863091a97b7bd2b8aa/google_crc32c-1.6.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:40b05ab32a5067525670880eb5d169529089a26fe35dce8891127aeddc1950e8", size = 30112 }, + { url = "https://files.pythonhosted.org/packages/3e/6d/33ca50cbdeec09c31bb5dac277c90994edee975662a4c890bda7ffac90ef/google_crc32c-1.6.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a9e4b426c3702f3cd23b933436487eb34e01e00327fac20c9aebb68ccf34117d", size = 32861 }, + { url = "https://files.pythonhosted.org/packages/67/1e/4870896fc81ec77b1b5ebae7fdd680d5a4d40e19a4b6d724032f996ca77a/google_crc32c-1.6.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51c4f54dd8c6dfeb58d1df5e4f7f97df8abf17a36626a217f169893d1d7f3e9f", size = 32490 }, + { url = "https://files.pythonhosted.org/packages/00/9c/f5f5af3ddaa7a639d915f8f58b09bbb8d1db90ecd0459b62cd430eb9a4b6/google_crc32c-1.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:bb8b3c75bd157010459b15222c3fd30577042a7060e29d42dabce449c087f2b3", size = 33446 }, + { url = "https://files.pythonhosted.org/packages/cf/41/65a91657d6a8123c6c12f9aac72127b6ac76dda9e2ba1834026a842eb77c/google_crc32c-1.6.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:ed767bf4ba90104c1216b68111613f0d5926fb3780660ea1198fc469af410e9d", size = 30268 }, + { url = "https://files.pythonhosted.org/packages/59/d0/ee743a267c7d5c4bb8bd865f7d4c039505f1c8a4b439df047fdc17be9769/google_crc32c-1.6.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:62f6d4a29fea082ac4a3c9be5e415218255cf11684ac6ef5488eea0c9132689b", size = 30113 }, + { url = "https://files.pythonhosted.org/packages/25/53/e5e449c368dd26ade5fb2bb209e046d4309ed0623be65b13f0ce026cb520/google_crc32c-1.6.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c87d98c7c4a69066fd31701c4e10d178a648c2cac3452e62c6b24dc51f9fcc00", size = 32995 }, + { url = "https://files.pythonhosted.org/packages/52/12/9bf6042d5b0ac8c25afed562fb78e51b0641474097e4139e858b45de40a5/google_crc32c-1.6.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bd5e7d2445d1a958c266bfa5d04c39932dc54093fa391736dbfdb0f1929c1fb3", size = 32614 }, + { url = "https://files.pythonhosted.org/packages/76/29/fc20f5ec36eac1eea0d0b2de4118c774c5f59c513f2a8630d4db6991f3e0/google_crc32c-1.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:7aec8e88a3583515f9e0957fe4f5f6d8d4997e36d0f61624e70469771584c760", size = 33445 }, + { url = "https://files.pythonhosted.org/packages/e7/ff/ed48d136b65ddc61f5aef6261c58cd817c8cd60640b16680e5419fb17018/google_crc32c-1.6.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48abd62ca76a2cbe034542ed1b6aee851b6f28aaca4e6551b5599b6f3ef175cc", size = 28057 }, + { url = "https://files.pythonhosted.org/packages/14/fb/54deefe679b7d1c1cc81d83396fcf28ad1a66d213bddeb275a8d28665918/google_crc32c-1.6.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18e311c64008f1f1379158158bb3f0c8d72635b9eb4f9545f8cf990c5668e59d", size = 27866 }, +] + +[[package]] +name = "google-resumable-media" +version = "2.7.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-crc32c" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/58/5a/0efdc02665dca14e0837b62c8a1a93132c264bd02054a15abb2218afe0ae/google_resumable_media-2.7.2.tar.gz", hash = "sha256:5280aed4629f2b60b847b0d42f9857fd4935c11af266744df33d8074cae92fe0", size = 2163099 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/82/35/b8d3baf8c46695858cb9d8835a53baa1eeb9906ddaf2f728a5f5b640fd1e/google_resumable_media-2.7.2-py2.py3-none-any.whl", hash = "sha256:3ce7551e9fe6d99e9a126101d2536612bb73486721951e9562fee0f90c6ababa", size = 81251 }, +] + +[[package]] +name = "googleapis-common-protos" +version = "1.65.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/53/3b/1599ceafa875ffb951480c8c74f4b77646a6b80e80970698f2aa93c216ce/googleapis_common_protos-1.65.0.tar.gz", hash = "sha256:334a29d07cddc3aa01dee4988f9afd9b2916ee2ff49d6b757155dc0d197852c0", size = 113657 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/08/49bfe7cf737952cc1a9c43e80cc258ed45dad7f183c5b8276fc94cb3862d/googleapis_common_protos-1.65.0-py2.py3-none-any.whl", hash = "sha256:2972e6c496f435b92590fd54045060867f3fe9be2c82ab148fc8885035479a63", size = 220890 }, +] + [[package]] name = "huggingface-hub" version = "0.24.6" @@ -818,6 +980,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b8/d7/bd7cb2d95ac6ac6e8d05bfa96cdce69619f1ef2808e072919044c2d47a8c/nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82", size = 66307 }, ] +[[package]] +name = "oauthlib" +version = "3.2.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6d/fa/fbf4001037904031639e6bfbfc02badfc7e12f137a8afa254df6c4c8a670/oauthlib-3.2.2.tar.gz", hash = "sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918", size = 177352 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/80/cab10959dc1faead58dc8384a781dfbf93cb4d33d50988f7a69f1b7c9bbe/oauthlib-3.2.2-py3-none-any.whl", hash = "sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca", size = 151688 }, +] + [[package]] name = "packaging" version = "24.1" @@ -896,6 +1067,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/07/92/caae8c86e94681b42c246f0bca35c059a2f0529e5b92619f6aba4cf7e7b6/pre_commit-3.8.0-py2.py3-none-any.whl", hash = "sha256:9a90a53bf82fdd8778d58085faf8d83df56e40dfe18f45b19446e26bf1b3a63f", size = 204643 }, ] +[[package]] +name = "proto-plus" +version = "1.24.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3e/fc/e9a65cd52c1330d8d23af6013651a0bc50b6d76bcbdf91fae7cd19c68f29/proto-plus-1.24.0.tar.gz", hash = "sha256:30b72a5ecafe4406b0d339db35b56c4059064e69227b8c3bda7462397f966445", size = 55942 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7c/6f/db31f0711c0402aa477257205ce7d29e86a75cb52cd19f7afb585f75cda0/proto_plus-1.24.0-py3-none-any.whl", hash = "sha256:402576830425e5f6ce4c2a6702400ac79897dab0b4343821aa5188b0fab81a12", size = 50080 }, +] + [[package]] name = "protobuf" version = "5.28.2" @@ -959,6 +1142,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ae/49/baafe2a964f663413be3bd1cf5c45ed98c5e42e804e2328e18f4570027c1/pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7", size = 25099235 }, ] +[[package]] +name = "pyasn1" +version = "0.6.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ba/e9/01f1a64245b89f039897cb0130016d79f77d52669aae6ee7b159a6c4c018/pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034", size = 145322 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/f1/d6a797abb14f6283c0ddff96bbdd46937f64122b8c925cab503dd37f8214/pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629", size = 83135 }, +] + +[[package]] +name = "pyasn1-modules" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1d/67/6afbf0d507f73c32d21084a79946bfcfca5fbc62a72057e9c23797a737c9/pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c", size = 310028 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/89/bc88a6711935ba795a679ea6ebee07e128050d6382eaa35a0a47c8032bdc/pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd", size = 181537 }, +] + [[package]] name = "pydantic" version = "2.9.1" @@ -1224,6 +1428,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928 }, ] +[[package]] +name = "requests-oauthlib" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "oauthlib" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/f2/05f29bc3913aea15eb670be136045bf5c5bbf4b99ecb839da9b422bb2c85/requests-oauthlib-2.0.0.tar.gz", hash = "sha256:b3dffaebd884d8cd778494369603a9e7b58d29111bf6b41bdc2dcd87203af4e9", size = 55650 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3b/5d/63d4ae3b9daea098d5d6f5da83984853c1bbacd5dc826764b249fe119d24/requests_oauthlib-2.0.0-py2.py3-none-any.whl", hash = "sha256:7dd8a5c40426b779b0868c404bdef9768deccf22749cde15852df527e6269b36", size = 24179 }, +] + [[package]] name = "rich" version = "13.8.1" @@ -1237,6 +1454,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b0/11/dadb85e2bd6b1f1ae56669c3e1f0410797f9605d752d68fb47b77f525b31/rich-13.8.1-py3-none-any.whl", hash = "sha256:1760a3c0848469b97b558fc61c85233e3dafb69c7a071b4d60c38099d3cd4c06", size = 241608 }, ] +[[package]] +name = "rsa" +version = "4.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/aa/65/7d973b89c4d2351d7fb232c2e452547ddfa243e93131e7cfa766da627b52/rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21", size = 29711 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/49/97/fa78e3d2f65c02c8e1268b9aba606569fe97f6c8f7c2d74394553347c145/rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7", size = 34315 }, +] + [[package]] name = "ruff" version = "0.6.4" @@ -1814,6 +2043,7 @@ source = { editable = "." } dependencies = [ { name = "datasets" }, { name = "einops" }, + { name = "fsspec", extra = ["gcs"] }, { name = "numpy" }, { name = "pydantic-config" }, { name = "setuptools" }, @@ -1838,6 +2068,7 @@ dev = [ requires-dist = [ { name = "datasets", specifier = ">=3.0.0" }, { name = "einops" }, + { name = "fsspec", extras = ["gcs"], specifier = ">=2024.3.1" }, { name = "numpy" }, { name = "pydantic-config", git = "https://github.com/samsja/pydantic_config.git?rev=e529c9c" }, { name = "setuptools" }, From 7622bbe7669f8bc1ee811af658e06eb096fa47ee Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sun, 29 Sep 2024 03:31:51 +0000 Subject: [PATCH 10/20] add async saving to remote --- src/zeroband/checkpoint.py | 36 +++++++++++++++++++++++++++++++++--- src/zeroband/train.py | 10 +++++++++- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/src/zeroband/checkpoint.py b/src/zeroband/checkpoint.py index db4adf94..9d486560 100644 --- a/src/zeroband/checkpoint.py +++ b/src/zeroband/checkpoint.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +import multiprocessing import os import time from typing import Any @@ -116,7 +117,9 @@ def __init__( self.process_group = process_group self._logger = get_logger() - def save(self, ckpt_path: str, remote_ckpt_path: str | None = None) -> None: + self.async_save_process: list[multiprocessing.Process] = [] + + def save(self, ckpt_path: str, remote_ckpt_path: str | None) -> None: """ Each rank will save the right shard of the model and optimizer. @@ -145,8 +148,35 @@ def save(self, ckpt_path: str, remote_ckpt_path: str | None = None) -> None: self._logger.info(f"Saved checkpoint to {ckpt_path} in {time.perf_counter() - time_start} seconds") if remote_ckpt_path is not None: - remote_ckpt_path = os.path.join(remote_ckpt_path, f"step_{self.training_progress.step}") - rsync_fsspec(ckpt_path, remote_ckpt_path) + self._async_save_remote(ckpt_path, remote_ckpt_path) + + def _async_save_remote(self, ckpt_path: str, remote_ckpt_path: str): + """asyncronously rsync a ckpt folder to a remote location. Using fsspec to handle remote cloud storage without to install + specific libraries (e.g. s3fs) + """ + + def rsync(): + time_start = time.perf_counter() + self._logger.info(f"start pushing {ckpt_path} to {remote_ckpt_path} asynchronously") + rsync_fsspec(ckpt_path, destination=remote_ckpt_path) + self._logger.info( + f"finish pushing {ckpt_path} to {remote_ckpt_path} in {time.perf_counter() - time_start} seconds" + ) + + processes = multiprocessing.Process(target=rsync, daemon=True) + processes.start() + + self.async_save_process.append(processes) + + def wait_async_save_process(self): + """ + wait for all async save process to finish + """ + for process in self.async_save_process: + process.join() + + def _del__(self): + self.wait_async_save_process() def load(self, resume_ckpt_path: str) -> None: """ diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 05ded3ed..240ff35e 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -1,5 +1,6 @@ import os from contextlib import nullcontext +import time from typing import Literal import torch @@ -193,6 +194,7 @@ def train(config: Config): perf_counter = PerfCounter(window_size=10) logger.info("starting training") + d = 0 while True: if num_inner_steps > 1: # if we don't use diloco we don't print the outer step logs @@ -268,6 +270,8 @@ def train(config: Config): logger.info(log) + time.sleep(2) + if config.diloco is not None: if config.train.log_model_hash: with FSDP.summon_full_params(model): @@ -285,7 +289,9 @@ def train(config: Config): and training_progress.step % config.ckpt.interval == 0 ): # we only allow to checkpoint after a outer step. For non diloco training outer step = 1 anyway - ckpt_manager.save(config.ckpt.path, config.ckpt.remote_path) + if d == 0: + ckpt_manager.save(config.ckpt.path, config.ckpt.remote_path) + d += 1 if training_progress.step >= config.optim.total_steps: # we only allow to break outisde of the inner loop. @@ -296,6 +302,8 @@ def train(config: Config): if world_info.rank == 0: metric_logger.finish() + ckpt_manager.wait_async_save_process() + if __name__ == "__main__": # Allow eager fallback during production so that that the training runs dont die From aab5a5b4d7171a007151dadc2118beb88f764458 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sun, 29 Sep 2024 03:49:43 +0000 Subject: [PATCH 11/20] remove unused file --- src/zeroband/checkpoint.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/zeroband/checkpoint.py b/src/zeroband/checkpoint.py index 9d486560..f3527ad8 100644 --- a/src/zeroband/checkpoint.py +++ b/src/zeroband/checkpoint.py @@ -28,9 +28,6 @@ ## code inspired by torchtitan https://github.com/pytorch/torchtitan/blob/main/torchtitan/checkpoint.py -GLOBAL_STATE_FILE = "global_state_dict.pt" - - @dataclass class TrainingProgress(Stateful): total_tokens: int From 6b63fe39966883af21fa2973778802707f010ab0 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sun, 29 Sep 2024 03:51:41 +0000 Subject: [PATCH 12/20] fix rebase --- src/zeroband/train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 240ff35e..5cb2b85b 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -27,7 +27,7 @@ from zeroband.models.llama import get_model from zeroband.utils.world_info import get_world_info from zeroband.utils.logging import get_logger -from zeroband.checkpoint import TrainingProgress +from zeroband.checkpoint import CkptManager, TrainingProgress class DataConfig(BaseConfig): @@ -113,7 +113,9 @@ def train(config: Config): model, model_config = get_model( config.name_model, config.type_model, - vocab_size=tokenizer.vocab_size if config.name_model != "debugmodel" or not config.data.fake else TEST_VOCAB_SIZE, + vocab_size=tokenizer.vocab_size + if config.name_model != "debugmodel" or not config.data.fake + else TEST_VOCAB_SIZE, ) if config.train.log_model_hash: From 00eaec7082f3b5a3aa552b0a71027b931cfe56ba Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sun, 29 Sep 2024 04:25:47 +0000 Subject: [PATCH 13/20] fix ckpt --- src/zeroband/train.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 5cb2b85b..41eaf92b 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -196,7 +196,6 @@ def train(config: Config): perf_counter = PerfCounter(window_size=10) logger.info("starting training") - d = 0 while True: if num_inner_steps > 1: # if we don't use diloco we don't print the outer step logs @@ -291,9 +290,7 @@ def train(config: Config): and training_progress.step % config.ckpt.interval == 0 ): # we only allow to checkpoint after a outer step. For non diloco training outer step = 1 anyway - if d == 0: - ckpt_manager.save(config.ckpt.path, config.ckpt.remote_path) - d += 1 + ckpt_manager.save(config.ckpt.path, config.ckpt.remote_path) if training_progress.step >= config.optim.total_steps: # we only allow to break outisde of the inner loop. From fcbd1028d20b2b9546ceabe15a757b7678e6444e Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sun, 29 Sep 2024 21:45:47 +0000 Subject: [PATCH 14/20] add diloco ckpt --- src/zeroband/checkpoint.py | 27 +++++++++++++++++++++++---- src/zeroband/train.py | 5 ++--- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/src/zeroband/checkpoint.py b/src/zeroband/checkpoint.py index f3527ad8..59158cbd 100644 --- a/src/zeroband/checkpoint.py +++ b/src/zeroband/checkpoint.py @@ -95,6 +95,8 @@ def __init__( dataloader: StatefulDataLoader, training_progress: TrainingProgress, process_group: ProcessGroup | None, + diloco_offloaded_param_list: list[nn.Parameter] | None, + diloco_offloaded_optimizer: Optimizer | None, ): self.model = ModelWrapper(model) self.optimizer = OptimizerWrapper(model, optimizer) @@ -111,6 +113,19 @@ def __init__( "training_progress": self.training_progress, } + assert (diloco_offloaded_param_list is None) == ( + diloco_offloaded_optimizer is None + ), "diloco_offloaded_model and diloco_offloaded_optimizer must be both None or both have values" + + if diloco_offloaded_optimizer is not None: + self.diloco_offloaded_optimizer = diloco_offloaded_optimizer # he we don't use Wrapper because it failed + # which might make the ckpt less generic in term of loading from different number of device. FSDP ckpt seems to be a mess tho + + self.diloco_offloaded_param_list = diloco_offloaded_param_list + # even if the diloco_offloaded target the cpu list model, we still use the gpu model to load and save state. + # main reason is that we actually don't a cpu model but just a list of cpu parameters. + self.states["diloco_offloaded_optimizer"] = self.diloco_offloaded_optimizer + self.process_group = process_group self._logger = get_logger() @@ -192,9 +207,13 @@ def load(self, resume_ckpt_path: str) -> None: rank = get_world_info().local_rank # todo check after on/off ramping pr which rank is good here ## the next part is a fix so that each rank save a different dataloader rank. It not efficient because it reads the state two times from disk - if self.dataloader is not None: - with open(os.path.join(resume_ckpt_path, f"__{rank}_0.pt"), "rb") as f: - rank_state_dict = torch.load(f) - self.dataloader.load_state_dict(rank_state_dict["data_loader"]) + with open(os.path.join(resume_ckpt_path, f"__{rank}_0.pt"), "rb") as f: + rank_state_dict = torch.load(f) + self.dataloader.load_state_dict(rank_state_dict["data_loader"]) + + # since we don't load the param list from the state dict as its the same as the model one we just copy + if self.diloco_offloaded_param_list is not None: + for param_offloaded, param_model in zip(self.diloco_offloaded_param_list, self.model.model.parameters()): + param_offloaded.data.copy_(param_model.data) self._logger.info(f"Loaded checkpoint from {resume_ckpt_path} in {time.perf_counter() - time_start} seconds") diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 41eaf92b..a59e6a91 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -1,6 +1,5 @@ import os from contextlib import nullcontext -import time from typing import Literal import torch @@ -172,6 +171,8 @@ def train(config: Config): scheduler=scheduler, dataloader=train_dataloader, training_progress=training_progress, + diloco_offloaded_optimizer=diloco.outer_optimizer if config.diloco is not None else None, + diloco_offloaded_param_list=diloco.param_list_cpu if config.diloco is not None else None, process_group=elastic_device_mesh.local_pg if config.diloco is not None else None, ) @@ -271,8 +272,6 @@ def train(config: Config): logger.info(log) - time.sleep(2) - if config.diloco is not None: if config.train.log_model_hash: with FSDP.summon_full_params(model): From 3c052ddb5d1a74dd81b88322532f1a6dd341bf8b Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sun, 29 Sep 2024 22:25:41 +0000 Subject: [PATCH 15/20] save into dioco scpeific folder --- src/zeroband/checkpoint.py | 29 ++++++++++++++++++++++++----- src/zeroband/utils/world_info.py | 4 ++++ 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/src/zeroband/checkpoint.py b/src/zeroband/checkpoint.py index 59158cbd..56eb8bdb 100644 --- a/src/zeroband/checkpoint.py +++ b/src/zeroband/checkpoint.py @@ -124,7 +124,9 @@ def __init__( self.diloco_offloaded_param_list = diloco_offloaded_param_list # even if the diloco_offloaded target the cpu list model, we still use the gpu model to load and save state. # main reason is that we actually don't a cpu model but just a list of cpu parameters. - self.states["diloco_offloaded_optimizer"] = self.diloco_offloaded_optimizer + self.diloco_states = {"optimizer": self.diloco_offloaded_optimizer} + else: + self.diloco_states = {} self.process_group = process_group self._logger = get_logger() @@ -145,7 +147,7 @@ def save(self, ckpt_path: str, remote_ckpt_path: str | None) -> None: # pytorch has an annoying warning when saving the optimizer state https://github.com/pytorch/pytorch/issues/136907 # we can ignore it if we are not logging in DEBUG mode - rank = get_world_info().local_rank + world_info = get_world_info() with warnings.catch_warnings(): if catch_warning: @@ -153,8 +155,14 @@ def save(self, ckpt_path: str, remote_ckpt_path: str | None) -> None: dcp.save(self.states, process_group=self.process_group, checkpoint_id=ckpt_path) + if self.diloco_states: + diloco_ckpt_path = os.path.join(ckpt_path, f"diloco_{world_info.diloco_rank}") + dcp.save(self.diloco_states, process_group=self.process_group, checkpoint_id=diloco_ckpt_path) ## the next part is a fix so that each rank save a different dataloader rank. It not efficient because it reads the state two times from disk - with open(os.path.join(ckpt_path, f"__{rank}_0.pt"), "wb") as f: + + # dataloader is different for each diloco worker. If diloco is enable we use diloco_ckpt_path + dataloader_path = ckpt_path if self.diloco_states is None else diloco_ckpt_path + with open(os.path.join(dataloader_path, f"__{world_info.local_rank}_0.pt"), "wb") as f: torch.save({"data_loader": self.dataloader.state_dict()}, f) self._logger.info(f"Saved checkpoint to {ckpt_path} in {time.perf_counter() - time_start} seconds") @@ -204,10 +212,21 @@ def load(self, resume_ckpt_path: str) -> None: self.states = dcp.load(self.states, process_group=self.process_group, checkpoint_id=resume_ckpt_path) - rank = get_world_info().local_rank # todo check after on/off ramping pr which rank is good here + world_info = get_world_info() + + self._logger.debug(msg=f"diloco_states {self.diloco_states}") + if self.diloco_states: + resume_ckpt_path_diloco = os.path.join(resume_ckpt_path, f"diloco_{world_info.diloco_rank}") + dcp.load(self.diloco_states, process_group=self.process_group, checkpoint_id=resume_ckpt_path_diloco) + + self._logger.debug(msg=f"postdiloco_states {self.diloco_states}") ## the next part is a fix so that each rank save a different dataloader rank. It not efficient because it reads the state two times from disk - with open(os.path.join(resume_ckpt_path, f"__{rank}_0.pt"), "rb") as f: + + # dataloader is different for each diloco worker. If diloco is enable we use diloco_ckpt_path + dataloader_path = resume_ckpt_path if self.diloco_states is None else resume_ckpt_path_diloco + self._logger.debug(f"loading dataloader from {dataloader_path}") + with open(os.path.join(dataloader_path, f"__{world_info.local_rank}_0.pt"), "rb") as f: rank_state_dict = torch.load(f) self.dataloader.load_state_dict(rank_state_dict["data_loader"]) diff --git a/src/zeroband/utils/world_info.py b/src/zeroband/utils/world_info.py index 9b73f328..f7c6548b 100644 --- a/src/zeroband/utils/world_info.py +++ b/src/zeroband/utils/world_info.py @@ -27,6 +27,10 @@ def __init__(self): def __repr__(self): return f"WorldInfo(world_size={self.world_size}, rank={self.rank}, local_rank={self.local_rank}, local_world_size={self.local_world_size}, nnodes={self.nnodes}, global_unique_id={self.global_unique_id}, global_addr={self.global_addr}, global_port={self.global_port}, global_world_size={self.global_world_size}, global_rank={self.global_rank})" + @property + def diloco_rank(self): + return self.rank // self.local_world_size + def get_world_info() -> WorldInfo: """ From 47c5396a9ebcd46294e4d6a59bbeec57ba770015 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Mon, 30 Sep 2024 00:09:26 +0000 Subject: [PATCH 16/20] firemove process group --- src/zeroband/checkpoint.py | 7 ++----- src/zeroband/train.py | 1 - 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/zeroband/checkpoint.py b/src/zeroband/checkpoint.py index 56eb8bdb..98159af4 100644 --- a/src/zeroband/checkpoint.py +++ b/src/zeroband/checkpoint.py @@ -10,7 +10,6 @@ from torch.optim.lr_scheduler import LambdaLR from torchdata.stateful_dataloader import StatefulDataLoader import torch.distributed.checkpoint as dcp -from torch.distributed import ProcessGroup from torch.distributed.checkpoint.state_dict import ( set_optimizer_state_dict, set_model_state_dict, @@ -94,7 +93,6 @@ def __init__( scheduler: LambdaLR, dataloader: StatefulDataLoader, training_progress: TrainingProgress, - process_group: ProcessGroup | None, diloco_offloaded_param_list: list[nn.Parameter] | None, diloco_offloaded_optimizer: Optimizer | None, ): @@ -128,7 +126,6 @@ def __init__( else: self.diloco_states = {} - self.process_group = process_group self._logger = get_logger() self.async_save_process: list[multiprocessing.Process] = [] @@ -153,11 +150,11 @@ def save(self, ckpt_path: str, remote_ckpt_path: str | None) -> None: if catch_warning: warnings.simplefilter("ignore") - dcp.save(self.states, process_group=self.process_group, checkpoint_id=ckpt_path) + dcp.save(self.states, checkpoint_id=ckpt_path) if self.diloco_states: diloco_ckpt_path = os.path.join(ckpt_path, f"diloco_{world_info.diloco_rank}") - dcp.save(self.diloco_states, process_group=self.process_group, checkpoint_id=diloco_ckpt_path) + dcp.save(self.diloco_states, checkpoint_id=diloco_ckpt_path) ## the next part is a fix so that each rank save a different dataloader rank. It not efficient because it reads the state two times from disk # dataloader is different for each diloco worker. If diloco is enable we use diloco_ckpt_path diff --git a/src/zeroband/train.py b/src/zeroband/train.py index a59e6a91..5618c812 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -173,7 +173,6 @@ def train(config: Config): training_progress=training_progress, diloco_offloaded_optimizer=diloco.outer_optimizer if config.diloco is not None else None, diloco_offloaded_param_list=diloco.param_list_cpu if config.diloco is not None else None, - process_group=elastic_device_mesh.local_pg if config.diloco is not None else None, ) if config.train.torch_compile: From e08064139d50107b6f26d978f22364935b26b774 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Mon, 30 Sep 2024 00:32:29 +0000 Subject: [PATCH 17/20] firemove process group --- src/zeroband/checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/zeroband/checkpoint.py b/src/zeroband/checkpoint.py index 98159af4..f980a595 100644 --- a/src/zeroband/checkpoint.py +++ b/src/zeroband/checkpoint.py @@ -207,14 +207,14 @@ def load(self, resume_ckpt_path: str) -> None: """ time_start = time.perf_counter() - self.states = dcp.load(self.states, process_group=self.process_group, checkpoint_id=resume_ckpt_path) + self.states = dcp.load(self.states, checkpoint_id=resume_ckpt_path) world_info = get_world_info() self._logger.debug(msg=f"diloco_states {self.diloco_states}") if self.diloco_states: resume_ckpt_path_diloco = os.path.join(resume_ckpt_path, f"diloco_{world_info.diloco_rank}") - dcp.load(self.diloco_states, process_group=self.process_group, checkpoint_id=resume_ckpt_path_diloco) + dcp.load(self.diloco_states, checkpoint_id=resume_ckpt_path_diloco) self._logger.debug(msg=f"postdiloco_states {self.diloco_states}") From 1c098b2fb778d260d326877cb4697f9e0d025b6d Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Mon, 30 Sep 2024 02:17:03 +0000 Subject: [PATCH 18/20] add diloco rank --- src/zeroband/utils/world_info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zeroband/utils/world_info.py b/src/zeroband/utils/world_info.py index f7c6548b..fcca5da2 100644 --- a/src/zeroband/utils/world_info.py +++ b/src/zeroband/utils/world_info.py @@ -29,7 +29,7 @@ def __repr__(self): @property def diloco_rank(self): - return self.rank // self.local_world_size + return self.global_rank def get_world_info() -> WorldInfo: From 2c54c2c928658ba562256e161a22f51479bcaedb Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Mon, 30 Sep 2024 02:45:40 +0000 Subject: [PATCH 19/20] fix ckpt issue --- src/zeroband/checkpoint.py | 60 ++++++++++++++++---------------------- src/zeroband/train.py | 1 + 2 files changed, 26 insertions(+), 35 deletions(-) diff --git a/src/zeroband/checkpoint.py b/src/zeroband/checkpoint.py index f980a595..9bc0f003 100644 --- a/src/zeroband/checkpoint.py +++ b/src/zeroband/checkpoint.py @@ -115,16 +115,14 @@ def __init__( diloco_offloaded_optimizer is None ), "diloco_offloaded_model and diloco_offloaded_optimizer must be both None or both have values" - if diloco_offloaded_optimizer is not None: - self.diloco_offloaded_optimizer = diloco_offloaded_optimizer # he we don't use Wrapper because it failed - # which might make the ckpt less generic in term of loading from different number of device. FSDP ckpt seems to be a mess tho + self.diloco_offloaded_optimizer = diloco_offloaded_optimizer # he we don't use Wrapper because it failed + # which might make the ckpt less generic in term of loading from different number of device. FSDP ckpt seems to be a mess tho + self.diloco_offloaded_param_list = diloco_offloaded_param_list - self.diloco_offloaded_param_list = diloco_offloaded_param_list + if diloco_offloaded_optimizer is not None: # even if the diloco_offloaded target the cpu list model, we still use the gpu model to load and save state. # main reason is that we actually don't a cpu model but just a list of cpu parameters. - self.diloco_states = {"optimizer": self.diloco_offloaded_optimizer} - else: - self.diloco_states = {} + self.states["diloco_optimizer"] = self.diloco_offloaded_optimizer self._logger = get_logger() @@ -138,28 +136,28 @@ def save(self, ckpt_path: str, remote_ckpt_path: str | None) -> None: """ time_start = time.perf_counter() + world_info = get_world_info() ckpt_path = os.path.join(ckpt_path, f"step_{self.training_progress.step}") - catch_warning = self._logger.getEffectiveLevel() <= logging.INFO - # pytorch has an annoying warning when saving the optimizer state https://github.com/pytorch/pytorch/issues/136907 - # we can ignore it if we are not logging in DEBUG mode + if self.diloco_offloaded_optimizer: + # here we save model and offloaded optimizer on each diloco rank even tho they are the same + # this is done for two reasons: + # * if the nodes don't share a filesystem nor a remote path, they still save all of the data + # * its easier to implement and avoid race condition on the shared data. + ckpt_path = os.path.join(ckpt_path, f"diloco_{world_info.diloco_rank}") - world_info = get_world_info() + catch_warning = self._logger.getEffectiveLevel() <= logging.INFO with warnings.catch_warnings(): + # pytorch has an annoying warning when saving the optimizer state https://github.com/pytorch/pytorch/issues/136907 + # we can ignore it if we are not logging in DEBUG mode if catch_warning: warnings.simplefilter("ignore") dcp.save(self.states, checkpoint_id=ckpt_path) - if self.diloco_states: - diloco_ckpt_path = os.path.join(ckpt_path, f"diloco_{world_info.diloco_rank}") - dcp.save(self.diloco_states, checkpoint_id=diloco_ckpt_path) ## the next part is a fix so that each rank save a different dataloader rank. It not efficient because it reads the state two times from disk - - # dataloader is different for each diloco worker. If diloco is enable we use diloco_ckpt_path - dataloader_path = ckpt_path if self.diloco_states is None else diloco_ckpt_path - with open(os.path.join(dataloader_path, f"__{world_info.local_rank}_0.pt"), "wb") as f: + with open(os.path.join(ckpt_path, f"__{world_info.local_rank}_0.pt"), "wb") as f: torch.save({"data_loader": self.dataloader.state_dict()}, f) self._logger.info(f"Saved checkpoint to {ckpt_path} in {time.perf_counter() - time_start} seconds") @@ -207,29 +205,21 @@ def load(self, resume_ckpt_path: str) -> None: """ time_start = time.perf_counter() - self.states = dcp.load(self.states, checkpoint_id=resume_ckpt_path) - world_info = get_world_info() + if self.diloco_offloaded_param_list is not None: + resume_ckpt_path = os.path.join(resume_ckpt_path, f"diloco_{world_info.diloco_rank}") - self._logger.debug(msg=f"diloco_states {self.diloco_states}") - if self.diloco_states: - resume_ckpt_path_diloco = os.path.join(resume_ckpt_path, f"diloco_{world_info.diloco_rank}") - dcp.load(self.diloco_states, checkpoint_id=resume_ckpt_path_diloco) - - self._logger.debug(msg=f"postdiloco_states {self.diloco_states}") - - ## the next part is a fix so that each rank save a different dataloader rank. It not efficient because it reads the state two times from disk - - # dataloader is different for each diloco worker. If diloco is enable we use diloco_ckpt_path - dataloader_path = resume_ckpt_path if self.diloco_states is None else resume_ckpt_path_diloco - self._logger.debug(f"loading dataloader from {dataloader_path}") - with open(os.path.join(dataloader_path, f"__{world_info.local_rank}_0.pt"), "rb") as f: - rank_state_dict = torch.load(f) - self.dataloader.load_state_dict(rank_state_dict["data_loader"]) + self.states = dcp.load(self.states, checkpoint_id=resume_ckpt_path) # since we don't load the param list from the state dict as its the same as the model one we just copy if self.diloco_offloaded_param_list is not None: for param_offloaded, param_model in zip(self.diloco_offloaded_param_list, self.model.model.parameters()): param_offloaded.data.copy_(param_model.data) + ## the next part is a fix so that each rank save a different dataloader rank. It not efficient because it reads the state two times from disk + with open(os.path.join(resume_ckpt_path, f"__{world_info.local_rank}_0.pt"), "rb") as f: + rank_state_dict = torch.load(f) + + self.dataloader.load_state_dict(rank_state_dict["data_loader"]) + self._logger.info(f"Loaded checkpoint from {resume_ckpt_path} in {time.perf_counter() - time_start} seconds") diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 5618c812..9879a17e 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -300,6 +300,7 @@ def train(config: Config): metric_logger.finish() ckpt_manager.wait_async_save_process() + logger.info("Training finished, exiting ...") if __name__ == "__main__": From b9e6eef706c7e82391c746801908fd2e4dccc18c Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Mon, 30 Sep 2024 04:33:44 +0000 Subject: [PATCH 20/20] remove ckpt tests --- tests/test_torchrun/test_train.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tests/test_torchrun/test_train.py b/tests/test_torchrun/test_train.py index e3b99893..79bbc61c 100644 --- a/tests/test_torchrun/test_train.py +++ b/tests/test_torchrun/test_train.py @@ -1,7 +1,5 @@ import copy import os -from pathlib import Path -import shutil import subprocess import pytest import socket @@ -73,14 +71,3 @@ def test_multi_gpu_diloco_non_full_shard(strategy): # we don't test 1,1 and 2,1 because 1 solo gpu failed with fsdp num_gpus = [2, 2] _test_multi_gpu(num_gpus, "debug/diloco.toml", extra_args=["--train.sharding_strategy", strategy]) - - -## test ckpt - - -def test_ckpt(tmp_path: Path): - ckpt_path = "outputs" # for some reason tmp_path is not working - os.makedirs(ckpt_path, exist_ok=True) - _test_multi_gpu([1, 1], "debug/normal.toml", extra_args=["--ckpt.path", f"{ckpt_path}/", "--ckpt.interval", "10"]) - _test_multi_gpu([1, 1], "debug/normal.toml", extra_args=["--resume", f"{ckpt_path}/step_10"]) - shutil.rmtree(ckpt_path)