diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py
index 8507acb1..85069139 100644
--- a/src/zeroband/diloco.py
+++ b/src/zeroband/diloco.py
@@ -66,11 +66,11 @@ def __init__(
         config: DilocoConfig,
         model: nn.Module,
         fsdp_sharding_strategy: ShardingStrategy,
-        elastic_device_mesh: ElasticDeviceMesh,
+        global_pg: dist.ProcessGroup,
     ):
         self.config = config
         self.fsdp_sharding_strategy = fsdp_sharding_strategy
-        self.elastic_device_mesh = elastic_device_mesh
+        self.global_pg = global_pg
 
         self._logger = get_logger()
         self.world_info = get_world_info()
@@ -93,13 +93,12 @@ def sync_pseudo_gradient(self, model: nn.Module):
         """
         self._logger.debug("sync pseudo gradient")
         for param_offloaded, param in zip(self.param_list_cpu, model.parameters()):
-            # todo check how to handle the SHARD_GRAD_OP strategy where the weight are replicated across the local devices
            param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device)
 
             # gloo does not support AVG
-            param_offloaded.grad = param_offloaded.grad / self.elastic_device_mesh.global_pg.size()
-            dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.elastic_device_mesh.global_pg)
-            # todo maybe do async here
+            param_offloaded.grad = param_offloaded.grad / self.global_pg.size()
+            dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.global_pg)
+            # todo async here
 
     def sync_inner_model(self, model: nn.Module):
         """
diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index 5a63fcf6..ed67b24d 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -128,7 +128,7 @@ def train(config: Config):
         if world_info.local_world_size == 1:
             raise ValueError("Diloco is not supported for local_world_size == 1 because of a pytorch bug")
 
-        diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh)
+        diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh.global_pg)
 
     # Setup optimizers
     inner_optimizer = torch.optim.AdamW(
diff --git a/tests/test_dist/conftest.py b/tests/test_dist/conftest.py
index 32f3f1c8..47a3eb99 100644
--- a/tests/test_dist/conftest.py
+++ b/tests/test_dist/conftest.py
@@ -45,12 +45,13 @@ def random_available_port():
 @pytest.fixture()
 def dist_environment() -> callable:
     @contextmanager
-    def dist_environment(random_available_port, local_rank=0, world_size=1):
+    def dist_environment(random_available_port, local_rank=0, world_size=1, local_world_size=1):
         with mock.patch.dict(
             os.environ,
             {
                 "LOCAL_RANK": str(local_rank),
                 "WORLD_SIZE": str(world_size),
+                "LOCAL_WORLD_SIZE": str(local_world_size),
                 "RANK": str(local_rank),
                 "MASTER_ADDR": "localhost",
                 "MASTER_PORT": str(random_available_port),
diff --git a/tests/test_dist/test_all_reduce.py b/tests/test_dist/test_all_reduce.py
index 3274631c..28133070 100644
--- a/tests/test_dist/test_all_reduce.py
+++ b/tests/test_dist/test_all_reduce.py
@@ -10,8 +10,6 @@
 import torch
 import pytest
 
-
-import os
 import multiprocessing
 
 
@@ -19,11 +17,8 @@
 def test_all_reduce(world_size, random_available_port, dist_environment):
     def all_reduce(rank: int, world_size: int):
         with dist_environment(random_available_port, local_rank=rank, world_size=world_size):
-            print(f"os.environ['LOCAL_RANK'] {os.environ['WORLD_SIZE']}")
             data = (rank + 1) * torch.ones(10, 10).to("cuda")
-            print(data.mean())
             dist.all_reduce(data, op=dist.ReduceOp.SUM)
-            print(data.mean())
             assert data.mean() == sum([i + 1 for i in range(world_size)])
 
     processes = [multiprocessing.Process(target=all_reduce, args=(rank, world_size)) for rank in range(world_size)]
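Side note on the averaging pattern in the sync_pseudo_gradient hunk above: gloo has no ReduceOp.AVG, so the pseudo-gradient is divided by the process-group size locally and then summed with all_reduce, which yields the mean across ranks. A minimal standalone sketch of that pattern (the helper name average_gradient is illustrative, not from the repo):

import torch
import torch.distributed as dist

def average_gradient(grad: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    # gloo does not support ReduceOp.AVG: divide by the group size first,
    # then SUM across ranks, so the result equals the average.
    grad = grad / group.size()
    dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=group)
    return grad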
diff --git a/tests/test_dist/test_diloco.py b/tests/test_dist/test_diloco.py
index c81ea742..43fbdcc9 100644
--- a/tests/test_dist/test_diloco.py
+++ b/tests/test_dist/test_diloco.py
@@ -1 +1,58 @@
-"""test Diloco. Need 4 gpus to run this tests"""
+"""test Diloco."""
+
+import multiprocessing
+import pytest
+
+import torch
+import torch.distributed as dist
+from torch.distributed.fsdp import ShardingStrategy
+
+from zeroband.diloco import Diloco, DilocoConfig
+
+
+@pytest.mark.parametrize("world_size", [2])  # [1, 2])
+def test_diloco_all_reduce(world_size, random_available_port, dist_environment):
+    """
+    In this test we manually create an inner model and an outer model where we control the weights:
+    inner has weight: (rank + 1) / 2
+    outer has weight: (rank + 1)
+
+    Since we know the world_size we can predict the result of the all-reduce of the pseudo gradient and therefore test
+    that it is done correctly.
+    """
+
+    def all_reduce(rank: int, world_size: int):
+        with dist_environment(random_available_port, local_rank=rank, world_size=world_size):
+            diloco_config = DilocoConfig(inner_steps=10)
+
+            model = torch.nn.Linear(10, 10)
+
+            # init param to rank + 1
+            for param in model.parameters():
+                param.data = (rank + 1) * torch.ones_like(param.data).to("cuda")
+
+            global_pg = dist.new_group(backend="gloo")
+            diloco = Diloco(diloco_config, model, ShardingStrategy.FULL_SHARD, global_pg)
+
+            # simulate inner model updates
+            for param in model.parameters():
+                param.data = (rank + 1) / 2 * torch.ones_like(param.data).to("cuda")
+
+            diloco.sync_pseudo_gradient(model)
+
+            for param in diloco.param_list_cpu:
+                print(f"param.grad.mean() {param.grad.mean()}")
+                target = (
+                    torch.ones_like(param.grad)
+                    * sum([(rank + 1) - (rank + 1) / 2 for rank in range(world_size)])
+                    / world_size
+                )
+                assert param.grad.mean() == target.mean()
+
+    processes = [multiprocessing.Process(target=all_reduce, args=(rank, world_size)) for rank in range(world_size)]
+    for p in processes:
+        p.start()
+    for p in processes:
+        p.join()
+        if p.exitcode != 0:
+            pytest.fail(f"Process {p.pid} failed with exit code {p.exitcode}")
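For reference, the value asserted in the test above can be checked by hand: each rank's pseudo-gradient is (rank + 1) - (rank + 1) / 2 = (rank + 1) / 2, and with the gloo workaround (divide by the group size, then SUM) the all-reduced value for world_size = 2 is (0.5 + 1.0) / 2 = 0.75. A quick sketch of the same arithmetic, assuming world_size = 2 as in the parametrization:

world_size = 2
pseudo_grads = [(rank + 1) - (rank + 1) / 2 for rank in range(world_size)]  # [0.5, 1.0]
expected = sum(pseudo_grads) / world_size  # divide-then-SUM average used with gloo
assert expected == 0.75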