
Commit

add diloco ckpt
samsja committed Sep 29, 2024
1 parent e239c27 commit ec499ce
Showing 2 changed files with 25 additions and 7 deletions.
27 changes: 23 additions & 4 deletions src/zeroband/checkpoint.py
@@ -95,6 +95,8 @@ def __init__(
dataloader: StatefulDataLoader,
training_progress: TrainingProgress,
process_group: ProcessGroup | None,
diloco_offloaded_param_list: list[nn.Parameter] | None,
diloco_offloaded_optimizer: Optimizer | None,
):
self.model = ModelWrapper(model)
self.optimizer = OptimizerWrapper(model, optimizer)
@@ -111,6 +113,19 @@ def __init__(
"training_progress": self.training_progress,
}

assert (diloco_offloaded_param_list is None) == (
diloco_offloaded_optimizer is None
), "diloco_offloaded_model and diloco_offloaded_optimizer must be both None or both have values"

if diloco_offloaded_optimizer is not None:
self.diloco_offloaded_optimizer = diloco_offloaded_optimizer # he we don't use Wrapper because it failed
# which might make the ckpt less generic in term of loading from different number of device. FSDP ckpt seems to be a mess tho

self.diloco_offloaded_param_list = diloco_offloaded_param_list
# even if diloco_offloaded targets the cpu param list, we still use the gpu model to load and save state.
# main reason is that we don't actually have a cpu model, just a list of cpu parameters.
self.states["diloco_offloaded_optimizer"] = self.diloco_offloaded_optimizer

self.process_group = process_group
self._logger = get_logger()

@@ -192,9 +207,13 @@ def load(self, resume_ckpt_path: str) -> None:
rank = get_world_info().local_rank  # todo: check which rank is correct here after the on/off ramping pr

## the next part is a fix so that each rank loads its own dataloader state. It is not efficient because it reads the state from disk twice
if self.dataloader is not None:
with open(os.path.join(resume_ckpt_path, f"__{rank}_0.pt"), "rb") as f:
rank_state_dict = torch.load(f)
self.dataloader.load_state_dict(rank_state_dict["data_loader"])
with open(os.path.join(resume_ckpt_path, f"__{rank}_0.pt"), "rb") as f:
rank_state_dict = torch.load(f)
self.dataloader.load_state_dict(rank_state_dict["data_loader"])

# since we don't load the param list from the state dict (it's the same as the model's), we just copy it
if self.diloco_offloaded_param_list is not None:
for param_offloaded, param_model in zip(self.diloco_offloaded_param_list, self.model.model.parameters()):
param_offloaded.data.copy_(param_model.data)

self._logger.info(f"Loaded checkpoint from {resume_ckpt_path} in {time.perf_counter() - time_start} seconds")
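
For readers skimming the diff, the new load-time step above boils down to refreshing the offloaded CPU copies from the freshly loaded GPU weights, since the CPU list is deliberately not stored in the checkpoint. A minimal self-contained sketch of that idea, where the function name sync_offloaded_params is hypothetical and only the copy pattern comes from the diff:

import torch
import torch.nn as nn

def sync_offloaded_params(model: nn.Module, param_list_cpu: list[nn.Parameter]) -> None:
    # The CPU list mirrors the model, so it is not part of the saved state dict;
    # after the GPU weights are restored we simply copy them over in place.
    with torch.no_grad():
        for param_offloaded, param_model in zip(param_list_cpu, model.parameters()):
            param_offloaded.data.copy_(param_model.data)
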
5 changes: 2 additions & 3 deletions src/zeroband/train.py
@@ -1,6 +1,5 @@
import os
from contextlib import nullcontext
import time
from typing import Literal

import torch
@@ -172,6 +171,8 @@ def train(config: Config):
scheduler=scheduler,
dataloader=train_dataloader,
training_progress=training_progress,
diloco_offloaded_optimizer=diloco.outer_optimizer if config.diloco is not None else None,
diloco_offloaded_param_list=diloco.param_list_cpu if config.diloco is not None else None,
process_group=elastic_device_mesh.local_pg if config.diloco is not None else None,
)

@@ -271,8 +272,6 @@ def train(config: Config):

logger.info(log)

time.sleep(2)

if config.diloco is not None:
if config.train.log_model_hash:
with FSDP.summon_full_params(model):
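
Taken together, the train.py change wires the DiLoCo outer optimizer and the CPU parameter list into the manager, and the checkpoint.py change registers that optimizer under self.states, which the manager's save/load path presumably hands to torch.distributed.checkpoint (the Wrapper classes and the FSDP remark point that way). A rough sketch of that mechanism under that assumption, using the PyTorch >= 2.2 dcp.save/dcp.load API and an illustrative states dict rather than the project's actual wrappers:

import torch
import torch.distributed.checkpoint as dcp

# illustrative stand-in for CheckpointManager.states: dcp walks the dict and
# saves every tensor (or Stateful object) it finds, so registering the outer
# optimizer's state here is enough for it to travel with the rest of the checkpoint
states = {
    "diloco_offloaded_optimizer": {
        "outer_lr": torch.tensor(0.7),
        "momentum_buffer": torch.zeros(16),
    },
}

dcp.save(states, checkpoint_id="ckpt/step_100")  # write alongside the other registered state
dcp.load(states, checkpoint_id="ckpt/step_100")  # restore in place on resume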
