Skip to content

Commit

Permalink
Code style fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
rcamino committed Jul 9, 2018
1 parent aec1147 commit 09e6083
Show file tree
Hide file tree
Showing 7 changed files with 10 additions and 37 deletions.
6 changes: 1 addition & 5 deletions multi_categorical_gans/methods/arae/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,7 @@ def sample(autoencoder, generator, num_samples, num_features, batch_size=100, no
noise = Variable(torch.FloatTensor(batch_size, noise_size).normal_())
noise = to_cuda_if_available(noise)
batch_code = generator(noise)

batch_samples = autoencoder.decode(batch_code,
training=False,
temperature=temperature)

batch_samples = autoencoder.decode(batch_code, training=False, temperature=temperature)
batch_samples = to_cpu_if_available(batch_samples)
batch_samples = batch_samples.data.numpy()

Expand Down
7 changes: 2 additions & 5 deletions multi_categorical_gans/methods/arae/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,12 @@ def regularize_ae_grad(grad):

batch_original = Variable(torch.from_numpy(batch))
batch_original = to_cuda_if_available(batch_original)
batch_code = autoencoder.encode(batch_original,
normalize_code=normalize_code)
batch_code = autoencoder.encode(batch_original, normalize_code=normalize_code)
batch_code = add_noise_to_code(batch_code, ae_noise_radius)
if regularization_penalty > 0:
batch_code.register_hook(update_last_ae_grad_norm)

batch_reconstructed = autoencoder.decode(batch_code,
training=True,
temperature=temperature)
batch_reconstructed = autoencoder.decode(batch_code, training=True, temperature=temperature)

ae_loss = categorical_variable_loss(batch_reconstructed, batch_original, variable_sizes)
ae_loss.backward()
Expand Down
6 changes: 1 addition & 5 deletions multi_categorical_gans/methods/mc_wgan_gp/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,7 @@ def train(generator,
fake_loss.backward()

# this is the magic from WGAN-GP
gradient_penalty = calculate_gradient_penalty(discriminator,
penalty,
real_features,
fake_features)

gradient_penalty = calculate_gradient_penalty(discriminator, penalty, real_features, fake_features)
gradient_penalty.backward()

# finally update the discriminator weights
Expand Down
5 changes: 1 addition & 4 deletions multi_categorical_gans/methods/medgan/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,9 @@ def batch_norm_train(self, mode=True):
batch_norm.train(mode=mode)

def forward(self, noise):
"""
This sums are called "shortcut connections"
"""
outputs = noise

for module in self.modules:
# Cannot write "outputs += module(outputs)" because it is an inplace operation (no differentiable)
outputs = module(outputs) + outputs
outputs = module(outputs) + outputs # shortcut connection
return outputs
5 changes: 1 addition & 4 deletions multi_categorical_gans/methods/medgan/pre_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,7 @@ def pre_train_epoch(autoencoder, data, batch_size, optim=None, variable_sizes=No
batch = Variable(torch.from_numpy(batch))
batch = to_cuda_if_available(batch)

_, batch_reconstructed = autoencoder(batch,
training=training,
temperature=temperature,
normalize_code=False)
_, batch_reconstructed = autoencoder(batch, training=training, temperature=temperature, normalize_code=False)

loss = categorical_variable_loss(batch_reconstructed, batch, variable_sizes)
loss.backward()
Expand Down
6 changes: 1 addition & 5 deletions multi_categorical_gans/methods/medgan/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,7 @@ def sample(autoencoder, generator, num_samples, num_features, batch_size=100, co
noise = Variable(torch.FloatTensor(batch_size, code_size).normal_())
noise = to_cuda_if_available(noise)
batch_code = generator(noise)

batch_samples = autoencoder.decode(batch_code,
training=False,
temperature=temperature)

batch_samples = autoencoder.decode(batch_code, training=False, temperature=temperature)
batch_samples = to_cpu_if_available(batch_samples)
batch_samples = batch_samples.data.numpy()

Expand Down
12 changes: 3 additions & 9 deletions multi_categorical_gans/methods/medgan/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,7 @@ def train(autoencoder,
noise = Variable(torch.FloatTensor(len(batch), code_size).normal_())
noise = to_cuda_if_available(noise)
fake_code = generator(noise)
fake_features = autoencoder.decode(fake_code,
training=True,
temperature=temperature)
fake_features = autoencoder.decode(fake_code, training=True, temperature=temperature)
fake_features = fake_features.detach() # do not propagate to the generator
fake_pred = discriminator(fake_features)
fake_loss = criterion(fake_pred, label_zeros)
Expand All @@ -127,9 +125,7 @@ def train(autoencoder,
noise = Variable(torch.FloatTensor(len(batch), code_size).normal_())
noise = to_cuda_if_available(noise)
gen_code = generator(noise)
gen_features = autoencoder.decode(gen_code,
training=True,
temperature=temperature)
gen_features = autoencoder.decode(gen_code, training=True, temperature=temperature)
gen_pred = discriminator(gen_features)

smooth_label_ones = Variable(torch.FloatTensor(len(batch)).uniform_(0.9, 1))
Expand Down Expand Up @@ -167,9 +163,7 @@ def train(autoencoder,
noise = Variable(torch.FloatTensor(len(batch), code_size).normal_())
noise = to_cuda_if_available(noise)
fake_code = generator(noise)
fake_features = autoencoder.decode(fake_code,
training=False,
temperature=temperature)
fake_features = autoencoder.decode(fake_code, training=False, temperature=temperature)
fake_pred = discriminator(fake_features)
fake_pred = to_cpu_if_available(fake_pred)
correct += (fake_pred.data.numpy().ravel() < .5).sum()
Expand Down

0 comments on commit 09e6083

Please sign in to comment.