diff --git a/configs/debug.toml b/configs/debug.toml
index 2a9bea2e..e7d6e30d 100644
--- a/configs/debug.toml
+++ b/configs/debug.toml
@@ -7,4 +7,7 @@ micro_bs = 8
 [optim]
 batch_size = 16
 warmup_steps = 10
-total_steps = 5000
\ No newline at end of file
+total_steps = 5000
+
+[data]
+fake_data = true
\ No newline at end of file
diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index b625cb7e..9efc8367 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -1,5 +1,6 @@
 import os
 from contextlib import nullcontext
+import time
 from typing import Literal
 
 import torch
@@ -16,6 +17,8 @@
     FullyShardedDataParallel as FSDP,
     MixedPrecision,
 )
+import torch.distributed as dist
+
 from zeroband.utils import get_sharding_strategy
 from zeroband.utils.monitor import WandbMonitor, DummyMonitor
 from zeroband.data import TEST_VOCAB_SIZE, get_dataloader
@@ -152,6 +155,7 @@ def train(config: Config):
 
         for inner_step in range(num_inner_steps):
             loss_batch = 0
+            beginning_step_time = time.time()
 
             for grad_acc_step in range(gradient_accumulation_steps):
                 is_accumulating = grad_acc_step < gradient_accumulation_steps - 1
@@ -179,16 +183,26 @@ def train(config: Config):
             real_step = outer_step * num_inner_steps + inner_step + 1  # add + 1 because inner_step start at 0
             inner_lr = [group["lr"] for group in inner_optimizer.param_groups][0]
 
+            dist.all_reduce(loss_batch, op=dist.ReduceOp.AVG)
+            # sync the loss across all data-parallel ranks
+            # todo(sami): when using DiLoCo, make sure the loss is computed only on the local world
+
             metrics = {
-                "Loss": loss_batch.item(),  # todo(sami): do local all reduce for the loss
+                "Loss": loss_batch.item(),
                 "step": real_step,
                 "inner_lr": inner_lr,
+                "tokens_per_second": config.data.seq_length
+                * config.optim.batch_size
+                / (time.time() - beginning_step_time),
+                "Perplexity": torch.exp(loss_batch).item(),
             }
 
             if world_info.rank == 0:
                 metric_logger.log(metrics)
 
-            logger.info(f"step: {real_step}, loss: {loss_batch.item()}, inner_lr: {inner_lr}")
+            logger.info(
+                f"step: {real_step}, loss: {loss_batch.item():.4f}, tokens_per_second: {metrics['tokens_per_second']:.2f}"
+            )
 
         outer_step += 1
 
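
Reviewer note (illustrative only, not part of the patch): below is a minimal standalone sketch of the metric math the diff adds. The constant values are stand-ins for config.data.seq_length, config.optim.batch_size, and the accumulated loss tensor; the loss averaging itself is omitted here because dist.all_reduce with ReduceOp.AVG needs an initialized process group (and, to my knowledge, the NCCL backend).

    import time

    import torch

    seq_length = 1024  # stand-in for config.data.seq_length
    batch_size = 16    # stand-in for config.optim.batch_size

    beginning_step_time = time.time()
    time.sleep(0.05)                 # stand-in for the forward/backward work of one step
    loss_batch = torch.tensor(2.31)  # stand-in for the loss accumulated over grad-acc steps

    metrics = {
        "Loss": loss_batch.item(),
        # tokens processed this step divided by wall-clock step time
        "tokens_per_second": seq_length * batch_size / (time.time() - beginning_step_time),
        # perplexity is exp of the mean cross-entropy loss
        "Perplexity": torch.exp(loss_batch).item(),
    }
    print(metrics)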