#!/usr/bin/env python3
# Copyright 2019 Christian Henning
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# @title :cifar/sa_hyper_model.py
# @author :ch
# @contact :[email protected]
# @created :02/21/2019
# @version :1.0
# @python_version :3.6.6
"""
Convolutional hypernetwork with self-attention layers
-----------------------------------------------------
The module :mod:`cifar.sa_hyper_model` implements a hypernetwork that uses
transpose convolutions (like the generator of a GAN) to generate weights.
Though, as convolutions usually suffer from only capturing local correlations
sufficiently, we incorporate the self-attention mechanism developed by
Zhang et al., "Self-Attention Generative Adversarial Networks", 2018,
https://arxiv.org/abs/1805.08318
See :class:`utils.self_attention_layer.SelfAttnLayerV2` for details on this
layer type.
.. note::
This module has been temporarily moved to this location from the deprecated
package ``classifier``. Once a new hypernetwork interface has been designed,
all hypernets (including this one) will be moved to the subpackage
``hnets``.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.self_attention_layer import SelfAttnLayerV2
from mnets.mnet_interface import MainNetInterface
from toy_example.hyper_model import HyperNetwork
from utils.module_wrappers import CLHyperNetInterface
from utils.misc import init_params

class SAHnetPart(nn.Module, CLHyperNetInterface):
    """The goal of this network is to produce a chunk of the weights that
    are used in a main network. Therefore, the network expects an embedding
    as input (in addition to the actual hypernet input), which encodes the
    chunk of main network weights that will be generated by this network.

    This is a convolutional network, employing transpose convolutions. The
    network structure is inspired by the DCGAN generator structure, though
    we additionally use self-attention layers to model global dependencies.

    In general, each transpose convolutional layer will roughly double its
    input size. However, we set the hard constraint that if the input size
    of a transpose convolutional layer would be smaller than 4, then the
    layer doesn't change the size.

    Args:
        out_size: A tuple of (width, height), denoting the output shape of
            the weights generated by this hypernet. The number of output
            channels is assumed to be 1, except if specified otherwise via
            (width, height, channels).
        num_layers: The number of transpose convolutional layers, including
            the initial fully-connected layer.
        num_filters (optional): List of integers of length ``num_layers-1``,
            the number of output channels in each hidden transpose conv
            layer. By default, the number of filters in the last hidden
            layer will be 128 and is doubled in every prior layer. Note, the
            output of the first layer (which is fully-connected) is here
            considered to be in the shape of an image tensor.
        kernel_size (optional): A single number, a tuple ``(k_x, k_y)`` or
            a list of scalars/tuples of length ``num_layers-1``, specifying
            the kernel size of each convolutional layer.
        sa_units: List of integers, each representing the index of a layer
            in this network after which a self-attention unit should be
            inserted. For instance, index 0 represents the fully-connected
            layer. The last layer may not be chosen.
        input_dim: The dimensionality of the input vectors (comprising
            both the chunk embedding and the actual hypernet input).
        use_batch_norm: If ``True``, batchnorm will be applied to all hidden
            layers.
        use_spectral_norm: Enable spectral normalization for all layers.
        no_theta: If set to ``True``, no trainable parameters ``theta`` will
            be constructed, i.e., weights are assumed to be produced ad-hoc
            by a hypernetwork and passed to the forward function.
            Does not affect task embeddings.
        init_theta (optional): This option is provided for convenience. It
            expects a list of parameter values that are used to initialize
            the network weights. As such, it provides a convenient way of
            initializing a network with, for instance, a weight draw
            produced by the hypernetwork.
            The given data has to have the same shapes as the attribute
            ``theta`` would have if the network were constructed with
            ``theta``. Does not affect task embeddings.
    """
    def __init__(self, out_size, num_layers, num_filters, kernel_size,
                 sa_units, input_dim, use_batch_norm, use_spectral_norm,
                 no_theta, init_theta):
        # FIXME find a way using super to handle multiple inheritance.
        #super(SAHnetPart, self).__init__()
        nn.Module.__init__(self)
        CLHyperNetInterface.__init__(self)

        assert(init_theta is None or not no_theta)

        if use_spectral_norm:
            raise NotImplementedError('Spectral normalization not yet ' +
                'implemented for this hypernetwork type.')
        if use_batch_norm:
            raise NotImplementedError('Batch normalization not yet ' +
                'implemented for this hypernetwork type.')

        # FIXME task embeddings are currently maintained outside of this
        # class.
        self._target_shapes = out_size
        self._task_embs = None
        self._size_ext_input = input_dim
        self._num_outputs = np.prod(out_size)

        if sa_units is None:
            sa_units = []
        self._sa_units_inds = sa_units
        self._use_batch_norm = use_batch_norm

        assert(num_layers > 0) # Initial fully-connected layer must exist.
        assert(num_filters is None or len(num_filters) == num_layers-1)
        assert(len(out_size) == 2 or len(out_size) == 3)
        #assert(num_layers-1 not in sa_units)
        assert(len(sa_units) == 0 or np.max(sa_units) < num_layers-1)

        out_channels = 1 if len(out_size) == 2 else out_size[2]
        if num_filters is None:
            num_filters = [128] * (num_layers-1)
            multipliers = np.power(2, range(num_layers-2, -1, -1)).tolist()
            num_filters = [e1 * e2 for e1, e2 in zip(num_filters,
                                                     multipliers)]
        num_filters.append(out_channels)
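        # For example, with num_layers=5 and num_filters=None, the filter
        # counts become [1024, 512, 256, 128, out_channels]: 128 in the
        # last hidden layer, doubled in every prior layer.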
        if kernel_size is None:
            kernel_size = 5
        if isinstance(kernel_size, tuple):
            # Normalize tuples to lists, so the per-layer indexing below
            # works for tuple arguments as well.
            kernel_size = list(kernel_size)
        if not isinstance(kernel_size, list):
            kernel_size = [kernel_size, kernel_size]
        if len(kernel_size) == 2:
            kernel_size = [kernel_size] * (num_layers-1)
        else:
            for i, tup in enumerate(kernel_size):
                if isinstance(tup, tuple):
                    kernel_size[i] = list(tup)
                elif not isinstance(tup, list):
                    kernel_size[i] = [tup, tup]

        print('Building a self-attention generator with %d layers and an '
              % num_layers + 'output shape of %s.' % str(out_size))

        ### Compute strides and pads of all transpose conv layers.
        # Keep in mind the formula:
        #     W_o = S * (W_i - 1) - 2 * P + K + P_o
        # S   - Strides
        # P   - Padding
        # P_o - Output padding
        # K   - Kernel size
        strides = [[2, 2] for _ in range(num_layers-1)]
        pads = [[0, 0] for _ in range(num_layers-1)]
        out_pads = [[0, 0] for _ in range(num_layers-1)]
        # Layer sizes.
        sizes = [[out_size[0], out_size[1]]] * (num_layers-1)

        w = out_size[0]
        h = out_size[1]

        def compute_pads(w, k, s):
            """Compute paddings. Given the equation

                W_o = S * (W_i - 1) - 2 * P + K + P_o

            padding and output padding are chosen such that it holds:

                W_o = S * W_i

            Args:
                w: Size of the output dimension.
                k: Kernel size.
                s: Stride.

            Returns:
                Padding, output padding.
            """
            offset = s
            if s == 2 and (w % 2) == 1:
                offset = 3

            if ((k-offset) % 2) == 0:
                p = (k-offset) // 2
                p_out = 0
            else:
                p = int(np.ceil((k-offset) / 2))
                p_out = -(k - offset - 2*p)

            return p, p_out
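        # Worked example: for k=5, s=2 and an even target size w, we get
        # offset=2; since k-offset=3 is odd, p = ceil(3/2) = 2 and
        # p_out = -(5 - 2 - 2*2) = 1, so W_o = 2*(W_i-1) - 4 + 5 + 1
        # = 2*W_i. For odd w, offset=3 yields p=1, p_out=0 and
        # W_o = 2*W_i + 1, which matches the floor division w // 2 used
        # below.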
        for i in range(num_layers-2, -1, -1):
            sizes[i] = [w, h]

            # This is a condition we set.
            # If one of the sizes is too small, we just keep the layer size.
            if w <= 4:
                strides[i][0] = 1
            if h <= 4:
                strides[i][1] = 1

            pads[i][0], out_pads[i][0] = compute_pads(w, kernel_size[i][0],
                                                      strides[i][0])
            pads[i][1], out_pads[i][1] = compute_pads(h, kernel_size[i][1],
                                                      strides[i][1])

            w = w if strides[i][0] == 1 else w // 2
            h = h if strides[i][1] == 1 else h // 2
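        # Illustrative trace: with out_size=[32, 32] and num_layers=5, the
        # spatial size after the fully-connected layer is 4x4 and then
        # grows as 4 -> 4 -> 8 -> 16 -> 32 (the first transpose conv keeps
        # the size, since its input is already <= 4).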
        self._fc_out_shape = [num_filters[0], w, h]
        if num_layers > 1:
            num_filters = num_filters[1:]

        # Just a sanity check.
        for i, s in enumerate(strides):
            w = s[0] * (w-1) + kernel_size[i][0] - 2*pads[i][0] + \
                out_pads[i][0]
            h = s[1] * (h-1) + kernel_size[i][1] - 2*pads[i][1] + \
                out_pads[i][1]
        assert(w == out_size[0] and h == out_size[1])

        # Shapes of the parameters maintained by this class (underlying
        # modules, like self-attention layers, maintain their own weights).
        theta_shapes_internal = []

        if no_theta:
            self._theta = None
        else:
            self._theta = nn.ParameterList()

        if init_theta is not None and len(sa_units) > 0:
            num_p = 7 # Number of param tensors per self-attention layer.
            num_sa_p = len(sa_units) * num_p

            sind = len(init_theta) - num_sa_p
            sa_init_weights = []
            for i in range(len(sa_units)):
                sa_init_weights.append(
                    init_theta[sind+i*num_p:sind+(i+1)*num_p])

            init_theta = init_theta[:sind]

        ### Initial fully-connected layer.
        num_units = np.prod(self._fc_out_shape)
        theta_shapes_internal.extend([[num_units, input_dim], [num_units]])

        print('The output shape of the fully-connected layer will be %s' %
              str(self._fc_out_shape))

        ### Transpose Convolutional Layers.
        self._sa_units = torch.nn.ModuleList()

        prev_nfilters = self._fc_out_shape[0]

        sa_ind = 0
        if 0 in sa_units:
            print('A self-attention unit is added after the initial fc ' +
                  'layer.')
            w_init = None
            if init_theta is not None:
                w_init = sa_init_weights[sa_ind]
            self._sa_units.append(SelfAttnLayerV2(prev_nfilters,
                use_spectral_norm, no_weights=no_theta,
                init_weights=w_init))
            sa_ind += 1

        # Needed to set up the transpose convolutional layers in the
        # forward method.
        self._strides = strides
        self._pads = pads
        self._out_pads = out_pads

        for i in range(num_layers-1):
            theta_shapes_internal.extend([
                [prev_nfilters, num_filters[i], *kernel_size[i]],
                [num_filters[i]]
            ])
            prev_nfilters = num_filters[i]

            msg = 'Transpose convolutional layer %d will have output ' + \
                  'shape %s. It uses strides=%s, padding=%s and ' + \
                  'output_padding=%s. The kernel size is %s.'
            print(msg % (i, str([num_filters[i], *sizes[i]]),
                         str(strides[i]), str(pads[i]), str(out_pads[i]),
                         str(kernel_size[i])))

            if (i+1) in sa_units:
                print('A self-attention unit is added after transpose ' +
                      'conv layer %d.' % i)
                w_init = None
                if init_theta is not None:
                    w_init = sa_init_weights[sa_ind]
                self._sa_units.append(SelfAttnLayerV2(num_filters[i],
                    use_spectral_norm, no_weights=no_theta,
                    init_weights=w_init))
                sa_ind += 1

        if not no_theta:
            for i, dims in enumerate(theta_shapes_internal):
                self._theta.append(nn.Parameter(torch.Tensor(*dims),
                                                requires_grad=True))

            if init_theta is not None:
                assert(len(init_theta) == len(theta_shapes_internal))
                for i in range(len(init_theta)):
                    assert(np.all(np.equal(list(init_theta[i].shape),
                                           list(self._theta[i].shape))))
                    self._theta[i].data = init_theta[i]
            else:
                for i in range(0, len(self._theta), 2):
                    init_params(self._theta[i], self._theta[i+1])

        self._theta_shapes = theta_shapes_internal
        for unit in self._sa_units:
            self._theta_shapes.extend(unit.weight_shapes)

        self._num_weights = np.sum([np.prod(s) for s in self._theta_shapes])
        print('Total number of parameters in the self-attention ' +
              'generator: %d' % self._num_weights)

        self._is_properly_setup()
    # @override from CLHyperNetInterface
    def forward(self, task_id=None, theta=None, dTheta=None, task_emb=None,
                ext_inputs=None, squeeze=True):
        """Implementation of abstract super class method.

        Note, we currently assume that task embeddings have been
        concatenated to ``ext_inputs``, as this class doesn't maintain any
        task embeddings!
        """
        if task_id is not None or task_emb is not None:
            raise Exception('This hypernet does not support task ' +
                            'embeddings, please concatenate them to the ' +
                            'external input.')
        if ext_inputs is None:
            raise ValueError('This hypernet type always expects an ' +
                             'external input ("ext_inputs" must be set).')

        if not self.has_theta and theta is None:
            raise Exception('Network was generated without internal ' +
                            'weights. Hence, the "theta" option may not ' +
                            'be None.')

        if theta is None:
            theta = self.theta
        else:
            assert(len(theta) == len(self.theta_shapes))
            for i, s in enumerate(self.theta_shapes):
                assert(np.all(np.equal(s, list(theta[i].shape))))

        if dTheta is not None:
            assert(len(dTheta) == len(self.theta_shapes))

        ### Split off the weights of the self-attention units.
        if len(self._sa_units) > 0:
            num_p = len(self._sa_units[0].weight_shapes)
            num_sa_p = len(self._sa_units) * num_p

            sind = len(theta) - num_sa_p
            sa_weights = []
            sa_dWeights = []
            for i in range(len(self._sa_units)):
                sa_weights.append(theta[sind+i*num_p:sind+(i+1)*num_p])
                if dTheta is not None:
                    sa_dWeights.append( \
                        dTheta[sind+i*num_p:sind+(i+1)*num_p])
                else:
                    sa_dWeights.append(None)

            theta = theta[:sind]
            if dTheta is not None:
                dTheta = dTheta[:sind]

        if dTheta is not None:
            weights = []
            for i, t in enumerate(theta):
                weights.append(t + dTheta[i])
        else:
            weights = theta

        ### Initial fully-connected layer.
        h = ext_inputs
        h = F.relu(F.linear(h, weights[0], bias=weights[1]))
        if self._use_batch_norm:
            raise NotImplementedError()
            #h = F.batch_norm(h, bn_stats[ii], bn_stats[ii+1],
            #                 weight=bn_weights[ii], bias=bn_weights[ii+1],
            #                 training=self.training)
        h = h.view([-1, *self._fc_out_shape])

        ### Transpose Convolutional Layers.
        sa_ind = 0
        if 0 in self._sa_units_inds:
            h = self._sa_units[sa_ind].forward(h,
                weights=sa_weights[sa_ind], dWeights=sa_dWeights[sa_ind])
            sa_ind += 1

        num_tc_layers = len(self._strides)

        for i in range(num_tc_layers):
            h = F.conv_transpose2d(h, weights[2+2*i], bias=weights[3+2*i],
                                   stride=self._strides[i],
                                   padding=self._pads[i],
                                   output_padding=self._out_pads[i])

            # No activation function and no batchnorm in the last layer.
            if i < num_tc_layers - 1:
                h = F.relu(h)
                if self._use_batch_norm:
                    raise NotImplementedError()

            if (i+1) in self._sa_units_inds:
                h = self._sa_units[sa_ind].forward(h,
                    weights=sa_weights[sa_ind],
                    dWeights=sa_dWeights[sa_ind])
                sa_ind += 1

        return h
    # @override from CLHyperNetInterface
    @property
    def theta(self):
        """Getter for read-only attribute ``theta``.

        Returns:
            A :class:`torch.nn.ParameterList` or ``None``, if this network
            has no weights.
        """
        if self._theta is None:
            return None

        ret = nn.ParameterList()
        ret.extend(self._theta)
        # Self-attention units need to be appended to the parameter list.
        for unit in self._sa_units:
            ret.extend(unit.weights)
        return ret
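
# A minimal usage sketch for SAHnetPart (hypothetical values; it assumes the
# repo's ``utils`` package is importable):
#
#     net = SAHnetPart(out_size=[32, 32], num_layers=5, num_filters=None,
#                      kernel_size=5, sa_units=[1, 3], input_dim=16,
#                      use_batch_norm=False, use_spectral_norm=False,
#                      no_theta=False, init_theta=None)
#     chunks = net.forward(ext_inputs=torch.randn(8, 16))
#
# Here ``chunks`` has shape [8, 1, 32, 32]: the fully-connected layer
# outputs [8, 1024, 4, 4] and each strided transpose conv doubles the
# spatial size.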

class SAHyperNetwork(nn.Module, CLHyperNetInterface):
    """This class manages an instance of class :class:`SAHnetPart` and most
    likely an instance of class
    :class:`toy_example.hyper_model.HyperNetwork`.

    Given a certain output shape, the network will use a transpose
    convolutional hypernetwork with self-attention layers (instance of
    class :class:`SAHnetPart`) to generate as many weights as possible by
    running the network multiple times with different (learned) embeddings
    as inputs. The remaining weights will be generated using an instance of
    class :class:`toy_example.hyper_model.HyperNetwork` (only necessary if
    the number of main network weights is not divisible by the number of
    :class:`SAHnetPart` outputs).

    Hence, the constructor creates an instance of the class
    :class:`SAHnetPart` and, if needed, an instance of the class
    :class:`toy_example.hyper_model.HyperNetwork`. Additionally, it will
    create all embedding vectors (including task embeddings).

    Here are some suggested configurations that have a relatively small
    number of remaining weights (thus, the bulk of weights is generated
    by the SA hypernet).

    **resnet32**:

    - out_size = [36, 36], remaining: 21 weights
    - out_size = [50, 50], remaining: 193 weights
    - out_size = [77, 77], remaining: 231 weights
    - out_size = [77, 77, 3], remaining: 231 weights
    - out_size = [90, 97], remaining: 3 weights
    - out_size = [51, 54], remaining: 21 weights
    - out_size = [51, 54, 3], remaining: 21 weights
    - out_size = [11, 21], remaining: 0 weights

    Attributes:
        chunk_embeddings: List of embedding vectors that encode the main
            network location of the weights to be generated.

    Args:
        main_dims: A list of lists, each entry denoting the shape of a
            weight or bias tensor in the main network. Note, the output of
            the :meth:`forward` method will be a list of tensors, each
            having the shape of the corresponding list of integers provided
            as entry via this argument.
            See attribute
            :attr:`mnets.mnet_interface.MainNetInterface.param_shapes` for
            more information.
        num_tasks: Number of task embeddings to be generated.
        out_size: See constructor of class :class:`SAHnetPart`.
        num_layers: See constructor of class :class:`SAHnetPart`.
        num_filters: See constructor of class :class:`SAHnetPart`.
        kernel_size: See constructor of class :class:`SAHnetPart`.
        sa_units: See constructor of class :class:`SAHnetPart`.
        rem_layers: A list of integers, each indicating the size of a
            hidden layer in the network
            :class:`toy_example.hyper_model.HyperNetwork` that handles the
            remaining weights.
        te_dim: The dimensionality of task embeddings.
        ce_dim: The dimensionality of the chunk embeddings (that should
            notify the hypernet which weights of the main network it has
            to generate).

            .. note::
                The fully-connected hypernet for the remaining weights
                receives no such embedding.
        no_theta: If set to ``True``, no trainable parameters ``theta``
            will be constructed, i.e., weights are assumed to be produced
            ad-hoc by a hypernetwork and passed to the forward function.
            Does not affect task embeddings.
        init_theta (optional): This option is provided for convenience. It
            expects a list of parameter values that are used to initialize
            the network weights. As such, it provides a convenient way of
            initializing a network with, for instance, a weight draw
            produced by the hypernetwork.
            The given data has to have the same shapes as the attribute
            ``theta`` would have if the network were constructed with
            ``theta``. Does not affect task embeddings.
        use_batch_norm: Enable batchnorm in all subnetworks.
        use_spectral_norm: Enable spectral normalization in all
            subnetworks.
        dropout_rate: See constructor of class
            :class:`toy_example.hyper_model.HyperNetwork`. Does only apply
            to this network type.
        discard_remainder: Instead of generating a separate
            :class:`toy_example.hyper_model.HyperNetwork` for the remaining
            weights, these will be generated by another run of the internal
            :class:`SAHnetPart` network, discarding those outputs that are
            not needed.
        noise_dim: If ``-1``, no noise will be applied. Otherwise, the
            hypernetwork will receive as additional input zero-mean
            Gaussian noise with unit variance during training (zeros are
            used as input during eval mode). The same noise vector is
            concatenated to all chunk embeddings when generating one set of
            weights.
        temb_std (optional): If not ``-1``, the task embeddings will be
            perturbed by zero-mean Gaussian noise with the given std
            (additive noise). The perturbation is only applied if the
            network is in training mode. Note, per batch of external
            inputs, the perturbation of the task embedding will be shared.
    """
    def __init__(self, main_dims, num_tasks, out_size=[64, 64],
                 num_layers=5, num_filters=None, kernel_size=5,
                 sa_units=[1, 3], rem_layers=[50, 50, 50], te_dim=8,
                 ce_dim=8, no_theta=False, init_theta=None,
                 use_batch_norm=False, use_spectral_norm=False,
                 dropout_rate=-1, discard_remainder=False, noise_dim=-1,
                 temb_std=-1):
        # FIXME find a way using super to handle multiple inheritance.
        #super(SAHyperNetwork, self).__init__()
        nn.Module.__init__(self)
        CLHyperNetInterface.__init__(self)

        if init_theta is not None:
            # FIXME I would need to know the number of parameter tensors in
            # each hypernet before creating them to split the list
            # init_theta.
            raise NotImplementedError('Argument "init_theta" not ' +
                                      'implemented yet!')

        assert(init_theta is None or no_theta is False)
        self._no_theta = no_theta
        self._te_dim = te_dim
        self._discard_remainder = discard_remainder

        self._target_shapes = main_dims
        self._num_outputs = MainNetInterface.shapes_to_num_weights( \
            main_dims)
        print('Building a self-attention hypernet for a network with %d '
              % self._num_outputs + 'weights.')

        assert(len(out_size) in [2, 3])
        self._out_size = out_size
        num_outs = np.prod(out_size)
        assert(num_outs <= self._num_outputs)

        self._noise_dim = noise_dim
        self._temb_std = temb_std

        num_embs = self._num_outputs // num_outs
        rem_weights = self._num_outputs % num_outs

        if rem_weights > 0 and not discard_remainder:
            print('%d remaining weights (%.2f%%) are generated by a fully-'
                  % (rem_weights, 100.0 * rem_weights / self._num_outputs)
                  + 'connected hypernetwork.')
        elif rem_weights > 0:
            num_embs += 1
            print('%d weights generated by the last chunk of the self-'
                  % (num_outs - rem_weights) + 'attention hypernet will ' +
                  'be discarded.')

        self._num_embs = num_embs

        ### Generate hypernet.
        self._hypernet = SAHnetPart(out_size=out_size,
            num_layers=num_layers, num_filters=num_filters,
            kernel_size=kernel_size, sa_units=sa_units,
            input_dim=te_dim+ce_dim+(noise_dim if noise_dim != -1 else 0),
            use_batch_norm=use_batch_norm,
            use_spectral_norm=use_spectral_norm,
            no_theta=no_theta, init_theta=None)

        self._rem_hypernet = None
        self._remainder = rem_weights
        if rem_weights > 0 and not discard_remainder:
            print('A second hypernet for the remainder of the weights ' +
                  'has to be created, as %d is not divisible by %d ' %
                  (self._num_outputs, num_outs) + '(remainder %d).' %
                  rem_weights)
            self._rem_hypernet = HyperNetwork([[rem_weights]], None,
                layers=rem_layers, te_dim=te_dim, no_te_embs=True,
                no_weights=no_theta,
                ce_dim=(noise_dim if noise_dim != -1 else None),
                dropout_rate=dropout_rate, use_batch_norm=use_batch_norm,
                use_spectral_norm=use_spectral_norm, noise_dim=-1,
                temb_std=None)

        ### Generate embeddings for all weight chunks.
        if no_theta:
            self._embs = None
        else:
            self._embs = nn.Parameter(data=torch.Tensor(num_embs, ce_dim),
                                      requires_grad=True)
            torch.nn.init.normal_(self._embs, mean=0., std=1.)

        # There is no need for a chunk embedding, as this network always
        # produces the same chunk.
        #if self._remainder > 0 and not discard_remainder:
        #    self._rem_emb = nn.Parameter(data=torch.Tensor(1, ce_dim),
        #                                 requires_grad=True)
        #    torch.nn.init.normal_(self._rem_emb, mean=0., std=1.)

        ### Generate task embeddings.
        self._task_embs = nn.ParameterList()
        # We store individual task embeddings as it makes it easier to pass
        # only subsets of task embeddings to an optimizer.
        for _ in range(num_tasks):
            self._task_embs.append(nn.Parameter(data=torch.Tensor(te_dim),
                                                requires_grad=True))
            torch.nn.init.normal_(self._task_embs[-1], mean=0., std=1.)

        self._num_weights = 0
        for p in list(self.parameters()):
            self._num_weights += np.prod(p.shape)
        print('Total number of parameters in the hypernetwork: %d' %
              self._num_weights)

        self._theta_shapes = [[num_embs, ce_dim]] + \
            self._hypernet.theta_shapes
        if self._rem_hypernet is not None:
            self._theta_shapes += self._rem_hypernet.theta_shapes
    # @override from CLHyperNetInterface
    def forward(self, task_id=None, theta=None, dTheta=None, task_emb=None,
                ext_inputs=None, squeeze=True):
        """Implementation of abstract super class method.

        Note, this method can't handle external inputs yet!

        The method will iterate through the set of internal chunk
        embeddings, calling the internally maintained transpose conv.
        hypernetwork (potentially with self-attention layers). If
        necessary, a small portion of the chunks will be created by an
        additional fully-connected network.
        """
        if task_id is None and task_emb is None:
            raise Exception('The hypernetwork has to get either a task ' +
                            'ID to choose the learned embedding or ' +
                            'directly get an embedding as input (e.g., ' +
                            'from a task recognition model).')

        if not self.has_theta and theta is None:
            raise Exception('Network was generated without internal ' +
                            'weights. Hence, the "theta" option may not ' +
                            'be None.')

        if ext_inputs is not None:
            # FIXME If this will be implemented, please consider:
            # * batch size will have to be multiplied based on num chunk
            #   embeddings and the number of external inputs -> large
            #   batches
            # * noise dim must adhere to correct behavior (different noise
            #   per external input).
            raise NotImplementedError('This hypernetwork implementation ' +
                'does not yet support the passing of external inputs.')

        if theta is None:
            theta = self.theta
        else:
            assert(len(theta) == len(self.theta_shapes))
            assert(np.all(np.equal(self._embs.shape,
                                   list(theta[0].shape))))

        nhnet_shapes = len(self._hypernet.theta_shapes)

        chunk_embs = theta[0]
        hnet_theta = theta[1:1+nhnet_shapes]
        if self._rem_hypernet is not None:
            rem_hnet_theta = theta[1+nhnet_shapes:]

        if dTheta is not None:
            assert(len(dTheta) == len(self.theta_shapes))

            chunk_embs = chunk_embs + dTheta[0]
            hnet_dTheta = dTheta[1:1+nhnet_shapes]
            if self._rem_hypernet is not None:
                rem_hnet_dTheta = dTheta[1+nhnet_shapes:]
        else:
            hnet_dTheta = None
            rem_hnet_dTheta = None

        # Currently, there is no option in the constructor to not generate
        # task embeddings, which is why the code below is commented out.
        # Select task embeddings.
        #if not self.has_task_embs and task_emb is None:
        #    raise Exception('The network was created with no internal ' +
        #                    'task embeddings, thus parameter "task_emb" ' +
        #                    'has to be specified.')

        if task_emb is None:
            task_emb = self._task_embs[task_id]
        if self.training and self._temb_std != -1:
            # Note, ``Tensor.add`` is not in-place; the result has to be
            # assigned, otherwise the perturbation would have no effect.
            task_emb = task_emb + torch.randn_like(task_emb) * \
                self._temb_std

        # Concatenate the same noise to all chunks, such that it can be
        # viewed as if it were an external input.
        if self._noise_dim != -1:
            if self.training:
                eps = torch.randn((1, self._noise_dim))
            else:
                eps = torch.zeros((1, self._noise_dim))
            if self._embs.is_cuda:
                eps = eps.to(self._embs.get_device())

            # The hypernet input is a concatenation of the task embedding
            # with the noise vector and each chunk embedding.
            hnet_input = torch.cat([task_emb.view(1, -1), eps], dim=1)
            hnet_input = hnet_input.expand(self._num_embs,
                                           self._te_dim + self._noise_dim)
            hnet_input = torch.cat([chunk_embs, hnet_input], dim=1)
        else:
            eps = None
            # The hypernet input is a concatenation of the task embedding
            # with each chunk embedding.
            hnet_input = task_emb.view(1, -1).expand(self._num_embs,
                                                     self._te_dim)
            hnet_input = torch.cat([chunk_embs, hnet_input], dim=1)
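        # Illustrative shapes: with the defaults te_dim=8 and ce_dim=8,
        # hnet_input has shape [num_embs, 16]; with noise_dim=5 it would
        # be [num_embs, 21].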
        ### Gather all generated weights.
        weights = self._hypernet.forward(task_id=None, theta=hnet_theta,
            dTheta=hnet_dTheta, task_emb=None, ext_inputs=hnet_input)
        weights = weights.view(1, -1)

        if self._rem_hypernet is not None:
            rem_weights = self._rem_hypernet.forward(theta=rem_hnet_theta,
                dTheta=rem_hnet_dTheta, task_emb=task_emb, ext_inputs=eps)
            weights = torch.cat([weights, rem_weights[0].view(1, -1)],
                                dim=1)

        ### Reshape weights.
        ind = 0
        ret = []
        for s in self.target_shapes:
            num = int(np.prod(s))
            W = weights[0][ind:ind+num]
            ind += num
            ret.append(W.view(*s))
        return ret
    @property
    def chunk_embeddings(self):
        """Getter for read-only attribute :attr:`chunk_embeddings`.

        Returns:
            A list of all chunk embedding vectors.
        """
        # Note, the remainder network has no chunk embedding.
        return list(torch.split(self._embs, 1, dim=0))

    # @override from CLHyperNetInterface
    @property
    def theta(self):
        """Getter for read-only attribute ``theta``.

        Theta are all learnable parameters of the chunked hypernet,
        including the chunk embeddings that need to be learned.
        Not included are the task embeddings.

        .. note::
            Chunk embeddings are prepended to the list of weights ``theta``
            from the internal SA hypernetwork (if existing, ``theta`` from
            the remainder network will be appended).

        Returns:
            A list of tensors or ``None``, if ``no_theta`` was set to
            ``True`` in the constructor of this class.
        """
        theta = [self._embs] + list(self._hypernet.theta)
        if self._rem_hypernet is not None:
            theta += list(self._rem_hypernet.theta)
        return theta

    # @override from CLHyperNetInterface
    @property
    def has_theta(self):
        """Getter for read-only attribute ``has_theta``."""
        return not self._no_theta

if __name__ == '__main__':
    pass
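    # A minimal smoke test (hypothetical configuration; it assumes the
    # repo's packages ``utils``, ``mnets`` and ``toy_example`` are
    # importable). The main-network shapes are chosen such that the weights
    # split evenly into three 32x32 chunks, so no remainder network is
    # needed.
    main_dims = [[1024, 2], [1024]]  # 3072 = 3 * (32*32) weights.
    hnet = SAHyperNetwork(main_dims, num_tasks=3, out_size=[32, 32],
                          num_layers=5, sa_units=[1, 3])
    weights = hnet.forward(task_id=0)
    print([list(w.shape) for w in weights])  # [[1024, 2], [1024]]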