Commit
add mfu
samsja committed Sep 21, 2024
1 parent 5d0deb9 commit e5a1938
Showing 3 changed files with 71 additions and 8 deletions.
src/zeroband/models/llama/__init__.py: 2 additions & 2 deletions
@@ -61,7 +61,7 @@
 }


-def get_model(name_model: str, type_model: str, vocab_size: int) -> Transformer:
+def get_model(name_model: str, type_model: str, vocab_size: int) -> tuple[Transformer, ModelArgs]:
     """get the transformer model"""

     if type_model == "llama2":
@@ -72,4 +72,4 @@ def get_model(name_model: str, type_model: str, vocab_size: int) -> Transformer:
         raise ValueError(f"Model type {type_model} not supported")

     config.vocab_size = vocab_size
-    return Transformer(config)
+    return Transformer(config), config
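
Note: get_model now returns the ModelArgs config alongside the model so that callers can read architectural fields (n_layers, n_heads, dim) for FLOP accounting. A minimal usage sketch, not part of the commit, with assumed arguments ("debugmodel" and the vocab size are placeholders):

# Hypothetical call site; argument values are assumed for illustration.
from zeroband.models.llama import get_model

model, model_config = get_model("debugmodel", "llama2", vocab_size=1024)
print(model_config.n_layers, model_config.n_heads, model_config.dim)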
src/zeroband/train.py: 21 additions & 5 deletions
@@ -18,6 +18,7 @@
     MixedPrecision,
 )
 import torch.distributed as dist
+from zeroband import utils

 from zeroband.utils import get_sharding_strategy
 from zeroband.utils.monitor import WandbMonitor, DummyMonitor
@@ -103,14 +104,25 @@ def train(config: Config):
         fake_data=config.data.fake_data,
     )

-    model = get_model(
+    model, model_config = get_model(
         config.name_model,
         config.type_model,
         vocab_size=tokenizer.vocab_size if config.name_model != "debugmodel" else TEST_VOCAB_SIZE,
     )
     model = model.to(world_info.local_rank)
     logger.debug("model loaded")

+    gpu_peak_flops = utils.get_peak_flops(torch.cuda.get_device_name(torch.device("cuda")))
+    logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")
+
+    num_params = utils.get_num_params(model, exclude_embedding=True)
+    logger.info(f"Number of parameters: {num_params}")
+    num_flop_per_token = utils.get_num_flop_per_token(
+        num_params,
+        model_config,
+        config.data.seq_length,
+    )
+
     model = FSDP(
         model,
         sharding_strategy=sharding_strategy,
@@ -187,22 +199,26 @@ def train(config: Config):
             # syncing loss across all data parallel rank
             # todo(sami): when using diloco make sure that the loss is computed only on local world

+            time_taken = time.time() - beginning_step_time
+            tokens_per_second = config.data.seq_length * config.optim.batch_size / time_taken
+
+            mfu = 100 * num_flop_per_token * tokens_per_second / gpu_peak_flops
+
             metrics = {
                 "Loss": loss_batch.item(),
                 "step": real_step,
                 "inner_lr": inner_lr,
-                "tokens_per_second": config.data.seq_length
-                * config.optim.batch_size
-                / (time.time() - beginning_step_time),
+                "tokens_per_second": tokens_per_second,
                 "Perplexity": torch.exp(loss_batch).item(),
                 "total_tokens": real_step * config.optim.batch_size * config.data.seq_length,
+                "mfu": mfu,
             }

             if world_info.rank == 0:
                 metric_logger.log(metrics)

             logger.info(
-                f"step: {real_step}, loss: {loss_batch.item():.4f}, tokens_per_second: {metrics['tokens_per_second']:.2f}"
+                f"step: {real_step}, loss: {loss_batch.item():.4f}, tokens_per_second: {metrics['tokens_per_second']:.2f}, mfu: {mfu:.2f}"
             )

         outer_step += 1
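
The mfu metric added above is achieved FLOP throughput divided by the device's peak: mfu = 100 * num_flop_per_token * tokens_per_second / gpu_peak_flops. A back-of-the-envelope sketch with assumed numbers (roughly 7B-scale FLOPs per token, an assumed throughput, and the H100 SXM BF16 dense peak), not measurements from this commit:

# Hypothetical MFU sanity check; every value below is assumed for illustration.
num_flop_per_token = 5.2e10    # ~7B-scale model, see get_num_flop_per_token below
tokens_per_second = 12_000     # assumed per-GPU throughput
gpu_peak_flops = 989e12        # H100 SXM BF16 dense peak
mfu = 100 * num_flop_per_token * tokens_per_second / gpu_peak_flops
print(f"mfu: {mfu:.2f}%")      # ~63% with these assumed numbers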
src/zeroband/utils/__init__.py: 48 additions & 1 deletion
@@ -1,7 +1,8 @@
+import torch
 from torch.distributed.fsdp import ShardingStrategy


-__all__ = ["get_sharding_strategy"]
+__all__ = ["get_sharding_strategy", "get_peak_flops", "get_num_flop_per_token", "get_num_params"]


 def get_sharding_strategy(sharding_strategy: str) -> ShardingStrategy:
@@ -19,3 +20,49 @@ def get_sharding_strategy(sharding_strategy: str) -> ShardingStrategy:
         raise ValueError(
             f"Invalid sharding_strategy: {sharding_strategy}. Please choose 'FULL_SHARD', 'SHARD_GRAD_OP', 'NO_SHARD', 'HYBRID_SHARD', or '_HYBRID_SHARD_ZERO2'."
         )


+### code below inspired by and copied from https://github.com/pytorch/torchtitan/blob/4b3f2e41a084bf79a8540068ed525539d1244edd/torchtitan/utils.py#L119
+
+
+# hardcoded BF16 dense peak FLOPS for NVIDIA A100 and H100 GPUs
+def get_peak_flops(device_name: str) -> float:
+    if "A100" in device_name:
+        # data from https://www.nvidia.com/en-us/data-center/a100/
+        return 312e12
+    elif "H100" in device_name:
+        # data from https://www.nvidia.com/en-us/data-center/h100/
+        # NOTE: the spec sheet lists FLOPS with sparsity; the dense peak is half of that, which is what we return
+        if "NVL" in device_name:
+            return 835e12
+        elif "PCIe" in device_name:
+            return 756e12
+        else:  # for H100 SXM and other variants
+            return 989e12
+    else:  # for other GPU types, assume A100
+        return 312e12
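
A brief usage sketch of the lookup above; the device strings are assumed examples of what torch.cuda.get_device_name may return, and any GPU that is neither an A100 nor an H100 silently falls back to the A100 figure:

# Hypothetical device names, for illustration only.
for name in ["NVIDIA A100-SXM4-80GB", "NVIDIA H100 PCIe", "NVIDIA H100 NVL", "NVIDIA GeForce RTX 4090"]:
    print(name, get_peak_flops(name))  # the RTX 4090 hits the A100 fallback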


+def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int:
+    l, h, q, t = (  # noqa: E741
+        model_config.n_layers,
+        model_config.n_heads,
+        model_config.dim // model_config.n_heads,
+        seq_len,
+    )
+    # Reasoning behind the factor of 12 for the self-attention part of the formula:
+    # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
+    # 2. the flash attention does 1 more matmul recomputation in the backward
+    #    but recomputation should not be counted in calculating MFU (+0)
+    # 3. each matmul performs 1 multiplication and 1 addition (*2)
+    # 4. we follow the convention and do not account for sparsity in causal attention
+    flop_per_token = 6 * num_params + 12 * l * h * q * t
+
+    return flop_per_token
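
As a worked check of the formula, with assumed Llama-7B-like dimensions (n_layers=32, n_heads=32, dim=4096, seq_len=8192, ~6.5e9 non-embedding parameters; all illustrative, not taken from this commit's configs):

# Assumed 7B-scale dimensions, for illustration only.
l, h, q, t = 32, 32, 4096 // 32, 8192
num_params = 6_500_000_000              # assumed non-embedding parameter count
dense_term = 6 * num_params             # 3.9e10 FLOPs/token from weight matmuls (fwd + bwd)
attn_term = 12 * l * h * q * t          # ~1.29e10 FLOPs/token from attention score matmuls
print(dense_term + attn_term)           # ~5.2e10 FLOPs per token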


+def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int:
+    num_params = sum(p.numel() for p in model.parameters())
+    if exclude_embedding:
+        num_params -= model.tok_embeddings.weight.numel()
+    return num_params
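
The embedding exclusion matters for the 6 * num_params term above, since embedding lookups are not matmuls. A quick sanity check on a toy module (toy sizes assumed; the tok_embeddings attribute name mirrors the Llama model used here):

import torch

# Toy module with a Llama-style tok_embeddings attribute; sizes are assumed for illustration.
class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.tok_embeddings = torch.nn.Embedding(1000, 64)  # 64,000 params
        self.proj = torch.nn.Linear(64, 64, bias=False)     # 4,096 params

toy = Toy()
total = sum(p.numel() for p in toy.parameters())            # 68,096
non_embed = total - toy.tok_embeddings.weight.numel()        # 4,096, i.e. get_num_params(toy, exclude_embedding=True)
print(total, non_embed)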
