diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py index 46a61cbb..567356f9 100644 --- a/src/zeroband/diloco.py +++ b/src/zeroband/diloco.py @@ -55,6 +55,9 @@ def __init__( if self.fsdp_sharding_strategy not in [ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP]: raise ValueError("Diloco only support FULL_SHARD and SHARD_GRAD_OP") + if self.world_info.global_world_size < 1: + raise ValueError("Diloco requires a global world size of at least 1") + self._init_offloaded_optimizer(model=model) def _init_offloaded_optimizer(self, model): diff --git a/src/zeroband/utils/world_info.py b/src/zeroband/utils/world_info.py index 9b73f328..94dc5cb1 100644 --- a/src/zeroband/utils/world_info.py +++ b/src/zeroband/utils/world_info.py @@ -21,7 +21,7 @@ def __init__(self): self.global_unique_id = os.environ.get("GLOBAL_UNIQUE_ID", None) self.global_addr = os.environ.get("GLOBAL_ADDR", None) self.global_port = int(os.environ.get("GLOBAL_PORT")) if "GLOBAL_PORT" in os.environ else None - self.global_world_size = int(os.environ.get("GLOBAL_WORLD_SIZE", 1)) + self.global_world_size = int(os.environ.get("GLOBAL_WORLD_SIZE", 0)) self.global_rank = int(os.environ.get("GLOBAL_RANK", 0)) def __repr__(self):