Better ElasticDeviceMesh (#9)
* refactor: move pg concerns into edm

* working but only rank 0 syncs

* use fake pg instead of None

* testing utils

* syncing correctly but ugly

* make cpu offload use mmaped file

* fix: allow none diloco to work with fake pg

* simulate multi node diloco script

* docs: update docs

* remove prints

* ruff lint

* move global info to world info and fix unique id

* fixes from merge

* move unique id to world info

* update command in readme

* remove broadcasts at init

* move summon full params to diloco class

* fix data split

* move testing to utils

* document offloading logic

* add envs to readme

* repre for worldinfo

* revert to global pg

* set unique id in tests

* fix: nccl cannot all reduce same device

* use get module signature instead of model hash

* change default global unique id to none

* revert data changes

* make /dev/shm/zeroband a constant and some fixes

* revert shm offload

* fix: non zero rank need to reduce too

* remove testing
Jackmin801 authored Sep 28, 2024
1 parent 55e0b71 commit 879828a
Showing 10 changed files with 372 additions and 56 deletions.
README.md (21 changes: 14 additions & 7 deletions)
@@ -2,7 +2,7 @@
ZeroBand is a production ready codebase for decentralized training of LLM


-## developlment
+## Development

install uv

@@ -40,28 +40,28 @@ run your code using
uv run ...
```

-## quick check
+## Quick check

To check that everything is working you can do

```bash
-ZERO_BAND_LOG_LEVEL=DEBUG torchrun --nproc_per_node=2 src/zeroband/train.py @configs/debug/normal.toml
+ZERO_BAND_LOG_LEVEL=DEBUG torchrun --nproc_per_node=2 src/zeroband/train.py @configs/debug/normal.toml
```

-## run diloco
+## Run diloco

+To run diloco locally you can use the helper script `scripts/simulate_multi_node_diloco.sh`

:note: you need 4 gpus to run the following command

```bash
-ZERO_BAND_LOG_LEVEL=DEBUG ./scripts/simulate_multi_node.sh 2 2 src/zeroband/train.py @configs/debug/diloco.toml
+ZERO_BAND_LOG_LEVEL=DEBUG ./scripts/simulate_multi_node_diloco.sh 2 2 src/zeroband/train.py @configs/debug/diloco.toml
```

if you have only two gpus

```bash
-ZERO_BAND_LOG_LEVEL=DEBUG ./scripts/simulate_multi_node.sh 2 1 src/zeroband/train.py @configs/debug/diloco.toml
+ZERO_BAND_LOG_LEVEL=DEBUG ./scripts/simulate_multi_node_diloco.sh 2 1 src/zeroband/train.py @configs/debug/diloco.toml
```

One gpu is not supported at the moment because of a fsdp bug in our implementation.
@@ -71,8 +71,15 @@ One gpu is not supported at the moment because of a fsdp bug in our implementation.
You need a machine with at least two gpus to run the full test suite.

Some tests must be run from the root directory.

```bash
uv run pytest
```

## Environment variables
| Environment Variable | Description | Default Value |
|-----------------------|--------------------------------------------------|---------------|
| `GLOBAL_UNIQUE_ID`    | Unique identifier of the worker in the global store. | `None`        |
| `GLOBAL_ADDR`         | IP address of the global store.                   | `None`        |
| `GLOBAL_PORT` | Port number of the global store. | `None` |
| `GLOBAL_WORLD_SIZE` | The size of the global process group. | `1` |
| `GLOBAL_RANK` | Rank of the process in the global process group. | `0` |
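
As a hypothetical illustration (not part of this commit), these variables could be set like this when starting two DiLoCo workers by hand rather than through the helper script; the address and IDs below are placeholders:

```bash
# Worker 0 also hosts the global TCPStore at GLOBAL_ADDR:GLOBAL_PORT (values are illustrative).
GLOBAL_UNIQUE_ID=A GLOBAL_RANK=0 GLOBAL_ADDR=192.168.1.10 GLOBAL_PORT=10000 GLOBAL_WORLD_SIZE=2 \
    uv run torchrun --nproc_per_node=2 src/zeroband/train.py @configs/debug/diloco.toml

# Worker 1 points at the same store and takes the next global rank.
GLOBAL_UNIQUE_ID=B GLOBAL_RANK=1 GLOBAL_ADDR=192.168.1.10 GLOBAL_PORT=10000 GLOBAL_WORLD_SIZE=2 \
    uv run torchrun --nproc_per_node=2 src/zeroband/train.py @configs/debug/diloco.toml
```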
scripts/simulate_multi_node_diloco.sh (69 changes: 69 additions & 0 deletions)
@@ -0,0 +1,69 @@
#!/bin/bash

#
# Simulate multiple nodes on one machine: start N torchrun instances, each using X GPUs.
# Example: ./scripts/simulate_multi_node_diloco.sh 2 1 src/zeroband/train.py @configs/debug/normal.toml

# Function to get CUDA devices based on the number of GPUs and index
function get_cuda_devices() {
    local num_gpu=$1
    local index=$2
    local start_gpu=$((num_gpu * index))
    local end_gpu=$((start_gpu + num_gpu - 1))

    if [ "$num_gpu" -eq 1 ]; then
        echo $start_gpu
    else
        echo $(seq -s ',' $start_gpu $end_gpu)
    fi
}

# Array to store PIDs of child processes
child_pids=()

# Function to kill all child processes
cleanup() {
    echo "Cleaning up child processes..."
    local killed=0
    for pid in "${child_pids[@]}"; do
        if kill -TERM "$pid" 2>/dev/null; then
            ((killed++))
        fi
    done
    wait
    echo "All child processes terminated. Killed $killed processes."
    exit
}

# Check if at least three arguments were passed
if [ "$#" -lt 3 ]; then
echo "Usage: $0 <N> <initial_peer> <num_gpu> [additional_python_args]"
exit 1
fi


N=$1 # Set N from the first argument
NUM_GPU=$2
shift 2  # Remove the first two arguments so $@ contains only the Python script and its arguments

# Register the cleanup function to be called on SIGINT (Ctrl+C)
trap cleanup SIGINT


mkdir -p logs

export GLOBAL_ADDR=localhost
export GLOBAL_PORT=10000
export GLOBAL_WORLD_SIZE=$N

for i in $(seq 0 $(($N - 1 )))
do
    > logs/log$i
    GLOBAL_UNIQUE_ID=$i GLOBAL_RANK=$i CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) uv run torchrun --nproc_per_node=$NUM_GPU --node-rank 0 --rdzv-endpoint localhost:$((10001 + $i)) --nnodes=1 $@ > logs/log$i 2>&1 &
    child_pids+=($!)
done

tail -f logs/log0 &
child_pids+=($!)

wait
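
For example, `./scripts/simulate_multi_node_diloco.sh 2 2 src/zeroband/train.py @configs/debug/diloco.toml` launches two torchrun instances with two GPUs each: instance 0 gets `CUDA_VISIBLE_DEVICES=0,1` and rendezvous port 10001, instance 1 gets `CUDA_VISIBLE_DEVICES=2,3` and port 10002, each writes to `logs/log0` and `logs/log1`, and the script tails `logs/log0` until interrupted (Ctrl+C triggers the `cleanup` trap).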
src/zeroband/comms.py (230 changes: 230 additions & 0 deletions)
@@ -0,0 +1,230 @@
from torch.distributed.device_mesh import init_device_mesh
from zeroband.utils.world_info import get_world_info
from zeroband.utils.logging import get_logger
import torch.distributed as dist
from datetime import timedelta
import time
from typing import List, Tuple, Optional
from torch.testing._internal.distributed.fake_pg import FakeProcessGroup


TCPSTORE_TIMEOUT = timedelta(seconds=10)
MAX_JOINERS = 100 # Maximum number of nodes that can join in a single reinit
MAX_LEAVERS = 100 # Maximum number of nodes that can leave in a single reinit


def _wait_for_status(store: dist.Store, status: Optional[str] = None) -> str:
    while True:
        try:
            ret = store.get("status").decode("utf-8")
            if status is None or ret == status:
                return ret
            time.sleep(0.1)
        except dist.DistStoreError as e:
            if status is not None:
                raise e
            time.sleep(0.1)


def _queue_join(store: dist.Store, unique_id: str):
    for i in range(MAX_JOINERS):
        joiner_id = store.get(f"joiner_{i}").decode("utf-8")
        if joiner_id == "null":
            store.set(f"joiner_{i}", unique_id)
            store.set(f"joiner_{i + 1}", "null")
            break
    else:
        raise RuntimeError("Too many joiners")


def _queue_leave(store: dist.Store, unique_id: str):
    for i in range(MAX_LEAVERS):
        leaver_id = store.get(f"leaver_{i}").decode("utf-8")
        if leaver_id == "null":
            store.set(f"leaver_{i}", unique_id)
            store.set(f"leaver_{i + 1}", "null")
            break
    else:
        raise RuntimeError("Too many leavers")


def _get_joiners_and_leavers(store: dist.Store) -> Tuple[List[str], List[str]]:
    joiners = []
    leavers = []
    for i in range(MAX_JOINERS):
        joiner_id = store.get(f"joiner_{i}").decode("utf-8")
        if joiner_id == "null":
            break
        joiners.append(joiner_id)
    for i in range(MAX_LEAVERS):
        leaver_id = store.get(f"leaver_{i}").decode("utf-8")
        if leaver_id == "null":
            break
        leavers.append(leaver_id)
    print(f"Joiners: {joiners}, Leavers: {leavers}")
    return joiners, leavers


def _clear_joiners_and_leavers(store: dist.Store):
    store.set("joiner_0", "null")
    store.set("leaver_0", "null")


class ElasticDeviceMesh:
    """A class to manage the process groups for elastic training without restarts.
    The way it works is rank 0 coordinates the joining and leaving of nodes.
    Rank 0 manages the status to coordinate the creation and recreation of the process groups.
    When a node wants to join, rank 0 will set up the store so that all nodes know the new world size and their respective ranks.
    Store keys used:
    - status: "init", "running", "reinit"
    - world_size: The current world size
    - mesh_count: The version of the mesh
    - rank_{uuid}: The rank of the node with the given uuid
    - rank_map_{rank}: The new rank of the node with the given rank. Used to remap ranks when nodes leave.
    - joiner_{i}: The uuid of the ith joiner. It's a KV implementation of a queue.
    - leaver_{i}: The uuid of the ith leaver. It's a KV implementation of a queue.
    """

    local_pg: dist.ProcessGroup
    global_pg: dist.ProcessGroup

    def __init__(self):
        self._logger = get_logger()
        self.world_info = get_world_info()

        # Initialize global process group
        self.global_pg = FakeProcessGroup(self.world_info.rank, 1)
        if self.world_info.global_world_size > 1:
            self.global_pg = self._init_global_pg()

        # Initialize local process group
        dist.init_process_group(backend="cpu:gloo,cuda:nccl")
        self._device_mesh = init_device_mesh(
            "cuda",
            (self.world_info.nnodes, self.world_info.local_world_size),
            mesh_dim_names=("internode", "intranode"),
        )
        self.local_pg = self._device_mesh.get_group("intranode")

        if self.world_info.rank == 0:
            self._logger.debug(f"global pg world : {self.global_pg.size()}, local pg: {self.local_pg.size()}")
        else:
            self._logger.debug(f"local pg world : {self.local_pg.size()}")

    def __del__(self):
        dist.destroy_process_group()

    def _init_global_pg(self) -> dist.ProcessGroup:
        store = dist.TCPStore(
            host_name=self.world_info.global_addr,
            port=self.world_info.global_port + self.world_info.rank,
            timeout=TCPSTORE_TIMEOUT,
            is_master=(self.world_info.global_rank == 0),
        )

        # Initialize store
        if self.world_info.global_rank == 0:
            store.set("mesh_count", "0")
            store.set("joiner_0", "null")
            store.set("leaver_0", "null")
            store.set("status", "init")
            status = "init"
        else:
            status = _wait_for_status(store)

        if status == "init":
            # First time initialization
            self.mesh_count = 0
            self.prefix_store = dist.PrefixStore("mesh_0", store)
            pg = dist.ProcessGroupGloo(
                self.prefix_store, self.world_info.global_rank, self.world_info.global_world_size, TCPSTORE_TIMEOUT
            )
            if self.world_info.global_rank == 0:
                store.set("status", "running")
            store.set(f"rank_{self.world_info.global_unique_id}", str(self.world_info.global_rank))
        elif status == "running":
            # Node wants to join
            _queue_join(store, self.world_info.global_unique_id)
            _wait_for_status(store, "reinit")
            # Get assigned rank
            self.world_info.global_rank = int(store.get(f"rank_{self.world_info.global_unique_id}").decode("utf-8"))
            # Get updated world_size
            self.world_info.global_world_size = int(store.get("world_size").decode("utf-8"))
            self.mesh_count = int(store.get("mesh_count").decode("utf-8"))
            self.prefix_store = dist.PrefixStore(f"mesh_{self.mesh_count}", store)
            pg = dist.ProcessGroupGloo(
                self.prefix_store, self.world_info.global_rank, self.world_info.global_world_size, TCPSTORE_TIMEOUT
            )
        else:
            # TODO: Could be in "reinit" status
            raise RuntimeError(f"Unknown status {status}")

        # Setting instance variables
        self.global_store = store
        self.leaving = False
        return pg

    def _resolve_world(self):
        """Set the new world size and ranks for all nodes."""
        # Find joiners and leavers
        joiners, leavers = _get_joiners_and_leavers(self.global_store)
        # If no joiners or leavers, no resolution needed
        if len(joiners) == 0 and len(leavers) == 0:
            return

        # Remap live ranks to smaller world_size caused by leavers
        leaving_ranks = {int(self.global_store.get(f"rank_{leaver_id}").decode("utf-8")) for leaver_id in leavers}
        live_ranks = [i for i in range(0, self.world_size, self.local_world_size) if i not in leaving_ranks]
        for i, rank in enumerate(live_ranks):
            self.global_store.set(f"rank_map_{rank}", str(i * self.local_world_size))
        new_world_size = len(live_ranks) * self.local_world_size

        # Give joiners new ranks
        for joiner_id in joiners:
            self.global_store.set(f"rank_{joiner_id}", str(new_world_size))
            new_world_size += self.local_world_size

        # Update world_size
        self.global_store.set("world_size", str(new_world_size))
        self.global_store.set("mesh_count", str(self.mesh_count + 1))
        # Set status to "reinit"
        self.global_store.set("status", "reinit")

    def maybe_reinit_device_mesh(self):
        """Reinitialize the device mesh if there are joiners or leavers."""
        if self.rank == 0:
            self._resolve_world()
        dist.barrier()
        status = self.global_store.get("status").decode("utf-8")
        if status == "running":
            return

        print("Reinitializing device mesh")
        dist.destroy_process_group()
        print("Destroyed process group")
        if self.leaving:
            print("Leaving")
            return

        # Check if we got remapped
        prev_uuid_rank = int(self.global_store.get(f"rank_{self.world_info.global_unique_id}").decode("utf-8"))
        new_uuid_rank = int(self.global_store.get(f"rank_map_{prev_uuid_rank}").decode("utf-8"))
        self.rank = new_uuid_rank + self.local_rank

        self.world_size = int(self.global_store.get("world_size").decode("utf-8"))
        self.mesh_count = int(self.global_store.get("mesh_count").decode("utf-8"))
        self.prefix_store = dist.PrefixStore(f"mesh_{self.mesh_count}", self.global_store)
        dist.init_process_group(
            backend="cpu:gloo,cuda:nccl", store=self.prefix_store, rank=self.rank, world_size=self.world_size
        )

        if self.rank == 0:
            _clear_joiners_and_leavers(self.global_store)
            self.global_store.set("status", "running")
        # Update rank if needed (otherwise, the next remap will do the lookup incorrectly)
        if self.local_rank == 0 and new_uuid_rank != prev_uuid_rank:
            self.global_store.set(f"rank_{self.world_info.global_unique_id}", str(new_uuid_rank))
        # Reinitialize sub process groups
        self.world_rank = self.rank // self.local_world_size
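
For context, here is a minimal, hypothetical sketch of how a DiLoCo-style outer step might use the class above. The training integration is not part of the files shown here, so the call site, tensor, and loop placement are assumptions, not the commit's own code:

```python
# Hypothetical usage sketch; run under torchrun with the GLOBAL_* variables set
# (see the README section above). The tensor and the outer-step placement are placeholders.
import torch
import torch.distributed as dist

from zeroband.comms import ElasticDeviceMesh

edm = ElasticDeviceMesh()  # builds the local (intranode) mesh and the global gloo group

pseudo_grad = torch.ones(8)  # stand-in for an offloaded pseudo-gradient
# Every rank takes part in the global reduce (cf. "fix: non zero rank need to reduce too").
dist.all_reduce(pseudo_grad, op=dist.ReduceOp.SUM, group=edm.global_pg)
pseudo_grad /= edm.global_pg.size()

# Between outer steps, fold in any nodes that queued to join or leave.
edm.maybe_reinit_device_mesh()
```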