diff --git a/multi_categorical_gans/methods/arae/sampler.py b/multi_categorical_gans/methods/arae/sampler.py index 820661b..0039fe5 100644 --- a/multi_categorical_gans/methods/arae/sampler.py +++ b/multi_categorical_gans/methods/arae/sampler.py @@ -31,11 +31,7 @@ def sample(autoencoder, generator, num_samples, num_features, batch_size=100, no noise = Variable(torch.FloatTensor(batch_size, noise_size).normal_()) noise = to_cuda_if_available(noise) batch_code = generator(noise) - - batch_samples = autoencoder.decode(batch_code, - training=False, - temperature=temperature) - + batch_samples = autoencoder.decode(batch_code, training=False, temperature=temperature) batch_samples = to_cpu_if_available(batch_samples) batch_samples = batch_samples.data.numpy() diff --git a/multi_categorical_gans/methods/arae/trainer.py b/multi_categorical_gans/methods/arae/trainer.py index 48061b9..7fa6300 100644 --- a/multi_categorical_gans/methods/arae/trainer.py +++ b/multi_categorical_gans/methods/arae/trainer.py @@ -106,15 +106,12 @@ def regularize_ae_grad(grad): batch_original = Variable(torch.from_numpy(batch)) batch_original = to_cuda_if_available(batch_original) - batch_code = autoencoder.encode(batch_original, - normalize_code=normalize_code) + batch_code = autoencoder.encode(batch_original, normalize_code=normalize_code) batch_code = add_noise_to_code(batch_code, ae_noise_radius) if regularization_penalty > 0: batch_code.register_hook(update_last_ae_grad_norm) - batch_reconstructed = autoencoder.decode(batch_code, - training=True, - temperature=temperature) + batch_reconstructed = autoencoder.decode(batch_code, training=True, temperature=temperature) ae_loss = categorical_variable_loss(batch_reconstructed, batch_original, variable_sizes) ae_loss.backward() diff --git a/multi_categorical_gans/methods/mc_wgan_gp/trainer.py b/multi_categorical_gans/methods/mc_wgan_gp/trainer.py index eec95a6..f800f54 100644 --- a/multi_categorical_gans/methods/mc_wgan_gp/trainer.py +++ b/multi_categorical_gans/methods/mc_wgan_gp/trainer.py @@ -107,11 +107,7 @@ def train(generator, fake_loss.backward() # this is the magic from WGAN-GP - gradient_penalty = calculate_gradient_penalty(discriminator, - penalty, - real_features, - fake_features) - + gradient_penalty = calculate_gradient_penalty(discriminator, penalty, real_features, fake_features) gradient_penalty.backward() # finally update the discriminator weights diff --git a/multi_categorical_gans/methods/medgan/generator.py b/multi_categorical_gans/methods/medgan/generator.py index f62c55f..ab7008f 100644 --- a/multi_categorical_gans/methods/medgan/generator.py +++ b/multi_categorical_gans/methods/medgan/generator.py @@ -31,12 +31,9 @@ def batch_norm_train(self, mode=True): batch_norm.train(mode=mode) def forward(self, noise): - """ - This sums are called "shortcut connections" - """ outputs = noise for module in self.modules: # Cannot write "outputs += module(outputs)" because it is an inplace operation (no differentiable) - outputs = module(outputs) + outputs + outputs = module(outputs) + outputs # shortcut connection return outputs diff --git a/multi_categorical_gans/methods/medgan/pre_trainer.py b/multi_categorical_gans/methods/medgan/pre_trainer.py index 01ac5fe..d460ca1 100644 --- a/multi_categorical_gans/methods/medgan/pre_trainer.py +++ b/multi_categorical_gans/methods/medgan/pre_trainer.py @@ -69,10 +69,7 @@ def pre_train_epoch(autoencoder, data, batch_size, optim=None, variable_sizes=No batch = Variable(torch.from_numpy(batch)) batch = to_cuda_if_available(batch) - _, batch_reconstructed = autoencoder(batch, - training=training, - temperature=temperature, - normalize_code=False) + _, batch_reconstructed = autoencoder(batch, training=training, temperature=temperature, normalize_code=False) loss = categorical_variable_loss(batch_reconstructed, batch, variable_sizes) loss.backward() diff --git a/multi_categorical_gans/methods/medgan/sampler.py b/multi_categorical_gans/methods/medgan/sampler.py index 433e096..8cb7efe 100644 --- a/multi_categorical_gans/methods/medgan/sampler.py +++ b/multi_categorical_gans/methods/medgan/sampler.py @@ -31,11 +31,7 @@ def sample(autoencoder, generator, num_samples, num_features, batch_size=100, co noise = Variable(torch.FloatTensor(batch_size, code_size).normal_()) noise = to_cuda_if_available(noise) batch_code = generator(noise) - - batch_samples = autoencoder.decode(batch_code, - training=False, - temperature=temperature) - + batch_samples = autoencoder.decode(batch_code, training=False, temperature=temperature) batch_samples = to_cpu_if_available(batch_samples) batch_samples = batch_samples.data.numpy() diff --git a/multi_categorical_gans/methods/medgan/trainer.py b/multi_categorical_gans/methods/medgan/trainer.py index cbc5d87..a6c6212 100644 --- a/multi_categorical_gans/methods/medgan/trainer.py +++ b/multi_categorical_gans/methods/medgan/trainer.py @@ -98,9 +98,7 @@ def train(autoencoder, noise = Variable(torch.FloatTensor(len(batch), code_size).normal_()) noise = to_cuda_if_available(noise) fake_code = generator(noise) - fake_features = autoencoder.decode(fake_code, - training=True, - temperature=temperature) + fake_features = autoencoder.decode(fake_code, training=True, temperature=temperature) fake_features = fake_features.detach() # do not propagate to the generator fake_pred = discriminator(fake_features) fake_loss = criterion(fake_pred, label_zeros) @@ -127,9 +125,7 @@ def train(autoencoder, noise = Variable(torch.FloatTensor(len(batch), code_size).normal_()) noise = to_cuda_if_available(noise) gen_code = generator(noise) - gen_features = autoencoder.decode(gen_code, - training=True, - temperature=temperature) + gen_features = autoencoder.decode(gen_code, training=True, temperature=temperature) gen_pred = discriminator(gen_features) smooth_label_ones = Variable(torch.FloatTensor(len(batch)).uniform_(0.9, 1)) @@ -167,9 +163,7 @@ def train(autoencoder, noise = Variable(torch.FloatTensor(len(batch), code_size).normal_()) noise = to_cuda_if_available(noise) fake_code = generator(noise) - fake_features = autoencoder.decode(fake_code, - training=False, - temperature=temperature) + fake_features = autoencoder.decode(fake_code, training=False, temperature=temperature) fake_pred = discriminator(fake_features) fake_pred = to_cpu_if_available(fake_pred) correct += (fake_pred.data.numpy().ravel() < .5).sum()