add tokens / sec
samsja committed Sep 21, 2024
1 parent 5267437 commit 9a5efa7
Showing 2 changed files with 20 additions and 3 deletions.
5 changes: 4 additions & 1 deletion configs/debug.toml
@@ -7,4 +7,7 @@ micro_bs = 8
 [optim]
 batch_size = 16
 warmup_steps = 10
-total_steps = 5000
+total_steps = 5000
+
+[data]
+fake_data = true
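
The new [data] section switches the debug config to synthetic batches. As a rough sketch of what such a path can look like (hypothetical code, not from this commit; the actual get_dataloader signature is not shown in this diff), a fake-data loader just streams random token ids:

import torch
from torch.utils.data import DataLoader, IterableDataset


class FakeTokenDataset(IterableDataset):
    """Streams random token ids so the training loop can run without real data."""

    def __init__(self, vocab_size: int, seq_length: int):
        self.vocab_size = vocab_size
        self.seq_length = seq_length

    def __iter__(self):
        while True:
            input_ids = torch.randint(0, self.vocab_size, (self.seq_length,))
            yield {"input_ids": input_ids, "labels": input_ids.clone()}


# Hypothetical usage, with seq_length assumed and micro_bs = 8 from this config:
# DataLoader(FakeTokenDataset(vocab_size=TEST_VOCAB_SIZE, seq_length=1024), batch_size=8)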
18 changes: 16 additions & 2 deletions src/zeroband/train.py
@@ -1,5 +1,6 @@
 import os
 from contextlib import nullcontext
+import time
 from typing import Literal
 
 import torch
@@ -16,6 +17,8 @@
     FullyShardedDataParallel as FSDP,
     MixedPrecision,
 )
+import torch.distributed as dist
+
 from zeroband.utils import get_sharding_strategy
 from zeroband.utils.monitor import WandbMonitor, DummyMonitor
 from zeroband.data import TEST_VOCAB_SIZE, get_dataloader
@@ -152,6 +155,7 @@ def train(config: Config):
 
         for inner_step in range(num_inner_steps):
             loss_batch = 0
+            beginning_step_time = time.time()
 
             for grad_acc_step in range(gradient_accumulation_steps):
                 is_accumulating = grad_acc_step < gradient_accumulation_steps - 1
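
Note that the timer starts before the gradient-accumulation loop, so the elapsed time used for the metric below spans all gradient_accumulation_steps micro-batches of a single inner step.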
@@ -179,16 +183,26 @@
             real_step = outer_step * num_inner_steps + inner_step + 1  # add +1 because inner_step starts at 0
             inner_lr = [group["lr"] for group in inner_optimizer.param_groups][0]
 
+            dist.all_reduce(loss_batch, op=dist.ReduceOp.AVG)
+            # syncing the loss across all data-parallel ranks
+            # todo(sami): when using diloco, make sure the loss is computed only on the local world
+
             metrics = {
-                "Loss": loss_batch.item(),  # todo(sami): do local all reduce for the loss
+                "Loss": loss_batch.item(),
                 "step": real_step,
                 "inner_lr": inner_lr,
+                "tokens_per_second": config.data.seq_length
+                * config.optim.batch_size
+                / (time.time() - beginning_step_time),
                 "Perplexity": torch.exp(loss_batch).item(),
             }
 
             if world_info.rank == 0:
                 metric_logger.log(metrics)
 
-            logger.info(f"step: {real_step}, loss: {loss_batch.item()}, inner_lr: {inner_lr}")
+            logger.info(
+                f"step: {real_step}, loss: {loss_batch.item():.4f}, tokens_per_second: {metrics['tokens_per_second']:.2f}"
+            )
 
         outer_step += 1

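As a back-of-the-envelope check of the new metric (seq_length is assumed here, since the debug config above does not set it): with seq_length = 1024 and batch_size = 16 from debug.toml, an inner step that takes 0.5 s yields 1024 * 16 / 0.5 = 32768 tokens/sec. A minimal standalone sketch of the two training-loop changes, assuming an initialized process group on a backend that supports ReduceOp.AVG (e.g. NCCL):

import time

import torch
import torch.distributed as dist


def step_metrics(loss_batch: torch.Tensor, seq_length: int, batch_size: int,
                 beginning_step_time: float) -> dict:
    # ReduceOp.AVG sums the loss across all data-parallel ranks and divides by
    # the world size, so every rank reports the same averaged value.
    dist.all_reduce(loss_batch, op=dist.ReduceOp.AVG)
    return {
        "Loss": loss_batch.item(),
        # tokens processed in one inner step divided by its wall-clock duration
        "tokens_per_second": seq_length * batch_size / (time.time() - beginning_step_time),
        "Perplexity": torch.exp(loss_batch).item(),
    }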