From e4b65cefbabe6740677590306cabeb4517e24ec5 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Thu, 21 Nov 2024 01:01:55 +0000 Subject: [PATCH] fix local pg --- src/zeroband/diloco.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py index 6a13e9c..5f262d7 100644 --- a/src/zeroband/diloco.py +++ b/src/zeroband/diloco.py @@ -68,12 +68,11 @@ def __init__( # just force compilation self.pccl_communicator = pccl_communicator - self.cpu_local_mesh = init_device_mesh("cpu", mesh_shape=(self.local_pg.size(),)) - - self._logger = get_logger() self.world_info = get_world_info() + self._logger = get_logger() self._init_offloaded_optimizer(model=model) + self.cpu_local_mesh = init_device_mesh("cpu", mesh_shape=(self.world_info.world_size,)) @torch.no_grad() def _init_offloaded_optimizer(self, model):