attention.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class AdditiveAttention(nn.Module):
    """Additive (MLP) attention as used by Bahdanau et al. (2015)."""

    def __init__(self, encoder_hidd_dim, decoder_hidd_dim):
        super(AdditiveAttention, self).__init__()
        self.atten = nn.Linear((encoder_hidd_dim * 2) + decoder_hidd_dim, decoder_hidd_dim)
        self.v = nn.Linear(decoder_hidd_dim, 1, bias=False)

    def forward(self, keys, query, mask):
        """keys: encoder hidden states.
        query: decoder hidden state at time t.
        mask: a vector of zeros and ones marking the valid source positions.
        """
        # keys shape: [batch_size, src_seq_length, encoder_hidd_dim * 2]
        # query shape: [num_layers * num_dirs, batch_size, decoder_hidd_dim]
        batch_size, src_seq_length, encoder_hidd_dim = keys.shape

        # apply attention to the hidden state of the last decoder layer
        query = query[-1, :, :]
        # query shape: [batch_size, decoder_hidd_dim]

        # repeat the query src_seq_length times at dim 1 so it lines up with the keys
        query = query.unsqueeze(1).repeat(1, src_seq_length, 1)
        # query shape: [batch_size, src_seq_length, decoder_hidd_dim]

        # Step 1: compute the attention scores through an MLP
        # concatenate the keys and the query
        atten_input = torch.cat((keys, query), dim=2)
        # atten_input shape: [batch_size, src_seq_length, (encoder_hidd_dim * 2) + decoder_hidd_dim]
        atten_scores = self.atten(atten_input)
        # atten_scores shape: [batch_size, src_seq_length, decoder_hidd_dim]
        atten_scores = torch.tanh(atten_scores)

        # map atten_scores from decoder_hidd_dim down to 1
        atten_scores = self.v(atten_scores)
        # atten_scores shape: [batch_size, src_seq_length, 1]
        atten_scores = atten_scores.squeeze(dim=2)
        # atten_scores shape: [batch_size, src_seq_length]

        # mask out the padded source positions
        atten_scores = atten_scores.masked_fill(mask == 0, -float('inf'))

        # Step 2: normalize atten_scores through a softmax to get probabilities
        atten_scores = F.softmax(atten_scores, dim=1)

        # Step 3: compute the new context vector
        context_vector = torch.matmul(keys.permute(0, 2, 1), atten_scores.unsqueeze(2)).squeeze(dim=2)
        # context_vector shape: [batch_size, encoder_hidd_dim * 2]

        return context_vector, atten_scores
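
# Illustrative only (not part of the original module): a minimal usage sketch for
# AdditiveAttention. The batch size, sequence length, and hidden sizes below are
# arbitrary assumptions; `keys` stands in for bidirectional encoder outputs and
# `query` for the stacked decoder hidden state.
def _demo_additive_attention():
    batch_size, src_seq_length = 2, 7
    encoder_hidd_dim, decoder_hidd_dim = 256, 512
    num_layers = 2
    attention = AdditiveAttention(encoder_hidd_dim, decoder_hidd_dim)
    keys = torch.randn(batch_size, src_seq_length, encoder_hidd_dim * 2)
    query = torch.randn(num_layers, batch_size, decoder_hidd_dim)
    mask = torch.ones(batch_size, src_seq_length)
    mask[:, -2:] = 0  # pretend the last two source positions are padding
    context_vector, atten_scores = attention(keys, query, mask)
    assert context_vector.shape == (batch_size, encoder_hidd_dim * 2)
    assert atten_scores.shape == (batch_size, src_seq_length)
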

class GeneralAttention(nn.Module):
    """General (multiplicative) attention as described by Luong et al. (2015)."""

    def __init__(self, encoder_hidd_dim, decoder_hidd_dim):
        super(GeneralAttention, self).__init__()
        self.linear_map = nn.Linear(encoder_hidd_dim * 2, decoder_hidd_dim, bias=False)

    def forward(self, keys, query, mask):
        """keys: encoder hidden states.
        query: decoder hidden state at time t.
        mask: a vector of zeros and ones marking the valid source positions.
        """
        # keys shape: [batch_size, src_seq_length, encoder_hidd_dim * 2]
        # query shape: [num_layers * num_dirs, batch_size, decoder_hidd_dim]
        batch_size, src_seq_length, encoder_hidd_dim = keys.shape

        # apply attention to the hidden state of the last decoder layer
        query = query[-1, :, :]
        # query shape: [batch_size, decoder_hidd_dim]

        # map the keys from encoder_hidd_dim * 2 to decoder_hidd_dim
        mapped_key_vectors = self.linear_map(keys)
        # mapped_key_vectors shape: [batch_size, src_seq_length, decoder_hidd_dim]

        # Step 1: compute the attention scores as a dot product between the query and the mapped keys
        atten_scores = torch.matmul(query.unsqueeze(1), mapped_key_vectors.permute(0, 2, 1)).squeeze(1)
        # atten_scores shape: [batch_size, src_seq_length]

        # mask out the padded source positions
        atten_scores = atten_scores.masked_fill(mask == 0, -float('inf'))

        # Step 2: normalize atten_scores through a softmax to get probabilities
        atten_scores = F.softmax(atten_scores, dim=1)

        # Step 3: compute the new context vector
        context_vector = torch.matmul(keys.permute(0, 2, 1), atten_scores.unsqueeze(2)).squeeze(dim=2)
        # context_vector shape: [batch_size, encoder_hidd_dim * 2]

        return context_vector, atten_scores
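
# Illustrative only (not part of the original module): a minimal usage sketch for
# GeneralAttention. As above, the sizes are arbitrary assumptions; note that the
# encoder and decoder hidden sizes may differ here because the keys are first
# projected into the decoder's space by `linear_map`.
def _demo_general_attention():
    batch_size, src_seq_length = 2, 7
    encoder_hidd_dim, decoder_hidd_dim = 256, 300
    attention = GeneralAttention(encoder_hidd_dim, decoder_hidd_dim)
    keys = torch.randn(batch_size, src_seq_length, encoder_hidd_dim * 2)
    query = torch.randn(1, batch_size, decoder_hidd_dim)
    mask = torch.ones(batch_size, src_seq_length)
    context_vector, atten_scores = attention(keys, query, mask)
    assert context_vector.shape == (batch_size, encoder_hidd_dim * 2)
    assert atten_scores.shape == (batch_size, src_seq_length)
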

def DotProductAttention(keys, query, mask):
    """Plain dot-product attention between the query and the keys.

    Args:
        keys: encoder outputs (hidden states from the last encoder layer).
        query: decoder hidden state at time t.
        mask: a vector of zeros and ones marking the valid source positions.

    Returns:
        context_vector: [batch_size, encoder_hidd_dim * 2]
        attention_scores: [batch_size, src_seq_length]

    NOTE: this attention only works when encoder_hidd_dim * 2 == decoder_hidd_dim.
    """
    # keys shape: [batch_size, src_seq_length, encoder_hidd_dim * 2]
    # query shape: [num_layers * num_dirs, batch_size, encoder_hidd_dim * 2]

    # apply attention to the hidden state of the last decoder layer
    query = query[-1, :, :]
    # query shape: [batch_size, encoder_hidd_dim * 2]

    # Step 1: compute the attention scores as a dot product between the keys and the query
    attention_scores = torch.matmul(keys, query.unsqueeze(-1)).squeeze(-1)
    # attention_scores shape: [batch_size, src_seq_length]

    # mask out the padded source positions
    attention_scores = attention_scores.masked_fill(mask == 0, -float('inf'))

    # Step 2: normalize the attention_scores through a softmax to get probabilities
    attention_scores = F.softmax(attention_scores, dim=1)

    # Step 3: compute the new context vector
    context_vector = torch.matmul(keys.permute(0, 2, 1), attention_scores.unsqueeze(-1)).squeeze(dim=2)
    # context_vector shape: [batch_size, encoder_hidd_dim * 2]

    return context_vector, attention_scores
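
# Illustrative only (not part of the original module): a minimal usage sketch for
# DotProductAttention. The sizes are arbitrary assumptions, chosen so that
# decoder_hidd_dim == encoder_hidd_dim * 2, which this variant requires.
def _demo_dot_product_attention():
    batch_size, src_seq_length, encoder_hidd_dim = 2, 7, 256
    keys = torch.randn(batch_size, src_seq_length, encoder_hidd_dim * 2)
    query = torch.randn(1, batch_size, encoder_hidd_dim * 2)
    mask = torch.ones(batch_size, src_seq_length)
    context_vector, attention_scores = DotProductAttention(keys, query, mask)
    assert context_vector.shape == (batch_size, encoder_hidd_dim * 2)
    assert attention_scores.shape == (batch_size, src_seq_length)


if __name__ == "__main__":
    # run the illustrative smoke tests above
    _demo_additive_attention()
    _demo_general_attention()
    _demo_dot_product_attention()
    print("all attention demos produced the expected shapes")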