From f0245f1f806e516a2b87b2e71889d5d8d41de940 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Wed, 20 Nov 2024 20:35:45 +0000 Subject: [PATCH] fix cpu local mehs --- src/zeroband/diloco.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py index 769f049..6a13e9c 100644 --- a/src/zeroband/diloco.py +++ b/src/zeroband/diloco.py @@ -9,6 +9,7 @@ from zeroband.utils.logging import get_logger from torch.distributed._tensor.api import DTensor from functools import lru_cache +from torch.distributed.device_mesh import init_device_mesh class DilocoConfig(BaseConfig): @@ -67,6 +68,7 @@ def __init__( # just force compilation self.pccl_communicator = pccl_communicator + self.cpu_local_mesh = init_device_mesh("cpu", mesh_shape=(self.local_pg.size(),)) self._logger = get_logger() self.world_info = get_world_info()