diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 9efc8367..056d89ac 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -195,6 +195,7 @@ def train(config: Config): * config.optim.batch_size / (time.time() - beginning_step_time), "Perplexity": torch.exp(loss_batch).item(), + "total_tokens": real_step * config.optim.batch_size * config.data.seq_length, } if world_info.rank == 0: