Commit

fid reinit
samsja committed Oct 2, 2024
1 parent 3d7d294 commit 3684133
Showing 3 changed files with 92 additions and 16 deletions.
38 changes: 22 additions & 16 deletions src/zeroband/checkpoint.py
@@ -90,6 +90,8 @@ class CkptManager:
     ...
     """
 
+    states: dict[str, Stateful]
+
     def __init__(
         self,
         model: nn.Module,
@@ -100,21 +102,12 @@ def __init__(
         diloco_offloaded_param_list: list[nn.Parameter] | None,
         diloco_offloaded_optimizer: Optimizer | None,
     ):
-        self.model = ModelWrapper(model)
-        self.optimizer = OptimizerWrapper(model, optimizer)
+        self.model = model
+        self.optimizer = optimizer
         self.scheduler = scheduler
         self.dataloader = dataloader
         self.training_progress = training_progress
 
-        # states can only be Stateful objects, hence we need to wrap Model and Optimizer
-        self.states: dict[str, Stateful] = {
-            "model": self.model,
-            "optimizer": self.optimizer,
-            "scheduler": self.scheduler,
-            # "dataloader": self.dataloader,  # ignoring dataloader for now as each rank has its own dataloader
-            "training_progress": self.training_progress,
-        }
-
         assert (diloco_offloaded_param_list is None) == (
             diloco_offloaded_optimizer is None
         ), "diloco_offloaded_model and diloco_offloaded_optimizer must be both None or both have values"
@@ -123,15 +116,27 @@ def __init__(
         # which might make the ckpt less generic in terms of loading from a different number of devices. FSDP ckpt seems to be a mess tho
         self.diloco_offloaded_param_list = diloco_offloaded_param_list
 
-        if diloco_offloaded_optimizer is not None:
-            # even if the diloco_offloaded optimizer targets the cpu param list, we still use the gpu model to load and save state.
-            # main reason is that we don't actually have a cpu model but just a list of cpu parameters.
-            self.states["diloco_optimizer"] = self.diloco_offloaded_optimizer
+        self._init_state()
 
         self._logger = get_logger()
 
         self.async_save_process: list[multiprocessing.Process] = []
 
+    def _init_state(self):
+        # states can only be Stateful objects, hence we need to wrap Model and Optimizer
+        self.states: dict[str, Stateful] = {
+            "model": ModelWrapper(self.model),
+            "optimizer": OptimizerWrapper(self.model, self.optimizer),
+            "scheduler": self.scheduler,
+            # "dataloader": self.dataloader,  # ignoring dataloader for now as each rank has its own dataloader
+            "training_progress": self.training_progress,
+        }
+
+        if self.diloco_offloaded_optimizer is not None:
+            # even if the diloco_offloaded optimizer targets the cpu param list, we still use the gpu model to load and save state.
+            # main reason is that we don't actually have a cpu model but just a list of cpu parameters.
+            self.states["diloco_optimizer"] = self.diloco_offloaded_optimizer
+
     def save(self, ckpt_path: str, remote_ckpt_path: str | None) -> None:
         """
         Each rank will save the right shard of the model and optimizer.
@@ -216,7 +221,6 @@ def load(self, resume_ckpt_path: str) -> None:
            resume_ckpt_path = os.path.join(resume_ckpt_path, f"diloco_{world_info.diloco_rank}")
 
         dcp.load(self.states, checkpoint_id=resume_ckpt_path)
-
         # since we don't load the param list from the state dict, as it's the same as the model's, we just copy
         if self.diloco_offloaded_param_list is not None:
             for param_offloaded, param_model in zip(self.diloco_offloaded_param_list, self.model.model.parameters()):
@@ -228,4 +232,6 @@ def load(self, resume_ckpt_path: str) -> None:
 
         self.dataloader.load_state_dict(rank_state_dict["data_loader"])
 
+        self._init_state()
+
         self._logger.info(f"Loaded checkpoint from {resume_ckpt_path} in {time.perf_counter() - time_start} seconds")
52 changes: 52 additions & 0 deletions src/zeroband/meow.py
@@ -0,0 +1,52 @@
import os

import torch
import torch.distributed as dist

from zeroband.comms import ElasticDeviceMesh

from zeroband.utils.world_info import get_world_info
from zeroband.utils.logging import get_logger


def train():
    elastic_device_mesh = ElasticDeviceMesh()
    # dist.init_process_group(backend="gloo")
    # group = dist.distributed_c10d._get_default_group()
    group = elastic_device_mesh.global_pg

    logger.info(f"rank: {group.rank()}")

    data = torch.ones(10, 10) * world_info.local_rank
    dist.all_reduce(data, op=dist.ReduceOp.SUM, group=group)
    logger.info(msg=f"data: {data.mean() / elastic_device_mesh.global_pg.size()}")

    # logger.info(f"global rank: {world_info.global_rank}")

    if world_info.local_rank == 1:
        dest_rank = 0
        logger.info(f"Sending param {data.shape} to {dest_rank}")
        group.send([data], dest_rank, 0).wait()

    if world_info.local_rank == 0:
        src_rank = 1
        logger.info(f"Receiving param {data.shape} from {src_rank}")
        group.recv([data], src_rank, 0).wait()

    # logger.info(f"data: {data.mean()}")
    logger.info("finish")


if __name__ == "__main__":
    # Allow eager fallback in production so that the training runs don't die.
    # However, in development, we want to know that we broke torch compile.
    torch._dynamo.config.suppress_errors = "ZERO_BAND_DEV" not in os.environ
    torch.set_float32_matmul_precision("high")
    torch.manual_seed(42)  # this ensures the same weight init across diloco workers

    world_info = get_world_info()
    logger = get_logger()

    # torch.cuda.set_device(world_info.local_rank)

    train()
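
meow.py above is a connectivity smoke test: an all_reduce over the ElasticDeviceMesh global process group, then a blocking send/recv between local ranks 1 and 0. The same pattern can be checked without zeroband using plain torch.distributed with the gloo backend; the sketch below makes that explicit. The file name and the torchrun launch line are assumptions for illustration.

# smoke_test.py — assumed launch: torchrun --nproc_per_node=2 smoke_test.py
import torch
import torch.distributed as dist


def main() -> None:
    dist.init_process_group(backend="gloo")  # torchrun supplies RANK / WORLD_SIZE / MASTER_* env vars
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Collective: every rank contributes rank * ones; the sum lands on every rank.
    data = torch.ones(10, 10) * rank
    dist.all_reduce(data, op=dist.ReduceOp.SUM)
    print(f"rank {rank}: mean / world_size = {data.mean() / world_size}")

    # Point-to-point: rank 1 sends its tensor to rank 0 (blocking).
    if world_size >= 2:
        if rank == 1:
            dist.send(data, dst=0)
        elif rank == 0:
            dist.recv(data, src=1)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()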
18 changes: 18 additions & 0 deletions src/zeroband/meow2.py
@@ -0,0 +1,18 @@
import torch
from zeroband.comms import ElasticDeviceMesh

edm = ElasticDeviceMesh(backend="gloo")

if edm.world_info.global_rank == 0:
    tensor = torch.randn(1000)
    work = edm.global_pg.send([tensor], 1, 0)
else:
    tensor = torch.randn(1000)
    work = edm.global_pg.recv([tensor], 0, 0)

work.wait()


print(f"Rank {edm.world_info.global_rank}:", tensor[:10])

del edm
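
meow2.py drives the point-to-point transfer through the Work handle returned by the process group's send/recv and then waits on it. Plain torch.distributed exposes the same wait-on-handle pattern through isend/irecv; a minimal sketch follows, assuming exactly two ranks launched with torchrun --nproc_per_node=2.

import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")
rank = dist.get_rank()

tensor = torch.randn(1000)
if rank == 0:
    work = dist.isend(tensor, dst=1, tag=0)  # non-blocking send, returns a Work handle
else:
    work = dist.irecv(tensor, src=0, tag=0)  # non-blocking recv into the preallocated buffer

work.wait()  # block until the transfer completes, mirroring edm.global_pg.send(...).wait()

print(f"Rank {rank}:", tensor[:10])
dist.destroy_process_group()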
