Add init code #2

Merged
19 commits merged on Sep 21, 2024
apply ruff format
samsja committed Sep 20, 2024
commit 93be8f6a7995d06ee7f265c8710516db4007e320
10 changes: 5 additions & 5 deletions src/zeroband/data.py
@@ -1,4 +1,3 @@

from functools import partial
from typing import Any, Generator

@@ -32,7 +31,7 @@ def collate_causal_mask(max_seq_length: int = -1, pad_id: int = 0, ignore_index:
def _collate_fn_causal_mask(
samples: list[dict[str, torch.LongTensor]], max_seq_length: int = -1, pad_id: int = 0, ignore_index: int = -100
) -> dict[str, torch.LongTensor]:
"""collate function for causal mask. Fill with padding tokens if sequence is shorter than max_seq_length.
"""collate function for causal mask. Fill with padding tokens if sequence is shorter than max_seq_length.
input_ids and labels are both of size max_seq_length.
"""

@@ -57,11 +56,13 @@ def _collate_fn_causal_mask(
return {"input_ids": torch.stack(batched["input_ids"], dim=0), "labels": torch.stack(batched["labels"], dim=0)}


def get_dataloader(pad_token_id: int, world_size: int, rank: int, seq_length: int, batch_size: int, num_workers: int) -> DataLoader:
def get_dataloader(
pad_token_id: int, world_size: int, rank: int, seq_length: int, batch_size: int, num_workers: int
) -> DataLoader:
"""
Get a pytorch dataloader to train on
"""
#todo add real dataset and world splitting
# todo add real dataset and world splitting
train_dataset = FakeTokenizedDataset(seq_length, TEST_VOCAB_SIZE)
data_collator = collate_causal_mask(max_seq_length=seq_length, pad_id=pad_token_id, ignore_index=-100)

@@ -71,4 +72,3 @@ def get_dataloader(pad_token_id: int, world_size: int, rank: int, seq_length: in
batch_size=batch_size,
num_workers=num_workers,
)

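Context note: the reformatted `get_dataloader` above wires a `FakeTokenizedDataset` and the causal-mask collator into a standard PyTorch `DataLoader`. The sketch below is a hypothetical stand-in for the padding behaviour the docstring describes (inputs padded with `pad_id`, labels padded with `ignore_index` so the loss skips padding), plus an illustrative call with made-up argument values; the real `_collate_fn_causal_mask` may batch or shift labels differently.

```python
import torch

def pad_one_sample(
    input_ids: torch.LongTensor, max_seq_length: int, pad_id: int = 0, ignore_index: int = -100
) -> dict[str, torch.LongTensor]:
    """Hypothetical per-sample padding in the spirit of _collate_fn_causal_mask."""
    seq = input_ids[:max_seq_length]
    n_pad = max_seq_length - seq.size(0)
    return {
        # inputs padded with pad_id, labels with ignore_index (-100) so cross_entropy ignores them
        "input_ids": torch.cat([seq, torch.full((n_pad,), pad_id, dtype=torch.long)]),
        "labels": torch.cat([seq, torch.full((n_pad,), ignore_index, dtype=torch.long)]),
    }

# calling the reformatted get_dataloader (argument values are illustrative only):
# loader = get_dataloader(
#     pad_token_id=0, world_size=1, rank=0, seq_length=1024, batch_size=8, num_workers=2
# )
# batch = next(iter(loader))  # {"input_ids": (8, 1024), "labels": (8, 1024)}
```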
6 changes: 3 additions & 3 deletions src/zeroband/models/llama/__init__.py
@@ -13,7 +13,7 @@

llama2_configs = {
"debugmodel": ModelArgs(dim=256, n_layers=2, n_heads=8),
"150M": ModelArgs(dim=1024, n_layers=12, n_heads=16), # todo(sami): double check this
"150M": ModelArgs(dim=1024, n_layers=12, n_heads=16), # todo(sami): double check this
"271M": ModelArgs(dim=1024, n_layers=16, n_heads=8),
"1B": ModelArgs(dim=2048, n_layers=18, n_heads=16),
"7B": ModelArgs(dim=4096, n_layers=32, n_heads=32),
@@ -60,6 +60,7 @@
),
}


def get_model(name_model: str, type_model: str, vocab_size: int) -> Transformer:
"""get the transformer model"""

@@ -69,7 +70,6 @@ def get_model(name_model: str, type_model: str, vocab_size: int) -> Transformer:
config = llama3_configs[name_model]
else:
raise ValueError(f"Model type {type_model} not supported")

config.vocab_size = vocab_size
return Transformer(config)

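Usage note: `get_model` looks the name up in `llama2_configs` or `llama3_configs`, overrides `vocab_size`, and returns a `Transformer`. A hypothetical call, with the import path assumed from the `src/zeroband/...` layout shown in the diff:

```python
from zeroband.models.llama import get_model

# "debugmodel" maps to the tiny ModelArgs(dim=256, n_layers=2, n_heads=8) entry above,
# so it is cheap enough for unit tests
model = get_model(name_model="debugmodel", type_model="llama2", vocab_size=1024)
```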
34 changes: 9 additions & 25 deletions src/zeroband/models/llama/model.py
@@ -1,4 +1,4 @@
# this code is copy pasted from the torchtitan repo https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py
# this code is copy pasted from the torchtitan repo https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py
# the commit at time of copy paste was commit f2a1551

# Copyright (c) Meta Platforms, Inc. and affiliates.
@@ -152,22 +152,14 @@ class Attention(nn.Module):
def __init__(self, model_args: ModelArgs):
super().__init__()
self.n_heads = model_args.n_heads
self.n_kv_heads = (
model_args.n_heads
if model_args.n_kv_heads is None
else model_args.n_kv_heads
)
self.n_kv_heads = model_args.n_heads if model_args.n_kv_heads is None else model_args.n_kv_heads
self.n_rep = self.n_heads // self.n_kv_heads
self.head_dim = model_args.dim // model_args.n_heads

self.wq = nn.Linear(
model_args.dim, model_args.n_heads * self.head_dim, bias=False
)
self.wq = nn.Linear(model_args.dim, model_args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(
model_args.n_heads * self.head_dim, model_args.dim, bias=False
)
self.wo = nn.Linear(model_args.n_heads * self.head_dim, model_args.dim, bias=False)

def init_weights(self, init_std: float):
for linear in (self.wq, self.wk, self.wv):
@@ -212,9 +204,7 @@ def forward(

# we use causal mask for training
output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
output = output.transpose(
1, 2
).contiguous() # (bs, seqlen, n_local_heads, head_dim)
output = output.transpose(1, 2).contiguous() # (bs, seqlen, n_local_heads, head_dim)
output = output.view(bs, seqlen, -1)
return self.wo(output)

@@ -297,12 +287,8 @@ def __init__(self, layer_id: int, model_args: ModelArgs):
self.layer_id = layer_id
self.num_layers = model_args.n_layers

self.attention_norm = build_norm(
model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
)
self.ffn_norm = build_norm(
model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
)
self.attention_norm = build_norm(model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps)
self.ffn_norm = build_norm(model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps)

if model_args.depth_init:
self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5
@@ -376,9 +362,7 @@ def __init__(self, model_args: ModelArgs):
for layer_id in range(model_args.n_layers):
self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)

self.norm = build_norm(
model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
)
self.norm = build_norm(model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps)

self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
self.init_weights()
@@ -457,4 +441,4 @@ def from_model_args(cls, model_args: ModelArgs) -> "Transformer":
Transformer: Transformer model.

"""
return cls(model_args)
return cls(model_args)
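The Attention changes above are formatting-only, but the one-line constructors make the grouped-query layout easy to skim past. Below is a self-contained shape sketch of the same pattern with a hypothetical class name; it skips rotary embeddings, caching, and everything else the full `model.py` layer does.

```python
import torch
import torch.nn.functional as F
from torch import nn

class MiniCausalAttention(nn.Module):
    """Shape-only sketch of the attention pattern in the diff (no RoPE, no KV cache)."""

    def __init__(self, dim: int, n_heads: int, n_kv_heads: int | None = None):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
        self.n_rep = self.n_heads // self.n_kv_heads
        self.head_dim = dim // n_heads
        self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bs, seqlen, _ = x.shape
        xq = self.wq(x).view(bs, seqlen, self.n_heads, self.head_dim).transpose(1, 2)
        xk = self.wk(x).view(bs, seqlen, self.n_kv_heads, self.head_dim).transpose(1, 2)
        xv = self.wv(x).view(bs, seqlen, self.n_kv_heads, self.head_dim).transpose(1, 2)
        # repeat kv heads so grouped-query attention matches the query head count
        xk = xk.repeat_interleave(self.n_rep, dim=1)
        xv = xv.repeat_interleave(self.n_rep, dim=1)
        out = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
        out = out.transpose(1, 2).contiguous().view(bs, seqlen, -1)  # (bs, seqlen, n_heads * head_dim)
        return self.wo(out)
```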
4 changes: 2 additions & 2 deletions src/zeroband/models/norms.py
@@ -1,4 +1,4 @@
# this code is copy pasted from the torchtitan repo https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py
# this code is copy pasted from the torchtitan repo https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py
# the commit at time of copy paste was commit f2a1551

# Copyright (c) Meta Platforms, Inc. and affiliates.
@@ -330,4 +330,4 @@ def fused_rms_norm_fn(
x,
weight,
eps,
)
)
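`norms.py` (also copied from torchtitan) exposes `build_norm` and a fused RMSNorm path ending in `fused_rms_norm_fn`. For readers who only need the math, here is a plain, non-fused reference sketch of what an RMSNorm computes; it is not the Triton kernel used in the repo.

```python
import torch
from torch import nn

class SimpleRMSNorm(nn.Module):
    """Reference (non-fused) RMSNorm: x * rsqrt(mean(x^2) + eps) * weight."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # compute the norm in float32 for stability, then cast back to the input dtype
        x_fp32 = x.float()
        normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        return normed.type_as(x) * self.weight
```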
49 changes: 33 additions & 16 deletions src/zeroband/train.py
@@ -31,6 +31,7 @@ def ddp_setup():
init_process_group()
torch.cuda.set_device(world_info.local_rank)


class DilocoConfig(BaseConfig):
outer_lr: float = 0.7
inner_steps: int = 10
@@ -42,6 +43,7 @@ class DataConfig(BaseConfig):
fake_data: bool = False
num_workers: int = 4


class OptimConfig(BaseConfig):
lr: float = 4e-4
weight_decay: float = 0.1
@@ -52,28 +54,26 @@ class OptimConfig(BaseConfig):
total_steps: int = 88_000
batch_size: int = 512


class TrainConfig(BaseConfig):
micro_bs: int
torch_compile: bool = True
sharding_strategy: str = "SHARD_GRAD_OP"


class Config(BaseConfig):

# main config
name_model: Literal["debugmodel", "150M", "271M", "1B", "7B", "13B", "26B", "70B"] = "150M"
type_model: Literal["llama2","llama3"] = "llama2"
type_model: Literal["llama2", "llama3"] = "llama2"

project: str = "zeroband"
metric_logger_type: Literal["wandb", "dummy"] = "wandb"


# sub config
diloco: DilocoConfig | None = None
data: DataConfig = DataConfig()
optim: OptimConfig = OptimConfig()
train: TrainConfig



def train(config: Config):
@@ -90,9 +90,20 @@ def train(config: Config):
tokenizer.pad_token = "</s>" # todo(sami): remove padding tokens once we have context stuffing

logger.debug("tokenizer loaded")
train_dataloader = get_dataloader(tokenizer.pad_token_id, world_info.world_size, world_info.rank, config.data.seq_length, config.train.micro_bs, config.data.num_workers)
train_dataloader = get_dataloader(
tokenizer.pad_token_id,
world_info.world_size,
world_info.rank,
config.data.seq_length,
config.train.micro_bs,
config.data.num_workers,
)

model = get_model(config.name_model, config.type_model, vocab_size=tokenizer.vocab_size if config.name_model != "debugmodel" else TEST_VOCAB_SIZE)
model = get_model(
config.name_model,
config.type_model,
vocab_size=tokenizer.vocab_size if config.name_model != "debugmodel" else TEST_VOCAB_SIZE,
)
model = model.to(world_info.local_rank)
logger.debug("model loaded")

@@ -108,13 +119,18 @@ def train(config: Config):
logger.debug("model compiled and fsdped")

# Setup optimizers
inner_optimizer = torch.optim.AdamW(model.parameters(), lr=config.optim.lr, weight_decay=config.optim.weight_decay, betas=(config.optim.adam_betas1, config.optim.adam_betas2))
inner_optimizer = torch.optim.AdamW(
model.parameters(),
lr=config.optim.lr,
weight_decay=config.optim.weight_decay,
betas=(config.optim.adam_betas1, config.optim.adam_betas2),
)

scheduler = get_cosine_schedule_with_warmup(
inner_optimizer,
num_warmup_steps=config.optim.warmup_steps,
num_training_steps=config.optim.total_steps,
)
)

model.train()

@@ -129,7 +145,6 @@ def train(config: Config):

logger.info("starting training")
while True:

if num_inner_steps > 1:
# if we don't use diloco we don't print the outer step logs
logger.info(f"outer_step step: {outer_step}")
@@ -144,11 +159,13 @@ def train(config: Config):
labels = batch["labels"].to("cuda")

with model.no_sync() if is_accumulating else nullcontext():
logits = model(tokens = input_ids).contiguous()
logits = model(tokens=input_ids).contiguous()
flatten_logits = rearrange(logits, "b seq vocab -> (b seq) vocab")
flatten_labels = rearrange(labels, "b seq -> (b seq)")

loss = F.cross_entropy(flatten_logits, flatten_labels, ignore_index=-100) / gradient_accumulation_steps
loss = (
F.cross_entropy(flatten_logits, flatten_labels, ignore_index=-100) / gradient_accumulation_steps
)
loss.backward()
loss_batch += loss.detach()

@@ -158,12 +175,12 @@ def train(config: Config):
inner_optimizer.zero_grad()

# logging
real_step = outer_step * num_inner_steps + inner_step + 1 # add + 1 because inner_step start at 0
real_step = outer_step * num_inner_steps + inner_step + 1 # add + 1 because inner_step start at 0
inner_lr = [group["lr"] for group in inner_optimizer.param_groups][0]

metrics = {
"Loss": loss_batch.item(), # todo(sami): do local all reduce for the loss
"step": real_step,
"Loss": loss_batch.item(), # todo(sami): do local all reduce for the loss
"step": real_step,
"inner_lr": inner_lr,
}

@@ -189,10 +206,10 @@ def train(config: Config):
# However, in development, we want to know that we broke torch compile
torch._dynamo.config.suppress_errors = "ZERO_BAND_DEV" not in os.environ
torch.set_float32_matmul_precision("high")

world_info = get_world_info()
logger = get_logger()

ddp_setup()

config = Config(**parse_argv())
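Most of the train.py churn is line wrapping, but the hunks above also show the heart of the inner loop: skip gradient sync while accumulating, flatten logits and labels with `einops.rearrange`, scale the loss by `gradient_accumulation_steps`, then step and report `real_step = outer_step * num_inner_steps + inner_step + 1`. Below is a condensed, hypothetical version of one optimizer step; model, dataloader, and scheduler setup plus metric logging are elided, and `no_sync()` is the DDP/FSDP context the real loop relies on.

```python
from contextlib import nullcontext

import torch
import torch.nn.functional as F
from einops import rearrange

def inner_training_step(model, inner_optimizer, batches, gradient_accumulation_steps: int) -> float:
    """One optimizer step built from several micro-batches, mirroring the loop in train.py."""
    loss_batch = torch.tensor(0.0, device="cuda")  # train.py assumes CUDA throughout
    for i, batch in enumerate(batches):
        input_ids = batch["input_ids"].to("cuda")
        labels = batch["labels"].to("cuda")
        is_accumulating = i < gradient_accumulation_steps - 1

        # skip gradient synchronization on every micro-batch except the last one
        with model.no_sync() if is_accumulating else nullcontext():
            logits = model(tokens=input_ids).contiguous()
            flatten_logits = rearrange(logits, "b seq vocab -> (b seq) vocab")
            flatten_labels = rearrange(labels, "b seq -> (b seq)")
            loss = F.cross_entropy(flatten_logits, flatten_labels, ignore_index=-100) / gradient_accumulation_steps
            loss.backward()
            loss_batch += loss.detach()

    inner_optimizer.step()
    inner_optimizer.zero_grad()
    return loss_batch.item()  # the real loop logs this as "Loss" along with real_step and inner_lr
```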
5 changes: 3 additions & 2 deletions src/zeroband/utils/logging.py
@@ -5,17 +5,19 @@

logger = None


class CustomFormatter(logging.Formatter):
def __init__(self, local_rank: int):
super().__init__()
self.local_rank = local_rank

def format(self, record):
log_format = "{asctime} [{levelname}] [Rank {local_rank}] {message}"
formatter = logging.Formatter(log_format, style='{', datefmt="%H:%M:%S")
formatter = logging.Formatter(log_format, style="{", datefmt="%H:%M:%S")
record.local_rank = self.local_rank # Add this line to set the local rank in the record
return formatter.format(record)


def get_logger():
global logger # Add this line to modify the global logger variable
if logger is not None:
@@ -36,4 +38,3 @@ def get_logger():
logger.propagate = False # Prevent the log messages from being propagated to the root logger

return logger

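The `CustomFormatter` change above is only a quote-style fix, and the handler wiring that makes the formatter useful is outside this hunk. A hypothetical, minimal version of what `get_logger()` presumably does with it (stream handler, rank-aware format, no propagation); the import path is taken from the diff's file layout and the rank is hard-coded for illustration.

```python
import logging

from zeroband.utils.logging import CustomFormatter  # module path assumed from src/zeroband/utils/logging.py

local_rank = 0  # normally world_info.local_rank when running under torchrun

logger = logging.getLogger("zeroband")
handler = logging.StreamHandler()
handler.setFormatter(CustomFormatter(local_rank))
logger.addHandler(handler)
logger.setLevel(logging.DEBUG)
logger.propagate = False  # same intent as the diff: keep messages out of the root logger

logger.info("model loaded")  # -> 12:34:56 [INFO] [Rank 0] model loaded
```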
8 changes: 6 additions & 2 deletions src/zeroband/utils/monitor.py
@@ -2,6 +2,7 @@
from typing import Any, Protocol
import importlib


class Monitor(Protocol):
def __init__(self, project, config): ...

@@ -14,18 +15,21 @@ class WandbMonitor:
def __init__(self, project, config, resume: bool):
if importlib.util.find_spec("wandb") is None:
raise ImportError("wandb is not installed. Please install it to use WandbMonitor.")

import wandb

wandb.init(
project=project, config=config, resume="auto" if resume else None
) # make wandb reuse the same run id if possible

def log(self, metrics: dict[str, Any]):
import wandb

wandb.log(metrics)

def finish(self):
import wandb

wandb.finish()


@@ -42,4 +46,4 @@ def log(self, metrics: dict[str, Any]):

def finish(self):
with open(self.project, "wb") as f:
pickle.dump(self.data, f)
pickle.dump(self.data, f)
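`Monitor` is a `Protocol`, `WandbMonitor` imports wandb lazily, and the tail of the hunk shows a second monitor that pickles `self.data` to `self.project` on `finish()`. Here is a hypothetical file-backed monitor with the same surface; the real dummy implementation in `monitor.py` is only partially visible in this diff.

```python
import pickle
from typing import Any

class FileMonitor:
    """Hypothetical monitor with the same surface as WandbMonitor: log() and finish()."""

    def __init__(self, project: str, config: Any = None, resume: bool = False):
        self.project = project  # reused as the output path, as the tail of the hunk suggests
        self.data: list[dict[str, Any]] = []

    def log(self, metrics: dict[str, Any]):
        self.data.append(metrics)

    def finish(self):
        # matches the visible lines: pickle the accumulated metrics to self.project
        with open(self.project, "wb") as f:
            pickle.dump(self.data, f)

# monitor = FileMonitor("metrics.pkl")
# monitor.log({"Loss": 2.3, "step": 1})
# monitor.finish()
```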
4 changes: 3 additions & 1 deletion src/zeroband/utils/world_info.py
@@ -2,8 +2,10 @@

world_info = None


class WorldInfo:
"""This class parses torch world env vars into class variables."""

world_size: int
rank: int
local_rank: int
@@ -15,6 +17,7 @@ def __init__(self):
self.local_rank = int(os.environ["LOCAL_RANK"])
self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])


def get_world_info() -> WorldInfo:
"""
Return a WorldInfo singleton.
@@ -23,4 +26,3 @@ def get_world_info() -> WorldInfo:
if world_info is None:
world_info = WorldInfo()
return world_info

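`WorldInfo` turns the torchrun environment variables into attributes, and `get_world_info()` caches a singleton. A minimal sketch follows; the `RANK` and `WORLD_SIZE` variable names are assumptions, since this hunk only shows `LOCAL_RANK` and `LOCAL_WORLD_SIZE` being read.

```python
import os

world_info = None

class MiniWorldInfo:
    """Sketch of WorldInfo: parse torchrun env vars into attributes."""

    def __init__(self):
        # RANK / WORLD_SIZE names are assumed; the diff only shows LOCAL_RANK and LOCAL_WORLD_SIZE
        self.world_size = int(os.environ["WORLD_SIZE"])
        self.rank = int(os.environ["RANK"])
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])

def get_world_info() -> MiniWorldInfo:
    """Return a singleton so every module sees the same parsed values."""
    global world_info
    if world_info is None:
        world_info = MiniWorldInfo()
    return world_info
```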
2 changes: 1 addition & 1 deletion tests/test_configs.py
@@ -10,10 +10,10 @@

config_file_names = [file for file in os.listdir("configs") if file.endswith(".toml")]


@pytest.mark.parametrize("config_file_name", config_file_names)
def test_load_config(config_file_name):
with open(f"configs/{config_file_name}", "rb") as f:
content = tomli.load(f)
config = Config(**content)
assert config is not None

5 changes: 3 additions & 2 deletions tests/test_model.py
@@ -5,17 +5,18 @@

VOCAB_SIZE = 1024


@pytest.fixture
def llama_config():
config = llama2_configs["debugmodel"]
config = llama2_configs["debugmodel"]
config.vocab_size = VOCAB_SIZE
return config


def test_llama(llama_config):
seq_len = 512
bs = 8
model = Transformer(llama_config)
input_ = torch.randint(0, llama_config.vocab_size, (bs, seq_len))
output = model(input_)
assert output.shape == (bs, seq_len, llama_config.vocab_size)

9 changes: 3 additions & 6 deletions tests/test_torchrun/test_train.py
@@ -15,17 +15,14 @@ def random_available_port():
return get_random_available_port()



@pytest.fixture()
def config_path() -> str:
# need to be executed in the root dir
return "configs/debug.toml"

return "configs/debug.toml"


@pytest.mark.parametrize("num_gpu", [1, 2])
def test_multi_gpu_ckpt(config_path, random_available_port, num_gpu):

cmd = [
"torchrun",
f"--nproc_per_node={num_gpu}",
@@ -34,10 +31,10 @@ def test_multi_gpu_ckpt(config_path, random_available_port, num_gpu):
"src/zeroband/train.py",
f"@{config_path}",
"--optim.total_steps",
"10"
"10",
]

result = subprocess.run(cmd)

if result.returncode != 0:
pytest.fail(f"Process {result} failed {result.stderr}")
pytest.fail(f"Process {result} failed {result.stderr}")