Commit 2c54c2c: fix ckpt issue

samsja committed Sep 30, 2024
1 parent 1c098b2 commit 2c54c2c
Showing 2 changed files with 26 additions and 35 deletions.
60 changes: 25 additions & 35 deletions src/zeroband/checkpoint.py
@@ -115,16 +115,14 @@ def __init__(
             diloco_offloaded_optimizer is None
         ), "diloco_offloaded_model and diloco_offloaded_optimizer must be both None or both have values"

-        if diloco_offloaded_optimizer is not None:
-            self.diloco_offloaded_optimizer = diloco_offloaded_optimizer  # here we don't use the Wrapper because it failed
-            # which might make the ckpt less generic in terms of loading from a different number of devices. FSDP ckpt seems to be a mess though
-            self.diloco_offloaded_param_list = diloco_offloaded_param_list
+        self.diloco_offloaded_optimizer = diloco_offloaded_optimizer  # here we don't use the Wrapper because it failed
+        # which might make the ckpt less generic in terms of loading from a different number of devices. FSDP ckpt seems to be a mess though
+        self.diloco_offloaded_param_list = diloco_offloaded_param_list

+        if diloco_offloaded_optimizer is not None:
             # even if diloco_offloaded targets the cpu list model, we still use the gpu model to load and save state.
             # main reason is that we don't actually have a cpu model but just a list of cpu parameters.
-            self.diloco_states = {"optimizer": self.diloco_offloaded_optimizer}
-        else:
-            self.diloco_states = {}
+            self.states["diloco_optimizer"] = self.diloco_offloaded_optimizer

         self._logger = get_logger()

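The gist of the fix: instead of keeping a separate diloco_states dict, the offloaded DiLoCo optimizer is now registered inside the single self.states dict, so one dcp.save / dcp.load call covers it. Below is a minimal sketch of that pattern, not the project's code: the toy model, optimizers, and path are made up, and it assumes a recent PyTorch where torch.distributed.checkpoint can save without an initialized process group.

    import torch
    import torch.distributed.checkpoint as dcp

    # toy stand-ins for the real model / inner optimizer / offloaded DiLoCo optimizer
    model = torch.nn.Linear(8, 8)
    inner_optimizer = torch.optim.AdamW(model.parameters())
    diloco_offloaded_optimizer = torch.optim.SGD(model.parameters(), lr=0.7)

    states = {
        "model": model.state_dict(),
        "optimizer": inner_optimizer.state_dict(),
    }
    if diloco_offloaded_optimizer is not None:
        # same idea as `self.states["diloco_optimizer"] = self.diloco_offloaded_optimizer`
        states["diloco_optimizer"] = diloco_offloaded_optimizer.state_dict()

    # a single save call now captures the diloco optimizer along with everything else
    dcp.save(states, checkpoint_id="/tmp/ckpt_example/step_0")
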
@@ -138,28 +136,28 @@ def save(self, ckpt_path: str, remote_ckpt_path: str | None) -> None:
         """

         time_start = time.perf_counter()
+        world_info = get_world_info()

         ckpt_path = os.path.join(ckpt_path, f"step_{self.training_progress.step}")
+        catch_warning = self._logger.getEffectiveLevel() <= logging.INFO
+        # pytorch has an annoying warning when saving the optimizer state https://github.com/pytorch/pytorch/issues/136907
+        # we can ignore it if we are not logging in DEBUG mode
+        if self.diloco_offloaded_optimizer:
+            # here we save the model and offloaded optimizer on each diloco rank even though they are the same
+            # this is done for two reasons:
+            #  * if the nodes don't share a filesystem or a remote path, they still save all of the data
+            #  * it's easier to implement and avoids race conditions on the shared data.
+            ckpt_path = os.path.join(ckpt_path, f"diloco_{world_info.diloco_rank}")

-        world_info = get_world_info()
-        catch_warning = self._logger.getEffectiveLevel() <= logging.INFO
-
         with warnings.catch_warnings():
-            # pytorch has an annoying warning when saving the optimizer state https://github.com/pytorch/pytorch/issues/136907
-            # we can ignore it if we are not logging in DEBUG mode
             if catch_warning:
                 warnings.simplefilter("ignore")

             dcp.save(self.states, checkpoint_id=ckpt_path)

-            if self.diloco_states:
-                diloco_ckpt_path = os.path.join(ckpt_path, f"diloco_{world_info.diloco_rank}")
-                dcp.save(self.diloco_states, checkpoint_id=diloco_ckpt_path)
-
         ## the next part is a fix so that each rank saves a different dataloader state. It's not efficient because it reads the state two times from disk
-        # dataloader is different for each diloco worker. If diloco is enabled we use diloco_ckpt_path
-        dataloader_path = ckpt_path if self.diloco_states is None else diloco_ckpt_path
-        with open(os.path.join(dataloader_path, f"__{world_info.local_rank}_0.pt"), "wb") as f:
+
+        with open(os.path.join(ckpt_path, f"__{world_info.local_rank}_0.pt"), "wb") as f:
             torch.save({"data_loader": self.dataloader.state_dict()}, f)

         self._logger.info(f"Saved checkpoint to {ckpt_path} in {time.perf_counter() - time_start} seconds")
@@ -207,29 +205,21 @@ def load(self, resume_ckpt_path: str) -> None:
         """
         time_start = time.perf_counter()

-        self.states = dcp.load(self.states, checkpoint_id=resume_ckpt_path)
-
         world_info = get_world_info()
+        if self.diloco_offloaded_param_list is not None:
+            resume_ckpt_path = os.path.join(resume_ckpt_path, f"diloco_{world_info.diloco_rank}")

-        self._logger.debug(msg=f"diloco_states {self.diloco_states}")
-        if self.diloco_states:
-            resume_ckpt_path_diloco = os.path.join(resume_ckpt_path, f"diloco_{world_info.diloco_rank}")
-            dcp.load(self.diloco_states, checkpoint_id=resume_ckpt_path_diloco)
-
-        self._logger.debug(msg=f"postdiloco_states {self.diloco_states}")
-
-        ## the next part is a fix so that each rank saves a different dataloader state. It's not efficient because it reads the state two times from disk
-
-        # dataloader is different for each diloco worker. If diloco is enabled we use diloco_ckpt_path
-        dataloader_path = resume_ckpt_path if self.diloco_states is None else resume_ckpt_path_diloco
-        self._logger.debug(f"loading dataloader from {dataloader_path}")
-        with open(os.path.join(dataloader_path, f"__{world_info.local_rank}_0.pt"), "rb") as f:
-            rank_state_dict = torch.load(f)
-        self.dataloader.load_state_dict(rank_state_dict["data_loader"])
+        self.states = dcp.load(self.states, checkpoint_id=resume_ckpt_path)

         # since we don't load the param list from the state dict, as it's the same as the model one, we just copy
         if self.diloco_offloaded_param_list is not None:
             for param_offloaded, param_model in zip(self.diloco_offloaded_param_list, self.model.model.parameters()):
                 param_offloaded.data.copy_(param_model.data)

+        ## the next part is a fix so that each rank saves a different dataloader state. It's not efficient because it reads the state two times from disk
+        with open(os.path.join(resume_ckpt_path, f"__{world_info.local_rank}_0.pt"), "rb") as f:
+            rank_state_dict = torch.load(f)
+
+        self.dataloader.load_state_dict(rank_state_dict["data_loader"])
+
         self._logger.info(f"Loaded checkpoint from {resume_ckpt_path} in {time.perf_counter() - time_start} seconds")
1 change: 1 addition & 0 deletions src/zeroband/train.py
@@ -300,6 +300,7 @@ def train(config: Config):
         metric_logger.finish()

     ckpt_manager.wait_async_save_process()
+    logger.info("Training finished, exiting ...")


 if __name__ == "__main__":
