Skip to content

Commit

Permalink
0.0.5
Browse files (browse the repository at this point in the history)
  • Loading branch information
dohlee committed Mar 10, 2023
1 parent d1b042d commit 38682cd
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name = 'tranception-pytorch-dohlee',
packages = find_packages(exclude=[]),
include_package_data = True,
version = '0.0.4',
version = '0.0.5',
license='MIT',
description = 'Tranception - Pytorch',
author = 'Dohoon Lee',
Expand Down
6 changes: 4 additions & 2 deletions tranception_pytorch/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def main():
cnt = 0
optimizer.zero_grad()
running_loss = []
for batch in cycle(train_loader, args.total_steps):
for batch in cycle(train_loader, args.total_steps * args.gradient_accumulation_steps):
seq, masked_seq, mask = batch['seq'].cuda(), batch['masked_seq'].cuda(), batch['mask'].cuda()
# Note that seq is not one-hot encoded. It's just a sequence of integers.
out = model(masked_seq) # (batch_size, seq_len, vocab_size)
Expand All @@ -126,15 +126,17 @@ def main():
optimizer.step()
optimizer.zero_grad()
scheduler.step()
cnt += 1

if cnt % 100 == 0:
print(f'Iteration {cnt}, loss={np.mean(running_loss):.4f}')
wandb.log({
'train/loss': np.mean(running_loss),
'train/lr': get_lr(optimizer),
'train/step': cnt // args.gradient_accumulation_steps
})
running_loss = []

cnt += 1

if __name__ == '__main__':
main()

0 comments on commit 38682cd

Please sign in to comment.