train.py
import math
import sys

import torch

import dist_misc


def train_one_epoch(model, data_loader, criterion, training_scheduler, epoch, device,
                    log_writer=None, args=None, finetune=None):
    model.train()
    metric_logger = dist_misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', dist_misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)

    # Log every 4 iterations under FSDP, otherwise every 200.
    print_freq = 4 if args.fsdp else 200
    accum_iter = 1 if args.accum_iter < 1 else args.accum_iter

    for batch_idx, (labels, strs, toks, masktoks, masks) in enumerate(
            metric_logger.log_every(data_loader, print_freq, header)):
        if finetune == 'cls':
            labels = torch.tensor(labels).long()
        # else:
        #     labels = torch.tensor(labels)

        if torch.cuda.is_available() and not args.nogpu:
            toks = toks.to(device=device, non_blocking=True)
            masktoks = masktoks.to(device=device, non_blocking=True)
            masks = masks.to(device=device, non_blocking=True)
            if finetune == 'cls':
                labels = labels.to(device=device, non_blocking=True)

        with torch.cuda.amp.autocast():
            if finetune:
                # Fine-tuning: predict labels from the unmasked tokens.
                pred = model(toks)
                loss = criterion(pred, labels)
            else:
                # Masked pretraining: reconstruct the original tokens at masked positions.
                out = model(masktoks)
                logits = out["logits"].permute(0, 2, 1)  # B*C*D
                loss = criterion(logits, toks, masks)

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # Gradient accumulation: scale the loss and only step every accum_iter batches.
        loss /= accum_iter
        training_scheduler.loss_scale_and_backward(loss)
        if (batch_idx + 1) % accum_iter == 0:
            training_scheduler.step_and_lr_schedule(batch_idx / len(data_loader) + epoch)
            training_scheduler.zero_grad()

        if torch.cuda.is_available() and not args.nogpu:
            torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)
        lr = training_scheduler.optim.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce = dist_misc.all_reduce_mean(loss_value)
        if log_writer is not None and (batch_idx + 1) % accum_iter == 0:
            """ We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((batch_idx / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
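

# --- Usage sketch (not part of the original file) ----------------------------
# A minimal illustration of how train_one_epoch might be driven for the
# finetune='cls' path. The model/loader builders and the training_scheduler
# object (expected to expose loss_scale_and_backward, step_and_lr_schedule,
# zero_grad, and an .optim attribute) are project-specific; every name below
# that is not defined in this file is an assumption, not the repo's actual
# training script.
#
# import argparse
# import torch.nn as nn
# from torch.utils.tensorboard import SummaryWriter
#
# args = argparse.Namespace(fsdp=False, nogpu=False, accum_iter=1)
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = build_model().to(device)                      # hypothetical helper
# data_loader = build_cls_loader()                      # hypothetical helper
# criterion = nn.CrossEntropyLoss()
# training_scheduler = TrainingScheduler(model, args)   # hypothetical class
# log_writer = SummaryWriter(log_dir='runs/finetune')
#
# for epoch in range(10):
#     stats = train_one_epoch(model, data_loader, criterion, training_scheduler,
#                             epoch, device, log_writer=log_writer, args=args,
#                             finetune='cls')
#     print("epoch {}: {}".format(epoch, stats))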