-
Notifications
You must be signed in to change notification settings - Fork 1
/
residual_network.py
68 lines (48 loc) · 2.28 KB
/
residual_network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""residual_network.py
Utilities for the residual autoencoder network
"""
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, Activation, Add
from tensorflow.keras.initializers import glorot_uniform
def _conv_relu(x, filters, kernel_size, strides=(1, 1)):
    """Apply a seeded, 'same'-padded Conv2D followed by a ReLU activation."""
    x = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides,
               padding='same', kernel_initializer=glorot_uniform(seed=0))(x)
    return Activation('relu')(x)


def _upsample_relu_add(x, skip, filters, kernel_size):
    """Upsample x by 2 with a seeded Conv2DTranspose + ReLU, then add skip."""
    x = Conv2DTranspose(filters=filters, kernel_size=kernel_size, strides=(2, 2),
                        padding='same', kernel_initializer=glorot_uniform(seed=0))(x)
    x = Activation('relu')(x)
    return Add()([x, skip])


def ResidualAutoencoder_V2_64(batch_size=32):
    """Return a Keras model of the residual autoencoder described in:
    https://web.stanford.edu/class/cs331b/2016/projects/zhao.pdf

    Encoder: five ReLU convolutions; the last four use stride 2, taking the
    (64, 64, 1) input down to a 4x4 bottleneck. Decoder: four stride-2
    transposed convolutions with ReLU. After each one, the encoder activation
    of matching shape is added (residual skip connection). A final
    single-filter transposed convolution with no activation produces the
    output.

    The required input shape is (64, 64, 1).

    Args:
        batch_size: Unused. Kept only so existing callers that pass it keep
            working. The model's batch dimension is left unspecified.

    Returns:
        An uncompiled ``tf.keras.Model`` named ``'res_autoencoder_v2'`` that
        maps a ``(None, 64, 64, 1)`` input to a ``(None, 64, 64, 1)`` output.
    """
    input_shape = (64, 64, 1)
    img_input = Input(shape=input_shape)

    ## Downsampling (spatial size after each stage, 'same' padding)
    shortcut_1 = _conv_relu(img_input, 64, (4, 4))                  # 64x64x64
    shortcut_2 = _conv_relu(shortcut_1, 64, (8, 8), strides=(2, 2))   # 32x32x64
    shortcut_3 = _conv_relu(shortcut_2, 128, (8, 8), strides=(2, 2))  # 16x16x128
    shortcut_4 = _conv_relu(shortcut_3, 256, (8, 8), strides=(2, 2))  # 8x8x256
    # NOTE(review): 518 breaks the 64/128/256 progression and may be a typo
    # for 512. It is kept as-is because changing it alters the weight shapes
    # (existing checkpoints would no longer load). Confirm against the paper.
    x = _conv_relu(shortcut_4, 518, (4, 4), strides=(2, 2))           # 4x4x518

    ## Upsampling: each step doubles the spatial size, then adds the skip
    x = _upsample_relu_add(x, shortcut_4, 256, 4)   # 8x8x256
    x = _upsample_relu_add(x, shortcut_3, 128, 8)   # 16x16x128
    x = _upsample_relu_add(x, shortcut_2, 64, 16)   # 32x32x64
    x = _upsample_relu_add(x, shortcut_1, 64, 16)   # 64x64x64

    # Linear (no activation) projection back to one channel. A 64x64 kernel
    # on a 64x64 map lets every output pixel see the whole feature map.
    x = Conv2DTranspose(filters=1, kernel_size=64, padding='same',
                        kernel_initializer=glorot_uniform(seed=0))(x)  # 64x64x1
    return tf.keras.Model(img_input, x, name='res_autoencoder_v2')