# bidaf_lstm.py
import torch
import torch.nn as nn

import constants


class LSTMContextLayer(nn.Module):
    """Extracts contextual information from inputs"""

    def __init__(self, embedding_dim, hidden_size, num_layers):
        super().__init__()
        self.pass_context = nn.LSTM(embedding_dim, hidden_size, num_layers, batch_first=True, bidirectional=True)
        self.ans_context = nn.LSTM(embedding_dim, hidden_size, num_layers, batch_first=True, bidirectional=True)

    def forward(self, passage, answer):
        """
        passage - N x P x E
        answer - N x A x E
        """
        out_pass, _ = self.pass_context(passage)  # N x P x 2*H
        out_ans, _ = self.ans_context(answer)  # N x A x 2*H
        return out_pass, out_ans

    def __call__(self, passage, answer):
        return self.forward(passage, answer)
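

# A minimal shape sketch for LSTMContextLayer. The batch size, sequence lengths,
# and dimensions below are illustrative assumptions, not values used elsewhere
# in this repo; this helper is not called by the model code.
def _example_context_layer():
    layer = LSTMContextLayer(embedding_dim=100, hidden_size=64, num_layers=2)
    passage = torch.randn(8, 50, 100)  # N x P x E
    answer = torch.randn(8, 10, 100)   # N x A x E
    H, U = layer(passage, answer)
    assert H.shape == (8, 50, 128)  # N x P x 2*hidden_size
    assert U.shape == (8, 10, 128)  # N x A x 2*hidden_size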


class AttentionFlowLayer(nn.Module):
    """
    Computes C2Q and Q2C as per
    https://arxiv.org/pdf/1611.01603.pdf
    https://towardsdatascience.com/the-definitive-guide-to-bidaf-part-2-word-embedding-character-embedding-and-contextual-c151fc4f05bb
    https://towardsdatascience.com/the-definitive-guide-to-bidaf-part-3-attention-92352bbdcb07
    """

    def __init__(self, embedding_dim, hidden_size):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.weights = nn.Linear(6*self.hidden_size, 1, bias=False)

    def forward(self, H, U):
        """
        H: N x P x 2d -- passage representation
        U: N x A x 2d -- answer representation
        G (output): N x P x 8d
        Implementation based off of
        https://github.com/jojonki/BiDAF/blob/master/layers/bidaf.py
        """
        context = H.unsqueeze(2)  # N x P x 1 x 2d
        ans = U.unsqueeze(1)  # N x 1 x A x 2d
        cast_shape = (H.shape[0], H.shape[1], U.shape[1], 2*self.hidden_size)  # N x P x A x 2d
        context = context.expand(cast_shape)
        ans = ans.expand(cast_shape)
        # Similarity Matrix Passage and Answer
        prod = torch.mul(context, ans)
        vec = torch.cat([context, ans, prod], dim=3)  # N x P x A x 6d
        sim = self.weights(vec).squeeze(-1)  # N x P x A (squeeze only the last dim so a batch of 1 keeps its batch dim)
        # C2Q - which query (i.e. answer) words matter most to each context (i.e. passage) word
        a = nn.Softmax(dim=2)(sim)  # N x P x A
        U_tilda = torch.bmm(a, U)  # N x P x 2d
        # Q2C - which context (i.e. passage) words matter most to the query; per the paper and the
        # reference implementation, the weights are a softmax over each passage word's max similarity
        b = nn.Softmax(dim=1)(torch.max(sim, dim=2)[0]).unsqueeze(1)  # N x 1 x P
        H_tilda = torch.bmm(b, H).tile((1, H.shape[1], 1))  # N x P x 2d
        # Merge to form G
        G = torch.cat([H, U_tilda, torch.mul(H, U_tilda), torch.mul(H, H_tilda)], dim=2)  # N x P x 8d
        return G

    def __call__(self, H, U):
        return self.forward(H, U)
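

# A minimal shape sketch for AttentionFlowLayer. The sizes are illustrative
# assumptions; the only constraint is that the last dim of H and U equals
# 2 * hidden_size. This helper is not called by the model code.
def _example_attention_flow():
    flow = AttentionFlowLayer(embedding_dim=100, hidden_size=64)
    H = torch.randn(8, 50, 128)  # N x P x 2d contextual passage
    U = torch.randn(8, 10, 128)  # N x A x 2d contextual answer
    G = flow(H, U)
    assert G.shape == (8, 50, 512)  # N x P x 8d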


class AttentionFlowLSTMEncoder(nn.Module):
    def __init__(self, embedding_dim, hidden_size, num_layers, vocab):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.vocab = vocab
        self.context_layer = LSTMContextLayer(self.embedding_dim, self.hidden_size, self.num_layers)
        self.attention_flow_layer = AttentionFlowLayer(self.embedding_dim, self.hidden_size)
        # note: this modeling LSTM uses the default single layer, not self.num_layers
        self.encoder = nn.LSTM(8*self.hidden_size, self.hidden_size, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(2*self.hidden_size, self.embedding_dim)

    def forward(self, passage, answer):
        """
        passage: embedded passage N x P x E
        answer: embedded answer N x A x E
        encoding: encoded representation of passage-answer N x 1 x E
        """
        H, U = self.context_layer(passage, answer)  # N x P x 2H, N x A x 2H
        G = self.attention_flow_layer(H, U)  # N x P x 8H
        encoding, _ = self.encoder(G)  # N x P x 2H
        encoding = torch.mean(encoding, dim=1, keepdim=True)  # N x 1 x 2H
        encoding = self.fc(encoding)  # N x 1 x E
        return encoding

    def __call__(self, passage, answer):
        return self.forward(passage, answer)
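

# A minimal shape sketch for AttentionFlowLSTMEncoder. The vocab is a
# hypothetical stand-in (this class only stores it); sizes are illustrative
# assumptions and this helper is not called by the model code.
def _example_encoder():
    vocab = ["<pad>", "<sos>", "<eos>", "the", "cat"]  # hypothetical
    enc = AttentionFlowLSTMEncoder(embedding_dim=100, hidden_size=64, num_layers=2, vocab=vocab)
    passage = torch.randn(8, 50, 100)  # N x P x E embedded passage
    answer = torch.randn(8, 10, 100)   # N x A x E embedded answer
    encoding = enc(passage, answer)
    assert encoding.shape == (8, 1, 100)  # N x 1 x E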


class LSTMDecoder(nn.Module):
    """Decoder to produce sequential output as question"""

    def __init__(self, embedding_dim, hidden_size, num_layers, vocab, question_length):
        """
        embedding_dim: dimension of word embedding
        hidden_size: hidden state dimension for decoder
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.vocab = vocab
        self.question_length = question_length
        self.decoder = nn.LSTM(self.embedding_dim, self.hidden_size, self.num_layers, batch_first=True)
        self.fc = nn.Linear(self.hidden_size, len(self.vocab))
        self.softmax = nn.Softmax(dim=2)
    def forward(self, encoded_inputs, embedded_targets=None, embedder=None, temperature=1.0):
        """
        Generates sequential output
        encoded_inputs: N x 1 x E
        embedded_targets: N x Q x E; supply targets for teacher forcing. Otherwise the decoder
        free-runs, sampling one word per step from its own predictions for question_length steps
        (embedder is then required to map sampled word ids back to embeddings)
        """
        if embedded_targets is not None:
            assert embedded_targets.shape[1] == self.question_length
            decoder_inputs = torch.cat([encoded_inputs, embedded_targets[:, :-1, :]], dim=1)
            sequence, _ = self.decoder(decoder_inputs)  # N x Q x hidden_size
            sequence = self.fc(sequence)  # N x Q x vocab_size
            return sequence
        else:
            with torch.no_grad():
                assert embedder is not None
                sequence = None
                hidden_state = None
                inputs = encoded_inputs
                for i in range(self.question_length):
                    out, hidden_state = self.decoder(inputs, hidden_state)  # N x 1 x H
                    out = self.fc(out)  # N x 1 x vocab_size
                    probs = self.softmax(out.div(temperature))
                    word = torch.multinomial(probs.squeeze(1), 1)  # N x 1 (squeeze only the time dim so batch size 1 works)
                    if i == 0:
                        sequence = word
                    else:
                        sequence = torch.cat([sequence, word], dim=1)
                    inputs = embedder(word.long())  # N x 1 x E
                return sequence  # N x Q

    def __call__(self, encoded_inputs, embedded_targets, embedder=None, temperature=1.0):
        return self.forward(encoded_inputs, embedded_targets, embedder, temperature)
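

# A minimal sketch of both decoding modes of LSTMDecoder: teacher forcing with
# embedded targets, and free-running sampling with an embedder. The vocab,
# embedder, and sizes are illustrative assumptions, not part of the model code.
def _example_decoder():
    vocab = ["<pad>", "<sos>", "<eos>", "what", "is"]  # hypothetical
    dec = LSTMDecoder(embedding_dim=100, hidden_size=64, num_layers=2, vocab=vocab, question_length=20)
    encoded = torch.randn(8, 1, 100)   # encoder output, N x 1 x E
    targets = torch.randn(8, 20, 100)  # embedded gold question, N x Q x E
    logits = dec(encoded, targets)     # teacher forcing
    assert logits.shape == (8, 20, len(vocab))  # N x Q x vocab_size
    # Free-running generation: sample word ids, re-embed them, and feed them back.
    embedder = nn.Embedding(len(vocab), 100)
    tokens = dec(encoded, None, embedder=embedder, temperature=1.0)
    assert tokens.shape == (8, 20)  # N x question_length word ids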


class BiDAF_LSTMNet(nn.Module):
    def __init__(self, embedding_dim, hidden_size, num_layers, vocab, question_length, temperature=1.0):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.vocab = vocab
        self.question_length = question_length
        self.temperature = temperature
        self.embedder = nn.Embedding(num_embeddings=len(self.vocab), embedding_dim=embedding_dim, padding_idx=0)
        self.encoder = AttentionFlowLSTMEncoder(self.embedding_dim, self.hidden_size, self.num_layers, self.vocab)
        self.decoder = LSTMDecoder(self.embedding_dim, self.hidden_size, self.num_layers, self.vocab, self.question_length)
        self.softmax = nn.Softmax(dim=2)

    def embed(self, words):
        """words is N x L"""
        return self.embedder(words)  # N x L x embedding

    def forward(self, passage, answer, question):
        """
        passage: N x P
        answer: N x A
        question: N x Q
        """
        P = passage.shape[1]
        A = answer.shape[1]
        Q = question.shape[1]
        assert P == constants.MAX_PASSAGE_LEN + 2
        assert A == constants.MAX_ANSWER_LEN + 2
        assert Q == constants.MAX_QUESTION_LEN + 2
        pass_ans_qembed = self.embed(torch.cat([passage, answer, question], dim=1))
        passage, answer, q_embed = pass_ans_qembed[:, :P, :], pass_ans_qembed[:, P:P+A, :], pass_ans_qembed[:, P+A:, :]
        encoded = self.encoder(passage, answer)
        sequence = self.decoder(encoded, q_embed)  # teacher forced
        return sequence  # N x Q x vocab_size

    def predict(self, passage, answer):
        """
        Generates question word-by-word
        Should only be used in evaluation mode
        """
        with torch.no_grad():
            pass_ans = self.embed(torch.cat([passage, answer], dim=1))
            passage, answer = pass_ans[:, :passage.shape[1], :], pass_ans[:, passage.shape[1]:, :]
            encoded = self.encoder(passage, answer)  # N x 1 x E
            sequence = self.decoder(encoded, embedded_targets=None, embedder=self.embedder, temperature=self.temperature)  # N x Q
            return sequence
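

# An end-to-end sketch of BiDAF_LSTMNet on random token ids. The vocabulary is a
# hypothetical stand-in; the sequence lengths must equal the MAX_*_LEN values in
# constants plus 2, per the asserts in forward, so question_length is derived
# from constants here. This helper is not called by the model code.
def _example_bidaf_lstm_net():
    vocab = ["<pad>", "<sos>", "<eos>", "what", "is", "the", "capital", "of", "france", "paris"]  # hypothetical
    P = constants.MAX_PASSAGE_LEN + 2
    A = constants.MAX_ANSWER_LEN + 2
    Q = constants.MAX_QUESTION_LEN + 2
    net = BiDAF_LSTMNet(embedding_dim=100, hidden_size=64, num_layers=2, vocab=vocab, question_length=Q)
    passage = torch.randint(0, len(vocab), (4, P))   # N x P token ids
    answer = torch.randint(0, len(vocab), (4, A))    # N x A token ids
    question = torch.randint(0, len(vocab), (4, Q))  # N x Q token ids
    logits = net(passage, answer, question)          # teacher-forced, N x Q x vocab_size
    assert logits.shape == (4, Q, len(vocab))
    net.eval()
    generated = net.predict(passage, answer)         # sampled word ids, N x Q
    assert generated.shape == (4, Q)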