"""
German to english MT
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import time
from dataloaders.translation import loader
import torch
from torch import optim
from torch.nn import CrossEntropyLoss
from lstm_attention import EncoderRNN, AttnDecoderRNN, deploy, train_batch
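
# `loader` appears to return torchtext-style (src field, trg field, train iterator,
# val iterator): the fields expose .vocab, and each batch carries time-major
# .src / .trg index tensors (sampler transposes them to batch-major below).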
de, en, train_loader, val_loader = loader(batch_size=16)
def sampler(x, pad_idx=1):
    """Convert a (time, batch) tensor of token ids into whitespace-joined strings."""
    def _skip_eos(row):
        # Collect tokens up to the first pad index.
        s = []
        for i in row:
            if i == pad_idx:
                break
            s.append(en.vocab.itos[i])
        return s
    x_np = x.data.cpu().numpy().T  # transpose to (batch, time)
    return [' '.join(_skip_eos(row)) for row in x_np]
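
# Embedding (input) and hidden dimensions for the encoder and decoder.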
d_enc_input = 300
d_enc = 256
d_dec_input = 300
d_dec = 256
encoder = EncoderRNN(
    d_enc_input,
    d_enc,
    use_embedding=True,
    vocab_size=len(de.vocab),
    pad_idx=de.vocab.stoi['<pad>'],
)
decoder = AttnDecoderRNN(
    d_dec_input,
    d_enc * 2,  # context size is 2 * d_enc, suggesting a bidirectional encoder
    d_dec,
    vocab_size=len(en.vocab),
    pad_idx=en.vocab.stoi['<pad>'],
    bos_token=en.vocab.stoi['<bos>'],
    eos_token=en.vocab.stoi['<eos>'],
)
criterion = CrossEntropyLoss()
if torch.cuda.is_available():
    print("Using cuda")
    encoder.cuda()
    decoder.cuda()
    criterion.cuda()
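
# Separate RMSprop optimizers for the encoder and decoder; both are handed to
# train_batch below, which presumably steps them after each backward pass.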
learning_rate = 0.01
encoder_optimizer = optim.RMSprop(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.RMSprop(decoder.parameters(), lr=learning_rate)
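
# Main loop: every 1000 training batches, decode one validation batch and print
# up to 10 sample translations; log the training loss every 200 batches.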
for epoch in range(1, 50):
    for b, batch in enumerate(train_loader):
        if b % 1000 == 0:
            for val_b, val_batch in enumerate(val_loader):
                sampled_outs_ = deploy(encoder, decoder, val_batch.src)
                sampled_outs = sampler(sampled_outs_)
                targets = sampler(val_batch.trg)
                for i in range(min(10, val_batch.src.size(1))):
                    print("----")
                    print("Pred: {}".format(sampled_outs[i]))
                    print("Target: {}".format(targets[i]))
                print("----", flush=True)
                break  # only the first validation batch is sampled
        start = time.time()
        loss = train_batch(encoder, decoder, [encoder_optimizer, decoder_optimizer], criterion,
                           batch.src, batch.trg)
        dur = time.time() - start
        if b % 200 == 0:
            print("e{:2d}b{:3d} Loss is {}, ({:.3f} sec/batch)".format(epoch, b, loss.item(), dur))