train.py
import unittest
from typing import List, Dict, Any
import random
from random import choices
import numpy as np
import torch
from torch import nn
from lr_scheduler import NoamOpt
from transformer import Transformer
from vocabulary import Vocabulary
from utils import construct_batches


def train(
    transformer: nn.Module,
    scheduler: Any,
    criterion: Any,
    batches: Dict[str, List[torch.Tensor]],
    masks: Dict[str, List[torch.Tensor]],
    n_epochs: int,
):
    """
    Main training loop
    :param transformer: the transformer model
    :param scheduler: the learning rate scheduler
    :param criterion: the optimization criterion (loss function)
    :param batches: aligned src and tgt batches that contain token ids
    :param masks: source key padding mask and target future mask for each batch
    :param n_epochs: the number of epochs to train the model for
    :return: the loss and accuracy on the latest batch
    """
    transformer.train(True)
    num_iters = 0
    for e in range(n_epochs):
        for i, (src_batch, src_mask, tgt_batch, tgt_mask) in enumerate(
            zip(batches["src"], masks["src"], batches["tgt"], masks["tgt"])
        ):
            encoder_output = transformer.encoder(src_batch, src_padding_mask=src_mask)  # type: ignore

            # Perform one decoder forward pass to obtain *all* next-token predictions for every index i given its
            # previous *gold standard* tokens [1, ..., i] (i.e. teacher forcing), in parallel.
            decoder_output = transformer.decoder(
                tgt_batch,
                encoder_output,
                src_padding_mask=src_mask,
                future_mask=tgt_mask,
            )  # type: ignore
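            # Assumed output shape (inferred from how decoder_output is used below): logits of shape
            # (batch_size, tgt_seq_len, vocab_size), i.e. one next-token distribution per target position.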
            # Align labels with predictions: the last decoder prediction is meaningless because we have no target
            # token for it. The BOS token in the target is also not something we want to compute a loss for.
            decoder_output = decoder_output[:, :-1, :]
            tgt_batch = tgt_batch[:, 1:]

            # Set pad tokens in the target to -100 so they don't incur a loss
            # tgt_batch[tgt_batch == transformer.padding_idx] = -100
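            # If the line above is enabled, pad positions are skipped by the loss because nn.CrossEntropyLoss
            # ignores targets equal to -100 by default (its ignore_index argument); an alternative would be to
            # construct the criterion with ignore_index set to the pad token id.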
            # Compute the average cross-entropy loss over all next-token predictions at each index i given [1, ..., i]
            # for the entire batch. Note that the original paper uses label smoothing (I was too lazy).
            batch_loss = criterion(
                decoder_output.contiguous().permute(0, 2, 1),
                tgt_batch.contiguous().long(),
            )
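            # The permute above moves the vocabulary dimension to position 1, giving (batch, vocab, seq), because
            # nn.CrossEntropyLoss expects class scores of shape (N, C, d1, ...) with integer targets of shape (N, d1, ...).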
            # Rough estimate of per-token accuracy in the current training batch
            batch_accuracy = (
                torch.sum(decoder_output.argmax(dim=-1) == tgt_batch)
            ) / torch.numel(tgt_batch)
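            # Note: since pad positions were not masked out above, this estimate also counts them,
            # which likely inflates the reported accuracy somewhat on padded batches.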
            if num_iters % 100 == 0:
                print(
                    f"epoch: {e}, num_iters: {num_iters}, batch_loss: {batch_loss}, batch_accuracy: {batch_accuracy}"
                )

            # Update parameters
            batch_loss.backward()
            scheduler.step()
            scheduler.optimizer.zero_grad()
            num_iters += 1
    return batch_loss, batch_accuracy


class TestTransformerTraining(unittest.TestCase):
    seed = 0
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    def test_copy_task(self):
        """
        Test training by trying to (over)fit a simple copy dataset - bringing the loss to ~zero. (GPU required)
        """
        device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        if device.type == "cpu":
            print("This unit test was not run because it requires a GPU")
            return
        # Hyperparameters
        synthetic_corpus_size = 600
        batch_size = 60
        n_epochs = 200
        n_tokens_in_batch = 10

        # Construct the vocabulary and create synthetic data by uniformly sampling tokens from it.
        # Note: the original paper uses byte pair encodings; we simply take each word to be a token.
        corpus = ["These are the tokens that will end up in our vocabulary"]
        vocab = Vocabulary(corpus)
        vocab_size = len(
            list(vocab.token2index.keys())
        )  # 14 tokens including bos, eos and pad
        valid_tokens = list(vocab.token2index.keys())[3:]
        corpus += [
            " ".join(choices(valid_tokens, k=n_tokens_in_batch))
            for _ in range(synthetic_corpus_size)
        ]

        # Construct src-tgt aligned input batches (note: the original paper uses dynamic batching based on tokens)
        corpus = [{"src": sent, "tgt": sent} for sent in corpus]
        batches, masks = construct_batches(
            corpus,
            vocab,
            batch_size=batch_size,
            src_lang_key="src",
            tgt_lang_key="tgt",
            device=device,
        )
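        # construct_batches is expected to return two dicts keyed by "src"/"tgt": one with lists of token-id
        # tensors per batch and one with the corresponding padding / future masks (see the train() signature above).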
        # Initialize transformer
        transformer = Transformer(
            hidden_dim=512,
            ff_dim=2048,
            num_heads=8,
            num_layers=2,
            max_decoding_length=25,
            vocab_size=vocab_size,
            padding_idx=vocab.token2index[vocab.PAD],
            bos_idx=vocab.token2index[vocab.BOS],
            dropout_p=0.1,
            tie_output_to_embedding=True,
        ).to(device)
        # Initialize learning rate scheduler, optimizer and loss (note: the original paper uses label smoothing)
        optimizer = torch.optim.Adam(
            transformer.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9
        )
        scheduler = NoamOpt(
            transformer.hidden_dim, factor=1, warmup=400, optimizer=optimizer,
        )
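        # NoamOpt presumably implements the warmup schedule from "Attention Is All You Need":
        # lr = factor * hidden_dim^(-0.5) * min(step^(-0.5), step * warmup^(-1.5)),
        # i.e. linear warmup for `warmup` steps followed by inverse-square-root decay.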
        criterion = nn.CrossEntropyLoss()
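        # If label smoothing (as in the original paper) is desired and your PyTorch version supports it,
        # this could be nn.CrossEntropyLoss(label_smoothing=0.1) instead.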
        # Start training and verify ~zero loss and >99% accuracy on the last batch
        latest_batch_loss, latest_batch_accuracy = train(
            transformer, scheduler, criterion, batches, masks, n_epochs=n_epochs
        )
        self.assertLessEqual(latest_batch_loss.item(), 0.01)
        self.assertGreaterEqual(latest_batch_accuracy, 0.99)


if __name__ == "__main__":
    unittest.main()