From 879828a23dff27503020ba70d34e2c2811f6bd59 Mon Sep 17 00:00:00 2001 From: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> Date: Sat, 28 Sep 2024 15:36:16 -0700 Subject: [PATCH] Better ElasticDeviceMesh (#9) * refactor: move pg concerns into edm * working but only rank 0 syncs * use fake pg instead of None * testing utils * syncing correctly but ugly * make cpu offload use mmaped file * fix: allow none diloco to work with fake pg * simulate multi node diloco script * docs: update docs * remove prints * ruff lint * move global info to world info and fix unique id * fixes from merge * move unique id to world info * update command in readme * remove broadcasts at init * move summon full params to diloco class * fix data split * move testing to utils * document offloading logic * add envs to readme * repre for worldinfo * revert to global pg * set unique id in tests * fix: nccl cannot all reduce same device * use get module signature instead of model hash * change default global unique id to none * revert data changes * make /dev/shm/zeroband a constant and some fixes * revert shm offload * fix: non zero rank need to reduce too * remove testing --- README.md | 21 ++- scripts/simulate_multi_node_diloco.sh | 69 ++++++++ src/zeroband/comms.py | 230 ++++++++++++++++++++++++++ src/zeroband/diloco.py | 25 +-- src/zeroband/train.py | 27 +-- src/zeroband/utils/__init__.py | 32 +++- src/zeroband/utils/world_info.py | 9 + tests/test_dist/conftest.py | 9 +- tests/test_dist/test_all_reduce.py | 4 +- tests/test_dist/test_diloco.py | 2 +- 10 files changed, 372 insertions(+), 56 deletions(-) create mode 100755 scripts/simulate_multi_node_diloco.sh create mode 100644 src/zeroband/comms.py diff --git a/README.md b/README.md index 25861da1..42116a54 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ ZeroBand is a production ready codebase for decentralized training of LLM -## developlment +## Development install uv @@ -40,28 +40,28 @@ run your code using uv run ... ``` -## quick check +## Quick check To check that everything is working you can do ```bash -ZERO_BAND_LOG_LEVEL=DEBUG torchrun --nproc_per_node=2 src/zeroband/train.py @configs/debug/normal.toml +ZERO_BAND_LOG_LEVEL=DEBUG torchrun --nproc_per_node=2 src/zeroband/train.py @configs/debug/normal.toml ``` -## run diloco +## Run diloco To run diloco locally you can use the helper script `scripts/simulatsimulate_multi_nodee_mutl.sh` :note: you need 4 gpus to run the following command ```bash -ZERO_BAND_LOG_LEVEL=DEBUG ./scripts/simulate_multi_node.sh 2 2 src/zeroband/train.py @configs/debug/diloco.toml +ZERO_BAND_LOG_LEVEL=DEBUG ./scripts/simulate_multi_node_diloco.sh 2 2 src/zeroband/train.py @configs/debug/diloco.toml ``` if you have only two gpus ```bash -ZERO_BAND_LOG_LEVEL=DEBUG ./scripts/simulate_multi_node.sh 2 1 src/zeroband/train.py @configs/debug/diloco.toml +ZERO_BAND_LOG_LEVEL=DEBUG ./scripts/simulate_multi_node_diloco.sh 2 1 src/zeroband/train.py @configs/debug/diloco.toml ``` One gpu is not supported at the moment because of a fsdp bug in our implementation. @@ -71,8 +71,15 @@ One gpu is not supported at the moment because of a fsdp bug in our implementati You need a machine with a least two gpus to run the full test suite. Some test must be run from the root directory. - ```bash uv run pytest ``` +## Environment variables +| Environment Variable | Description | Default Value | +|-----------------------|--------------------------------------------------|---------------| +| `GLOBAL_UNIQUE_ID` | Unique identifier worker in global store. | `None` | +| `GLOBAL_ADDR` | IP Address of the global store | `None` | +| `GLOBAL_PORT` | Port number of the global store. | `None` | +| `GLOBAL_WORLD_SIZE` | The size of the global process group. | `1` | +| `GLOBAL_RANK` | Rank of the process in the global process group. | `0` | diff --git a/scripts/simulate_multi_node_diloco.sh b/scripts/simulate_multi_node_diloco.sh new file mode 100755 index 00000000..cbbd8737 --- /dev/null +++ b/scripts/simulate_multi_node_diloco.sh @@ -0,0 +1,69 @@ +#!/bin/bash + +# +# simulate multi nodes on one gpu. start N torchrun on X gpu locally. +# example how to run ./scripts/simulate_multi_node.sh 2 1 src/zeroband/train.py @configs/debug/normal.toml + +# Function to get CUDA devices based on the number of GPUs and index +function get_cuda_devices() { + local num_gpu=$1 + local index=$2 + local start_gpu=$((num_gpu * index)) + local end_gpu=$((start_gpu + num_gpu - 1)) + + if [ "$num_gpu" -eq 1 ]; then + echo $start_gpu + else + echo $(seq -s ',' $start_gpu $end_gpu) + fi +} + +# Array to store PIDs of child processes +child_pids=() + +# Function to kill all child processes +cleanup() { + echo "Cleaning up child processes..." + local killed=0 + for pid in "${child_pids[@]}"; do + if kill -TERM "$pid" 2>/dev/null; then + ((killed++)) + fi + done + wait + echo "All child processes terminated. Killed $killed processes." + exit +} + +# Check if at least three arguments were passed +if [ "$#" -lt 3 ]; then + echo "Usage: $0 [additional_python_args]" + exit 1 +fi + + +N=$1 # Set N from the first argument +NUM_GPU=$2 +shift 2 # Remove the first three arguments so $@ contains only additional Python arguments + +# Register the cleanup function to be called on SIGINT (Ctrl+C) +trap cleanup SIGINT + + +mkdir -p logs + +export GLOBAL_ADDR=localhost +export GLOBAL_PORT=10000 +export GLOBAL_WORLD_SIZE=$N + +for i in $(seq 0 $(($N - 1 ))) +do + > logs/log$i + GLOBAL_UNIQUE_ID=$i GLOBAL_RANK=$i CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) uv run torchrun --nproc_per_node=$NUM_GPU --node-rank 0 --rdzv-endpoint localhost:$((10001 + $i)) --nnodes=1 $@ > logs/log$i 2>&1 & + child_pids+=($!) +done + +tail -f logs/log0 & +child_pids+=($!) + +wait diff --git a/src/zeroband/comms.py b/src/zeroband/comms.py new file mode 100644 index 00000000..6a7fdded --- /dev/null +++ b/src/zeroband/comms.py @@ -0,0 +1,230 @@ +from torch.distributed.device_mesh import init_device_mesh +from zeroband.utils.world_info import get_world_info +from zeroband.utils.logging import get_logger +import torch.distributed as dist +from datetime import timedelta +import time +from typing import List, Tuple, Optional +from torch.testing._internal.distributed.fake_pg import FakeProcessGroup + + +TCPSTORE_TIMEOUT = timedelta(seconds=10) +MAX_JOINERS = 100 # Maximum number of nodes that can join in a single reinit +MAX_LEAVERS = 100 # Maximum number of nodes that can leave in a single reinit + + +def _wait_for_status(store: dist.Store, status: Optional[str] = None) -> str: + while True: + try: + ret = store.get("status").decode("utf-8") + if status is None or ret == status: + return ret + time.sleep(0.1) + except dist.DistStoreError as e: + if status is not None: + raise e + time.sleep(0.1) + + +def _queue_join(store: dist.Store, unique_id: str): + for i in range(MAX_JOINERS): + joiner_id = store.get(f"joiner_{i}").decode("utf-8") + if joiner_id == "null": + store.set(f"joiner_{i}", unique_id) + store.set(f"joiner_{i + 1}", "null") + break + else: + raise RuntimeError("Too many joiners") + + +def _queue_leave(store: dist.Store, unique_id: str): + for i in range(MAX_LEAVERS): + leaver_id = store.get(f"leaver_{i}").decode("utf-8") + if leaver_id == "null": + store.set(f"leaver_{i}", unique_id) + store.set(f"leaver_{i + 1}", "null") + break + else: + raise RuntimeError("Too many leavers") + + +def _get_joiners_and_leavers(store: dist.Store) -> Tuple[List[str], List[str]]: + joiners = [] + leavers = [] + for i in range(MAX_JOINERS): + joiner_id = store.get(f"joiner_{i}").decode("utf-8") + if joiner_id == "null": + break + joiners.append(joiner_id) + for i in range(MAX_LEAVERS): + leaver_id = store.get(f"leaver_{i}").decode("utf-8") + if leaver_id == "null": + break + leavers.append(leaver_id) + print(f"Joiners: {joiners}, Leavers: {leavers}") + return joiners, leavers + + +def _clear_joiners_and_leavers(store: dist.Store): + store.set("joiner_0", "null") + store.set("leaver_0", "null") + + +class ElasticDeviceMesh: + """A class to manage the process groups for elastic training without restarts. + + The way it works is rank 0 coordinates the joining and leaving of nodes. + Rank 0 manages the status to coordinate the creation and recreation of the process groups. + When a node wants to join, rank 0 will setup the store so that all nodes know the new world size and their respective ranks. + + Store keys used: + - status: "init", "running", "reinit" + - world_size: The current world size + - mesh_count: The version of the mesh + - rank_{uuid}: The rank of the node with the given uuid + - rank_map_{rank}: The new rank of the node with the given rank. Used to remap ranks when nodes leave. + - joiner_{i}: The uuid of the ith joiner. Its a KV implmentation of a queue. + - leaver_{i}: The uuid of the ith leaver. Its a KV implmentation of a queue. + """ + + local_pg: dist.ProcessGroup + global_pg: dist.ProcessGroup + + def __init__(self): + self._logger = get_logger() + self.world_info = get_world_info() + + # Initialize global process group + self.global_pg = FakeProcessGroup(self.world_info.rank, 1) + if self.world_info.global_world_size > 1: + self.global_pg = self._init_global_pg() + + # Initialize local process group + dist.init_process_group(backend="cpu:gloo,cuda:nccl") + self._device_mesh = init_device_mesh( + "cuda", + (self.world_info.nnodes, self.world_info.local_world_size), + mesh_dim_names=("internode", "intranode"), + ) + self.local_pg = self._device_mesh.get_group("intranode") + + if self.world_info.rank == 0: + self._logger.debug(f"global pg world : {self.global_pg.size()}, local pg: {self.local_pg.size()}") + else: + self._logger.debug(f"local pg world : {self.local_pg.size()}") + + def __del__(self): + dist.destroy_process_group() + + def _init_global_pg(self) -> dist.Store: + store = dist.TCPStore( + host_name=self.world_info.global_addr, + port=self.world_info.global_port + self.world_info.rank, + timeout=TCPSTORE_TIMEOUT, + is_master=(self.world_info.global_rank == 0), + ) + + # Initialize store + if self.world_info.global_rank == 0: + store.set("mesh_count", "0") + store.set("joiner_0", "null") + store.set("leaver_0", "null") + store.set("status", "init") + status = "init" + else: + status = _wait_for_status(store) + + if status == "init": + # First time initialization + self.mesh_count = 0 + self.prefix_store = dist.PrefixStore("mesh_0", store) + pg = dist.ProcessGroupGloo( + self.prefix_store, self.world_info.global_rank, self.world_info.global_world_size, TCPSTORE_TIMEOUT + ) + if self.world_info.global_rank == 0: + store.set("status", "running") + store.set(f"rank_{self.world_info.global_unique_id}", str(self.world_info.global_rank)) + elif status == "running": + # Node wants to join + _queue_join(store, self.world_info.global_unique_id) + _wait_for_status(store, "reinit") + # Get assigned rank + self.world_info.global_rank = int(store.get(f"rank_{self.world_info.global_unique_id}").decode("utf-8")) + # Get updated world_size + self.world_info.global_world_size = int(store.get("world_size").decode("utf-8")) + self.mesh_count = int(store.get("mesh_count").decode("utf-8")) + self.prefix_store = dist.PrefixStore(f"mesh_{self.mesh_count}", store) + pg = dist.ProcessGroupGloo( + self.prefix_store, self.world_info.global_rank, self.world_info.global_world_size, TCPSTORE_TIMEOUT + ) + else: + # TODO: Could be in "reinit" status + raise RuntimeError(f"Unknown status {status}") + + # Setting instance variables + self.global_store = store + self.leaving = False + return pg + + def _resolve_world(self): + """Set the new world size and ranks for all nodes.""" + # Find joiners and leavers + joiners, leavers = _get_joiners_and_leavers(self.global_store) + # If no joiners or leavers, no resolution needed + if len(joiners) == 0 and len(leavers) == 0: + return + + # Remap live ranks to smaller world_size caused by leavers + leaving_ranks = {int(self.global_store.get(f"rank_{leaver_id}").decode("utf-8")) for leaver_id in leavers} + live_ranks = [i for i in range(0, self.world_size, self.local_world_size) if i not in leaving_ranks] + for i, rank in enumerate(live_ranks): + self.global_store.set(f"rank_map_{rank}", str(i * self.local_world_size)) + new_world_size = len(live_ranks) * self.local_world_size + + # Give joiners new ranks + for joiner_id in joiners: + self.global_store.set(f"rank_{joiner_id}", str(new_world_size)) + new_world_size += self.local_world_size + + # Update world_size + self.global_store.set("world_size", str(new_world_size)) + self.global_store.set("mesh_count", str(self.mesh_count + 1)) + # Set status to "reinit" + self.global_store.set("status", "reinit") + + def maybe_reinit_device_mesh(self): + """Reinitialize the device mesh if there are joiners or leavers.""" + if self.rank == 0: + self._resolve_world() + dist.barrier() + status = self.global_store.get("status").decode("utf-8") + if status == "running": + return + + print("Reinitializing device mesh") + dist.destroy_process_group() + print("Destroyed process group") + if self.leaving: + print("Leaving") + return + + # Check if we got remapped + prev_uuid_rank = int(self.global_store.get(f"rank_{self.world_info.global_unique_id}").decode("utf-8")) + new_uuid_rank = int(self.global_store.get(f"rank_map_{prev_uuid_rank}").decode("utf-8")) + self.rank = new_uuid_rank + self.local_rank + + self.world_size = int(self.global_store.get("world_size").decode("utf-8")) + self.mesh_count = int(self.global_store.get("mesh_count").decode("utf-8")) + self.prefix_store = dist.PrefixStore(f"mesh_{self.mesh_count}", self.global_store) + dist.init_process_group( + backend="cpu:gloo,cuda:nccl", store=self.prefix_store, rank=self.rank, world_size=self.world_size + ) + + if self.rank == 0: + _clear_joiners_and_leavers(self.global_store) + self.global_store.set("status", "running") + # Update rank if needed (otherwise, the next remap will do the lookup incorrectly) + if self.local_rank == 0 and new_uuid_rank != prev_uuid_rank: + self.global_store.set(f"rank_{self.world_info.global_unique_id}", str(new_uuid_rank)) + # Reinitialize sub process groups + self.world_rank = self.rank // self.local_world_size diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py index b132c707..46a61cbb 100644 --- a/src/zeroband/diloco.py +++ b/src/zeroband/diloco.py @@ -1,6 +1,5 @@ from pydantic_config import BaseConfig import torch -from torch.distributed.device_mesh import init_device_mesh from torch import nn from zeroband.utils.world_info import get_world_info from zeroband.utils.logging import get_logger @@ -13,28 +12,6 @@ class DilocoConfig(BaseConfig): inner_steps: int -class ElasticDeviceMesh: - """Init two process group through device mesh, one local on gpu and one global on cpu""" - - def __init__(self): - self._logger = get_logger() - - self.world_info = get_world_info() - - # right now device mesh does not support two backend so we just create two identicaly mesh expect the backend - self.device_mesh = init_device_mesh( - "cuda", (self.world_info.nnodes, self.world_info.local_world_size), mesh_dim_names=("global", "local") - ) - self.device_mesh_cpu = init_device_mesh( - "gloo", (self.world_info.nnodes, self.world_info.local_world_size), mesh_dim_names=("global", "local") - ) - - self.global_pg = self.device_mesh_cpu.get_group("global") - self.local_pg = self.device_mesh.get_group("local") - - self._logger.debug(f"global pg world : {self.global_pg.size()}, local pg: {self.local_pg.size()}") - - class Diloco: """ This class implements the diloco algorithm from https://arxiv.org/abs/2311.08105 and https://arxiv.org/abs/2407.07852. @@ -93,6 +70,8 @@ def sync_pseudo_gradient(self, model: nn.Module): """ self._logger.debug("sync pseudo gradient") for param_offloaded, param in zip(self.param_list_cpu, model.parameters()): + if param.shape[0] == 0: + continue param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device) # gloo does not support AVG diff --git a/src/zeroband/train.py b/src/zeroband/train.py index b37c0fe4..4ff5134d 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -4,7 +4,6 @@ import torch from pydantic_config import parse_argv, BaseConfig -from torch.distributed import destroy_process_group, init_process_group from einops import rearrange from torch.nn import functional as F @@ -18,9 +17,10 @@ ) import torch.distributed as dist from zeroband import utils -from zeroband.diloco import Diloco, DilocoConfig, ElasticDeviceMesh +from zeroband.diloco import Diloco, DilocoConfig +from zeroband.comms import ElasticDeviceMesh -from zeroband.utils import PerfCounter, get_model_hash, get_sharding_strategy +from zeroband.utils import PerfCounter, get_module_signature, get_sharding_strategy from zeroband.utils.monitor import WandbMonitor, DummyMonitor from zeroband.data import TEST_VOCAB_SIZE, get_dataloader from zeroband.models.llama import get_model @@ -85,8 +85,8 @@ def train(config: Config): train_dataloader = get_dataloader( tokenizer=tokenizer, - world_size=world_info.world_size, - rank=world_info.rank, + world_size=world_info.world_size * world_info.global_world_size, + rank=world_info.rank + world_info.global_rank * world_info.global_world_size, seq_length=config.data.seq_length, batch_size=config.train.micro_bs, num_workers=config.data.num_workers, @@ -95,12 +95,14 @@ def train(config: Config): model, model_config = get_model( config.name_model, config.type_model, - vocab_size=tokenizer.vocab_size if config.name_model != "debugmodel" else TEST_VOCAB_SIZE, + vocab_size=tokenizer.vocab_size + if config.name_model != "debugmodel" or not config.data.fake + else TEST_VOCAB_SIZE, ) if config.train.log_model_hash: # Compute SHA256 hash - logger.info(f"Model hash: {get_model_hash(model)}") + logger.info(f"Model hash: {get_module_signature(model)}") model = model.to(world_info.local_rank) logger.debug("model loaded") @@ -139,9 +141,6 @@ def train(config: Config): ) if config.diloco is not None: - if world_info.local_world_size == 1: - raise ValueError("Diloco is not supported for local_world_size == 1 because of a pytorch bug") - diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh.global_pg) scheduler = get_cosine_schedule_with_warmup( @@ -231,7 +230,13 @@ def train(config: Config): logger.info(log) if config.diloco is not None: + if config.train.log_model_hash: + with FSDP.summon_full_params(model): + logger.debug("Pre diloco model: %s", get_module_signature(model)) diloco.step(model) + if config.train.log_model_hash: + with FSDP.summon_full_params(model): + logger.debug("Post diloco model: %s", get_module_signature(model)) outer_step += 1 @@ -255,11 +260,9 @@ def train(config: Config): world_info = get_world_info() logger = get_logger() - init_process_group() torch.cuda.set_device(world_info.local_rank) config = Config(**parse_argv()) logger.debug(f"config: {config.model_dump()}") train(config) - destroy_process_group() diff --git a/src/zeroband/utils/__init__.py b/src/zeroband/utils/__init__.py index 80abef69..14486b23 100644 --- a/src/zeroband/utils/__init__.py +++ b/src/zeroband/utils/__init__.py @@ -93,15 +93,31 @@ def get_tokens_per_second(self) -> float | None: return sum(self.tokens) / (self.times[-1] - self.times[0]) -def get_model_hash(model: torch.nn.Module) -> str: +TENSOR_SIG_SAMPLE_SIZE = 100 + + +def get_tensor_signature(a: torch.Tensor | torch.nn.Parameter) -> str: """ - Get the hash of the model parameters. + Get the tensor signature """ - # Concatenate all model parameters into a single tensor - all_params = torch.cat([p.data.view(-1) for p in model.parameters()]) + while isinstance(a, torch.nn.Parameter): + a = a.data + if a.numel() < TENSOR_SIG_SAMPLE_SIZE: + b = a.as_strided(size=(a.numel(),), stride=(1,)) + else: + step_size = a.numel() // TENSOR_SIG_SAMPLE_SIZE + b = a.as_strided(size=(TENSOR_SIG_SAMPLE_SIZE,), stride=(step_size,)) + element_str = "".join([f"{x:.3e}" for x in b]) + element_hash = hashlib.md5(element_str.encode("utf-8")).hexdigest() + return f"{a.dtype}{a.shape}{a.stride()}<{element_hash}>" - # Convert the tensor to a byte string - param_bytes = all_params.cpu().numpy().tobytes() - # Compute SHA256 hash - return hashlib.sha256(param_bytes).hexdigest() +def get_module_signature(module: torch.nn.Module, compress: bool = True) -> str: + """ + Get the module signature + """ + state_dict_sig = {name: get_tensor_signature(param) for name, param in module.named_parameters()} + if compress: + return hashlib.md5(str(state_dict_sig).encode("utf-8")).hexdigest() + else: + return "\n".join(f"{name}: {sig}" for name, sig in state_dict_sig.items()) diff --git a/src/zeroband/utils/world_info.py b/src/zeroband/utils/world_info.py index efe30a1a..9b73f328 100644 --- a/src/zeroband/utils/world_info.py +++ b/src/zeroband/utils/world_info.py @@ -18,6 +18,15 @@ def __init__(self): self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) self.nnodes = self.world_size // self.local_world_size + self.global_unique_id = os.environ.get("GLOBAL_UNIQUE_ID", None) + self.global_addr = os.environ.get("GLOBAL_ADDR", None) + self.global_port = int(os.environ.get("GLOBAL_PORT")) if "GLOBAL_PORT" in os.environ else None + self.global_world_size = int(os.environ.get("GLOBAL_WORLD_SIZE", 1)) + self.global_rank = int(os.environ.get("GLOBAL_RANK", 0)) + + def __repr__(self): + return f"WorldInfo(world_size={self.world_size}, rank={self.rank}, local_rank={self.local_rank}, local_world_size={self.local_world_size}, nnodes={self.nnodes}, global_unique_id={self.global_unique_id}, global_addr={self.global_addr}, global_port={self.global_port}, global_world_size={self.global_world_size}, global_rank={self.global_rank})" + def get_world_info() -> WorldInfo: """ diff --git a/tests/test_dist/conftest.py b/tests/test_dist/conftest.py index 2f01829e..dd9df4d6 100644 --- a/tests/test_dist/conftest.py +++ b/tests/test_dist/conftest.py @@ -45,14 +45,17 @@ def random_available_port(): @pytest.fixture() def dist_environment() -> callable: @contextmanager - def dist_environment(random_available_port, local_rank=0, world_size=1, local_world_size=1): + def dist_environment( + random_available_port, rank=0, local_rank=0, world_size=1, local_world_size=1, global_unique_id="" + ): with mock.patch.dict( os.environ, { - "LOCAL_RANK": str(local_rank), + "GLOBAL_UNIQUE_ID": global_unique_id, + "RANK": str(rank), "WORLD_SIZE": str(world_size), + "LOCAL_RANK": str(local_rank), "LOCAL_WORLD_SIZE": str(local_world_size), - "RANK": str(local_rank), "MASTER_ADDR": "localhost", "MASTER_PORT": str(random_available_port), "ZERO_BAND_LOG_LEVEL": "DEBUG", diff --git a/tests/test_dist/test_all_reduce.py b/tests/test_dist/test_all_reduce.py index 28133070..2ee020f4 100644 --- a/tests/test_dist/test_all_reduce.py +++ b/tests/test_dist/test_all_reduce.py @@ -16,8 +16,8 @@ @pytest.mark.parametrize("world_size", [2]) def test_all_reduce(world_size, random_available_port, dist_environment): def all_reduce(rank: int, world_size: int): - with dist_environment(random_available_port, local_rank=rank, world_size=world_size): - data = (rank + 1) * torch.ones(10, 10).to("cuda") + with dist_environment(random_available_port, rank=rank, world_size=world_size): + data = (rank + 1) * torch.ones(10, 10).to(f"cuda:{rank}") dist.all_reduce(data, op=dist.ReduceOp.SUM) assert data.mean() == sum([i + 1 for i in range(world_size)]) diff --git a/tests/test_dist/test_diloco.py b/tests/test_dist/test_diloco.py index 43fbdcc9..c9a90e94 100644 --- a/tests/test_dist/test_diloco.py +++ b/tests/test_dist/test_diloco.py @@ -22,7 +22,7 @@ def test_diloco_all_reduce(world_size, random_available_port, dist_environment): """ def all_reduce(rank: int, world_size: int): - with dist_environment(random_available_port, local_rank=rank, world_size=world_size): + with dist_environment(random_available_port, rank=rank, world_size=world_size, global_unique_id=str(rank)): diloco_config = DilocoConfig(inner_steps=10) model = torch.nn.Linear(10, 10)