train_tgif.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import time
from dataloaders.tgif import loader
import torch
from torch import optim
from torch.nn import CrossEntropyLoss
from lstm_attention import EncoderRNN, AttnDecoderRNN, deploy, train_batch
from itertools import chain
BATCH_SIZE = 2
# TODO: Figure out memory usage issues
# Build the TGIF train/val iterators and the shared vocabulary.
train_loader, val_loader, vocab = loader(batch_size=BATCH_SIZE)
def sampler(x, pad_idx=1):
    """Decode a (seq_len, batch) tensor of token indices into strings."""
    def _skip_eos(row):
        # Collect tokens up to (but not including) the first pad index.
        s = []
        for i in row:
            if i == pad_idx:
                break
            s.append(vocab.vocab.itos[i])
        return s
    x_np = x.data.cpu().numpy().T
    return [' '.join(_skip_eos(row)) for row in x_np]
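# A rough sketch of sampler()'s behavior (illustrative values, not from the
# real vocab): given a (seq_len, batch) index tensor like
#     [[4, 9],
#      [7, 1],
#      [1, 1]]
# the transpose yields one row per example; each row is cut at the first
# pad_idx and the surviving indices are mapped through vocab.vocab.itos,
# producing e.g. ["a cat", "dog"].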
# Embedding and hidden sizes for the encoder and the attention decoder.
d_enc_input = 300
d_enc = 256
d_dec_input = 300
d_dec = 128
encoder = EncoderRNN(d_enc_input, d_enc, use_cnn=True)
decoder = AttnDecoderRNN(d_dec_input, d_enc*2, d_dec, vocab_size=len(vocab.vocab))
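# The decoder attends over contexts of size d_enc*2, which suggests the
# encoder GRU is bidirectional (an assumption; lstm_attention.py defines
# the actual EncoderRNN/AttnDecoderRNN architectures).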
criterion = CrossEntropyLoss()
if torch.cuda.is_available():
    print("Using cuda")
    encoder.cuda()
    decoder.cuda()
    criterion.cuda()
learning_rate = 0.0001
encoder_optimizer = optim.RMSprop(chain(encoder.gru.parameters(), encoder.embed.fc.parameters()), lr=learning_rate)
decoder_optimizer = optim.RMSprop(decoder.parameters(), lr=learning_rate)
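# Only the encoder's GRU and the fc head of its embedding module receive
# gradients here; presumably the CNN backbone enabled by use_cnn=True is a
# pretrained feature extractor left frozen (an assumption inferred from
# this parameter selection, not stated in the source).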
for epoch in range(1, 50):
    for b, (train_x, train_y, train_y_lens) in enumerate(train_loader):
        # Periodically sample predictions from a single validation batch.
        if b % 100 == 1:
            for val_b, (val_x, val_y, val_y_lens) in enumerate(val_loader):
                sampled_outs = sampler(deploy(encoder, decoder, val_x))
                targets = sampler(val_y)
                for i in range(BATCH_SIZE):
                    print("----")
                    print("Pred: {}".format(sampled_outs[i]))
                    print("Target: {}".format(targets[i]))
                print("----", flush=True)
                break
        start = time.time()
        loss = train_batch(encoder, decoder, [encoder_optimizer, decoder_optimizer], criterion,
                           train_x, train_y, train_y_lens)
        dur = time.time() - start
        if b % 50 == 0:
            print("e{:2d}b{:3d} Loss is {}, ({:.3f} sec/batch)".format(epoch, b, loss.item(), dur))