-
Notifications
You must be signed in to change notification settings - Fork 149
/
cell.py
107 lines (87 loc) · 3.67 KB
/
cell.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import tensorflow.compat.v1 as tf
class ConvLSTMCell(tf.nn.rnn_cell.RNNCell):
    """A LSTM cell with convolutions instead of multiplications.

    Reference:
      Xingjian, S. H. I., et al. "Convolutional LSTM network: A machine learning approach for precipitation nowcasting." Advances in Neural Information Processing Systems. 2015.

    Args:
      shape: Spatial shape of one sample, excluding the batch and channel
        dimensions (list or tuple of ints).
      filters: Number of channels of the cell output and of each state tensor.
      kernel: Spatial size of the convolution kernel (list or tuple of ints,
        one entry per spatial dimension).
      forget_bias: Constant added to the forget-gate pre-activation.
      activation: Nonlinearity applied to the candidate and to the cell state.
      data_format: 'channels_last' for (batch, ..., channels) tensors or
        'channels_first' for (batch, channels, ...) tensors.
      reuse: Forwarded to RNNCell as `_reuse`.

    Raises:
      ValueError: If `data_format` is not one of the two supported values.
    """

    def __init__(self, shape, filters, kernel, forget_bias=1.0, activation=tf.tanh, data_format='channels_last', reuse=None):
        super(ConvLSTMCell, self).__init__(_reuse=reuse)
        # Copy into lists so tuples work too; both are concatenated with lists below.
        shape = list(shape)
        self._kernel = list(kernel)
        self._filters = filters
        self._forget_bias = forget_bias
        self._activation = activation
        if data_format == 'channels_last':
            self._size = tf.TensorShape(shape + [self._filters])
            # Batched tensors are (N, *spatial, C): channels are the last axis.
            self._feature_axis = len(shape) + 1
            self._data_format = None
        elif data_format == 'channels_first':
            self._size = tf.TensorShape([self._filters] + shape)
            # Batched tensors are (N, C, *spatial): axis 0 is the batch, so
            # channels live on axis 1 (axis 0 would concat/split the batch).
            self._feature_axis = 1
            self._data_format = 'NC'
        else:
            raise ValueError('Unknown data_format')

    @property
    def state_size(self):
        """Per-sample shapes of the (c, h) state pair."""
        return tf.nn.rnn_cell.LSTMStateTuple(self._size, self._size)

    @property
    def output_size(self):
        """Per-sample shape of the output h."""
        return self._size

    def call(self, x, state):
        """Advance the cell by one time step.

        Args:
          x: Batched input for this step, laid out per `data_format`.
          state: LSTMStateTuple (c, h) from the previous step.

        Returns:
          A pair (h, LSTMStateTuple(c, h)) with the new output and state.
        """
        c, h = state
        # Convolve the input and the previous output jointly.
        x = tf.concat([x, h], axis=self._feature_axis)
        # as_list() yields plain ints on both TF1 and TF2 (Dimension.value is TF1-only).
        n = x.get_shape().as_list()[self._feature_axis]
        m = 4 * self._filters  # one block of `filters` channels per gate j, i, f, o
        W = tf.get_variable('kernel', self._kernel + [n, m])
        y = tf.nn.convolution(x, W, 'SAME', data_format=self._data_format)
        bias = tf.get_variable('bias', [m], initializer=tf.zeros_initializer())
        if self._data_format is not None:
            # channels_first: a bare [m] bias would broadcast onto the last
            # spatial axis; reshape so it lines up with the channel axis.
            bias = tf.reshape(bias, [m] + [1] * (self._size.ndims - 1))
        y += bias
        j, i, f, o = tf.split(y, 4, axis=self._feature_axis)
        f = tf.sigmoid(f + self._forget_bias)
        i = tf.sigmoid(i)
        c = c * f + i * self._activation(j)
        o = tf.sigmoid(o)
        h = o * self._activation(c)
        state = tf.nn.rnn_cell.LSTMStateTuple(c, h)
        return h, state
