diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py index 6a13e9c..5f262d7 100644 --- a/src/zeroband/diloco.py +++ b/src/zeroband/diloco.py @@ -68,12 +68,11 @@ def __init__( # just force compilation self.pccl_communicator = pccl_communicator - self.cpu_local_mesh = init_device_mesh("cpu", mesh_shape=(self.local_pg.size(),)) - - self._logger = get_logger() self.world_info = get_world_info() + self._logger = get_logger() self._init_offloaded_optimizer(model=model) + self.cpu_local_mesh = init_device_mesh("cpu", mesh_shape=(self.world_info.world_size,)) @torch.no_grad() def _init_offloaded_optimizer(self, model):