
fix diloco 2 #13

Merged: 9 commits, Sep 27, 2024
Changes from 8 commits
2 changes: 1 addition & 1 deletion configs/150M/3090.toml
@@ -3,7 +3,7 @@ project = "debug_150m_zero_band"

[train]
micro_bs = 16 # change this base on the gpu
sharding_strategy = "NO_SHARD"
sharding_strategy = "SHARD_GRAD_OP"

[optim]
batch_size = 512
2 changes: 1 addition & 1 deletion configs/150M/A40.toml
@@ -3,7 +3,7 @@ project = "debug_150m_zero_band"

[train]
micro_bs = 32 # change this base on the gpu
sharding_strategy = "NO_SHARD"
sharding_strategy = "SHARD_GRAD_OP"

[optim]
batch_size = 512
2 changes: 1 addition & 1 deletion configs/150M/H100.toml
@@ -3,7 +3,7 @@ project = "debug_150m_zero_band"

[train]
micro_bs = 64 # change this base on the gpu
sharding_strategy = "NO_SHARD"
sharding_strategy = "SHARD_GRAD_OP"

[optim]
batch_size = 512
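
The three config changes above switch the FSDP sharding strategy from NO_SHARD (nothing is sharded, similar to DDP) to SHARD_GRAD_OP (gradients and optimizer state are sharded, while parameters stay gathered between forward and backward, similar to ZeRO-2). As a rough illustration of how such a TOML string can map onto the torch enum, here is a minimal sketch; the function name is illustrative and this is not necessarily how the repo's get_sharding_strategy helper is implemented:

from torch.distributed.fsdp import ShardingStrategy

def sharding_strategy_from_str(name: str) -> ShardingStrategy:
    # Assumes the config string matches an enum member name, e.g. "SHARD_GRAD_OP".
    try:
        return ShardingStrategy[name]
    except KeyError as exc:
        valid = ", ".join(s.name for s in ShardingStrategy)
        raise ValueError(f"Unknown sharding strategy {name!r}, expected one of: {valid}") from exc

print(sharding_strategy_from_str("SHARD_GRAD_OP"))  # ShardingStrategy.SHARD_GRAD_OP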
15 changes: 7 additions & 8 deletions src/zeroband/diloco.py
@@ -66,11 +66,11 @@ def __init__(
config: DilocoConfig,
model: nn.Module,
fsdp_sharding_strategy: ShardingStrategy,
-elastic_device_mesh: ElasticDeviceMesh,
+global_pg: dist.ProcessGroup,
):
self.config = config
self.fsdp_sharding_strategy = fsdp_sharding_strategy
-self.elastic_device_mesh = elastic_device_mesh
+self.global_pg = global_pg

self._logger = get_logger()
self.world_info = get_world_info()
@@ -93,22 +93,21 @@ def sync_pseudo_gradient(self, model: nn.Module):
"""
self._logger.debug("sync pseudo gradient")
for param_offloaded, param in zip(self.param_list_cpu, model.parameters()):
-# todo check how to handle the SHARD_GRAD_OP strategy where the weight are replicated across the local devices
param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device)

# gloo does not support AVG
-param_offloaded.grad = param_offloaded.grad / self.elastic_device_mesh.global_pg.size()
-dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.elastic_device_mesh.global_pg)
-# todo maybe do async here
+param_offloaded.grad = param_offloaded.grad / self.global_pg.size()
+dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.global_pg)
+# todo async here

def sync_inner_model(self, model: nn.Module):
"""
-Sync the inner model from the global process group to the local process group
+Sync the inner model from the CPU outer model to GPU
"""

self._logger.debug("sync inner model")
for param_offloaded, param in zip(self.param_list_cpu, model.parameters()):
param.data = param_offloaded.data.to("cuda") # todo: use copy_ here
param.data.copy_(param_offloaded.data.to(param.device)) # todo: use copy_ here
Member:

The .to(device) isn't needed here. You should be able to copy the data directly.

Collaborator (Author):

Interesting, so it will handle the move to device by itself?

Member:

My understanding is that a.copy_(b) means "copy the content of b into a in place", so it works across devices and doesn't allocate a new intermediate storage. Meanwhile, param.copy_(param_offloaded.data.to(param.device)) will first evaluate the .to, which creates a new intermediate storage, and then copy its content. Might be wrong though, not sure how smart compilers/interpreters are.

Collaborator (Author) @samsja, Sep 27, 2024:

done


def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]:
"""
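
The sync_pseudo_gradient hunk above keeps the existing workaround for gloo's missing ReduceOp.AVG: the pseudo-gradient is divided by the process-group size first and then SUM-reduced, which is equivalent to averaging. A minimal runnable sketch of that pattern (not part of the diff; the single-process gloo group and the port number are purely illustrative):

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29512")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

pseudo_grad = torch.ones(4)  # stands in for param_offloaded.data - param.data
pseudo_grad = pseudo_grad / dist.get_world_size()  # pre-divide because gloo has no ReduceOp.AVG
dist.all_reduce(pseudo_grad, op=dist.ReduceOp.SUM)  # SUM of the pre-divided values equals the average

dist.destroy_process_group()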
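The review thread above concerns Tensor.copy_: it copies the source into the destination in place and handles the cross-device transfer itself, so the extra .to(param.device) is unnecessary. A minimal sketch of that behavior (not part of the diff; requires a CUDA device):

import torch

cpu_param = torch.randn(4)                 # stands in for the offloaded CPU copy
gpu_param = torch.zeros(4, device="cuda")  # stands in for the on-GPU parameter

gpu_param.copy_(cpu_param)  # copy_ moves the data across devices itself, in place
assert torch.equal(gpu_param.cpu(), cpu_param)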
29 changes: 18 additions & 11 deletions src/zeroband/train.py
@@ -20,7 +20,7 @@
from zeroband import utils
from zeroband.diloco import Diloco, DilocoConfig, ElasticDeviceMesh

-from zeroband.utils import PerfCounter, get_sharding_strategy
+from zeroband.utils import PerfCounter, get_model_hash, get_sharding_strategy
from zeroband.utils.monitor import WandbMonitor, DummyMonitor
from zeroband.data import TEST_VOCAB_SIZE, get_dataloader
from zeroband.models.llama import get_model
@@ -50,6 +50,8 @@ class TrainConfig(BaseConfig):
torch_compile: bool = True
sharding_strategy: str = "SHARD_GRAD_OP"

+log_model_hash: bool = False


class Config(BaseConfig):
# main config
@@ -90,12 +92,16 @@ def train(config: Config):
num_workers=config.data.num_workers,
fake_data=config.data.fake,
)

model, model_config = get_model(
config.name_model,
config.type_model,
vocab_size=tokenizer.vocab_size if config.name_model != "debugmodel" else TEST_VOCAB_SIZE,
)

+if config.train.log_model_hash:
+    # Compute SHA256 hash
+    logger.info(f"Model hash: {get_model_hash(model)}")

model = model.to(world_info.local_rank)
logger.debug("model loaded")

@@ -124,12 +130,6 @@ def train(config: Config):
model = torch.compile(model)
logger.debug("model compiled and fsdped")

-if config.diloco is not None:
-    if world_info.local_world_size == 1:
-        raise ValueError("Diloco is not supported for local_world_size == 1 because of a pytorch bug")
-
-    diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh)

# Setup optimizers
inner_optimizer = torch.optim.AdamW(
model.parameters(),
@@ -138,6 +138,12 @@
betas=(config.optim.adam_betas1, config.optim.adam_betas2),
)

+if config.diloco is not None:
+    if world_info.local_world_size == 1:
+        raise ValueError("Diloco is not supported for local_world_size == 1 because of a pytorch bug")
+
+    diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh.global_pg)

scheduler = get_cosine_schedule_with_warmup(
inner_optimizer,
num_warmup_steps=config.optim.warmup_steps,
@@ -192,9 +198,9 @@ def train(config: Config):
real_step = outer_step * num_inner_steps + inner_step + 1 # add + 1 because inner_step start at 0
inner_lr = [group["lr"] for group in inner_optimizer.param_groups][0]

-dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.AVG)
-# syncing loss across all data parallel rank
-# todo(sami): when using diloco make sure that the loss is computed only on local world
+dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg)
+# syncing loss across all data parallel ranks within a node

perf_counter.count_tokens(config.data.seq_length * config.optim.batch_size)

metrics = {
@@ -244,6 +250,7 @@ def train(config: Config):
# However, in development, we want to know that we broke torch compile
torch._dynamo.config.suppress_errors = "ZERO_BAND_DEV" not in os.environ
torch.set_float32_matmul_precision("high")
+torch.manual_seed(42) # this ensures the same weight init across diloco workers

world_info = get_world_info()
logger = get_logger()
15 changes: 15 additions & 0 deletions src/zeroband/utils/__init__.py
@@ -1,3 +1,4 @@
+import hashlib
import time
import torch
from torch.distributed.fsdp import ShardingStrategy
@@ -90,3 +91,17 @@ def get_tokens_per_second(self) -> float | None:
if len(self.tokens) < 2:
return None
return sum(self.tokens) / (self.times[-1] - self.times[0])


+def get_model_hash(model: torch.nn.Module) -> str:
+    """
+    Get the hash of the model parameters.
+    """
+    # Concatenate all model parameters into a single tensor
+    all_params = torch.cat([p.data.view(-1) for p in model.parameters()])
+
+    # Convert the tensor to a byte string
+    param_bytes = all_params.cpu().numpy().tobytes()
+
+    # Compute SHA256 hash
+    return hashlib.sha256(param_bytes).hexdigest()
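
A small usage sketch of the new helper, illustrating the torch.manual_seed(42) call added in train.py: with the same seed, independently constructed models hash identically. Not part of the diff; nn.Linear stands in for the real get_model(...) call, and the import assumes the zeroband package is installed:

import torch
from zeroband.utils import get_model_hash

torch.manual_seed(42)
worker_a = torch.nn.Linear(8, 8)

torch.manual_seed(42)  # a second "worker" seeding identically before building its model
worker_b = torch.nn.Linear(8, 8)

assert get_model_hash(worker_a) == get_model_hash(worker_b)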
84 changes: 0 additions & 84 deletions tests/test_dist.py

This file was deleted.

68 changes: 68 additions & 0 deletions tests/test_dist/conftest.py
@@ -0,0 +1,68 @@
"""
torch distribted test

this test are different from the torchrun integration tests

They manually do the job of torchrun to start the distributed process making it easy to write unit tests
"""

import torch
import pytest
from torch.distributed import destroy_process_group, init_process_group


import os
from unittest import mock
import socket
from contextlib import contextmanager
import gc


@pytest.fixture(autouse=True)
def memory_cleanup():
# credits to : https://github.com/pytorch/pytorch/issues/82218#issuecomment-1675254117
try:
gc.collect()
torch.cuda.empty_cache()
yield
finally:
gc.collect()
torch.cuda.empty_cache()


def get_random_available_port():
# https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]


@pytest.fixture()
def random_available_port():
return get_random_available_port()


@pytest.fixture()
def dist_environment() -> callable:
@contextmanager
def dist_environment(random_available_port, local_rank=0, world_size=1, local_world_size=1):
with mock.patch.dict(
os.environ,
{
"LOCAL_RANK": str(local_rank),
"WORLD_SIZE": str(world_size),
"LOCAL_WORLD_SIZE": str(local_world_size),
"RANK": str(local_rank),
"MASTER_ADDR": "localhost",
"MASTER_PORT": str(random_available_port),
"ZERO_BAND_LOG_LEVEL": "DEBUG",
},
):
try:
init_process_group()
torch.cuda.set_device(local_rank)
yield
finally:
destroy_process_group()

return dist_environment
30 changes: 30 additions & 0 deletions tests/test_dist/test_all_reduce.py
@@ -0,0 +1,30 @@
"""
torch distribted test

this test are different from the torchrun integration tests

They manually do the job of torchrun to start the distributed process making it easy to write unit tests
"""

import torch.distributed as dist
import torch
import pytest

import multiprocessing


@pytest.mark.parametrize("world_size", [2])
def test_all_reduce(world_size, random_available_port, dist_environment):
def all_reduce(rank: int, world_size: int):
with dist_environment(random_available_port, local_rank=rank, world_size=world_size):
data = (rank + 1) * torch.ones(10, 10).to("cuda")
dist.all_reduce(data, op=dist.ReduceOp.SUM)
assert data.mean() == sum([i + 1 for i in range(world_size)])

processes = [multiprocessing.Process(target=all_reduce, args=(rank, world_size)) for rank in range(world_size)]
for p in processes:
p.start()
for p in processes:
p.join()
if p.exitcode != 0:
pytest.fail(f"Process {p.pid} failed with exit code {p.exitcode}")