diff --git a/src/zeroband/checkpoint.py b/src/zeroband/checkpoint.py
index f3527ad8..59158cbd 100644
--- a/src/zeroband/checkpoint.py
+++ b/src/zeroband/checkpoint.py
@@ -95,6 +95,8 @@ def __init__(
         dataloader: StatefulDataLoader,
         training_progress: TrainingProgress,
         process_group: ProcessGroup | None,
+        diloco_offloaded_param_list: list[nn.Parameter] | None,
+        diloco_offloaded_optimizer: Optimizer | None,
     ):
         self.model = ModelWrapper(model)
         self.optimizer = OptimizerWrapper(model, optimizer)
@@ -111,6 +113,19 @@ def __init__(
             "training_progress": self.training_progress,
         }

+        assert (diloco_offloaded_param_list is None) == (
+            diloco_offloaded_optimizer is None
+        ), "diloco_offloaded_param_list and diloco_offloaded_optimizer must both be None or both have values"
+
+        if diloco_offloaded_optimizer is not None:
+            self.diloco_offloaded_optimizer = diloco_offloaded_optimizer  # here we don't use a Wrapper because it failed,
+            # which might make the ckpt less generic in terms of loading from a different number of devices. FSDP ckpt seems to be a mess tho
+
+            self.diloco_offloaded_param_list = diloco_offloaded_param_list
+            # even if diloco_offloaded targets the cpu param list, we still use the gpu model to load and save state.
+            # the main reason is that we don't actually have a cpu model, just a list of cpu parameters.
+            self.states["diloco_offloaded_optimizer"] = self.diloco_offloaded_optimizer
+
         self.process_group = process_group
         self._logger = get_logger()

@@ -192,9 +207,13 @@ def load(self, resume_ckpt_path: str) -> None:
         rank = get_world_info().local_rank  # todo check after on/off ramping pr which rank is good here

         ## the next part is a fix so that each rank save a different dataloader rank. It not efficient because it reads the state two times from disk
-        if self.dataloader is not None:
-            with open(os.path.join(resume_ckpt_path, f"__{rank}_0.pt"), "rb") as f:
-                rank_state_dict = torch.load(f)
-            self.dataloader.load_state_dict(rank_state_dict["data_loader"])
+        with open(os.path.join(resume_ckpt_path, f"__{rank}_0.pt"), "rb") as f:
+            rank_state_dict = torch.load(f)
+        self.dataloader.load_state_dict(rank_state_dict["data_loader"])
+
+        # we don't load the param list from the state dict since it mirrors the model's params; we just copy them over
+        if self.diloco_offloaded_param_list is not None:
+            for param_offloaded, param_model in zip(self.diloco_offloaded_param_list, self.model.model.parameters()):
+                param_offloaded.data.copy_(param_model.data)

         self._logger.info(f"Loaded checkpoint from {resume_ckpt_path} in {time.perf_counter() - time_start} seconds")
diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index 41eaf92b..a59e6a91 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -1,6 +1,5 @@
 import os
 from contextlib import nullcontext
-import time
 from typing import Literal

 import torch
@@ -172,6 +171,8 @@ def train(config: Config):
         scheduler=scheduler,
         dataloader=train_dataloader,
         training_progress=training_progress,
+        diloco_offloaded_optimizer=diloco.outer_optimizer if config.diloco is not None else None,
+        diloco_offloaded_param_list=diloco.param_list_cpu if config.diloco is not None else None,
         process_group=elastic_device_mesh.local_pg if config.diloco is not None else None,
     )

@@ -271,8 +272,6 @@ def train(config: Config):

         logger.info(log)

-        time.sleep(2)
-
         if config.diloco is not None:
             if config.train.log_model_hash:
                 with FSDP.summon_full_params(model):
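
Note for reviewers: the load-time change above amounts to one pattern — the CPU-offloaded DiLoCo parameters are never serialized on their own; once the GPU model has been restored from the checkpoint, the CPU copies are rebuilt from it. A minimal sketch of that idea, assuming a plain nn.Module and a CPU parameter list created from model.parameters() in the same order (the helper name is illustrative, not part of this diff):

```python
import torch
import torch.nn as nn


def resync_offloaded_params(model: nn.Module, offloaded_params: list[nn.Parameter]) -> None:
    """After a checkpoint load, rebuild the CPU-offloaded copies from the GPU model."""
    with torch.no_grad():
        for cpu_param, gpu_param in zip(offloaded_params, model.parameters()):
            # copy_ transfers data across devices (GPU -> CPU) without reallocating the CPU tensor
            cpu_param.data.copy_(gpu_param.data)
```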