diff --git a/README.md b/README.md
index 424fc304..42116a54 100644
--- a/README.md
+++ b/README.md
@@ -78,8 +78,8 @@ uv run pytest
 ## Environment variables
 | Environment Variable | Description | Default Value |
 |-----------------------|--------------------------------------------------|---------------|
-| `GLOBAL_UNIQUE_ID` | Unique identifier worker in global store. | `""` (empty string) |
-| `GLOBAL_ADDR` | IP Address of the global store | `""` (empty string) |
-| `GLOBAL_PORT` | Port number of the global store. | `-1` |
+| `GLOBAL_UNIQUE_ID` | Unique identifier worker in global store. | `None` |
+| `GLOBAL_ADDR` | IP Address of the global store | `None` |
+| `GLOBAL_PORT` | Port number of the global store. | `None` |
 | `GLOBAL_WORLD_SIZE` | The size of the global process group. | `1` |
 | `GLOBAL_RANK` | Rank of the process in the global process group. | `0` |
diff --git a/scripts/simulate_multi_node_diloco.sh b/scripts/simulate_multi_node_diloco.sh
index 858c1805..cbbd8737 100755
--- a/scripts/simulate_multi_node_diloco.sh
+++ b/scripts/simulate_multi_node_diloco.sh
@@ -59,7 +59,7 @@ export GLOBAL_WORLD_SIZE=$N
 for i in $(seq 0 $(($N - 1 )))
 do
     > logs/log$i
-    TORCH_UNIQUE_ID=$i GLOBAL_RANK=$i CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) uv run torchrun --nproc_per_node=$NUM_GPU --node-rank 0 --rdzv-endpoint localhost:$((10001 + $i)) --nnodes=1 $@ > logs/log$i 2>&1 &
+    GLOBAL_UNIQUE_ID=$i GLOBAL_RANK=$i CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) uv run torchrun --nproc_per_node=$NUM_GPU --node-rank 0 --rdzv-endpoint localhost:$((10001 + $i)) --nnodes=1 $@ > logs/log$i 2>&1 &
 
     child_pids+=($!)
 done
diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py
index 9d9c12fc..73fc23dd 100644
--- a/src/zeroband/diloco.py
+++ b/src/zeroband/diloco.py
@@ -1,5 +1,5 @@
-import os
 import shutil
+import os
 from pydantic_config import BaseConfig
 import torch
 from torch import nn
@@ -16,6 +16,9 @@ class DilocoConfig(BaseConfig):
     inner_steps: int
 
 
+SHARED_MEMORY_PATH = "/dev/shm/zeroband"
+
+
 class Diloco:
     """
     This class implements the diloco algorithm from https://arxiv.org/abs/2311.08105 and https://arxiv.org/abs/2407.07852.
@@ -99,12 +102,12 @@ def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]:
         # All the processes use the same shared memory file to create a storage for each parameter. Only rank 0 will do the copy.
         # A barrier is added to ensure that after the function completes, the parameters are all offloaded. Otherwise, the non 0 ranks might access uninitialized memory.
         offloaded_params = []
-        os.makedirs(f"/dev/shm/zeroband/{self.world_info.global_unique_id}", exist_ok=True)
+        os.makedirs(f"{SHARED_MEMORY_PATH}/{self.world_info.global_unique_id}", exist_ok=True)
 
         for param_name, param in model.named_parameters():
             if param.requires_grad:
                 storage = torch.UntypedStorage.from_file(
-                    f"/dev/shm/zeroband/{self.world_info.global_unique_id}/{param_name}",
+                    f"{SHARED_MEMORY_PATH}/{self.world_info.global_unique_id}/{param_name}",
                     True,
                     param.data.untyped_storage().size(),
                 )
@@ -136,4 +139,4 @@ def step(self, model: nn.Module):
         self._logger.debug("Post meow diloco step %s", get_module_signature(model))
 
     def __del__(self):
-        shutil.rmtree(f"/dev/shm/zeroband/{self.world_info.global_unique_id}", ignore_errors=True)
+        shutil.rmtree(f"{SHARED_MEMORY_PATH}/{self.world_info.global_unique_id}", ignore_errors=True)
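
The diloco.py changes above factor the hard-coded `/dev/shm/zeroband` prefix into a `SHARED_MEMORY_PATH` constant used by the shared-memory offloading path. As a rough sketch of that pattern (not the repository's exact code; `offload_to_shm`, `unique_id`, and `is_rank_zero` are illustrative names), each local rank maps the same file-backed storage under `SHARED_MEMORY_PATH`, so a single rank can populate the CPU copy that all ranks then share:

```python
# Sketch only: shared-memory parameter offloading via file-backed storages under /dev/shm.
import os
import torch

SHARED_MEMORY_PATH = "/dev/shm/zeroband"              # same constant the diff introduces
unique_id = os.environ.get("GLOBAL_UNIQUE_ID", "0")   # set per node by the simulate script


def offload_to_shm(param: torch.Tensor, name: str, is_rank_zero: bool) -> torch.Tensor:
    """Return a CPU copy of `param` backed by a shared-memory file."""
    os.makedirs(f"{SHARED_MEMORY_PATH}/{unique_id}", exist_ok=True)
    # shared=True maps the same bytes in every process that opens this file.
    storage = torch.UntypedStorage.from_file(
        f"{SHARED_MEMORY_PATH}/{unique_id}/{name}",
        True,
        param.untyped_storage().size(),               # size in bytes
    )
    # View the raw storage with the parameter's dtype and shape.
    offloaded = torch.empty(0, dtype=param.dtype).set_(storage).view_as(param)
    if is_rank_zero:                                  # only one rank needs to do the copy
        offloaded.copy_(param.detach().cpu())
    return offloaded
```

Because the storages are opened shared, the `shutil.rmtree` call in `__del__` (now also using `SHARED_MEMORY_PATH`) is what ultimately removes the per-worker files from `/dev/shm` when the `Diloco` instance is torn down.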