diff --git a/multi_categorical_gans/methods/arae/trainer.py b/multi_categorical_gans/methods/arae/trainer.py index 129abf5..d253cce 100644 --- a/multi_categorical_gans/methods/arae/trainer.py +++ b/multi_categorical_gans/methods/arae/trainer.py @@ -394,7 +394,7 @@ def main(): discriminator = Discriminator( options.code_size, hidden_sizes=parse_int_list(options.discriminator_hidden_sizes), - bn_decay=options.bn_decay, + bn_decay=0, # no batch normalization for the critic critic=True ) diff --git a/multi_categorical_gans/methods/mc_wgan_gp/trainer.py b/multi_categorical_gans/methods/mc_wgan_gp/trainer.py index e3a2e1d..49e1134 100644 --- a/multi_categorical_gans/methods/mc_wgan_gp/trainer.py +++ b/multi_categorical_gans/methods/mc_wgan_gp/trainer.py @@ -276,7 +276,7 @@ def main(): discriminator = Discriminator( features.shape[1], hidden_sizes=parse_int_list(options.discriminator_hidden_sizes), - bn_decay=options.bn_decay, + bn_decay=0, # no batch normalization for the critic critic=True )