diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py
index 9486a97e..3a11f809 100644
--- a/src/zeroband/diloco.py
+++ b/src/zeroband/diloco.py
@@ -81,14 +81,25 @@ def _init_offloaded_optimizer(self, model):
         )
         self._logger.debug("offload model to cpu")
 
-    def sync_pseudo_gradient(self, model: nn.Module, fake: bool = False, flag: str = "outer"):
+    def sync_pseudo_gradient(
+        self, model: nn.Module, fake: bool = False, flag: str = "outer", num_effective_peers: int | None = None
+    ):
         """
         Sync the pseudo gradient from the local process group to the global process group
         """
         _start_time = time.perf_counter()
 
-        self._logger.debug("sync pseudo gradient %s", " fake" if fake else "")
+        world_size_pre_init = self.elastic_device_mesh.global_pg.size()
         self.elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=False)
+        world_size_post_init = self.elastic_device_mesh.global_pg.size()
+
+        if world_size_pre_init == world_size_post_init and num_effective_peers is not None:
+            world_size = num_effective_peers
+        else:
+            world_size = world_size_post_init
+
+        self._logger.debug("sync pseudo gradient %s with world size %d", " fake" if fake else "", world_size)
+
         global_pg = self.elastic_device_mesh.global_pg
         for i in range(self.config.retry_all_reduce):
             for param_offloaded, param in zip(self.param_list_cpu, model.parameters()):
@@ -98,7 +109,7 @@ def sync_pseudo_gradient(self, model: nn.Module, fake: bool = False, flag: str =
                     param_offloaded.grad.to_local().copy_(param_offloaded.data.to_local())
                     param_offloaded.grad.to_local().sub_(param.data.to_local().to(param_offloaded.data.device))
             try:
-                self.offloaded_grad_flat_tensor.div_(global_pg.size())
+                self.offloaded_grad_flat_tensor.div_(world_size)
                 _collective_start_time = time.perf_counter()
                 self._logger.debug("Waiting on barrier")
                 self.elastic_device_mesh.monitored_barrier(flag)
@@ -198,12 +209,12 @@ def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]:
         # )
         return offloaded_params
 
-    def step(self, model: nn.Module, fake: bool = False, flag: str = "outer"):
+    def step(self, model: nn.Module, fake: bool = False, num_effective_peers: int | None = None, flag: str = "outer"):
        """
        Step the optimizer
        """
        time_start = time.perf_counter()
-        self.sync_pseudo_gradient(model, fake=fake, flag=flag)
+        self.sync_pseudo_gradient(model, fake=fake, flag=flag, num_effective_peers=num_effective_peers)
        self._logger.info(f"all reduce pseudo gradient in: {time.perf_counter() - time_start} seconds")
 
        if self.outer_optimizer is not None:
diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index 11014f50..4089de7c 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -316,7 +316,10 @@ def train(config: Config):
         time_start_outer = time.perf_counter()
 
         if config.diloco is not None:
+            # temporary patch: snapshot the peer count so a live recovery worker joining now does not affect the all reduce
+            num_effective_peers = elastic_device_mesh.global_pg.size()
             elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=True)
+
             # at the beginning of the inner steps we allow joiner to arrive.
             # We maybe reinit before the all reduce but only to allow leaving, not to join anymore
 
@@ -459,7 +462,7 @@
                 ckpt_manager.cache_inner_optimizer()
 
                 time_start_inner = time.perf_counter()
-                diloco.step(model, flag=training_progress.outer_step)
+                diloco.step(model=model, flag=training_progress.outer_step, num_effective_peers=num_effective_peers)
                 diloco_time = time.perf_counter() - time_start_inner
 
                 if config.train.log_model_hash:
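
Taken together, the diloco.py hunks implement one rule: if the global process group kept the same size across the leave-only re-init (nobody left), divide the summed pseudo-gradient by the peer count snapshotted before joiners were admitted, so a live recovery worker that joined mid-step does not change the average; otherwise fall back to the current group size. A minimal sketch of that rule, extracted for clarity (`select_allreduce_divisor` and its parameter names are illustrative, not part of the patch):

```python
# Minimal sketch of the divisor-selection rule in sync_pseudo_gradient.
# All names here are illustrative; only the logic mirrors the diff.

def select_allreduce_divisor(
    size_pre_reinit: int,
    size_post_reinit: int,
    num_effective_peers: int | None,
) -> int:
    """Pick the divisor used to average the summed pseudo-gradients.

    If the group size is unchanged across the leave-only re-init and the
    caller captured a peer count before admitting joiners, trust that
    snapshot; otherwise fall back to the current group size.
    """
    if size_pre_reinit == size_post_reinit and num_effective_peers is not None:
        return num_effective_peers
    return size_post_reinit


# A live recovery worker joined after the snapshot was taken: the group
# size is unchanged across the leave-only re-init, so the snapshot wins.
assert select_allreduce_divisor(5, 5, num_effective_peers=4) == 4

# A peer left during the re-init: fall back to the post-re-init size.
assert select_allreduce_divisor(5, 4, num_effective_peers=4) == 4

# No snapshot available: use the current group size, as before the patch.
assert select_allreduce_divisor(5, 5, num_effective_peers=None) == 5
```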
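The train.py hunks make the ordering explicit: the peer count is snapshotted *before* `maybe_reinit_global_pg(admit_joiners=True)` runs, and that snapshot is threaded through to `diloco.step` at the end of the inner steps. A condensed, self-contained view of that ordering, under the assumption that exactly one recovery worker joins (`_FakeMesh` and `_FakePG` are stand-ins so the sketch runs on its own; they are not the real ElasticDeviceMesh API):

```python
class _FakePG:
    """Stand-in for a torch.distributed process group; only size() is needed."""

    def __init__(self, mesh: "_FakeMesh") -> None:
        self._mesh = mesh

    def size(self) -> int:
        return self._mesh.num_peers


class _FakeMesh:
    """Stand-in for ElasticDeviceMesh: just tracks a peer count."""

    def __init__(self, num_peers: int) -> None:
        self.num_peers = num_peers
        self.global_pg = _FakePG(self)

    def maybe_reinit_global_pg(self, admit_joiners: bool) -> None:
        # Assumption for the sketch: one live recovery worker is waiting.
        if admit_joiners:
            self.num_peers += 1


elastic_device_mesh = _FakeMesh(num_peers=4)

# 1) Snapshot the peer count before admitting joiners: these are the
#    peers that ran the whole outer step and hold real pseudo-gradients.
num_effective_peers = elastic_device_mesh.global_pg.size()

# 2) Admit joiners; a live recovery worker may enter the group here.
elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=True)

# ... inner steps run here ...

# 3) diloco.step(...) receives the snapshot, so sync_pseudo_gradient
#    divides by 4 (the effective peers), not the inflated group size of 5.
assert num_effective_peers == 4
assert elastic_device_mesh.global_pg.size() == 5
```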