-
Notifications
You must be signed in to change notification settings - Fork 1
/
Attention.py
47 lines (41 loc) · 1.87 KB
/
Attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from keras.layers import Layer
from keras import backend as K
class Attention(Layer):
    """Additive (tanh) attention that pools a sequence into one vector.

    Input:  tensor of shape (batch_size, timesteps, embed_size).
    Output: tensor of shape (batch_size, embed_size), the attention-weighted
    sum of the input over the time axis.

    Each timestep is scored as ``tanh(x . W + b) . u`` and the scores are
    softmax-normalised over time. When a mask is supplied, masked timesteps
    receive zero weight and the remaining weights are renormalised so they
    still sum to 1.

    # Arguments
        attention_size: int, width of the hidden projection ``W``.
    """

    def __init__(self, attention_size, **kwargs):
        # Run the base constructor first: it may initialise attributes such
        # as ``supports_masking``, which would otherwise overwrite ours.
        super(Attention, self).__init__(**kwargs)
        self.attention_size = attention_size
        # Accept a mask from upstream layers; it is consumed in call() and
        # not propagated further (see compute_mask).
        self.supports_masking = True

    def build(self, input_shape):
        """Create the trainable weights.

        W: (embed_size, attention_size)
        b: (timesteps, 1), one bias per timestep broadcast across the
           attention dimension. This needs a fixed number of timesteps
           (``input_shape[1]`` must be known when the layer is built).
        u: (attention_size, 1), the context vector that scores each step.
        """
        self.W = self.add_weight(name="W_{:s}".format(self.name),
                                 shape=(input_shape[-1], self.attention_size),
                                 initializer="glorot_normal",
                                 trainable=True)
        self.b = self.add_weight(name="b_{:s}".format(self.name),
                                 shape=(input_shape[1], 1),
                                 initializer="zeros",
                                 trainable=True)
        self.u = self.add_weight(name="u_{:s}".format(self.name),
                                 shape=(self.attention_size, 1),
                                 initializer="glorot_normal",
                                 trainable=True)
        super(Attention, self).build(input_shape)

    def call(self, x, mask=None):
        """Return the attention-pooled input.

        # Arguments
            x: (batch_size, timesteps, embed_size) input tensor.
            mask: optional (batch_size, timesteps) mask; falsy entries are
                excluded from the weighted sum.

        # Returns
            (batch_size, embed_size) tensor.
        """
        # et: (batch_size, timesteps, attention_size)
        et = K.tanh(K.dot(x, self.W) + self.b)
        # at: (batch_size, timesteps), one weight per timestep
        at = K.softmax(K.squeeze(K.dot(et, self.u), axis=-1))
        if mask is not None:
            # Zero the weights of masked steps, then renormalise so the
            # surviving weights sum to 1 again. Without this the output is
            # scaled down by the probability mass that fell on padding.
            # The epsilon guards against division by zero when every step
            # is masked.
            at *= K.cast(mask, K.floatx())
            at /= K.sum(at, axis=1, keepdims=True) + K.epsilon()
        # ot: (batch_size, timesteps, embed_size)
        atx = K.expand_dims(at, axis=-1)
        ot = atx * x
        # output: (batch_size, embed_size)
        output = K.sum(ot, axis=1)
        return output

    def compute_mask(self, input, input_mask=None):
        # The time axis is summed away, so there is no mask to pass on.
        return None

    def compute_output_shape(self, input_shape):
        # (batch_size, timesteps, embed_size) -> (batch_size, embed_size)
        return input_shape[0], input_shape[-1]

    def get_config(self):
        """Serialise constructor arguments so the layer can be reloaded."""
        config = super(Attention, self).get_config()
        config["attention_size"] = self.attention_size
        return config