-
Notifications
You must be signed in to change notification settings - Fork 0
/
odeblocktensorflow.py
50 lines (37 loc) · 1.89 KB
/
odeblocktensorflow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#imports
import tensorflow as tf
import tensorflow_scientific as tfs
import tensorflow.keras.backend as K
#############################################################
class ODEBlock(tf.keras.layers.Layer):
    """Keras layer whose output is the solution of an ODE at t = 1.

    The layer input is used as the initial state at t = 0.  The state
    derivative is given by ``ode_func``: two "same"-padded 2-D convolutions,
    each followed by bias-add and ReLU, where the current time ``t`` is
    appended to the features as one extra channel before each convolution.

    Because the ODE state must keep its shape, the number of input channels
    has to equal ``filters``; the output shape equals the input shape.

    Args:
        filters: Number of channels of the state (and of both convolutions).
        kernel_size: Spatial kernel size, either a single int (used for both
            spatial dimensions) or a sequence of two ints.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.

    Raises:
        ValueError: If ``kernel_size`` does not describe exactly two
            dimensions, or (in ``build``) if the statically known channel
            count of the input differs from ``filters``.
    """

    def __init__(self, filters, kernel_size, **kwargs):
        # Initialise the Keras base class first so that attribute tracking is
        # fully set up before any attribute is assigned on the layer.
        super(ODEBlock, self).__init__(**kwargs)
        self.filters = filters
        # Normalise to a plain tuple: the weight shapes in build() are formed
        # by tuple concatenation, which fails for an int or a list.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        kernel_size = tuple(kernel_size)
        if len(kernel_size) != 2:
            raise ValueError(
                "kernel_size must be an int or a sequence of 2 ints, got %r"
                % (kernel_size,)
            )
        self.kernel_size = kernel_size

    def build(self, input_shape):
        """Create the two convolution kernels and their biases."""
        # The kernels below expect `filters + 1` input channels (state plus
        # the time channel), so the incoming state must have `filters`
        # channels.  Fail early with a readable message when that is
        # statically known to be violated.
        in_channels = tf.compat.dimension_value(tf.TensorShape(input_shape)[-1])
        if in_channels is not None and in_channels != self.filters:
            raise ValueError(
                "ODEBlock requires the input channel count (%s) to equal "
                "filters (%s)" % (in_channels, self.filters)
            )

        # Kernel layout: (kernel_h, kernel_w, in_channels + 1, out_channels).
        kernel_shape = self.kernel_size + (self.filters + 1, self.filters)
        # Keyword arguments keep these calls independent of add_weight's
        # positional parameter order.
        self.conv2d_w1 = self.add_weight(
            name="conv2d_w1", shape=kernel_shape, initializer="glorot_uniform"
        )
        self.conv2d_w2 = self.add_weight(
            name="conv2d_w2", shape=kernel_shape, initializer="glorot_uniform"
        )
        self.conv2d_b1 = self.add_weight(
            name="conv2d_b1", shape=(self.filters,), initializer="zeros"
        )
        self.conv2d_b2 = self.add_weight(
            name="conv2d_b2", shape=(self.filters,), initializer="zeros"
        )
        super(ODEBlock, self).build(input_shape)

    def call(self, x):
        """Integrate ``ode_func`` from t = 0 to t = 1 starting at ``x``."""
        t = K.constant([0, 1], dtype="float32")
        # NOTE: on TensorFlow 1.x the same solver was available as
        # tf.contrib.integrate.odeint.
        # The solver result is indexed along the requested time points;
        # index 1 is the state at the second time point, t = 1.
        return tfs.integrate.odeint(self.ode_func, x, t, rtol=1e-3, atol=1e-3)[1]

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` unchanged."""
        return input_shape

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(ODEBlock, self).get_config()
        config.update({"filters": self.filters, "kernel_size": self.kernel_size})
        return config

    def ode_func(self, x, t):
        """State derivative: two time-conditioned conv + bias + ReLU stages.

        Args:
            x: Current state tensor.
            t: Current time, as passed in by the ODE solver.

        Returns:
            Tensor produced by the second ReLU; it has ``filters`` channels.
        """
        y = self.concat_t(x, t)
        y = K.conv2d(y, self.conv2d_w1, padding="same")
        y = K.bias_add(y, self.conv2d_b1)
        y = K.relu(y)
        y = self.concat_t(y, t)
        y = K.conv2d(y, self.conv2d_w2, padding="same")
        y = K.bias_add(y, self.conv2d_b2)
        y = K.relu(y)
        return y

    def concat_t(self, x, t):
        """Append a channel filled with the value ``t`` to ``x``.

        Args:
            x: Feature tensor; the new channel is added on the last axis.
            t: Time value; expected to hold a single element.

        Returns:
            ``x`` with one additional trailing channel whose entries all
            equal ``t``.
        """
        # Build the time channel from x itself so that it shares x's dtype
        # and leading dimensions, whatever the dtype or rank of the state.
        t_channel = tf.ones_like(x[..., :1]) * tf.cast(t, x.dtype)
        return tf.concat([x, t_channel], axis=-1)
#############################################################