target_lstm.py
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F


class TargetLSTM(nn.Module):
    """Target LSTM: a randomly initialized oracle used to generate and score sequences."""

    def __init__(self, num_emb, emb_dim, hidden_dim, use_cuda):
        super(TargetLSTM, self).__init__()
        self.num_emb = num_emb
        self.emb_dim = emb_dim
        self.hidden_dim = hidden_dim
        self.use_cuda = use_cuda
        self.emb = nn.Embedding(num_emb, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
        self.lin = nn.Linear(hidden_dim, num_emb)
        self.softmax = nn.LogSoftmax(dim=1)
        self.init_params()

    def forward(self, x):
        """
        Args:
            x: (batch_size, seq_len), sequence of tokens generated by generator

        Returns:
            pred: (batch_size * seq_len, num_emb), log-probabilities over the vocabulary
        """
        emb = self.emb(x)
        h0, c0 = self.init_hidden(x.size(0))
        output, (h, c) = self.lstm(emb, (h0, c0))
        pred = self.softmax(self.lin(output.contiguous().view(-1, self.hidden_dim)))
        return pred

    def step(self, x, h, c):
        """Run the LSTM for a single time step.

        Args:
            x: (batch_size, 1), current input tokens
            h: (1, batch_size, hidden_dim), LSTM hidden state
            c: (1, batch_size, hidden_dim), LSTM cell state
        """
        emb = self.emb(x)
        output, (h, c) = self.lstm(emb, (h, c))
        pred = F.softmax(self.lin(output.view(-1, self.hidden_dim)), dim=1)
        return pred, h, c

    def init_hidden(self, batch_size):
        h = torch.zeros((1, batch_size, self.hidden_dim))
        c = torch.zeros((1, batch_size, self.hidden_dim))
        if self.use_cuda:
            h, c = h.cuda(), c.cuda()
        return h, c

    def init_params(self):
        # Weights are drawn from N(0, 1) so the oracle defines a fixed random
        # target distribution rather than a trained language model.
        for param in self.parameters():
            param.data.normal_(0, 1)

    def sample(self, batch_size, seq_len):
        """Sample a batch of token sequences of length seq_len, starting from token 0."""
        with torch.no_grad():
            x = torch.zeros((batch_size, 1)).long()
            if self.use_cuda:
                x = x.cuda()
            h, c = self.init_hidden(batch_size)
            samples = []
            for i in range(seq_len):
                output, h, c = self.step(x, h, c)
                x = output.multinomial(1)
                samples.append(x)
            output = torch.cat(samples, dim=1)
        return output
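

# Minimal usage sketch (illustrative; not part of the original module). The
# hyperparameter values below are assumptions chosen for demonstration, not
# settings prescribed by this file.
if __name__ == "__main__":
    use_cuda = torch.cuda.is_available()
    oracle = TargetLSTM(num_emb=5000, emb_dim=32, hidden_dim=32, use_cuda=use_cuda)
    if use_cuda:
        oracle = oracle.cuda()

    # Draw a batch of 16 sequences of length 20 from the fixed random oracle.
    data = oracle.sample(batch_size=16, seq_len=20)  # shape: (16, 20), dtype: long

    # Teacher-forced scoring: shift the inputs right by one position and use a
    # start token of 0, matching the start token used inside sample(). forward()
    # returns log-probabilities flattened to (batch_size * seq_len, num_emb),
    # which pairs directly with nn.NLLLoss.
    start = torch.zeros(data.size(0), 1, dtype=torch.long, device=data.device)
    inputs = torch.cat([start, data[:, :-1]], dim=1)
    nll = nn.NLLLoss()(oracle(inputs), data.contiguous().view(-1))
    print(data.shape, float(nll))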