diff --git a/src/zeroband/models/llama/__init__.py b/src/zeroband/models/llama/__init__.py
index c0103fcf..5250fe57 100644
--- a/src/zeroband/models/llama/__init__.py
+++ b/src/zeroband/models/llama/__init__.py
@@ -61,7 +61,7 @@
 }


-def get_model(name_model: str, type_model: str, vocab_size: int) -> Transformer:
+def get_model(name_model: str, type_model: str, vocab_size: int) -> tuple[Transformer, ModelArgs]:
     """get the transformer model"""

     if type_model == "llama2":
@@ -72,4 +72,4 @@ def get_model(name_model: str, type_model: str, vocab_size: int) -> Transformer:
         raise ValueError(f"Model type {type_model} not supported")

     config.vocab_size = vocab_size
-    return Transformer(config)
+    return Transformer(config), config
diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index 056d89ac..9988e767 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -18,6 +18,7 @@
     MixedPrecision,
 )
 import torch.distributed as dist
+from zeroband import utils
 from zeroband.utils import get_sharding_strategy
 from zeroband.utils.monitor import WandbMonitor, DummyMonitor

@@ -103,7 +104,7 @@ def train(config: Config):
         fake_data=config.data.fake_data,
     )

-    model = get_model(
+    model, model_config = get_model(
         config.name_model,
         config.type_model,
         vocab_size=tokenizer.vocab_size if config.name_model != "debugmodel" else TEST_VOCAB_SIZE,
@@ -111,6 +112,17 @@ def train(config: Config):
     model = model.to(world_info.local_rank)
     logger.debug("model loaded")

+    gpu_peak_flops = utils.get_peak_flops(torch.cuda.get_device_name(torch.device("cuda")))
+    logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")
+
+    num_params = utils.get_num_params(model, exclude_embedding=True)
+    logger.info(f"Number of parameters: {num_params}")
+    num_flop_per_token = utils.get_num_flop_per_token(
+        num_params,
+        model_config,
+        config.data.seq_length,
+    )
+
     model = FSDP(
         model,
         sharding_strategy=sharding_strategy,
@@ -187,22 +199,26 @@

             # syncing loss across all data parallel rank
             # todo(sami): when using diloco make sure that the loss is computed only on local world
+            time_taken = time.time() - beginning_step_time
+            tokens_per_second = config.data.seq_length * config.optim.batch_size / time_taken
+
+            mfu = 100 * num_flop_per_token * tokens_per_second / gpu_peak_flops
+
             metrics = {
                 "Loss": loss_batch.item(),
                 "step": real_step,
                 "inner_lr": inner_lr,
-                "tokens_per_second": config.data.seq_length
-                * config.optim.batch_size
-                / (time.time() - beginning_step_time),
+                "tokens_per_second": tokens_per_second,
                 "Perplexity": torch.exp(loss_batch).item(),
                 "total_tokens": real_step * config.optim.batch_size * config.data.seq_length,
+                "mfu": mfu,
             }

             if world_info.rank == 0:
                 metric_logger.log(metrics)

             logger.info(
-                f"step: {real_step}, loss: {loss_batch.item():.4f}, tokens_per_second: {metrics['tokens_per_second']:.2f}"
+                f"step: {real_step}, loss: {loss_batch.item():.4f}, tokens_per_second: {metrics['tokens_per_second']:.2f}, mfu: {mfu:.2f}"
             )

         outer_step += 1
diff --git a/src/zeroband/utils/__init__.py b/src/zeroband/utils/__init__.py
index d26823e4..a9b8fad2 100644
--- a/src/zeroband/utils/__init__.py
+++ b/src/zeroband/utils/__init__.py
@@ -1,7 +1,8 @@
+import torch
 from torch.distributed.fsdp import ShardingStrategy


-__all__ = ["get_sharding_strategy"]
+__all__ = ["get_sharding_strategy", "get_peak_flops", "get_num_flop_per_token", "get_num_params"]


 def get_sharding_strategy(sharding_strategy: str) -> ShardingStrategy:
@@ -19,3 +20,49 @@ def get_sharding_strategy(sharding_strategy: str) -> ShardingStrategy:
         raise ValueError(
             f"Invalid sharding_strategy: {sharding_strategy}. Please choose 'FULL_SHARD', 'SHARD_GRAD_OP', 'NO_SHARD', 'HYBRID_SHARD', or '_HYBRID_SHARD_ZERO2'."
         )
+
+
+### code above inspired and copied from https://github.com/pytorch/torchtitan/blob/4b3f2e41a084bf79a8540068ed525539d1244edd/torchtitan/utils.py#L119
+
+
+# hardcoded BF16 type peak flops for NVIDIA A100 and H100 GPU
+def get_peak_flops(device_name: str) -> int:
+    if "A100" in device_name:
+        # data from https://www.nvidia.com/en-us/data-center/a100/
+        return 312e12
+    elif "H100" in device_name:
+        # data from https://www.nvidia.com/en-us/data-center/h100/
+        # NOTE: Specifications are one-half lower without sparsity.
+        if "NVL" in device_name:
+            return 835e12
+        elif "PCIe" in device_name:
+            return 756e12
+        else:  # for H100 SXM and other variants
+            return 989e12
+    else:  # for other GPU types, assume A100
+        return 312e12
+
+
+def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int:
+    l, h, q, t = (  # noqa: E741
+        model_config.n_layers,
+        model_config.n_heads,
+        model_config.dim // model_config.n_heads,
+        seq_len,
+    )
+    # Reasoning behind the factor of 12 for the self-attention part of the formula:
+    # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
+    # 2. the flash attention does 1 more matmul recomputation in the backward
+    #    but recomputation should not be counted in calculating MFU           (+0)
+    # 3. each matmul performs 1 multiplication and 1 addition (*2)
+    # 4. we follow the convention and do not account for sparsity in causal attention
+    flop_per_token = 6 * num_params + 12 * l * h * q * t
+
+    return flop_per_token
+
+
+def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int:
+    num_params = sum(p.numel() for p in model.parameters())
+    if exclude_embedding:
+        num_params -= model.tok_embeddings.weight.numel()
+    return num_params
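
Note on the new metric: mfu is the share of the GPU's BF16 dense peak accounted for by the measured token throughput, i.e. 100 * num_flop_per_token * tokens_per_second / gpu_peak_flops, with num_flop_per_token = 6N + 12*l*h*q*t. Below is a minimal standalone sketch of that arithmetic; FakeModelArgs, the parameter count, the throughput, and the 312 TFLOPS peak are illustrative stand-ins, not values taken from this repo.

from dataclasses import dataclass


@dataclass
class FakeModelArgs:
    # stand-in for the ModelArgs fields that get_num_flop_per_token reads
    n_layers: int = 32
    n_heads: int = 32
    dim: int = 4096


def flops_per_token(num_params: int, cfg: FakeModelArgs, seq_len: int) -> int:
    # 6N for the parameter matmuls (fwd + bwd), 12*l*h*q*t for self-attention
    head_dim = cfg.dim // cfg.n_heads
    return 6 * num_params + 12 * cfg.n_layers * cfg.n_heads * head_dim * seq_len


# illustrative numbers only: ~6.5B non-embedding params, 1024-token sequences,
# 3000 tokens/s per GPU, and the A100 BF16 dense peak of 312 TFLOPS
num_params = 6_500_000_000
seq_len = 1024
tokens_per_second = 3000
gpu_peak_flops = 312e12

mfu = 100 * flops_per_token(num_params, FakeModelArgs(), seq_len) * tokens_per_second / gpu_peak_flops
print(f"mfu: {mfu:.2f}%")  # roughly 39% with these assumed numbers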