From 2b07243873921273edecfd4746ffdc4e1a4b2387 Mon Sep 17 00:00:00 2001
From: samsja <55492238+samsja@users.noreply.github.com>
Date: Thu, 3 Oct 2024 13:10:34 -0700
Subject: [PATCH] fix ckpt (#34)

* uniformize state dict option

* load does not overide states anymore

* fid reinit

* fix dilocoo

* revert meow
---
 src/zeroband/checkpoint.py | 49 ++++++++++++++++++++++----------------
 1 file changed, 29 insertions(+), 20 deletions(-)

diff --git a/src/zeroband/checkpoint.py b/src/zeroband/checkpoint.py
index 2455a380..692e050d 100644
--- a/src/zeroband/checkpoint.py
+++ b/src/zeroband/checkpoint.py
@@ -48,7 +48,7 @@ def __init__(self, model: nn.Module) -> None:
         self.model = model
 
     def state_dict(self) -> dict[str, Any]:
-        return get_model_state_dict(self.model)
+        return get_model_state_dict(self.model, options=StateDictOptions(strict=False))
 
     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         set_model_state_dict(model=self.model, model_state_dict=state_dict, options=StateDictOptions(strict=False))
@@ -70,7 +70,10 @@ def state_dict(self) -> dict[str, Any]:
 
     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         set_optimizer_state_dict(
-            model=self.model, optimizers=self.optim, optim_state_dict=state_dict, options=StateDictOptions(strict=False)
+            model=self.model,
+            optimizers=self.optim,
+            optim_state_dict=state_dict,
+            options=StateDictOptions(flatten_optimizer_state_dict=True),
         )
 
 
@@ -87,6 +90,8 @@ class CkptManager:
     ...
     """
 
+    states: dict[str, Stateful]
+
     def __init__(
         self,
         model: nn.Module,
@@ -97,21 +102,12 @@ def __init__(
         optimizer: Optimizer,
         scheduler: LambdaLR,
         dataloader: StatefulDataLoader,
         training_progress: TrainingProgress,
         diloco_offloaded_param_list: list[nn.Parameter] | None,
         diloco_offloaded_optimizer: Optimizer | None,
     ):
-        self.model = ModelWrapper(model)
-        self.optimizer = OptimizerWrapper(model, optimizer)
+        self.model = model
+        self.optimizer = optimizer
         self.scheduler = scheduler
         self.dataloader = dataloader
         self.training_progress = training_progress
 
-        # states can only be stateful object, hence we need to wrap Model and Optimizer
-        self.states: dict[str, Stateful] = {
-            "model": self.model,
-            "optimizer": self.optimizer,
-            "scheduler": self.scheduler,
-            # "dataloader": self.dataloader,  # ignoring dataloader for now as each rank has its own dataloader
-            "training_progress": self.training_progress,
-        }
-
         assert (diloco_offloaded_param_list is None) == (
             diloco_offloaded_optimizer is None
         ), "diloco_offloaded_model and diloco_offloaded_optimizer must be both None or both have values"
@@ -120,15 +116,27 @@ def __init__(
         # which might make the ckpt less generic in term of loading from different number of device. FSDP ckpt seems to be a mess tho
        self.diloco_offloaded_param_list = diloco_offloaded_param_list
 
-        if diloco_offloaded_optimizer is not None:
-            # even if the diloco_offloaded target the cpu list model, we still use the gpu model to load and save state.
-            # main reason is that we actually don't a cpu model but just a list of cpu parameters.
- self.states["diloco_optimizer"] = self.diloco_offloaded_optimizer + self._init_state() self._logger = get_logger() self.async_save_process: list[multiprocessing.Process] = [] + def _init_state(self): + # states can only be stateful object, hence we need to wrap Model and Optimizer + self.states: dict[str, Stateful] = { + "model": ModelWrapper(self.model), + "optimizer": OptimizerWrapper(self.model, self.optimizer), + "scheduler": self.scheduler, + # "dataloader": self.dataloader, # ignoring dataloader for now as each rank has its own dataloader + "training_progress": self.training_progress, + } + + if self.diloco_offloaded_optimizer is not None: + # even if the diloco_offloaded target the cpu list model, we still use the gpu model to load and save state. + # main reason is that we actually don't a cpu model but just a list of cpu parameters. + self.states["diloco_optimizer"] = self.diloco_offloaded_optimizer + def save(self, ckpt_path: str, remote_ckpt_path: str | None) -> None: """ Each rank will save the right shard of the model and optimizer. @@ -212,11 +220,10 @@ def load(self, resume_ckpt_path: str) -> None: if self.diloco_offloaded_param_list is not None: resume_ckpt_path = os.path.join(resume_ckpt_path, f"diloco_{world_info.diloco_rank}") - self.states = dcp.load(self.states, checkpoint_id=resume_ckpt_path) - + dcp.load(self.states, checkpoint_id=resume_ckpt_path) # since we don't load the param list from the state dict as its the same as the model one we just copy if self.diloco_offloaded_param_list is not None: - for param_offloaded, param_model in zip(self.diloco_offloaded_param_list, self.model.model.parameters()): + for param_offloaded, param_model in zip(self.diloco_offloaded_param_list, self.model.parameters()): param_offloaded.data.copy_(param_model.data) ## the next part is a fix so that each rank save a different dataloader rank. It not efficient because it reads the state two times from disk @@ -225,4 +232,6 @@ def load(self, resume_ckpt_path: str) -> None: self.dataloader.load_state_dict(rank_state_dict["data_loader"]) + self._init_state() + self._logger.info(f"Loaded checkpoint from {resume_ckpt_path} in {time.perf_counter() - time_start} seconds")