#!/usr/bin/env python3
# Copyright 2020 Christian Henning
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# @title :mnets/lenet.py
# @author :ch
# @contact :[email protected]
# @created :02/28/2020
# @version :1.0
# @python_version :3.6.10
"""
LeNet
-----
This module contains a general classifier template and a LeNet-like network
to classify either MNIST or CIFAR-10 images. The network can be instantiated
without any trainable parameters, in which case the weights have to be passed
to the ``forward`` method. This makes the usage of a
hypernetwork (a network that generates the weights of another network)
particularly easy.
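A minimal usage sketch (for illustration only; random tensors stand in for the
weights a hypernetwork would normally produce, and ``param_shapes`` is the
attribute provided by :class:`mnets.mnet_interface.MainNetInterface`):

.. code-block:: python

    import torch
    from mnets.lenet import LeNet

    net = LeNet(in_shape=(28, 28, 1), num_classes=10, arch='mnist_small',
                no_weights=True)
    x = torch.rand(4, 28, 28, 1)  # Batch of 4 dummy MNIST images.
    # One tensor per entry of ``net.param_shapes``, e.g., generated by a
    # hypernetwork. Here, random values are used for illustration.
    weights = [torch.rand(*s) for s in net.param_shapes]
    y = net.forward(x, weights=weights)  # Output shape: (4, 10).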
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from mnets.classifier_interface import Classifier
from mnets.mnet_interface import MainNetInterface
from utils.torch_utils import init_params
class LeNet(Classifier):
"""The network consists of two convolutional layers followed by two fully-
connected layers. See implementation for details.
LeNet was originally introduced in
"Gradient-based learning applied to document recognition", LeCun et
al., 1998.
However, the implementation provided here has several differences compared
to the original LeNet architecture (e.g., the LeNet-5 architecture):
- There is no special connectivity map before the second convolutional
layer as described by table 1 in the original paper.
- The dimensions of layers and their activation functions are different.
- The original LeNet-5 has a third fully connected layer with 1x1 kernels.
We mainly use this modified LeNet architecture for MNIST:
- A small architecture with only 21,840 weights.
- A larger architecture with 431,080 weights.
Both of these architectures are typically used for MNIST nowadays.
Note, a variant of this architecture is also used for CIFAR-10, e.g. in
"Bayesian Convolutional Neural Networks with Bernoulli Approximate
Variational Inference", Gal et al., 2015.
and
"Multiplicative Normalizing Flows for Variational Bayesian Neural
Networks", Louizos et al., 2017.
In both these works, the dimensions of the weight parameters are:
.. code-block:: python
main_dims=[[192,3,5,5],[192],[192,192,5,5],[192],[1000,4800],
[1000],[10,1000],[10]],
which is an architecture with 5,747,394 weights. Note, the authors used
dropout in different configurations, e.g., after each layer, only after
the fully-connected layer or no dropout at all.
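These parameter counts follow directly from the per-layer weight shapes
listed in the :attr:`_ARCHITECTURES` dictionary below. For instance, a short
sketch to recompute them:

.. code-block:: python

    import numpy as np

    num_weights = sum(int(np.prod(s))
                      for s in LeNet._ARCHITECTURES['mnist_small'])
    # -> 21840; 'mnist_large' yields 431080 and 'cifar' yields 5747394.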
Args:
in_shape (tuple or list): The shape of an input sample.
.. note::
We assume the Tensorflow format, where the last entry
denotes the number of channels.
num_classes (int): The number of output neurons.
verbose (bool): Allow printing of general information about the
generated network (such as number of weights).
arch (str): The architecture to be employed. The following options are
available:
- ``'mnist_small'``: A small LeNet with 21,840 weights suitable
for MNIST
- ``'mnist_large'``: A larger LeNet with 431,080 weights suitable
for MNIST
- ``'cifar'``: A huge LeNet with 5,747,394 weights designed for
CIFAR-10.
no_weights (bool): If set to ``True``, no trainable parameters will be
constructed, i.e., weights are assumed to be produced ad-hoc
by a hypernetwork and passed to the :meth:`forward` method.
init_weights (optional): This option is for convenience reasons.
The option expects a list of parameter values that are used to
initialize the network weights. As such, it provides a
convenient way of initializing a network with a weight draw
produced by the hypernetwork.
dropout_rate (float): If ``-1``, no dropout will be applied. Otherwise a
number between 0 and 1 is expected, denoting the dropout rate.
Dropout will be applied after the convolutional layers
(before pooling) and after the first fully-connected layer
(after the activation function).
**kwargs: Keyword arguments regarding context modulation. This class
can process the same context-modulation related arguments as class
:class:`mnets.mlp.MLP`. One may additionally specify the argument
``context_mod_apply_pixel_wise`` (see class
:class:`mnets.resnet.ResNet`).
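For instance, context modulation could be enabled as in the following sketch
(assuming that unspecified context-mod options are filled with defaults by
the internal argument parsing):

.. code-block:: python

    net = LeNet(in_shape=(32, 32, 3), num_classes=10, arch='cifar',
                use_context_mod=True,
                context_mod_apply_pixel_wise=False)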
"""
_ARCHITECTURES = {
'mnist_small': [[10,1,5,5],[10],[20,10,5,5],[20],
[50,320],[50],[10,50],[10]],
'mnist_large': [[20,1,5,5],[20],[50,20,5,5],[50],[500,800],[500],
[10,500],[10]],
'cifar': [[192,3,5,5],[192],[192,192,5,5],[192],[1000,4800],[1000],
[10,1000],[10]]
}
def __init__(self, in_shape=(28, 28, 1), num_classes=10, verbose=True,
arch='mnist_large', no_weights=False, init_weights=None,
dropout_rate=-1, #0.5
**kwargs):
super(LeNet, self).__init__(num_classes, verbose)
self._in_shape = in_shape
assert arch in LeNet._ARCHITECTURES.keys()
# Copy the architecture spec, such that the class-level definitions in
# `_ARCHITECTURES` are not modified when adapting the number of output
# classes below.
self._chosen_arch = [list(s) for s in LeNet._ARCHITECTURES[arch]]
if num_classes != 10:
self._chosen_arch[-2][0] = num_classes
self._chosen_arch[-1][0] = num_classes
# Sanity check, given current implementation.
if arch.startswith('mnist'):
if not in_shape[0] == in_shape[1] == 28:
raise ValueError('MNIST LeNet architectures expect input ' +
'images of size 28x28.')
else:
if not in_shape[0] == in_shape[1] == 32:
raise ValueError('CIFAR LeNet architectures expect input ' +
'images of size 32x32.')
### Parse or set context-mod arguments ###
rem_kwargs = MainNetInterface._parse_context_mod_args(kwargs)
if len(rem_kwargs) > 0:
raise ValueError('Keyword arguments %s unknown.' % str(rem_kwargs))
# Since this is a conv-net, we may also want to add the following.
if 'context_mod_apply_pixel_wise' not in kwargs.keys():
kwargs['context_mod_apply_pixel_wise'] = False
self._use_context_mod = kwargs['use_context_mod']
self._context_mod_inputs = kwargs['context_mod_inputs']
self._no_last_layer_context_mod = kwargs['no_last_layer_context_mod']
self._context_mod_no_weights = kwargs['context_mod_no_weights']
self._context_mod_post_activation = \
kwargs['context_mod_post_activation']
self._context_mod_gain_offset = kwargs['context_mod_gain_offset']
self._context_mod_gain_softplus = kwargs['context_mod_gain_softplus']
self._context_mod_apply_pixel_wise = \
kwargs['context_mod_apply_pixel_wise']
### Setup class attributes ###
assert(init_weights is None or \
(not no_weights or not self._context_mod_no_weights))
self._no_weights = no_weights
self._dropout_rate = dropout_rate
self._has_bias = True
self._has_fc_out = True
# We need to make sure that the last 2 entries of `weights` correspond
# to the weight matrix and bias vector of the last layer!
self._mask_fc_out = True
self._has_linear_out = True
self._param_shapes = []
self._param_shapes_meta = []
self._internal_params = None if no_weights and \
self._context_mod_no_weights else nn.ParameterList()
self._hyper_shapes_learned = None \
if not no_weights and not self._context_mod_no_weights else []
self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
is None else []
self._layer_weight_tensors = nn.ParameterList()
self._layer_bias_vectors = nn.ParameterList()
if dropout_rate != -1:
assert(dropout_rate >= 0. and dropout_rate <= 1.)
# FIXME `nn.Dropout2d` zeroes out whole feature maps. Is that really
# desired here?
self._drop_conv1 = nn.Dropout2d(p=dropout_rate)
self._drop_conv2 = nn.Dropout2d(p=dropout_rate)
self._drop_fc1 = nn.Dropout(p=dropout_rate)
### Define and initialize context mod layers/weights ###
self._context_mod_layers = nn.ModuleList() if self._use_context_mod \
else None
if self._use_context_mod:
cm_layer_inds = []
cm_shapes = [] # Output shape of all context-mod layers.
if self._context_mod_inputs:
cm_shapes.append([in_shape[2], *in_shape[:2]])
# We reserve layer zero for input context-mod. Otherwise, there
# is no layer zero.
cm_layer_inds.append(0)
layer_out_shapes = self._compute_layer_out_sizes()
# Context-modulation is applied after the pooling layers.
# So we delete the shapes of the conv-layer outputs and keep the
# ones of the pooling layer outputs.
del layer_out_shapes[2]
del layer_out_shapes[0]
cm_shapes.extend(layer_out_shapes)
cm_layer_inds.extend(range(2, 2*len(layer_out_shapes)+1, 2))
if self._no_last_layer_context_mod:
cm_shapes = cm_shapes[:-1]
cm_layer_inds = cm_layer_inds[:-1]
if not self._context_mod_apply_pixel_wise:
# Only scalar gain and shift per feature map!
for i, s in enumerate(cm_shapes):
if len(s) == 3:
cm_shapes[i] = [s[0], 1, 1]
self._add_context_mod_layers(cm_shapes, cm_layers=cm_layer_inds)
### Define and add conv- and fc-layer weights.
for i, s in enumerate(self._chosen_arch):
if not no_weights:
self._internal_params.append(nn.Parameter(torch.Tensor(*s),
requires_grad=True))
if len(s) == 1:
self._layer_bias_vectors.append(self._internal_params[-1])
else:
self._layer_weight_tensors.append(self._internal_params[-1])
else:
self._hyper_shapes_learned.append(s)
self._hyper_shapes_learned_ref.append(len(self.param_shapes))
self._param_shapes.append(s)
self._param_shapes_meta.append({
'name': 'weight' if len(s) != 1 else 'bias',
'index': -1 if no_weights else len(self._internal_params)-1,
'layer': 2 * (i // 2) + 1
})
### Initialize weights.
if init_weights is not None:
assert len(init_weights) == len(self.weights)
for i in range(len(init_weights)):
assert np.all(np.equal(list(init_weights[i].shape),
self.weights[i].shape))
self.weights[i].data = init_weights[i]
else:
for i in range(len(self._layer_weight_tensors)):
init_params(self._layer_weight_tensors[i],
self._layer_bias_vectors[i])
### Print user info.
if verbose:
if self._use_context_mod:
cm_param_shapes = []
for cm_layer in self.context_mod_layers:
cm_param_shapes.extend(cm_layer.param_shapes)
cm_num_weights = \
MainNetInterface.shapes_to_num_weights(cm_param_shapes)
print('Creating a LeNet with %d weights' % self.num_params
+ (' (including %d weights associated with ' % cm_num_weights
+ 'context modulation)' if self._use_context_mod else '')
+ '.'
+ (' The network uses dropout.' if dropout_rate != -1 \
else ''))
self._is_properly_setup()
def forward(self, x, weights=None, distilled_params=None, condition=None):
"""Compute the output :math:`y` of this network given the input
:math:`x`.
Args:
(....): See docstring of method
:meth:`mnets.mnet_interface.MainNetInterface.forward`. We
provide some more specific information below.
weights (list or dict): See argument ``weights`` of method
:meth:`mnets.mlp.MLP.forward`.
condition (int, optional): If provided, then this argument will be
passed as argument ``ckpt_id`` to the method
:meth:`utils.context_mod_layer.ContextModLayer.forward`.
Returns:
(torch.Tensor): The output of the network.
"""
if ((not self._use_context_mod and self._no_weights) or \
(self._no_weights or self._context_mod_no_weights)) and \
weights is None:
raise Exception('Network was generated without weights. ' +
'Hence, "weights" option may not be None.')
############################################
### Extract which weights should be used ###
############################################
# FIXME Code copied from the `forward` method of class MLP.
# I.e., are we using internally maintained weights or externally given
# ones or are we even mixing between these groups.
n_cm = self._num_context_mod_shapes()
if weights is None:
weights = self.weights
if self._use_context_mod:
cm_weights = weights[:n_cm]
int_weights = weights[n_cm:]
else:
cm_weights = None
int_weights = weights
else:
int_weights = None
cm_weights = None
if isinstance(weights, dict):
assert('internal_weights' in weights.keys() or \
'mod_weights' in weights.keys())
if 'internal_weights' in weights.keys():
int_weights = weights['internal_weights']
if 'mod_weights' in weights.keys():
cm_weights = weights['mod_weights']
else:
if self._use_context_mod and \
len(weights) == n_cm:
cm_weights = weights
else:
assert len(weights) == len(self.param_shapes)
if self._use_context_mod:
cm_weights = weights[:n_cm]
int_weights = weights[n_cm:]
else:
int_weights = weights
if self._use_context_mod and cm_weights is None:
if self._context_mod_no_weights:
raise Exception('Network was generated without weights ' +
'for context-mod layers. Hence, they must be passed ' +
'via the "weights" option.')
cm_weights = self.weights[:n_cm]
if int_weights is None:
if self._no_weights:
raise Exception('Network was generated without internal ' +
'weights. Hence, they must be passed via the ' +
'"weights" option.')
if self._context_mod_no_weights:
int_weights = self.weights
else:
int_weights = self.weights[n_cm:]
# Note, context-mod weights might have different shapes, as they
# may be parametrized on a per-sample basis.
if self._use_context_mod:
assert len(cm_weights) == n_cm
int_shapes = self.param_shapes[n_cm:]
assert len(int_weights) == len(int_shapes)
for i, s in enumerate(int_shapes):
assert np.all(np.equal(s, list(int_weights[i].shape)))
cm_ind = 0
# Split context-mod weights per context-mod layer.
if cm_weights is not None:
cm_weights_layer = []
cm_start = 0
for cm_layer in self.context_mod_layers:
cm_end = cm_start + len(cm_layer.param_shapes)
cm_weights_layer.append(cm_weights[cm_start:cm_end])
cm_start = cm_end
#######################
### Parse condition ###
#######################
cmod_cond = None
if condition is not None:
assert isinstance(condition, int)
cmod_cond = condition
# FIXME We always require context-mod weights above, but
# we can't pass both (a condition and weights) to the
# context-mod layers.
# An inelegant solution would be to just set all
# context-mod weights to None.
raise NotImplementedError('CM-conditions not implemented!')
cm_weights_layer = [None] * len(cm_weights_layer)
###########################
### Forward Computation ###
###########################
### Helper function to handle context-mod and non-linearities.
def modulate_layer(h):
"""Compute context-modulation and non-linearity.
The order is as follows:
context-mod (if pre-activation) -> non-linearity ->
context-mod (if post-activation)
This function increments the index ``cm_ind``.
Args:
h: Input activity.
Returns:
Output of layer.
"""
nonlocal cm_ind
# Context-dependent modulation (pre-activation).
if self._use_context_mod and \
not self._context_mod_post_activation:
h = self._context_mod_layers[cm_ind].forward(h,
weights=cm_weights_layer[cm_ind], ckpt_id=cmod_cond)
cm_ind += 1
# Non-linearity
h = F.relu(h)
# Context-dependent modulation (post-activation).
if self._use_context_mod and self._context_mod_post_activation:
h = self._context_mod_layers[cm_ind].forward(h,
weights=cm_weights_layer[cm_ind], ckpt_id=cmod_cond)
cm_ind += 1
return h
x = x.view(-1, *self._in_shape)
x = x.permute(0, 3, 1, 2)
h = x
# Context-dependent modulation of inputs directly.
if self._use_context_mod and self._context_mod_inputs:
h = self._context_mod_layers[cm_ind].forward(h,
weights=cm_weights_layer[cm_ind], ckpt_id=cmod_cond)
cm_ind += 1
h = F.conv2d(h, int_weights[0], bias=int_weights[1])
if self._dropout_rate != -1:
h = self._drop_conv1(h)
h = F.max_pool2d(h, 2)
h = modulate_layer(h)
h = F.conv2d(h, int_weights[2], bias=int_weights[3])
if self._dropout_rate != -1:
h = self._drop_conv2(h)
h = F.max_pool2d(h, 2)
h = modulate_layer(h)
h = h.reshape(-1, int_weights[4].size()[1])
h = F.linear(h, int_weights[4], bias=int_weights[5])
h = modulate_layer(h)
# FIXME Previously, context-modulation was applied after dropout, since
# dropout was located before the non-linearity and not after it.
if self._dropout_rate != -1:
h = self._drop_fc1(h)
h = F.linear(h, int_weights[6], bias=int_weights[7])
# Context-dependent modulation in output layer.
if self._use_context_mod and not self._no_last_layer_context_mod:
h = self._context_mod_layers[cm_ind].forward(h,
weights=cm_weights_layer[cm_ind], ckpt_id=cmod_cond)
return h
def distillation_targets(self):
"""Targets to be distilled after training.
See docstring of abstract super method
:meth:`mnets.mnet_interface.MainNetInterface.distillation_targets`.
This network does not have any distillation targets.
Returns:
``None``
"""
return None
def _compute_layer_out_sizes(self):
"""Compute the output shapes of all layers in this network.
This method will compute the output shape of each layer in this network,
including the output layer, which just corresponds to the number of
classes.
Returns:
(list): A list of shapes (lists of integers). The first entry will
correspond to the shape of the output of the first convolutional
layer. The last entry will correspond to the output shape.
.. note::
Output shapes of convolutional layers will adhere to the PyTorch
convention, i.e., ``[C, H, W]``, where ``C`` denotes the channel
dimension.
.. note::
Pooling layers are considered as individual layers.
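.. note::
    As a worked example, for ``arch='mnist_large'`` and a 28x28x1 input,
    the returned shapes are ``[20, 24, 24]`` (conv1), ``[20, 12, 12]``
    (pool1), ``[50, 8, 8]`` (conv2), ``[50, 4, 4]`` (pool2), ``[500]``
    (hidden fc) and ``[10]`` (output). Note that ``50 * 4 * 4 = 800``
    matches the input size of the first fully-connected layer.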
"""
in_shape = self._in_shape
assert len(self._chosen_arch) == 8
fs = [self._chosen_arch[0][0], self._chosen_arch[2][0]]
ks = [self._chosen_arch[0][2], 2, self._chosen_arch[2][2], 2]
assert self._chosen_arch[0][2] == self._chosen_arch[0][3] and \
self._chosen_arch[2][2] == self._chosen_arch[2][3]
pd = 0 # all paddings are 0.
# Note, `in_shape` is in Tensorflow layout.
assert(len(in_shape) == 3)
in_shape = [in_shape[2], *in_shape[:2]]
ret = []
C, H, W = in_shape
# First conv layer (stride 1).
C = fs[0]
H = (H - ks[0] + 2*pd) // 1 + 1
W = (W - ks[0] + 2*pd) // 1 + 1
ret.append([C, H, W])
# First pooling layer (stride == kernel size)
H = (H - ks[1] + 2*pd) // ks[1] + 1
W = (W - ks[1] + 2*pd) // ks[1] + 1
ret.append([C, H, W])
# Second conv layer (stride 1).
C = fs[1]
H = (H - ks[2] + 2*pd) // 1 + 1
W = (W - ks[2] + 2*pd) // 1 + 1
ret.append([C, H, W])
# Second pooling layer (stride == kernel size)
H = (H - ks[3] + 2*pd) // ks[3] + 1
W = (W - ks[3] + 2*pd) // ks[3] + 1
ret.append([C, H, W])
assert C * H * W == self._chosen_arch[4][1]
# First fully-connected layer.
ret.append([self._chosen_arch[6][1]])
# Output layer.
assert self._num_classes == self._chosen_arch[6][0]
ret.append([self._num_classes])
return ret
if __name__ == '__main__':
pass
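    # Minimal smoke-test sketch (for illustration only): instantiate a LeNet
    # with internally maintained weights and push a random MNIST-sized batch
    # through it.
    net = LeNet(in_shape=(28, 28, 1), num_classes=10, arch='mnist_large')
    x = torch.rand(8, 28, 28, 1)  # Batch of 8 dummy MNIST images.
    y = net.forward(x)
    print('Output shape:', tuple(y.shape))  # Expected: (8, 10).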