Commit

fix all reduce fake div
samsja committed Oct 17, 2024
1 parent e3e1fa2 commit b0171d3
Showing 2 changed files with 20 additions and 6 deletions.
21 changes: 16 additions & 5 deletions src/zeroband/diloco.py
@@ -81,14 +81,25 @@ def _init_offloaded_optimizer(self, model):
         )
         self._logger.debug("offload model to cpu")

-    def sync_pseudo_gradient(self, model: nn.Module, fake: bool = False, flag: str = "outer"):
+    def sync_pseudo_gradient(
+        self, model: nn.Module, fake: bool = False, flag: str = "outer", num_effective_peers: int | None = None
+    ):
         """
         Sync the pseudo gradient from the local process group to the global process group
         """
         _start_time = time.perf_counter()
-        self._logger.debug("sync pseudo gradient %s", " fake" if fake else "")

+        world_size_pre_init = self.elastic_device_mesh.global_pg.size()
+        self.elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=False)
+        world_size_post_init = self.elastic_device_mesh.global_pg.size()
+
+        if world_size_pre_init == world_size_post_init and num_effective_peers is not None:
+            world_size = num_effective_peers
+        else:
+            world_size = world_size_post_init
+
+        self._logger.debug("sync pseudo gradient %s with world size %d", " fake" if fake else "", world_size)

         global_pg = self.elastic_device_mesh.global_pg
         for i in range(self.config.retry_all_reduce):
             for param_offloaded, param in zip(self.param_list_cpu, model.parameters()):
@@ -98,7 +109,7 @@ def sync_pseudo_gradient(
                     param_offloaded.grad.to_local().copy_(param_offloaded.data.to_local())
                     param_offloaded.grad.to_local().sub_(param.data.to_local().to(param_offloaded.data.device))
             try:
-                self.offloaded_grad_flat_tensor.div_(global_pg.size())
+                self.offloaded_grad_flat_tensor.div_(world_size)
                 _collective_start_time = time.perf_counter()
                 self._logger.debug("Waiting on barrier")
                 self.elastic_device_mesh.monitored_barrier(flag)
@@ -198,12 +209,12 @@ def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]:
         # )
         return offloaded_params

-    def step(self, model: nn.Module, fake: bool = False, flag: str = "outer"):
+    def step(self, model: nn.Module, fake: bool = False, num_effective_peers: int | None = None, flag: str = "outer"):
         """
         Step the optimizer
         """
         time_start = time.perf_counter()
-        self.sync_pseudo_gradient(model, fake=fake, flag=flag)
+        self.sync_pseudo_gradient(model, fake=fake, flag=flag, num_effective_peers=num_effective_peers)
         self._logger.info(f"all reduce pseudo gradient in: {time.perf_counter() - time_start} seconds")

         if self.outer_optimizer is not None:
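
The heart of the fix is the divisor selection added above: the pseudo gradient is averaged over the number of peers that actually contributed, rather than over whatever the global process group size happens to be after re-initialization. Below is a minimal sketch of that selection logic pulled out as a standalone function; the name pick_all_reduce_divisor is illustrative and not part of the repository.

    def pick_all_reduce_divisor(
        world_size_pre_init: int,
        world_size_post_init: int,
        num_effective_peers: int | None,
    ) -> int:
        """Choose the divisor used to average pseudo gradients across peers.

        If the group size did not change during re-initialization and the caller
        supplied the number of peers that actually computed a pseudo gradient,
        use that number; otherwise fall back to the freshly initialized group size.
        """
        if world_size_pre_init == world_size_post_init and num_effective_peers is not None:
            return num_effective_peers
        return world_size_post_init


    # Group size unchanged and 4 peers contributed: divide by 4.
    assert pick_all_reduce_divisor(4, 4, 4) == 4
    # Group size changed (e.g. a peer left): trust the new group size.
    assert pick_all_reduce_divisor(5, 4, 5) == 4
    # No effective-peer count supplied: fall back to the group size.
    assert pick_all_reduce_divisor(4, 4, None) == 4
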
5 changes: 4 additions & 1 deletion src/zeroband/train.py
@@ -316,7 +316,10 @@ def train(config: Config):
         time_start_outer = time.perf_counter()

         if config.diloco is not None:
+            # this is a patch for now to allow live recovery worker to not affect the all reduce at all
+            num_effective_peers = elastic_device_mesh.global_pg.size()
             elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=True)

             # at the beginning of the inner steps we allow joiner to arrive.
+            # We maybe reinit before the all reduce but only to allow leaving, not to join anymore

@@ -459,7 +462,7 @@ def train(config: Config):
             ckpt_manager.cache_inner_optimizer()

             time_start_inner = time.perf_counter()
-            diloco.step(model, flag=training_progress.outer_step)
+            diloco.step(model=model, flag=training_progress.outer_step, num_effective_peers=num_effective_peers)
             diloco_time = time.perf_counter() - time_start_inner

             if config.train.log_model_hash:
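
To see why the divisor matters, here is a toy calculation with made-up numbers. It assumes, following the "live recovery worker" comment above, that a worker admitted only for live recovery joins the all-reduce with a zeroed pseudo gradient; the variable names are illustrative only.

    # Four workers each contribute a pseudo gradient of 1.0; a freshly joined
    # live-recovery worker contributes 0.0 (assumed behaviour of the fake path).
    contributions = [1.0, 1.0, 1.0, 1.0, 0.0]
    reduced_sum = sum(contributions)

    num_effective_peers = 4    # peers that actually computed a pseudo gradient
    group_size_after_join = 5  # process group size once the joiner is admitted

    print(reduced_sum / group_size_after_join)  # 0.8 -> the outer update is silently scaled down
    print(reduced_sum / num_effective_peers)    # 1.0 -> the intended average
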
