'''Encoder, Decoder and Auto-encoder based on the VGG-16 with batch normalization
Written By: Anders Ohrn, September 2020
'''
import torch
from torch import nn
from torchvision import models

class EncoderVGG(nn.Module):
    '''Encoder of image based on the architecture of VGG-16 with batch normalization.

    Args:
        pretrained_params (bool, optional): If the network should be populated with pre-trained VGG parameters.
            Defaults to True.

    '''
    channels_in = 3
    channels_code = 512

    def __init__(self, pretrained_params=True):
        super(EncoderVGG, self).__init__()

        vgg = models.vgg16_bn(pretrained=pretrained_params)
        del vgg.classifier
        del vgg.avgpool

        self.encoder = self._encodify_(vgg)

    def forward(self, x):
        '''Execute the encoder on the image input

        Args:
            x (Tensor): image tensor

        Returns:
            x_code (Tensor): code tensor
            pool_indices (list): Pool indices tensors in order of the pooling modules

        '''
        pool_indices = []
        x_current = x
        for module_encode in self.encoder:
            output = module_encode(x_current)

            # If the module is pooling, there are two outputs, the second being the pool indices
            if isinstance(output, tuple) and len(output) == 2:
                x_current = output[0]
                pool_indices.append(output[1])
            else:
                x_current = output

        return x_current, pool_indices

    @staticmethod
    def dim_code(img_dim):
        '''Convenience function to provide the dimension of the code given a square image of specified size. The
        transformation is defined by the details of the VGG method. The aim should be to resize the image to
        produce an integer code dimension.

        Args:
            img_dim (int): Height/width dimension of the tentative square image to input to the auto-encoder

        Returns:
            code_dim (float): Height/width dimension of the code
            int_value (bool): If False, the tentative image dimension will not produce an integer dimension for the
                code. If True it will. For actual applications, this value should be True.

        '''
        value = img_dim / 2**5
        int_value = img_dim % 2**5 == 0

        return value, int_value
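
    # Worked example (illustrative, not in the original source): VGG-16 halves the spatial
    # dimensions five times, so dim_code(224) returns (7.0, True), while dim_code(100)
    # returns (3.125, False) because 100 is not divisible by 2**5 = 32.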

    def _encodify_(self, encoder):
        '''Create a list of modules for the encoder based on the architecture in the VGG template model.

        In the encoder-decoder architecture, the unpooling operations in the decoder require pooling
        indices from the corresponding pooling operation in the encoder. In the VGG template, these indices
        are not returned. Hence the need for this method to extend the pooling operations.

        Args:
            encoder : the template VGG model

        Returns:
            modules : the list of modules that define the encoder corresponding to the VGG model

        '''
        modules = nn.ModuleList()
        for module in encoder.features:
            if isinstance(module, nn.MaxPool2d):
                module_add = nn.MaxPool2d(kernel_size=module.kernel_size,
                                          stride=module.stride,
                                          padding=module.padding,
                                          return_indices=True)
                modules.append(module_add)
            else:
                modules.append(module)

        return modules
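

# Usage sketch (an illustrative addition, not part of the original module): encode a random
# 224x224 RGB batch. The five max-pooling stages of VGG-16 shrink the spatial dimensions by
# a factor of 2**5, so the code has 512 channels of 7x7 cells.
def _demo_encoder():
    encoder = EncoderVGG(pretrained_params=False)
    x = torch.randn(1, 3, 224, 224)
    x_code, pool_indices = encoder(x)
    print(x_code.shape)        # torch.Size([1, 512, 7, 7])
    print(len(pool_indices))   # 5, one indices tensor per pooling module
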
class DecoderVGG(nn.Module):
    '''Decoder of code based on the architecture of VGG-16 with batch normalization.

    The decoder is created from a pseudo-inversion of the encoder based on VGG-16 with batch normalization. The
    pseudo-inversion is obtained by (1) replacing max pooling layers in the encoder with max un-pooling layers with
    pooling indices from the mirror image max pooling layer, and by (2) replacing 2D convolutions with transposed
    2D convolutions. The ReLU and batch normalization layers are the same as in the encoder, that is, subsequent to
    the convolution layer.

    Args:
        encoder: The encoder instance of `EncoderVGG` that is to be inverted into a decoder

    '''
    channels_in = EncoderVGG.channels_code
    channels_out = 3

    def __init__(self, encoder):
        super(DecoderVGG, self).__init__()

        self.decoder = self._invert_(encoder)

    def forward(self, x, pool_indices):
        '''Execute the decoder on the code tensor input

        Args:
            x (Tensor): code tensor obtained from the encoder
            pool_indices (list): Pool indices PyTorch tensors in order of the pooling modules in the encoder

        Returns:
            x (Tensor): decoded image tensor

        '''
        x_current = x

        k_pool = 0
        reversed_pool_indices = list(reversed(pool_indices))
        for module_decode in self.decoder:

            # If the module is unpooling, collect the appropriate pooling indices
            if isinstance(module_decode, nn.MaxUnpool2d):
                x_current = module_decode(x_current, indices=reversed_pool_indices[k_pool])
                k_pool += 1
            else:
                x_current = module_decode(x_current)

        return x_current

    def _invert_(self, encoder):
        '''Invert the encoder in order to create the decoder as a (more or less) mirror image of the encoder

        The decoder comprises two principal module types: the 2D transposed convolution and the 2D unpooling. The
        2D transposed convolution is followed by batch normalization and activation. Therefore, as the module list
        of the encoder is iterated over in reverse, a convolution in the encoder is turned into a transposed
        convolution plus normalization and activation, and a max pooling in the encoder is turned into unpooling.

        Args:
            encoder (ModuleList): the encoder

        Returns:
            decoder (ModuleList): the decoder obtained by "inversion" of the encoder

        '''
        modules_transpose = []
        for module in reversed(encoder):

            if isinstance(module, nn.Conv2d):
                kwargs = {'in_channels' : module.out_channels, 'out_channels' : module.in_channels,
                          'kernel_size' : module.kernel_size, 'stride' : module.stride,
                          'padding' : module.padding}
                module_transpose = nn.ConvTranspose2d(**kwargs)
                module_norm = nn.BatchNorm2d(module.in_channels)
                module_act = nn.ReLU(inplace=True)
                modules_transpose += [module_transpose, module_norm, module_act]

            elif isinstance(module, nn.MaxPool2d):
                kwargs = {'kernel_size' : module.kernel_size, 'stride' : module.stride,
                          'padding' : module.padding}
                module_transpose = nn.MaxUnpool2d(**kwargs)
                modules_transpose += [module_transpose]

        # Discard the final normalization and activation, so the final module is a convolution with bias
        modules_transpose = modules_transpose[:-2]

        return nn.ModuleList(modules_transpose)
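

# Usage sketch (an illustrative addition): invert an encoder into a decoder and check that
# decoding the code with the recorded pooling indices restores the input dimensions.
def _demo_decoder():
    encoder = EncoderVGG(pretrained_params=False)
    decoder = DecoderVGG(encoder.encoder)
    x = torch.randn(1, 3, 224, 224)
    x_code, pool_indices = encoder(x)
    x_prime = decoder(x_code, pool_indices)
    print(x_prime.shape)   # torch.Size([1, 3, 224, 224]), matching the input
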
class AutoEncoderVGG(nn.Module):
    '''Auto-Encoder based on the VGG-16 with batch normalization template model. The class comprises
    an encoder and a decoder.

    Args:
        pretrained_params (bool, optional): If the network should be populated with pre-trained VGG parameters.
            Defaults to True.

    '''
    channels_in = EncoderVGG.channels_in
    channels_code = EncoderVGG.channels_code
    channels_out = DecoderVGG.channels_out

    def __init__(self, pretrained_params=True):
        super(AutoEncoderVGG, self).__init__()

        self.encoder = EncoderVGG(pretrained_params=pretrained_params)
        self.decoder = DecoderVGG(self.encoder.encoder)

    @staticmethod
    def dim_code(img_dim):
        '''Convenience function to provide the dimension of the code given a square image of specified size. The
        transformation is defined by the details of the VGG method. The aim should be to resize the image to
        produce an integer code dimension.

        Args:
            img_dim (int): Height/width dimension of the tentative square image to input to the auto-encoder

        Returns:
            code_dim (float): Height/width dimension of the code
            int_value (bool): If False, the tentative image dimension will not produce an integer dimension for the
                code. If True it will. For actual applications, this value should be True.

        '''
        return EncoderVGG.dim_code(img_dim)

    @staticmethod
    def state_dict_mutate(encoder_or_decoder, ae_state_dict):
        '''Mutate an auto-encoder state dictionary into a pure encoder or decoder state dictionary

        The method depends on the names of the encoder and decoder attributes as defined in the auto-encoder
        initialization. Currently these names are "encoder" and "decoder". The state dictionary that is returned
        can be loaded into a pure EncoderVGG or DecoderVGG instance; a usage sketch follows the class definition.

        Args:
            encoder_or_decoder (str): Specification of whether the mutation should be to an encoder state
                dictionary or a decoder state dictionary, where the former is denoted with "encoder" and the
                latter with "decoder"
            ae_state_dict (OrderedDict): The auto-encoder state dictionary to mutate

        Returns:
            state_dict (OrderedDict): The mutated state dictionary that can be loaded into either an EncoderVGG
                or DecoderVGG instance

        Raises:
            RuntimeError : if the state dictionary contains keys that cannot be attributed to either the encoder
                or the decoder
            ValueError : if the specified mutation is neither "encoder" nor "decoder"

        '''
        if encoder_or_decoder not in ('encoder', 'decoder'):
            raise ValueError('State dictionary mutation only for "encoder" or "decoder", not {}'.format(encoder_or_decoder))

        keys = list(ae_state_dict)
        for key in keys:
            if 'encoder' in key or 'decoder' in key:
                if encoder_or_decoder in key:
                    key_new = key[len(encoder_or_decoder) + 1:]
                    ae_state_dict[key_new] = ae_state_dict[key]
                    del ae_state_dict[key]
                else:
                    del ae_state_dict[key]
            else:
                raise RuntimeError('State dictionary key {} is part of neither the encoder nor the decoder'.format(key))

        return ae_state_dict

    def forward(self, x):
        '''Forward the auto-encoder for image input

        Args:
            x (Tensor): image tensor

        Returns:
            x_prime (Tensor): image tensor following encoding and decoding

        '''
        code, pool_indices = self.encoder(x)
        x_prime = self.decoder(code, pool_indices)

        return x_prime
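

# Usage sketch (an illustrative addition): reconstruct an image batch with the auto-encoder,
# then carve its state dictionary into one that a pure EncoderVGG instance can load. The
# deep copy guards the auto-encoder state, since `state_dict_mutate` modifies its argument.
def _demo_autoencoder():
    import copy

    ae = AutoEncoderVGG(pretrained_params=False)
    x = torch.randn(1, 3, 224, 224)
    x_prime = ae(x)
    print(x_prime.shape)   # torch.Size([1, 3, 224, 224])

    encoder_state = AutoEncoderVGG.state_dict_mutate('encoder', copy.deepcopy(ae.state_dict()))
    encoder = EncoderVGG(pretrained_params=False)
    encoder.load_state_dict(encoder_state)
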
class EncoderVGGMerged(EncoderVGG):
    '''Special case of the VGG encoder wherein the code is merged along the height/width dimension. This is a thin
    child class of `EncoderVGG`.

    Args:
        merger_type (str, optional): Defines how the code is merged. If `None`, there is no merger and the
            functionality is identical to that of the parent class `EncoderVGG`. If "mean", the channels for the
            height/width code cells are averaged; the number of channels is identical between input and output.
            If "flatten", the channels for the height/width code cells are stacked on top of each other; the number
            of channels of the output is the number of channels of the input multiplied by the number of height
            cells and the number of width cells.
        pretrained_params (bool, optional): If the network should be populated with pre-trained VGG parameters.
            Defaults to True.

    '''
    def __init__(self, merger_type=None, pretrained_params=True):
        super(EncoderVGGMerged, self).__init__(pretrained_params=pretrained_params)

        if merger_type is None:
            self.code_post_process = lambda x: x
            self.code_post_process_kwargs = {}
        elif merger_type == 'mean':
            self.code_post_process = torch.mean
            self.code_post_process_kwargs = {'dim' : (-2, -1)}
        elif merger_type == 'flatten':
            self.code_post_process = torch.flatten
            self.code_post_process_kwargs = {'start_dim' : 1, 'end_dim' : -1}
        else:
            raise ValueError('Unknown merger type for the encoder code: {}'.format(merger_type))

    def forward(self, x):
        '''Execute the encoder on the image input

        Args:
            x (Tensor): image tensor

        Returns:
            x_code (Tensor): merged code tensor

        '''
        x_current, _ = super().forward(x)
        x_code = self.code_post_process(x_current, **self.code_post_process_kwargs)

        return x_code
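

# Usage sketch (an illustrative addition): compare the two merger modes. For a 224x224 input
# the unmerged code is 512x7x7; 'mean' averages over the 7x7 cells, giving a 512-vector per
# image, while 'flatten' stacks them into a 512 * 7 * 7 = 25088-vector per image.
def _demo_encoder_merged():
    x = torch.randn(1, 3, 224, 224)
    encoder_mean = EncoderVGGMerged(merger_type='mean', pretrained_params=False)
    encoder_flat = EncoderVGGMerged(merger_type='flatten', pretrained_params=False)
    print(encoder_mean(x).shape)   # torch.Size([1, 512])
    print(encoder_flat(x).shape)   # torch.Size([1, 25088])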