diff --git a/configs/150M/3090.toml b/configs/150M/3090.toml
index 866d6054..e792dd00 100644
--- a/configs/150M/3090.toml
+++ b/configs/150M/3090.toml
@@ -3,7 +3,7 @@ project = "debug_150m_zero_band"
 
 [train]
 micro_bs = 16 # change this base on the gpu
-sharding_strategy = "NO_SHARD"
+sharding_strategy = "SHARD_GRAD_OP"
 
 [optim]
 batch_size = 512
diff --git a/configs/150M/A40.toml b/configs/150M/A40.toml
index e7799417..867679c1 100644
--- a/configs/150M/A40.toml
+++ b/configs/150M/A40.toml
@@ -3,7 +3,7 @@ project = "debug_150m_zero_band"
 
 [train]
 micro_bs = 32 # change this base on the gpu
-sharding_strategy = "NO_SHARD"
+sharding_strategy = "SHARD_GRAD_OP"
 
 [optim]
 batch_size = 512
diff --git a/configs/150M/H100.toml b/configs/150M/H100.toml
index 49a65475..3b5d7dfa 100644
--- a/configs/150M/H100.toml
+++ b/configs/150M/H100.toml
@@ -3,7 +3,7 @@ project = "debug_150m_zero_band"
 
 [train]
 micro_bs = 64 # change this base on the gpu
-sharding_strategy = "NO_SHARD"
+sharding_strategy = "SHARD_GRAD_OP"
 
 [optim]
 batch_size = 512
diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py
index 114284f0..b132c707 100644
--- a/src/zeroband/diloco.py
+++ b/src/zeroband/diloco.py
@@ -66,11 +66,11 @@ def __init__(
         config: DilocoConfig,
         model: nn.Module,
         fsdp_sharding_strategy: ShardingStrategy,
-        elastic_device_mesh: ElasticDeviceMesh,
+        global_pg: dist.ProcessGroup,
     ):
         self.config = config
         self.fsdp_sharding_strategy = fsdp_sharding_strategy
-        self.elastic_device_mesh = elastic_device_mesh
+        self.global_pg = global_pg
 
         self._logger = get_logger()
         self.world_info = get_world_info()
@@ -93,22 +93,21 @@ def sync_pseudo_gradient(self, model: nn.Module):
         """
         self._logger.debug("sync pseudo gradient")
         for param_offloaded, param in zip(self.param_list_cpu, model.parameters()):
-            # todo check how to handle the SHARD_GRAD_OP strategy where the weight are replicated across the local devices
             param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device)
 
             # gloo does not support AVG
-            param_offloaded.grad = param_offloaded.grad / self.elastic_device_mesh.global_pg.size()
-            dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.elastic_device_mesh.global_pg)
-            # todo maybe do async here
+            param_offloaded.grad = param_offloaded.grad / self.global_pg.size()
+            dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.global_pg)
+            # todo async here
 
     def sync_inner_model(self, model: nn.Module):
         """
-        Sync the inner model from the global process group to the local process group
+        Sync the inner model from the CPU outer model to the GPU
         """
 
         self._logger.debug("sync inner model")
         for param_offloaded, param in zip(self.param_list_cpu, model.parameters()):
-            param.data = param_offloaded.data.to("cuda")  # todo: use copy_ here
+            param.data.copy_(param_offloaded.data)
 
     def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]:
         """
diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index 16087c71..b37c0fe4 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -20,7 +20,7 @@
 from zeroband import utils
 from zeroband.diloco import Diloco, DilocoConfig, ElasticDeviceMesh
 
-from zeroband.utils import PerfCounter, get_sharding_strategy
+from zeroband.utils import PerfCounter, get_model_hash, get_sharding_strategy
 from zeroband.utils.monitor import WandbMonitor, DummyMonitor
 from zeroband.data import TEST_VOCAB_SIZE, get_dataloader
 from zeroband.models.llama import get_model
@@ -50,6 +50,8 @@ class TrainConfig(BaseConfig):
     torch_compile: bool = True
     sharding_strategy: str = "SHARD_GRAD_OP"
 
+    log_model_hash: bool = False
+
 
 class Config(BaseConfig):
     # main config
@@ -90,12 +92,16 @@ def train(config: Config):
         num_workers=config.data.num_workers,
         fake_data=config.data.fake,
     )
-
     model, model_config = get_model(
         config.name_model,
         config.type_model,
         vocab_size=tokenizer.vocab_size if config.name_model != "debugmodel" else TEST_VOCAB_SIZE,
     )
+
+    if config.train.log_model_hash:
+        # Compute SHA256 hash
+        logger.info(f"Model hash: {get_model_hash(model)}")
+
     model = model.to(world_info.local_rank)
     logger.debug("model loaded")
 
@@ -124,12 +130,6 @@ def train(config: Config):
         model = torch.compile(model)
     logger.debug("model compiled and fsdped")
 
-    if config.diloco is not None:
-        if world_info.local_world_size == 1:
-            raise ValueError("Diloco is not supported for local_world_size == 1 because of a pytorch bug")
-
-        diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh)
-
     # Setup optimizers
     inner_optimizer = torch.optim.AdamW(
         model.parameters(),
@@ -138,6 +138,12 @@
         betas=(config.optim.adam_betas1, config.optim.adam_betas2),
     )
 
+    if config.diloco is not None:
+        if world_info.local_world_size == 1:
+            raise ValueError("Diloco is not supported for local_world_size == 1 because of a pytorch bug")
+
+        diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh.global_pg)
+
     scheduler = get_cosine_schedule_with_warmup(
         inner_optimizer,
         num_warmup_steps=config.optim.warmup_steps,
@@ -192,9 +198,9 @@
             real_step = outer_step * num_inner_steps + inner_step + 1  # add + 1 because inner_step start at 0
             inner_lr = [group["lr"] for group in inner_optimizer.param_groups][0]
 
-            dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.AVG)
-            # syncing loss across all data parallel rank
-            # todo(sami): when using diloco make sure that the loss is computed only on local world
+            dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg)
+            # syncing loss across all data parallel ranks within a node
+
             perf_counter.count_tokens(config.data.seq_length * config.optim.batch_size)
 
             metrics = {
@@ -244,6 +250,7 @@
     # However, in development, we want to know that we broke torch compile
     torch._dynamo.config.suppress_errors = "ZERO_BAND_DEV" not in os.environ
     torch.set_float32_matmul_precision("high")
+    torch.manual_seed(42)  # this ensures the same weight init across diloco workers
     world_info = get_world_info()
     logger = get_logger()
 
diff --git a/src/zeroband/utils/__init__.py b/src/zeroband/utils/__init__.py
index 12a6ce8c..80abef69 100644
--- a/src/zeroband/utils/__init__.py
+++ b/src/zeroband/utils/__init__.py
@@ -1,3 +1,4 @@
+import hashlib
 import time
 import torch
 from torch.distributed.fsdp import ShardingStrategy
@@ -90,3 +91,17 @@ def get_tokens_per_second(self) -> float | None:
         if len(self.tokens) < 2:
             return None
         return sum(self.tokens) / (self.times[-1] - self.times[0])
+
+
+def get_model_hash(model: torch.nn.Module) -> str:
+    """
+    Get the hash of the model parameters.
+    """
+    # Concatenate all model parameters into a single tensor
+    all_params = torch.cat([p.data.view(-1) for p in model.parameters()])
+
+    # Convert the tensor to a byte string
+    param_bytes = all_params.cpu().numpy().tobytes()
+
+    # Compute SHA256 hash
+    return hashlib.sha256(param_bytes).hexdigest()
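Note on the sync_pseudo_gradient hunk in src/zeroband/diloco.py above: the gloo backend has no ReduceOp.AVG, so the average is emulated by pre-dividing the pseudo-gradient by the group size and then doing a SUM all-reduce. A minimal standalone sketch of the same pattern (the helper name and tensor are illustrative, and an already-initialized gloo process group is assumed):

    import torch
    import torch.distributed as dist

    def all_reduce_mean_gloo(tensor: torch.Tensor, pg: dist.ProcessGroup) -> None:
        # gloo does not support ReduceOp.AVG: divide by the group size first, then SUM
        tensor /= pg.size()
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=pg)

The result matches a single ReduceOp.AVG on backends that support it (e.g. NCCL), up to floating point rounding.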
+ """ + # Concatenate all model parameters into a single tensor + all_params = torch.cat([p.data.view(-1) for p in model.parameters()]) + + # Convert the tensor to a byte string + param_bytes = all_params.cpu().numpy().tobytes() + + # Compute SHA256 hash + return hashlib.sha256(param_bytes).hexdigest() diff --git a/tests/test_dist.py b/tests/test_dist.py deleted file mode 100644 index 4e3314e9..00000000 --- a/tests/test_dist.py +++ /dev/null @@ -1,84 +0,0 @@ -""" -torch distribted test - -this test are different from the torchrun integration tests - -They manually do the job of torchrun to start the distributed process making it easy to write unit tests -""" - -import torch.distributed as dist -import torch -import pytest -from torch.distributed import destroy_process_group, init_process_group - - -import os -from unittest import mock -import socket -from contextlib import contextmanager -import multiprocessing -import gc - - -@pytest.fixture(autouse=True) -def memory_cleanup(): - # credits to : https://github.com/pytorch/pytorch/issues/82218#issuecomment-1675254117 - try: - gc.collect() - torch.cuda.empty_cache() - yield - finally: - gc.collect() - torch.cuda.empty_cache() - - -def get_random_available_port(): - # https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) - return s.getsockname()[1] - - -@pytest.fixture() -def random_available_port(): - return get_random_available_port() - - -@contextmanager -def dist_environment(random_available_port, local_rank=0, world_size=1): - with mock.patch.dict( - os.environ, - { - "LOCAL_RANK": str(local_rank), - "WORLD_SIZE": str(world_size), - "RANK": str(local_rank), - "MASTER_ADDR": "localhost", - "MASTER_PORT": str(random_available_port), - }, - ): - try: - init_process_group() - torch.cuda.set_device(local_rank) - yield - finally: - destroy_process_group() - - -@pytest.mark.parametrize("world_size", [2]) -def test_all_reduce(world_size, random_available_port): - def all_reduce(rank: int, world_size: int): - with dist_environment(random_available_port, local_rank=rank, world_size=world_size): - print(f"os.environ['LOCAL_RANK'] {os.environ['WORLD_SIZE']}") - data = (rank + 1) * torch.ones(10, 10).to("cuda") - print(data.mean()) - dist.all_reduce(data, op=dist.ReduceOp.SUM) - print(data.mean()) - assert data.mean() == sum([i + 1 for i in range(world_size)]) - - processes = [multiprocessing.Process(target=all_reduce, args=(rank, world_size)) for rank in range(world_size)] - for p in processes: - p.start() - for p in processes: - p.join() - if p.exitcode != 0: - pytest.fail(f"Process {p.pid} failed with exit code {p.exitcode}") diff --git a/tests/test_dist/conftest.py b/tests/test_dist/conftest.py new file mode 100644 index 00000000..2f01829e --- /dev/null +++ b/tests/test_dist/conftest.py @@ -0,0 +1,68 @@ +""" +torch distribted test + +this test are different from the torchrun integration tests + +They manually do the job of torchrun to start the distributed process making it easy to write unit tests +""" + +import torch +import pytest +from torch.distributed import destroy_process_group, init_process_group + + +import os +from unittest import mock +import socket +from contextlib import contextmanager +import gc + + +@pytest.fixture(autouse=True) +def memory_cleanup(): + # credits to : https://github.com/pytorch/pytorch/issues/82218#issuecomment-1675254117 + try: + gc.collect() + torch.cuda.empty_cache() + yield + finally: + 
diff --git a/tests/test_dist/conftest.py b/tests/test_dist/conftest.py
new file mode 100644
index 00000000..2f01829e
--- /dev/null
+++ b/tests/test_dist/conftest.py
@@ -0,0 +1,68 @@
+"""
+torch distributed tests
+
+These tests are different from the torchrun integration tests.
+
+They manually do the job of torchrun to start the distributed processes, making it easy to write unit tests.
+"""
+
+import torch
+import pytest
+from torch.distributed import destroy_process_group, init_process_group
+
+
+import os
+from unittest import mock
+import socket
+from contextlib import contextmanager
+import gc
+
+
+@pytest.fixture(autouse=True)
+def memory_cleanup():
+    # credits to: https://github.com/pytorch/pytorch/issues/82218#issuecomment-1675254117
+    try:
+        gc.collect()
+        torch.cuda.empty_cache()
+        yield
+    finally:
+        gc.collect()
+        torch.cuda.empty_cache()
+
+
+def get_random_available_port():
+    # https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(("", 0))
+        return s.getsockname()[1]
+
+
+@pytest.fixture()
+def random_available_port():
+    return get_random_available_port()
+
+
+@pytest.fixture()
+def dist_environment() -> callable:
+    @contextmanager
+    def dist_environment(random_available_port, local_rank=0, world_size=1, local_world_size=1):
+        with mock.patch.dict(
+            os.environ,
+            {
+                "LOCAL_RANK": str(local_rank),
+                "WORLD_SIZE": str(world_size),
+                "LOCAL_WORLD_SIZE": str(local_world_size),
+                "RANK": str(local_rank),
+                "MASTER_ADDR": "localhost",
+                "MASTER_PORT": str(random_available_port),
+                "ZERO_BAND_LOG_LEVEL": "DEBUG",
+            },
+        ):
+            try:
+                init_process_group()
+                torch.cuda.set_device(local_rank)
+                yield
+            finally:
+                destroy_process_group()
+
+    return dist_environment
diff --git a/tests/test_dist/test_all_reduce.py b/tests/test_dist/test_all_reduce.py
new file mode 100644
index 00000000..28133070
--- /dev/null
+++ b/tests/test_dist/test_all_reduce.py
@@ -0,0 +1,30 @@
+"""
+torch distributed tests
+
+These tests are different from the torchrun integration tests.
+
+They manually do the job of torchrun to start the distributed processes, making it easy to write unit tests.
+"""
+
+import torch.distributed as dist
+import torch
+import pytest
+
+import multiprocessing
+
+
+@pytest.mark.parametrize("world_size", [2])
+def test_all_reduce(world_size, random_available_port, dist_environment):
+    def all_reduce(rank: int, world_size: int):
+        with dist_environment(random_available_port, local_rank=rank, world_size=world_size):
+            data = (rank + 1) * torch.ones(10, 10).to("cuda")
+            dist.all_reduce(data, op=dist.ReduceOp.SUM)
+            assert data.mean() == sum([i + 1 for i in range(world_size)])
+
+    processes = [multiprocessing.Process(target=all_reduce, args=(rank, world_size)) for rank in range(world_size)]
+    for p in processes:
+        p.start()
+    for p in processes:
+        p.join()
+        if p.exitcode != 0:
+            pytest.fail(f"Process {p.pid} failed with exit code {p.exitcode}")
+ """ + + def all_reduce(rank: int, world_size: int): + with dist_environment(random_available_port, local_rank=rank, world_size=world_size): + diloco_config = DilocoConfig(inner_steps=10) + + model = torch.nn.Linear(10, 10) + + # init param to rank + 1 + for param in model.parameters(): + param.data = (rank + 1) * torch.ones_like(param.data).to("cuda") + + global_pg = dist.new_group(backend="gloo") + diloco = Diloco(diloco_config, model, ShardingStrategy.FULL_SHARD, global_pg) + + # simulate inner model updates + for param in model.parameters(): + param.data = (rank + 1) / 2 * torch.ones_like(param.data).to("cuda") + + diloco.sync_pseudo_gradient(model) + + for param in diloco.param_list_cpu: + print(f"param.grad.mean() {param.grad.mean()}") + target = ( + torch.ones_like(param.grad) + * sum([(rank + 1) - (rank + 1) / 2 for rank in range(world_size)]) + / world_size + ) + assert param.grad.mean() == target.mean() + + processes = [multiprocessing.Process(target=all_reduce, args=(rank, world_size)) for rank in range(world_size)] + for p in processes: + p.start() + for p in processes: + p.join() + if p.exitcode != 0: + pytest.fail(f"Process {p.pid} failed with exit code {p.exitcode}")