class ConvGRUCell(tf.nn.rnn_cell.RNNCell):
    """A GRU cell with convolutions instead of multiplications.

    Args:
      shape: Spatial shape of one sample, excluding the batch and channel
        dimensions (list or tuple of ints).
      filters: Number of channels of the cell output/state.
      kernel: Spatial size of the convolution kernel (list or tuple of ints,
        one entry per spatial dimension).
      activation: Nonlinearity applied to the candidate state.
      data_format: 'channels_last' for (batch, ..., channels) tensors or
        'channels_first' for (batch, channels, ...) tensors.
      reuse: Forwarded to RNNCell as `_reuse`.

    Raises:
      ValueError: If `data_format` is not one of the two supported values.
    """

    def __init__(self, shape, filters, kernel, activation=tf.tanh, data_format='channels_last', reuse=None):
        super(ConvGRUCell, self).__init__(_reuse=reuse)
        # Copy into lists so tuples work too; both are concatenated with lists below.
        shape = list(shape)
        self._filters = filters
        self._kernel = list(kernel)
        self._activation = activation
        if data_format == 'channels_last':
            self._size = tf.TensorShape(shape + [self._filters])
            # Batched tensors are (N, *spatial, C): channels are the last axis.
            self._feature_axis = len(shape) + 1
            self._data_format = None
        elif data_format == 'channels_first':
            self._size = tf.TensorShape([self._filters] + shape)
            # Batched tensors are (N, C, *spatial): axis 0 is the batch, so
            # channels live on axis 1 (axis 0 would concat/split the batch).
            self._feature_axis = 1
            self._data_format = 'NC'
        else:
            raise ValueError('Unknown data_format')

    @property
    def state_size(self):
        """Per-sample shape of the hidden state h."""
        return self._size

    @property
    def output_size(self):
        """Per-sample shape of the output (identical to the state)."""
        return self._size

    def _channel_bias(self, bias, m):
        """Reshape a length-`m` bias so it broadcasts along the channel axis."""
        if self._data_format is None:
            # channels_last: the trailing axis is already the channel axis.
            return bias
        return tf.reshape(bias, [m] + [1] * (self._size.ndims - 1))

    def call(self, x, h):
        """Advance the cell by one time step.

        Args:
          x: Batched input for this step, laid out per `data_format`.
          h: Hidden state from the previous step.

        Returns:
          A pair (h, h): the new hidden state is both output and state.
        """
        # as_list() yields plain ints on both TF1 and TF2 (Dimension.value is TF1-only).
        channels = x.get_shape().as_list()[self._feature_axis]
        with tf.variable_scope('gates'):
            inputs = tf.concat([x, h], axis=self._feature_axis)
            n = channels + self._filters
            m = 2 * self._filters  # reset and update gate channels
            W = tf.get_variable('kernel', self._kernel + [n, m])
            y = tf.nn.convolution(inputs, W, 'SAME', data_format=self._data_format)
            # Gate bias starts at ones (vs zeros for the candidate below).
            y += self._channel_bias(
                tf.get_variable('bias', [m], initializer=tf.ones_initializer()), m)
            r, u = tf.split(y, 2, axis=self._feature_axis)
            r, u = tf.sigmoid(r), tf.sigmoid(u)
        with tf.variable_scope('candidate'):
            # The reset gate scales the previous state before it is convolved.
            inputs = tf.concat([x, r * h], axis=self._feature_axis)
            n = channels + self._filters
            m = self._filters
            W = tf.get_variable('kernel', self._kernel + [n, m])
            y = tf.nn.convolution(inputs, W, 'SAME', data_format=self._data_format)
            y += self._channel_bias(
                tf.get_variable('bias', [m], initializer=tf.zeros_initializer()), m)
            # Update gate interpolates between the old state and the candidate.
            h = u * h + (1 - u) * self._activation(y)
        return h, h