#!/usr/bin/env python3
# Copyright 2020 Christian Henning
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# @title :probabilistic/prob_cifar/train_utils.py
# @author :ch
# @contact :[email protected]
# @created :01/30/2020
# @version :1.0
# @python_version :3.6.9
"""
Training utilities
------------------
A collection of helper functions for training scripts of this subpackage.
"""
from argparse import Namespace
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch.nn import functional as F
from warnings import warn
from data.special.permuted_mnist import PermutedMNIST
from hnets.chunked_mlp_hnet import ChunkedHMLP
from hnets.hnet_helpers import init_conditional_embeddings
from hnets.hnet_perturbation_wrapper import HPerturbWrapper
from hnets.mlp_hnet import HMLP
from hnets.structured_hmlp_examples import resnet_chunking, wrn_chunking
from hnets.structured_mlp_hnet import StructuredHMLP
from probabilistic import GaussianBNNWrapper
from probabilistic import prob_utils as putils
from probabilistic.regression import train_utils as rtu
from probabilistic.prob_cifar import hpsearch_config_resnet_avb as hpresnetavb
from probabilistic.prob_cifar import hpsearch_config_resnet_avb_pf as \
hpresnetavbpf
from probabilistic.prob_cifar import hpsearch_config_zenke_avb as hpzenkeavb
from probabilistic.prob_cifar import hpsearch_config_zenke_avb_pf as \
hpzenkeavbpf
from probabilistic.prob_cifar import hpsearch_config_zenke_bbb as hpzenkebbb
from probabilistic.prob_cifar import hpsearch_config_resnet_bbb as hpresnetbbb
from probabilistic.prob_cifar import hpsearch_config_resnet_ewc as hpresnetewc
from probabilistic.prob_cifar import hpsearch_config_resnet_mt as hpresnetmt
from probabilistic.prob_cifar import hpsearch_config_resnet_ssge as hpresnetssge
from probabilistic.prob_cifar import hpsearch_config_resnet_ssge_pf as \
hpresnetssgepf
from probabilistic.prob_gmm import hpsearch_config_gmm_bbb as hpgmmbbb
from probabilistic.prob_gmm import hpsearch_config_gmm_ewc as hpgmmewc
from probabilistic.prob_gmm import hpsearch_config_gmm_avb as hpgmmavb
from probabilistic.prob_gmm import hpsearch_config_gmm_avb_pf as hpgmmavbpf
from probabilistic.prob_gmm import hpsearch_config_gmm_ssge as hpgmmssge
from probabilistic.prob_gmm import hpsearch_config_gmm_ssge_pf as hpgmmssgepf
from probabilistic.prob_gmm import hpsearch_config_gmm_mt as hpgmmmt
from probabilistic.prob_mnist import train_utils as pmutils
from probabilistic.prob_mnist import hpsearch_config_split_avb as hpsplitavb
from probabilistic.prob_mnist import hpsearch_config_split_avb_pf as \
hpsplitavbpf
from probabilistic.prob_mnist import hpsearch_config_perm_avb as hppermavb
from probabilistic.prob_mnist import hpsearch_config_perm_avb_pf as \
hppermavbpf
from probabilistic.prob_mnist import hpsearch_config_perm_bbb as hppermbbb
from probabilistic.prob_mnist import hpsearch_config_perm_ewc as hppermewc
from probabilistic.prob_mnist import hpsearch_config_perm_mt as hppermmt
from probabilistic.prob_mnist import hpsearch_config_split_bbb as hpsplitbbb
from probabilistic.prob_mnist import hpsearch_config_split_ewc as hpsplitewc
from probabilistic.prob_mnist import hpsearch_config_split_mt as hpsplitmt
from probabilistic.prob_mnist import hpsearch_config_split_ssge as \
hpsplitssge
from probabilistic.prob_mnist import hpsearch_config_split_ssge_pf as \
hpsplitssgepf
from utils import gan_helpers as gan
from utils import sim_utils as sutils
from utils import torch_utils as tutils
def generate_networks(config, shared, logger, data_handlers, device,
create_mnet=True, create_hnet=True, create_hhnet=True,
create_dis=True):
"""Create the networks required for training with implicit distributions.
This function creates networks based on the user configuration and also
takes care of weight initialization.
Args:
config (argparse.Namespace): Command-line arguments.
shared (argparse.Namespace): Miscellaneous data shared among training
functions.
logger: Console (and file) logger.
data_handlers: List of data handlers, one for each task. Needed to
extract the number of inputs/outputs of the main network and to
infer the number of tasks.
device: Torch device.
create_mnet (bool, optional): If ``False``, the user can force that no
main network is generated.
create_hnet (bool, optional): If ``False``, the user can force that no
hypernet ``hnet`` is generated.
Note:
Even if ``True``, the ``hnet`` is only generated if the user
configuration ``config`` requests it.
create_hhnet (bool, optional): If ``False``, the user can force that no
hyper-hypernet ``hhnet`` is generated.
Note:
Even if ``True``, the ``hhnet`` is only generated if the user
configuration ``config`` requests it.
create_dis (bool, optional): If ``False``, the user can force that no
discriminator ``dis`` is generated.
Note:
Even if ``True``, the ``dis`` is only generated if the user
configuration ``config`` requests it.
Returns:
(tuple): Tuple containing:
- **mnet**: Main network instance.
- **hnet** (optional): Hypernetwork instance. This return value is
``None`` if no hypernetwork should be constructed.
- **hhnet** (optional): Hyper-hypernetwork instance. This return value
is ``None`` if no hyper-hypernetwork should be constructed.
- **dis** (optional): Discriminator instance. This return value is
``None`` if no discriminator should be constructed.
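Example:
    A minimal usage sketch (hypothetical setup; assumes ``config``,
    ``shared``, ``logger``, ``data_handlers`` and ``device`` have been
    prepared as in the training scripts of this subpackage)::

        mnet, hnet, hhnet, dis = generate_networks(config, shared,
            logger, data_handlers, device)
        if hnet is not None:
            # Draw one weight sample from the implicit distribution.
            z = torch.normal(torch.zeros(1, shared.noise_dim),
                             config.latent_std)
            w_sample = hnet.forward(uncond_input=z)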
"""
num_tasks = len(data_handlers)
if hasattr(config, 'cl_scenario'):
num_heads = 1 if config.cl_scenario == 2 else num_tasks
else:
assert hasattr(config, 'multi_head')
num_heads = num_tasks if config.multi_head else 1
# Sanity check!
for i in range(1, num_tasks):
assert np.prod(data_handlers[i].in_shape) == \
np.prod(data_handlers[0].in_shape)
if data_handlers[0].classification:
assert data_handlers[i].num_classes == data_handlers[0].num_classes
else:
assert np.prod(data_handlers[i].out_shape) == \
np.prod(data_handlers[0].out_shape)
# Parse user "wishes".
use_hnet = False
use_hhnet = False
use_dis = False
no_mnet_weights = False
if hasattr(config, 'mnet_only'):
use_hnet = not config.mnet_only
use_hhnet = not config.mnet_only and not shared.prior_focused and \
not config.no_hhnet
# Note, without the hypernet, there is no weight distribution and therefore
# no discriminator needed.
use_dis = use_hnet and not config.no_dis
no_mnet_weights = not config.mnet_only
if hasattr(config, 'distill_iter'):
# Note, if distillation is used, the hnet is first trained independently
# of a hyper-hypernetwork, which is why it needs its own weights.
no_hnet_weights = use_hhnet and config.distill_iter == -1
else:
no_hnet_weights = use_hhnet
####################
### Main network ###
####################
if 'gmm' in shared.experiment_type or \
'regression' in shared.experiment_type:
mnet_type = 'mlp'
in_shape = data_handlers[0].in_shape
elif 'mnist' in shared.experiment_type:
if hasattr(config, 'net_type'):
logger.debug('Main network will be of type: %s.' % config.net_type)
mnet_type = config.net_type
else:
logger.debug('Main network will be an MLP.')
mnet_type = 'mlp'
assert len(data_handlers[0].in_shape) == 3 # MNIST
in_shape = data_handlers[0].in_shape
# Note, that padding is currently only applied when transforming the
# image to a torch tensor.
if isinstance(data_handlers[0], PermutedMNIST):
assert len(data_handlers[0].torch_in_shape) == 3 # MNIST
in_shape = data_handlers[0].torch_in_shape
else:
assert 'cifar' in shared.experiment_type
in_shape = [32, 32, 3]
if 'zenke' in shared.experiment_type:
assert not hasattr(config, 'net_type')
mnet_type = 'zenke'
else:
assert 'resnet' in shared.experiment_type
mnet_type = config.net_type
if mnet_type == 'mlp':
if len(in_shape) > 1:
n_x = np.prod(in_shape)
in_shape = [n_x]
else:
assert len(in_shape) == 3
assert mnet_type in ['lenet', 'resnet', 'wrn', 'iresnet', 'zenke']
if data_handlers[0].classification:
out_shape = [data_handlers[0].num_classes * num_heads]
else:
assert len(data_handlers[0].out_shape) == 1
out_shape = [data_handlers[0].out_shape[0] * num_heads]
if not create_mnet:
# FIXME We would need to allow the passing of old `mnet`s.
raise NotImplementedError('This function does not yet support ' +
                          'constructing networks without also ' +
                          'constructing a main network.')
logger.info('Creating main network ...')
mnet_kwargs = {}
if mnet_type == 'iresnet':
mnet_kwargs['cutout_mod'] = True
mnet = sutils.get_mnet_model(config, mnet_type, in_shape, out_shape,
device, no_weights=no_mnet_weights,
**mnet_kwargs)
# Initialize main net weights, if any.
assert not hasattr(config, 'custom_network_init')
if hasattr(config, 'normal_init'):
mnet.custom_init(normal_init=config.normal_init,
normal_std=config.std_normal_init, zero_bias=True)
else:
mnet.custom_init(zero_bias=True)
#####################
### Discriminator ###
#####################
dis = None
if use_dis and create_dis:
logger.info('Creating discriminator ...')
if config.use_batchstats:
in_shape = [mnet.num_params * 2]
else:
in_shape = [mnet.num_params]
dis = sutils.get_mnet_model(config, config.dis_net_type, in_shape, [1],
device, cprefix='dis_', no_weights=False)
dis.custom_init(normal_init=config.normal_init,
normal_std=config.std_normal_init, zero_bias=True)
#####################
### Hypernetwork ###
#####################
def _hyperfan_init(net, mnet, cond_var, uncond_var):
if isinstance(net, HMLP):
net.apply_hyperfan_init(method='in', use_xavier=False,
uncond_var=uncond_var, cond_var=cond_var,
mnet=mnet)
elif isinstance(net, ChunkedHMLP):
net.apply_chunked_hyperfan_init(method='in', use_xavier=False,
uncond_var=uncond_var, cond_var=cond_var, mnet=mnet, eps=1e-5,
cemb_normal_init=False)
elif isinstance(net, StructuredHMLP):
# FIXME We should adapt `uncond_var`, as chunk embeddings are
# additionally inputted as unconditional inputs.
# FIXME We should provide further instructions on what individual
# chunks represent (e.g., batchnorm scales and shifts should be
# initialized differently).
for int_hnet in net.internal_hnets:
    int_hnet.apply_hyperfan_init(method='in', use_xavier=False,
        uncond_var=uncond_var, cond_var=cond_var, mnet=None)
else:
raise NotImplementedError('No hyperfan-init implemented for ' +
'hypernetwork of type %s.' % type(net))
hnet = None
if use_hnet and create_hnet:
logger.info('Creating hypernetwork ...')
# For now, we either produce all or no weights with the hypernet.
# Note, it can be that the mnet was produced with internal weights.
assert mnet.hyper_shapes_learned is None or \
len(mnet.param_shapes) == len(mnet.hyper_shapes_learned)
chunk_shapes = None
num_per_chunk = None
assembly_fct = None
if config.imp_hnet_type == 'structured_hmlp':
if mnet_type == 'resnet':
chunk_shapes, num_per_chunk, assembly_fct = \
resnet_chunking(mnet,
gcd_chunking=config.imp_shmlp_gcd_chunking)
elif mnet_type == 'wrn':
chunk_shapes, num_per_chunk, assembly_fct = \
wrn_chunking(mnet,
gcd_chunking=config.imp_shmlp_gcd_chunking,
ignore_bn_weights=False, ignore_out_weights=False)
else:
raise NotImplementedError('"structured_hmlp" not implemented ' +
'for network of type %s.' % mnet_type)
# The hypernet represents an implicit distribution; it only receives noise
# as (unconditional) input.
hnet = sutils.get_hypernet(config, device, config.imp_hnet_type,
mnet.param_shapes, 0, cprefix='imp_',
no_uncond_weights=no_hnet_weights, no_cond_weights=True,
uncond_in_size=config.latent_dim, shmlp_chunk_shapes=chunk_shapes,
shmlp_num_per_chunk=num_per_chunk, shmlp_assembly_fct=assembly_fct)
#if isinstance(hnet, StructuredHMLP):
# print(num_per_chunk)
# for ii, int_hnet in enumerate(hnet.internal_hnets):
# print(' Internal hnet %d with %d outputs.' % \
# (ii, int_hnet.num_outputs))
### Initialize hypernetwork.
if not no_hnet_weights:
if not config.hyper_fan_init:
rtu.apply_custom_hnet_init(config, logger, hnet)
else:
_hyperfan_init(hnet, mnet, -1, config.latent_std**2)
### Apply noise trick if requested by user.
if config.full_support_perturbation != -1:
hnet = HPerturbWrapper(hnet, hnet_uncond_in_size=config.latent_dim,
sigma_noise=config.full_support_perturbation)
shared.noise_dim = hnet.num_outputs
else:
shared.noise_dim = config.latent_dim
##########################
### Hyper-hypernetwork ###
##########################
hhnet = None
if use_hhnet and create_hhnet:
if not create_hnet:
# FIXME We require an existing hnet to do this.
raise NotImplementedError('This function does not yet allow ' +
                          'creating a hyper-hypernetwork without ' +
                          'first creating a hypernetwork.')
logger.info('Creating hyper-hypernetwork ...')
assert hnet is not None
assert len(hnet.unconditional_param_shapes) == len(hnet.param_shapes)
hhnet = sutils.get_hypernet(config, device, config.hh_hnet_type,
hnet.unconditional_param_shapes, num_tasks,
cprefix='hh_')
### Initialize hypernetwork.
if not config.hyper_fan_init:
rtu.apply_custom_hnet_init(config, logger, hhnet)
else:
# Note, hyperfan-init doesn't take care of task-embedding
# initialization.
init_conditional_embeddings(hhnet,
normal_std=config.std_normal_temb)
_hyperfan_init(hhnet, hnet, config.std_normal_temb**2, -1)
return mnet, hnet, hhnet, dis
def setup_summary_dict(config, shared, experiment, mnet, hnet=None,
hhnet=None, dis=None):
"""Setup the summary dictionary that is written to the performance
summary file (in the result folder).
This method adds the keyword "summary" to ``shared``.
Args:
config (argparse.Namespace): Command-line arguments.
shared (argparse.Namespace): Miscellaneous data shared among training
functions (summary dict will be added).
experiment: Type of experiment. See argument `experiment` of method
:func:`probabilistic.prob_cifar.train_avb.run`.
mnet: Main network.
hnet (optional): Implicit Hypernetwork.
hhnet (optional): Hyper-Hypernetwork.
dis (optional): Discriminator.
"""
assert experiment in ['gmm_bbb', 'gmm_avb', 'gmm_avb_pf',
'split_bbb', 'perm_bbb',
'cifar_zenke_bbb', 'cifar_resnet_bbb',
'split_mnist_avb', 'split_mnist_avb_pf',
'perm_mnist_avb', 'perm_mnist_avb_pf',
'cifar_zenke_avb', 'cifar_zenke_avb_pf',
'cifar_resnet_avb', 'cifar_resnet_avb_pf',
'gmm_ssge', 'gmm_ssge_pf',
'split_mnist_ssge', 'split_mnist_ssge_pf',
'perm_mnist_ssge', 'perm_mnist_ssge_pf',
'cifar_resnet_ssge', 'cifar_resnet_ssge_pf',
'gmm_ewc', 'split_mnist_ewc', 'perm_mnist_ewc',
'cifar_resnet_ewc',
'gmm_mt', 'split_mnist_mt', 'perm_mnist_mt',
'cifar_resnet_mt']
summary = dict()
mnum = mnet.num_params
hnum = -1
hhnum = -1
dnum = -1
hm_ratio = -1
hhm_ratio = -1
dm_ratio = -1
if hnet is not None:
hnum = hnet.num_params
hm_ratio = hnum / mnum
if hhnet is not None:
hhnum = hhnet.num_params
hhm_ratio = hhnum / mnum
if dis is not None:
dnum = dis.num_params
dm_ratio = dnum / mnum
if experiment == 'gmm_bbb':
summary_keys = hpgmmbbb._SUMMARY_KEYWORDS
elif experiment == 'split_bbb':
summary_keys = hpsplitbbb._SUMMARY_KEYWORDS
elif experiment == 'perm_bbb':
summary_keys = hppermbbb._SUMMARY_KEYWORDS
elif experiment == 'cifar_zenke_bbb':
summary_keys = hpzenkebbb._SUMMARY_KEYWORDS
elif experiment == 'cifar_resnet_bbb':
summary_keys = hpresnetbbb._SUMMARY_KEYWORDS
elif experiment == 'gmm_avb':
summary_keys = hpgmmavb._SUMMARY_KEYWORDS
elif experiment == 'gmm_avb_pf':
summary_keys = hpgmmavbpf._SUMMARY_KEYWORDS
elif experiment == 'split_mnist_avb':
summary_keys = hpsplitavb._SUMMARY_KEYWORDS
elif experiment == 'split_mnist_avb_pf':
summary_keys = hpsplitavbpf._SUMMARY_KEYWORDS
elif experiment == 'perm_mnist_avb':
summary_keys = hppermavb._SUMMARY_KEYWORDS
elif experiment == 'perm_mnist_avb_pf':
summary_keys = hppermavbpf._SUMMARY_KEYWORDS
elif experiment == 'cifar_resnet_avb':
summary_keys = hpresnetavb._SUMMARY_KEYWORDS
elif experiment == 'cifar_resnet_avb_pf':
summary_keys = hpresnetavbpf._SUMMARY_KEYWORDS
elif experiment == 'cifar_zenke_avb':
summary_keys = hpzenkeavb._SUMMARY_KEYWORDS
elif experiment == 'gmm_ssge':
summary_keys = hpgmmssge._SUMMARY_KEYWORDS
elif experiment == 'gmm_ssge_pf':
summary_keys = hpgmmssgepf._SUMMARY_KEYWORDS
elif experiment == 'split_mnist_ssge':
summary_keys = hpsplitssge._SUMMARY_KEYWORDS
elif experiment == 'split_mnist_ssge_pf':
summary_keys = hpsplitssgepf._SUMMARY_KEYWORDS
elif experiment == 'perm_mnist_ssge':
summary_keys = hpsplitssgepf._SUMMARY_KEYWORDS
elif experiment == 'perm_mnist_ssge_pf':
summary_keys = hpsplitssgepf._SUMMARY_KEYWORDS
elif experiment == 'cifar_resnet_ssge':
summary_keys = hpresnetssge._SUMMARY_KEYWORDS
elif experiment == 'cifar_resnet_ssge_pf':
summary_keys = hpresnetssgepf._SUMMARY_KEYWORDS
elif experiment == 'cifar_zenke_avb_pf':
summary_keys = hpzenkeavbpf._SUMMARY_KEYWORDS
elif experiment == 'gmm_ewc':
summary_keys = hpgmmewc._SUMMARY_KEYWORDS
elif experiment == 'split_mnist_ewc':
summary_keys = hpsplitewc._SUMMARY_KEYWORDS
elif experiment == 'perm_mnist_ewc':
summary_keys = hppermewc._SUMMARY_KEYWORDS
elif experiment == 'cifar_resnet_ewc':
summary_keys = hpresnetewc._SUMMARY_KEYWORDS
elif experiment == 'gmm_mt':
summary_keys = hpgmmmt._SUMMARY_KEYWORDS
elif experiment == 'split_mnist_mt':
summary_keys = hpsplitmt._SUMMARY_KEYWORDS
elif experiment == 'perm_mnist_mt':
summary_keys = hppermmt._SUMMARY_KEYWORDS
else:
assert experiment == 'cifar_resnet_mt'
summary_keys = hpresnetmt._SUMMARY_KEYWORDS
for k in summary_keys:
if k == 'acc_task_given' or \
k == 'acc_task_given_during' or \
k == 'acc_task_inferred_ent' or \
k == 'acc_task_inferred_ent_during' or \
k == 'acc_dis':
summary[k] = [-1] * config.num_tasks
elif k == 'acc_avg_final' or \
k == 'acc_avg_during' or \
k == 'acc_avg_task_given' or \
k == 'acc_avg_task_given_during' or \
k == 'acc_avg_task_inferred_ent' or \
k == 'acc_avg_task_inferred_ent_during' or \
k == 'avg_task_inference_acc_ent' or \
k == 'acc_avg_task_inferred_conf' or \
k == 'avg_task_inference_acc_conf' or \
k == 'acc_avg_task_inferred_agree' or \
k == 'avg_task_inference_acc_agree' or \
k == 'acc_avg_dis':
summary[k] = -1
elif k == 'num_weights_main':
summary[k] = mnum
elif k == 'num_weights_hyper':
summary[k] = hnum
elif k == 'num_weights_hyper_hyper':
summary[k] = hhnum
elif k == 'num_weights_dis':
summary[k] = dnum
elif k == 'num_weights_hm_ratio':
summary[k] = hm_ratio
elif k == 'num_weights_hhm_ratio':
summary[k] = hhm_ratio
elif k == 'num_weights_dm_ratio':
summary[k] = dm_ratio
elif k == 'finished':
summary[k] = 0
else:
# Implementation must have changed if this exception is
# raised.
raise ValueError('Summary argument %s unknown!' % k)
shared.summary = summary
def set_train_mode(training, mnet, hnet, hhnet, dis):
"""Set mode of all given networks.
Note, any of the given networks may be passed as ``None``; the mode is
only set for the networks that are provided.
Args:
training (bool): If ``True``, training mode will be activated.
Otherwise, evaluation mode is activated.
(....): The remaining arguments refer to network instances.
"""
for net in [mnet, hnet, hhnet, dis]:
if net is not None:
if training:
net.train()
else:
net.eval()
def compute_acc(task_id, data, mnet, hnet, hhnet, device, config, shared,
split_type='test', return_dataset=False, return_entropies=False,
return_confidence=False, return_agreement=False,
return_pred_labels=False, return_labels=False,
return_samples=False, deterministic_sampling=False,
in_samples=None, out_samples=None, num_w_samples=None,
w_samples=None):
"""Compute the accuracy over a specified dataset split.
Note, this function does not explicitly execute the code within a
``torch.no_grad()`` context. This needs to be handled from the outside if
desired.
Note, this function serves the same purpose as function
:func:`probabilistic.prob_mnist.train_utils.compute_acc`.
The ``task_id`` is used only to select the task embedding (if ``hhnet``
is given) and the correct output units depending on the CL scenario.
Args:
(....): See docstring of function
:func:`probabilistic.prob_mnist.train_utils.compute_acc`.
return_samples: If ``True``, the attribute ``samples`` will be
added to the ``return_vals`` Namespace (see return values). This
field will contain all weight samples that have been drawn from
the hypernetwork ``hnet``. If ``hnet`` is not provided,
this field will be ``None``. The field will be filled with a
numpy array.
Returns:
(tuple): Tuple containing:
- **accuracy**: Overall accuracy on dataset split.
- **return_vals**: A namespace object that contains several attributes,
depending on the arguments passed. It will always contain the
following attribute, denoting the current weights of the implicit
distribution:
- ``theta``: The current output of the ``hhnet`` for ``task_id``.
If no ``hhnet`` is provided but an ``hnet`` is given,
then its weights ``theta`` will be provided. It will be
``None`` if only a main network ``mnet`` is provided.
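Example:
    A minimal usage sketch (hypothetical; assumes the networks stem from
    :func:`generate_networks` and gradients are not needed)::

        with torch.no_grad():
            acc, rvals = compute_acc(0, data_handlers[0], mnet, hnet,
                hhnet, device, config, shared, split_type='val',
                return_entropies=True)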
"""
# FIXME The code is almost a perfect copy from the original function.
assert in_samples is not None or split_type in ['test', 'val', 'train']
assert out_samples is None or in_samples is not None
generator = None
if deterministic_sampling:
generator = torch.Generator()  # Note, samples are moved to the device later.
# Note, PyTorch recommends using large random seeds:
# https://tinyurl.com/yx7fwrry
generator.manual_seed(2147483647)
return_vals = Namespace()
allowed_outputs = pmutils.out_units_of_task(config, data, task_id,
shared.num_trained)
ST = shared.softmax_temp[task_id]
if not config.calibrate_temp:
assert ST == 1.
if in_samples is not None:
X = in_samples
T = out_samples
elif split_type == 'train':
X = data.get_train_inputs()
T = data.get_train_outputs()
elif split_type == 'test' or data.num_val_samples == 0:
X = data.get_test_inputs()
T = data.get_test_outputs()
else:
X = data.get_val_inputs()
T = data.get_val_outputs()
num_samples = X.shape[0]
if T is not None:
T = pmutils.fit_targets_to_softmax(config, shared, device, data,
task_id, T)
if return_dataset:
return_vals.inputs = X
return_vals.targets = T
labels = None
if T is not None:
labels = np.argmax(T, axis=1)
if return_labels:
return_vals.labels = labels
X = data.input_to_torch_tensor(X, device)
#if T is not None:
# T = data.output_to_torch_tensor(T, device)
hnet_theta = None
return_vals.theta = None
if hhnet is not None:
assert hnet is not None
hnet_theta = hhnet.forward(cond_id=task_id)
return_vals.theta = hnet_theta
elif hnet is not None:
return_vals.theta = hnet.unconditional_params
# There is no weight sampling without an implicit hypernetwork.
if w_samples is not None:
num_w_samples = len(w_samples)
elif num_w_samples is None:
num_w_samples = 1 if hnet is None else config.val_sample_size
else:
if hnet is None and num_w_samples > 1:
warn('Cannot draw multiple weight samples for deterministic ' +
'network')
num_w_samples = 1
if hasattr(config, 'non_growing_sf_cl3') and config.cl_scenario == 3 \
and config.non_growing_sf_cl3:
softmax_width = config.num_tasks * data.num_classes
elif config.cl_scenario == 3 and not config.split_head_cl3:
softmax_width = len(allowed_outputs)
else:
softmax_width = data.num_classes
softmax_outputs = np.empty((num_w_samples, X.shape[0], softmax_width))
if return_samples:
return_vals.samples = None
# FIXME Note, that a continually learned hypernet (whose weights come from a
# hyper-hypernet) would in principle also require correct argument passing,
# e.g., to choose the correct set of batch statistics.
kwargs = pmutils.mnet_kwargs(config, task_id, mnet)
for j in range(num_w_samples):
weights = None
if w_samples is not None:
weights = w_samples[j]
elif hnet is not None:
z = torch.normal(torch.zeros(1, shared.noise_dim),
config.latent_std, generator=generator).to(device)
weights = hnet.forward(uncond_input=z, weights=hnet_theta)
if weights is not None and return_samples:
if j == 0:
return_vals.samples = np.empty((num_w_samples,
hnet.num_outputs))
return_vals.samples[j, :] = torch.cat([p.detach().flatten() \
for p in weights]).cpu().numpy()
curr_bs = config.val_batch_size
n_processed = 0
while n_processed < num_samples:
if n_processed + curr_bs > num_samples:
curr_bs = num_samples - n_processed
n_processed += curr_bs
sind = n_processed - curr_bs
eind = n_processed
Y = mnet.forward(X[sind:eind, :], weights=weights, **kwargs)
if allowed_outputs is not None:
Y = Y[:, allowed_outputs]
softmax_outputs[j, sind:eind, :] = F.softmax(Y / ST, dim=1). \
detach().cpu().numpy()
# Predictive distribution per sample.
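# Note, averaging the per-sample softmax outputs over all weight draws
# yields a Monte-Carlo estimate of p(y | x) = E_w [ p(y | x, w) ].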
pred_dists = softmax_outputs.mean(axis=0)
pred_labels = np.argmax(pred_dists, axis=1)
# Note, that for CL3 (without split heads) `labels` are already absolute,
# not relative to the head (see post-processing of targets `T` above).
if labels is not None:
accuracy = 100. * np.sum(pred_labels == labels) / num_samples
else:
accuracy = None
if return_pred_labels:
assert pred_labels.size == X.shape[0]
return_vals.pred_labels = pred_labels
if return_entropies:
# We use the "maximum" trick to improve numerical stability.
return_vals.entropies = - np.sum(pred_dists * \
np.log(np.maximum(pred_dists, 1e-5)),
axis=1)
# return_vals.entropies = - np.sum(pred_dists * np.log(pred_dists),
# axis=1)
assert return_vals.entropies.size == X.shape[0]
# Normalize by maximum entropy.
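# A categorical distribution over `data.num_classes` classes has entropy
# at most log(num_classes) (uniform case), hence the normalized values
# below lie in [0, 1].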
max_ent = - np.log(1.0 / data.num_classes)
return_vals.entropies /= max_ent
if return_confidence:
return_vals.confidence = np.max(pred_dists, axis=1)
assert return_vals.confidence.size == X.shape[0]
if return_agreement:
return_vals.agreement = softmax_outputs.std(axis=0).mean(axis=1)
assert return_vals.agreement.size == X.shape[0]
return accuracy, return_vals
def estimate_implicit_moments(config, shared, task_id, hnet, hhnet, num_samples,
device):
"""Estimate the first two moments of an implicit distribution.
This function takes the implicit distribution represented by ``hnet`` and
estimates the mean and the variances of its outputs.
Args:
config (argparse.Namespace): Command-line arguments.
shared (argparse.Namespace): Miscellaneous data shared among training
functions.
task_id (int): In case ``hhnet`` is provided, this will be used to
select the task embedding.
hnet: The hypernetwork.
hhnet: The hyper-hypernetwork, may be ``None``.
num_samples: The number of samples that should be drawn from the
``hnet`` to estimate the statistics.
device: The PyTorch device.
Returns:
(tuple): Tuple containing:
- **sample_mean** (torch.Tensor): Estimated mean of the implicit
distribution.
- **sample_std** (torch.Tensor): Estimated standard deviation of the
implicit distribution.
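Example:
    A minimal usage sketch (hypothetical; e.g., to build a Gaussian
    surrogate of the implicit distribution, as needed for the adaptive
    contrast trick)::

        mean, std = estimate_implicit_moments(config, shared, task_id,
                                              hnet, None, 100, device)
        dist_ac = torch.distributions.Normal(mean, std)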
"""
theta = None
if hhnet is not None:
theta = hhnet.forward(cond_id=task_id)
samples = torch.empty((num_samples, hnet.num_outputs)).to(device)
for j in range(num_samples):
z = torch.normal(torch.zeros(1, shared.noise_dim), config.latent_std).\
to(device)
weights = hnet.forward(uncond_input=z, weights=theta)
samples[j, :] = torch.cat([p.detach().flatten() for p in weights])
sample_mean = samples.mean(dim=0)
sample_std = samples.std(dim=0)
return sample_mean, sample_std
def process_dis_batch(config, shared, batch_size, device, dis, hnet, hnet_theta,
dist=None):
"""Process a batch of weight samples via the discriminator.
Args:
config (argparse.Namespace): Command-line arguments.
shared (argparse.Namespace): Miscellaneous data shared among training
functions.
batch_size (int): How many samples should be fed through the
discriminator.
device: PyTorch device.
dis: Discriminator.
hnet: The hypernetwork, representing an implicit distribution from
which to sample weights. Is only used to draw samples if
``dist`` is ``None``.
hnet_theta: The weights passed to ``hnet`` when drawing samples.
dist (torch.distributions.normal.Normal): A normal distribution,
from which discriminator inputs can be sampled.
Returns:
(tuple): Tuple containing:
- **dis_out** (torch.Tensor): The discriminator output for the given
batch of samples.
- **dis_input** (torch.Tensor): The samples that have been passed
through the discriminator.
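Example:
    A minimal usage sketch (hypothetical; samples are drawn from the
    hypernetwork's internally maintained implicit distribution)::

        dis_out, dis_in = process_dis_batch(config, shared, 32, device,
                                            dis, hnet, None)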
"""
if dist is not None:
samples = dist.sample([batch_size])
if hnet is not None:
assert np.all(np.equal(samples.shape,
[batch_size, hnet.num_outputs]))
else:
assert hnet is not None
z = torch.normal(torch.zeros(batch_size, shared.noise_dim),
config.latent_std).to(device)
samples = hnet.forward(uncond_input=z, weights=hnet_theta,
ret_format='flattened')
if config.use_batchstats:
samples = gan.concat_mean_stats(samples)
return dis.forward(samples), samples
def calc_prior_matching(config, shared, batch_size, device, dis, hnet,
theta_current, dist_prior, dist_ac,
return_current_samples=False):
"""Calculate the prior-matching term.
Args:
config (argparse.Namespace): Command-line arguments.
shared (argparse.Namespace): Miscellaneous data shared among training
functions.
batch_size (int): How many samples should be fed through the
discriminator.
device: PyTorch device.
dis: Discriminator.
hnet: The hypernetwork, representing an implicit distribution from
which to sample weights. Is used to draw samples from the current
implicit distribution ``theta_current`` (which may be ``None`` if
internal weights should be selected).
theta_current: The weights passed to ``hnet`` when drawing samples from
the current implicit distribution that should be matched to the
prior (can be ``None`` if internally maintained weights of ``hnet``
should be used).
dist_prior (torch.distributions.normal.Normal): A normal distribution,
that represents an explicit prior. Only used if ``dist_ac`` is
not ``None``.
dist_ac (torch.distributions.normal.Normal): A normal distribution,
that can be passed if the adaptive contrast trick is used. If not
``None``, then ``dist_prior`` may not be ``None``.
return_current_samples (bool): If ``True``, the samples collected from
the current implicit distribution are returned.
Returns:
(tuple): Tuple containing:
- **loss_pm**: (torch.Tensor): The unscaled loss value for the
prior-matching term.
- **curr_samples** (list): List of samples drawn from the implicit
distribution ``hnet`` (using ``theta_current``).
"""
assert dist_ac is None or dist_prior is not None
# The following two terms are only required if AC is used.
log_prob_ac = 0
log_prob_prior = 0
if return_current_samples:
curr_samples = []
else:
curr_samples = None
# Translate into samples from the current implicit distribution.
w_samples = torch.empty((batch_size, hnet.num_outputs)).to(device)
# FIXME Create batch of samples rather than looping.
for j in range(batch_size):
z = torch.normal(torch.zeros(1, shared.noise_dim), config.latent_std).\
to(device)
weights = hnet.forward(uncond_input=z, weights=theta_current)
w_samples[j, :] = torch.cat([p.flatten() for p in weights])
if return_current_samples:
curr_samples.append(weights)
if dist_ac is not None:
log_prob_ac = dist_ac.log_prob(w_samples).sum(dim=1).mean()
log_prob_prior = dist_prior.log_prob(w_samples).sum(dim=1).mean()
if config.use_batchstats:
w_samples = gan.concat_mean_stats(w_samples)
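# Note, after training, the discriminator output T(w) approximates the
# log-density ratio between the current implicit distribution q and the
# contrastive distribution (the AC distribution if given, else the
# prior). Hence, the value returned below is a sample-based estimate of
# E_q[T(w)] + E_q[log q_ac(w)] - E_q[log p(w)].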
value_t = dis.forward(w_samples).mean()
return value_t + log_prob_ac - log_prob_prior, curr_samples
def calc_batch_uncertainty(config, shared, task_id, device, inputs, mnet, hnet,
hhnet, data, num_w_samples, hnet_theta=None,
allowed_outputs=None):
"""Compute the per-sample uncertainties for a given batch of inputs.
Note:
This function is executed inside a ``torch.no_grad()`` context.
Args:
config (argparse.Namespace): Command-line arguments.
shared: Miscellaneous data shared among training functions (softmax
temperature is stored in here).
task_id (int): In case a hypernet ``hnet`` is given, the ``task_id`` is
used to load the corresponding main network ``mnet`` weights.
device: PyTorch device.
inputs (torch.Tensor): A batch of main network ``mnet`` inputs.
mnet: The main network.
hnet (optional): The implicit hypernetwork, can be ``None``.
hhnet (optional): The hyper-hypernetwork, can be ``None``.
data: Dataset loader. Needed to determine the number of classes.
num_w_samples (int): The number of weight samples that should be drawn
to estimate predictive uncertainty.
hnet_theta (tuple, optional): To save computation, one can pass
weights for the implicit hypernetwork ``hnet``, if they have been
computed prior to calling this method.
allowed_outputs (tuple, optional): The indices of the neurons belonging
to the output head of task ``task_id``. Only needs to be specified
in a multi-head setting.
Returns:
(numpy.ndarray): The entropy of the estimated predictive distribution
per input sample.
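Example:
    A minimal usage sketch (hypothetical; ``x`` denotes a batch of main
    network inputs residing on ``device``)::

        entropies = calc_batch_uncertainty(config, shared, task_id,
            device, x, mnet, hnet, hhnet, data, config.val_sample_size)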
"""
assert data.classification
assert config.cl_scenario == 2 or allowed_outputs is not None
assert hhnet is None or hnet is not None
# FIXME We calibrate the temperature after training on a task. This function
# is currently only used to track batch uncertainty during training or
# choose coreset samples that have maximum uncertainty on a single model
# (note, relative order of uncertain samples doesn't change due to
# calibration for a single model). Hence, the function is invoked before
# the temperature is optimized.
# Therefore, I throw an assertion if we use the function in the future for
# other purposes, just in case the programmer is unaware.
assert shared.softmax_temp[task_id] == 1.
ST = shared.softmax_temp[task_id]
with torch.no_grad():
if hnet_theta is None and hhnet is not None:
hnet_theta = hhnet.forward(cond_id=task_id)
if allowed_outputs is not None:
num_outs = len(allowed_outputs)
else:
num_outs = data.num_classes
softmax_outputs = np.empty((num_w_samples, inputs.shape[0], num_outs))
kwargs = pmutils.mnet_kwargs(config, task_id, mnet)
for j in range(num_w_samples):
weights = None
if hnet is not None:
z = torch.normal(torch.zeros(1, shared.noise_dim),
config.latent_std).to(device)
weights = hnet.forward(uncond_input=z, weights=hnet_theta)
Y = mnet.forward(inputs, weights=weights, **kwargs)