Skip to content

Commit

Permalink
add jack todo
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Sep 25, 2024
1 parent a90d07e commit f3ef344
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/zeroband/diloco.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,10 @@ def sync_pseudo_gradient(self, model: nn.Module):

# gloo does not support AVG
param_offloaded.grad = param_offloaded.grad / self.elastic_device_mesh.global_pg.size()
dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.elastic_device_mesh.global_pg)
dist.all_reduce(
param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.elastic_device_mesh.global_pg, async_op=True
)
# todo async here

def sync_inner_model(self, model: nn.Module):
"""
Expand All @@ -113,7 +116,7 @@ def sync_inner_model(self, model: nn.Module):
# here each rank has a shard of the model in memory so all rank do the sync
self._logger.debug("sync inner model")
for param_offloaded, param in zip(self.cpu_model, model.parameters()):
param.data = param_offloaded.data.to("cuda")
param.data = param_offloaded.data.to("cuda") # todo: use copy_ here

elif self.fsdp_sharding_strategy in [ShardingStrategy.SHARD_GRAD_OP, ShardingStrategy.NO_SHARD]:
self._logger.debug("sync inner model")
Expand Down

0 comments on commit f3ef344

Please sign in to comment.