Commit 8fa1f0e: first working poc

samsja committed Oct 5, 2024
1 parent 565109b
Showing 3 changed files with 8 additions and 3 deletions.
3 changes: 3 additions & 0 deletions src/zeroband/checkpoint.py
@@ -2,6 +2,7 @@
 import gc
 import multiprocessing
 import os
+import shutil
 import time
 from typing import Any
 from fsspec.generic import rsync as rsync_fsspec
@@ -279,6 +280,8 @@ def load(self, resume_ckpt_path: str, diloco_rank: int | None = None) -> None:
 
     def download_and_load_ckpt_from_peers(self):
         path = f"/tmp/zeroband/node_{self.world_info.global_rank}"
+        if os.path.exists(path):
+            shutil.rmtree(path)
         os.makedirs(path, exist_ok=True)
         dest_rank = self.world_info.global_rank - 1
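Note on the checkpoint.py change: the new guard wipes any leftover download directory before recreating it, so a partially written checkpoint from an earlier recovery attempt cannot be mixed with the fresh one. A minimal standalone sketch of this clean-then-recreate pattern (the helper name and example path are illustrative, not from the repo):

    import os
    import shutil

    def fresh_dir(path: str) -> None:
        # Drop any stale contents from a previous attempt, then recreate empty.
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path, exist_ok=True)

    fresh_dir("/tmp/zeroband/node_0")  # hypothetical rank-0 scratch dir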
2 changes: 1 addition & 1 deletion src/zeroband/diloco.py
@@ -71,7 +71,7 @@ def sync_pseudo_gradient(self, model: nn.Module, fake: bool = False):
         Sync the pseudo gradient from the local process group to the global process group
         """
         self._logger.debug("sync pseudo gradient")
-        global_pg = self.elastic_device_mesh.get_global_pg(maybe_reinit=True)
+        global_pg = self.elastic_device_mesh.get_global_pg(maybe_reinit=False)
         for param_offloaded, param in zip(self.param_list_cpu, model.parameters()):
             if param.shape[0] == 0:
                 continue
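Note on the diloco.py change: sync_pseudo_gradient now takes the cached global process group instead of triggering a reinit itself; the reinit moves to the top of each outer step in train.py (next hunk). A hedged sketch of how a maybe_reinit-style accessor could cache its group (the real ElasticDeviceMesh internals are not shown in this commit):

    class ElasticMeshSketch:
        def __init__(self) -> None:
            self._global_pg: object | None = None

        def _rebuild_global_pg(self) -> None:
            # Stand-in for the real, expensive process-group rebuild.
            self._global_pg = object()

        def get_global_pg(self, maybe_reinit: bool = False):
            # Rebuild only when asked, or on first use; otherwise return
            # the cached group, which is what the callsite above now
            # relies on with maybe_reinit=False.
            if maybe_reinit or self._global_pg is None:
                self._rebuild_global_pg()
            return self._global_pg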
6 changes: 4 additions & 2 deletions src/zeroband/train.py
@@ -223,9 +223,8 @@ def train(config: Config):
         ckpt_manager.load(resume_ckpt_path=config.ckpt.resume)
 
     if elastic_device_mesh.live_recovery.need_live_recovery():
-        # time.sleep(4)
-        # diloco.fake_step(model)
         ckpt_manager.download_and_load_ckpt_from_peers()
+        diloco.fake_step(model)
 
     if world_info.rank == 0:
         logger_cls = WandbMonitor if config.metric_logger_type == "wandb" else DummyMonitor
@@ -247,6 +246,9 @@
         logger.info(f"outer_step step: {training_progress.outer_step}")
 
         time_start_outer = time.perf_counter()
+
+        elastic_device_mesh.maybe_reinit_global_pg()  # we call maybe_reinit at the beginning of each outer step
+
         for _inner_step in range(num_inner_steps):
             loss_batch = 0
 
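Note on the train.py changes: live recovery now downloads the checkpoint from a peer and runs a fake step, and the global process group is reinitialized once at the start of every outer step rather than inside the gradient sync. A hedged, runnable sketch of the resulting control flow (the stub class and step counts are illustrative, not from the repo):

    class MeshStub:
        def maybe_reinit_global_pg(self) -> None:
            # Stand-in for the real elastic reinit that admits joining
            # peers and drops departed ones.
            print("checking for peers joining or leaving")

    def outer_loop(mesh: MeshStub, num_outer_steps: int, num_inner_steps: int) -> None:
        for _outer_step in range(num_outer_steps):
            mesh.maybe_reinit_global_pg()  # membership settles before any collective
            for _inner_step in range(num_inner_steps):
                pass  # local training steps run here
            # The pseudo-gradient sync then reuses the cached global pg,
            # which is why diloco.py now passes maybe_reinit=False.

    outer_loop(MeshStub(), num_outer_steps=2, num_inner_steps=3)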
