Commit
make /dev/shm/zeroband a constant and some fixes
Jackmin801 committed Sep 28, 2024
1 parent 73800d9 commit e64eb2d
Showing 3 changed files with 11 additions and 8 deletions.
README.md (6 changes: 3 additions & 3 deletions)
@@ -78,8 +78,8 @@ uv run pytest
 ## Environment variables
 | Environment Variable | Description | Default Value |
 |-----------------------|--------------------------------------------------|---------------|
-| `GLOBAL_UNIQUE_ID` | Unique identifier worker in global store. | `""` (empty string) |
-| `GLOBAL_ADDR` | IP Address of the global store | `""` (empty string) |
-| `GLOBAL_PORT` | Port number of the global store. | `-1` |
+| `GLOBAL_UNIQUE_ID` | Unique identifier worker in global store. | `None` |
+| `GLOBAL_ADDR` | IP Address of the global store | `None` |
+| `GLOBAL_PORT` | Port number of the global store. | `None` |
 | `GLOBAL_WORLD_SIZE` | The size of the global process group. | `1` |
 | `GLOBAL_RANK` | Rank of the process in the global process group. | `0` |
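These variables configure the global (inter-node) process group. As a point of reference, here is a minimal sketch of how a worker might read them now that the unset defaults are `None` rather than sentinel values; the parsing below is illustrative, assumes plain `os.environ` lookups, and is not the repo's actual config code.

```python
# Illustrative sketch only: the repo's real env parsing is not part of this diff.
import os

global_unique_id = os.environ.get("GLOBAL_UNIQUE_ID")  # None when unset
global_addr = os.environ.get("GLOBAL_ADDR")            # None when unset
global_port = int(os.environ["GLOBAL_PORT"]) if "GLOBAL_PORT" in os.environ else None
global_world_size = int(os.environ.get("GLOBAL_WORLD_SIZE", "1"))
global_rank = int(os.environ.get("GLOBAL_RANK", "0"))
```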
scripts/simulate_multi_node_diloco.sh (2 changes: 1 addition & 1 deletion)
@@ -59,7 +59,7 @@ export GLOBAL_WORLD_SIZE=$N
 for i in $(seq 0 $(($N - 1 )))
 do
   > logs/log$i
-  TORCH_UNIQUE_ID=$i GLOBAL_RANK=$i CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) uv run torchrun --nproc_per_node=$NUM_GPU --node-rank 0 --rdzv-endpoint localhost:$((10001 + $i)) --nnodes=1 $@ > logs/log$i 2>&1 &
+  GLOBAL_UNIQUE_ID=$i GLOBAL_RANK=$i CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) uv run torchrun --nproc_per_node=$NUM_GPU --node-rank 0 --rdzv-endpoint localhost:$((10001 + $i)) --nnodes=1 $@ > logs/log$i 2>&1 &
   child_pids+=($!)
 done

src/zeroband/diloco.py (11 changes: 7 additions & 4 deletions)
@@ -1,5 +1,5 @@
-import os
 import shutil
+import os
 from pydantic_config import BaseConfig
 import torch
 from torch import nn
@@ -16,6 +16,9 @@ class DilocoConfig(BaseConfig):
     inner_steps: int
 
 
+SHARED_MEMORY_PATH = "/dev/shm/zeroband"
+
+
 class Diloco:
     """
     This class implements the diloco algorithm from https://arxiv.org/abs/2311.08105 and https://arxiv.org/abs/2407.07852.
@@ -99,12 +102,12 @@ def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]:
         # All the processes use the same shared memory file to create a storage for each parameter. Only rank 0 will do the copy.
         # A barrier is added to ensure that after the function completes, the parameters are all offloaded. Otherwise, the non 0 ranks might access uninitialized memory.
         offloaded_params = []
-        os.makedirs(f"/dev/shm/zeroband/{self.world_info.global_unique_id}", exist_ok=True)
+        os.makedirs(f"{SHARED_MEMORY_PATH}/{self.world_info.global_unique_id}", exist_ok=True)
 
         for param_name, param in model.named_parameters():
             if param.requires_grad:
                 storage = torch.UntypedStorage.from_file(
-                    f"/dev/shm/zeroband/{self.world_info.global_unique_id}/{param_name}",
+                    f"{SHARED_MEMORY_PATH}/{self.world_info.global_unique_id}/{param_name}",
                     True,
                     param.data.untyped_storage().size(),
                 )
@@ -136,4 +139,4 @@ def step(self, model: nn.Module):
         self._logger.debug("Post meow diloco step %s", get_module_signature(model))
 
     def __del__(self):
-        shutil.rmtree(f"/dev/shm/zeroband/{self.world_info.global_unique_id}", ignore_errors=True)
+        shutil.rmtree(f"{SHARED_MEMORY_PATH}/{self.world_info.global_unique_id}", ignore_errors=True)
