Commit

fix ckpt (#34)
* uniformize state dict option (see the API sketch below)

* load does not override states anymore (see the load sketch after the diff)

* fix reinit

* fix diloco

* revert meow
samsja authored Oct 3, 2024
1 parent 31caeac commit 2b07243
Showing 1 changed file with 29 additions and 20 deletions.
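
The first commit bullet touches how model and optimizer state dicts are produced and applied. As orientation only, here is a minimal single-process sketch of the torch.distributed.checkpoint.state_dict helpers and StateDictOptions flags the diff below relies on; the toy model, optimizer, and variable names are illustrative rather than the repository's code, and the flatten_optimizer_state_dict flag assumes a recent PyTorch release.

import torch
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    get_optimizer_state_dict,
    set_model_state_dict,
    set_optimizer_state_dict,
)

model = nn.Linear(4, 4)                        # toy stand-ins for the real model/optimizer
optim = torch.optim.AdamW(model.parameters())
model(torch.randn(2, 4)).sum().backward()
optim.step()                                   # populate optimizer state so there is something to snapshot

# strict=False tolerates keys that are missing from (or extra in) the state dict being applied
sd = get_model_state_dict(model, options=StateDictOptions(strict=False))
set_model_state_dict(model=model, model_state_dict=sd, options=StateDictOptions(strict=False))

# flatten_optimizer_state_dict=True keys optimizer state by parameter name (FQN)
# instead of by parameter index, which makes the saved layout easier to reload
opts = StateDictOptions(flatten_optimizer_state_dict=True)
osd = get_optimizer_state_dict(model, optim, options=opts)
set_optimizer_state_dict(model=model, optimizers=optim, optim_state_dict=osd, options=opts)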
49 changes: 29 additions & 20 deletions src/zeroband/checkpoint.py
@@ -48,7 +48,7 @@ def __init__(self, model: nn.Module) -> None:
self.model = model

def state_dict(self) -> dict[str, Any]:
return get_model_state_dict(self.model)
return get_model_state_dict(self.model, options=StateDictOptions(strict=False))

def load_state_dict(self, state_dict: dict[str, Any]) -> None:
set_model_state_dict(model=self.model, model_state_dict=state_dict, options=StateDictOptions(strict=False))
@@ -70,7 +70,10 @@ def state_dict(self) -> dict[str, Any]:

def load_state_dict(self, state_dict: dict[str, Any]) -> None:
set_optimizer_state_dict(
model=self.model, optimizers=self.optim, optim_state_dict=state_dict, options=StateDictOptions(strict=False)
model=self.model,
optimizers=self.optim,
optim_state_dict=state_dict,
options=StateDictOptions(flatten_optimizer_state_dict=True),
)


@@ -87,6 +90,8 @@ class CkptManager:
...
"""

states: dict[str, Stateful]

def __init__(
self,
model: nn.Module,
@@ -97,21 +102,12 @@ def __init__(
diloco_offloaded_param_list: list[nn.Parameter] | None,
diloco_offloaded_optimizer: Optimizer | None,
):
self.model = ModelWrapper(model)
self.optimizer = OptimizerWrapper(model, optimizer)
self.model = model
self.optimizer = optimizer
self.scheduler = scheduler
self.dataloader = dataloader
self.training_progress = training_progress

# states can only hold Stateful objects, hence we need to wrap Model and Optimizer
self.states: dict[str, Stateful] = {
"model": self.model,
"optimizer": self.optimizer,
"scheduler": self.scheduler,
# "dataloader": self.dataloader, # ignoring dataloader for now as each rank has its own dataloader
"training_progress": self.training_progress,
}

assert (diloco_offloaded_param_list is None) == (
diloco_offloaded_optimizer is None
), "diloco_offloaded_model and diloco_offloaded_optimizer must be both None or both have values"
@@ -120,15 +116,27 @@ def __init__(
# which might make the ckpt less generic in term of loading from different number of device. FSDP ckpt seems to be a mess tho
self.diloco_offloaded_param_list = diloco_offloaded_param_list

if diloco_offloaded_optimizer is not None:
# even if the diloco offloaded optimizer targets the cpu param list, we still use the gpu model to load and save state.
# the main reason is that we don't actually have a cpu model, just a list of cpu parameters.
self.states["diloco_optimizer"] = self.diloco_offloaded_optimizer
self._init_state()

self._logger = get_logger()

self.async_save_process: list[multiprocessing.Process] = []

def _init_state(self):
# states can only hold Stateful objects, hence we need to wrap Model and Optimizer
self.states: dict[str, Stateful] = {
"model": ModelWrapper(self.model),
"optimizer": OptimizerWrapper(self.model, self.optimizer),
"scheduler": self.scheduler,
# "dataloader": self.dataloader, # ignoring dataloader for now as each rank has its own dataloader
"training_progress": self.training_progress,
}

if self.diloco_offloaded_optimizer is not None:
# even if the diloco offloaded optimizer targets the cpu param list, we still use the gpu model to load and save state.
# the main reason is that we don't actually have a cpu model, just a list of cpu parameters.
self.states["diloco_optimizer"] = self.diloco_offloaded_optimizer

def save(self, ckpt_path: str, remote_ckpt_path: str | None) -> None:
"""
Each rank will save the right shard of the model and optimizer.
@@ -212,11 +220,10 @@ def load(self, resume_ckpt_path: str) -> None:
if self.diloco_offloaded_param_list is not None:
resume_ckpt_path = os.path.join(resume_ckpt_path, f"diloco_{world_info.diloco_rank}")

self.states = dcp.load(self.states, checkpoint_id=resume_ckpt_path)

dcp.load(self.states, checkpoint_id=resume_ckpt_path)
# since we don't load the param list from the state dict (it's the same as the model's), we just copy it
if self.diloco_offloaded_param_list is not None:
for param_offloaded, param_model in zip(self.diloco_offloaded_param_list, self.model.model.parameters()):
for param_offloaded, param_model in zip(self.diloco_offloaded_param_list, self.model.parameters()):
param_offloaded.data.copy_(param_model.data)

## the next part is a fix so that each rank loads its own dataloader state. It is not efficient because it reads the state twice from disk
@@ -225,4 +232,6 @@ def load(self, resume_ckpt_path: str) -> None:

self.dataloader.load_state_dict(rank_state_dict["data_loader"])

self._init_state()

self._logger.info(f"Loaded checkpoint from {resume_ckpt_path} in {time.perf_counter() - time_start} seconds")
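
On the load path, the "load does not override states anymore" bullet corresponds to dropping the self.states = dcp.load(...) reassignment above: dcp.load restores in place, calling load_state_dict on every Stateful value in the dict it is given, and _init_state() is then called again so CkptManager holds fresh wrappers around the restored objects. A minimal single-process sketch of that in-place behaviour, with a condensed wrapper and an illustrative checkpoint path (not the repository's code):

import torch.nn as nn
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    set_model_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful


class ModelWrapper(Stateful):
    # condensed version of the wrapper in this file: expose an nn.Module as a Stateful object
    def __init__(self, model: nn.Module) -> None:
        self.model = model

    def state_dict(self):
        return get_model_state_dict(self.model, options=StateDictOptions(strict=False))

    def load_state_dict(self, state_dict) -> None:
        set_model_state_dict(
            model=self.model, model_state_dict=state_dict, options=StateDictOptions(strict=False)
        )


model = nn.Linear(4, 4)
states = {"model": ModelWrapper(model)}

dcp.save(states, checkpoint_id="/tmp/ckpt_demo")  # illustrative path
dcp.load(states, checkpoint_id="/tmp/ckpt_demo")  # restores in place; no need to rebind `states`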
