diff --git a/multi_categorical_gans/methods/__init__.py b/multi_categorical_gans/methods/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/multi_categorical_gans/methods/arae/__init__.py b/multi_categorical_gans/methods/arae/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/multi_categorical_gans/methods/arae/sampler.py b/multi_categorical_gans/methods/arae/sampler.py new file mode 100644 index 0000000..820661b --- /dev/null +++ b/multi_categorical_gans/methods/arae/sampler.py @@ -0,0 +1,167 @@ +from __future__ import print_function + +import argparse +import torch + +import numpy as np + +from torch.autograd.variable import Variable + +from multi_categorical_gans.methods.general.autoencoder import AutoEncoder +from multi_categorical_gans.methods.general.generator import Generator + +from multi_categorical_gans.utils.categorical import load_variable_sizes_from_metadata +from multi_categorical_gans.utils.commandline import parse_int_list +from multi_categorical_gans.utils.cuda import to_cuda_if_available, to_cpu_if_available, load_without_cuda + + +def sample(autoencoder, generator, num_samples, num_features, batch_size=100, noise_size=128, temperature=None, + round_features=False): + + autoencoder, generator = to_cuda_if_available(autoencoder, generator) + + autoencoder.train(mode=False) + generator.train(mode=False) + + samples = np.zeros((num_samples, num_features), dtype=np.float32) + + start = 0 + while start < num_samples: + with torch.no_grad(): + noise = Variable(torch.FloatTensor(batch_size, noise_size).normal_()) + noise = to_cuda_if_available(noise) + batch_code = generator(noise) + + batch_samples = autoencoder.decode(batch_code, + training=False, + temperature=temperature) + + batch_samples = to_cpu_if_available(batch_samples) + batch_samples = batch_samples.data.numpy() + + # if rounding is activated (for ARAE with binary outputs) + if round_features: + batch_samples = np.round(batch_samples) + + # do not go further than the desired number of samples + end = min(start + batch_size, num_samples) + # limit the samples taken from the batch based on what is missing + samples[start:end, :] = batch_samples[:min(batch_size, end - start), :] + + # move to next batch + start = end + return samples + + +def main(): + options_parser = argparse.ArgumentParser(description="Sample data with ARAE.") + + options_parser.add_argument("autoencoder", type=str, help="Autoencoder input file.") + options_parser.add_argument("generator", type=str, help="Generator input file.") + options_parser.add_argument("num_samples", type=int, help="Number of output samples.") + options_parser.add_argument("num_features", type=int, help="Number of output features.") + options_parser.add_argument("data", type=str, help="Output data.") + + options_parser.add_argument("--metadata", type=str, + help="Information about the categorical variables in json format.") + + options_parser.add_argument( + "--code_size", + type=int, + default=128, + help="Dimension of the autoencoder latent space." + ) + + options_parser.add_argument( + "--noise_size", + type=int, + default=128, + help="Dimension of the generator input noise." + ) + + options_parser.add_argument( + "--encoder_hidden_sizes", + type=str, + default="", + help="Size of each hidden layer in the encoder separated by commas (no spaces)." + ) + + options_parser.add_argument( + "--decoder_hidden_sizes", + type=str, + default="", + help="Size of each hidden layer in the decoder separated by commas (no spaces)." 
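+        # e.g. "--decoder_hidden_sizes 128,64"; with the empty default the decoder
+        # has no hidden layers and is just the output layer applied to the code.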
+ ) + + options_parser.add_argument( + "--batch_size", + type=int, + default=100, + help="Amount of samples per batch." + ) + + options_parser.add_argument( + "--generator_hidden_sizes", + type=str, + default="256,128", + help="Size of each hidden layer in the generator separated by commas (no spaces)." + ) + + options_parser.add_argument( + "--generator_bn_decay", + type=float, + default=0.01, + help="Generator batch normalization decay." + ) + + options_parser.add_argument( + "--temperature", + type=float, + default=None, + help="Gumbel-Softmax temperature." + ) + + options = options_parser.parse_args() + + if options.metadata is not None and options.temperature is not None: + variable_sizes = load_variable_sizes_from_metadata(options.metadata) + temperature = options.temperature + else: + variable_sizes = None + temperature = None + + autoencoder = AutoEncoder( + options.num_features, + code_size=options.code_size, + encoder_hidden_sizes=parse_int_list(options.encoder_hidden_sizes), + decoder_hidden_sizes=parse_int_list(options.decoder_hidden_sizes), + variable_sizes=variable_sizes + ) + + load_without_cuda(autoencoder, options.autoencoder) + + generator = Generator( + options.noise_size, + options.code_size, + hidden_sizes=parse_int_list(options.generator_hidden_sizes), + bn_decay=options.generator_bn_decay + ) + + load_without_cuda(generator, options.generator) + + data = sample( + autoencoder, + generator, + options.num_samples, + options.num_features, + batch_size=options.batch_size, + noise_size=options.noise_size, + temperature=temperature, + round_features=(temperature is None) + ) + + np.save(options.data, data) + + +if __name__ == "__main__": + main() diff --git a/multi_categorical_gans/methods/arae/trainer.py b/multi_categorical_gans/methods/arae/trainer.py new file mode 100644 index 0000000..48061b9 --- /dev/null +++ b/multi_categorical_gans/methods/arae/trainer.py @@ -0,0 +1,451 @@ +from __future__ import print_function + +import argparse +import torch + +import numpy as np + +from torch.autograd.variable import Variable +from torch.optim import Adam + +from multi_categorical_gans.datasets.dataset import Dataset +from multi_categorical_gans.datasets.formats import data_formats, loaders + +from multi_categorical_gans.methods.general.autoencoder import AutoEncoder +from multi_categorical_gans.methods.general.generator import Generator +from multi_categorical_gans.methods.general.discriminator import Discriminator + +from multi_categorical_gans.utils.categorical import load_variable_sizes_from_metadata, categorical_variable_loss +from multi_categorical_gans.utils.commandline import DelayedKeyboardInterrupt, parse_int_list +from multi_categorical_gans.utils.cuda import to_cuda_if_available, to_cpu_if_available +from multi_categorical_gans.utils.initialization import load_or_initialize +from multi_categorical_gans.utils.logger import Logger + + +def add_noise_to_code(code, noise_radius): + if noise_radius > 0: + means = torch.zeros_like(code) + gauss_noise = torch.normal(means, noise_radius) + return code + to_cuda_if_available(Variable(gauss_noise)) + else: + return code + + +def train(autoencoder, + generator, + discriminator, + train_data, + val_data, + output_ae_path, + output_gen_path, + output_disc_path, + output_loss_path, + batch_size=1000, + start_epoch=0, + num_epochs=1000, + num_ae_steps=1, + num_disc_steps=2, + num_gen_steps=1, + noise_size=128, + l2_regularization=0.001, + learning_rate=0.001, + discriminator_clamp=0.01, + ae_noise_radius=0.2, + 
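+          # stddev of the Gaussian noise added to the latent code;
+          # multiplied by ae_noise_anneal after every epoch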
ae_noise_anneal=0.995, + normalize_code=True, + variable_sizes=None, + temperature=None, + regularization_penalty=0.01 + ): + autoencoder, generator, discriminator = to_cuda_if_available(autoencoder, generator, discriminator) + + optim_ae = Adam(autoencoder.parameters(), weight_decay=l2_regularization, lr=learning_rate) + optim_gen = Adam(generator.parameters(), weight_decay=l2_regularization, lr=learning_rate) + optim_disc = Adam(discriminator.parameters(), weight_decay=l2_regularization, lr=learning_rate) + + logger = Logger(output_loss_path) + + for epoch_index in range(start_epoch, num_epochs): + logger.start_timer() + + # train + autoencoder.train(mode=True) + generator.train(mode=True) + discriminator.train(mode=True) + + ae_losses = [] + disc_losses = [] + gen_losses = [] + + more_batches = True + train_data_iterator = train_data.batch_iterator(batch_size) + + last_ae_grad_norm = None + + def update_last_ae_grad_norm(grad): + global last_ae_grad_norm + norm = torch.norm(grad, 2, 1) + last_ae_grad_norm = norm.detach().data.mean() + return grad + + def regularize_ae_grad(grad): + global last_ae_grad_norm + gan_norm = torch.norm(grad, 2, 1).detach().data.mean() + return (grad * last_ae_grad_norm / gan_norm) * (-regularization_penalty) + + while more_batches: + # train autoencoder + for _ in range(num_ae_steps): + try: + batch = next(train_data_iterator) + except StopIteration: + more_batches = False + break + + autoencoder.zero_grad() + + batch_original = Variable(torch.from_numpy(batch)) + batch_original = to_cuda_if_available(batch_original) + batch_code = autoencoder.encode(batch_original, + normalize_code=normalize_code) + batch_code = add_noise_to_code(batch_code, ae_noise_radius) + if regularization_penalty > 0: + batch_code.register_hook(update_last_ae_grad_norm) + + batch_reconstructed = autoencoder.decode(batch_code, + training=True, + temperature=temperature) + + ae_loss = categorical_variable_loss(batch_reconstructed, batch_original, variable_sizes) + ae_loss.backward() + + optim_ae.step() + + ae_loss = to_cpu_if_available(ae_loss) + ae_losses.append(ae_loss.data.numpy()) + + # train discriminator + for _ in range(num_disc_steps): + # clamp parameters to a cube + for discriminator_parameter in discriminator.parameters(): + discriminator_parameter.data.clamp_(-discriminator_clamp, discriminator_clamp) + + try: + batch = next(train_data_iterator) + except StopIteration: + more_batches = False + break + + discriminator.zero_grad() + autoencoder.zero_grad() + + # first train the discriminator only with real data + real_features = Variable(torch.from_numpy(batch)) + real_features = to_cuda_if_available(real_features) + real_code = autoencoder.encode(real_features, + normalize_code=normalize_code) + real_code = add_noise_to_code(real_code, ae_noise_radius) + if regularization_penalty > 0: + real_code.register_hook(regularize_ae_grad) + real_pred = discriminator(real_code) + real_loss = - real_pred.mean(0).view(1) + real_loss.backward() + + # then train the discriminator only with fake data + noise = Variable(torch.FloatTensor(len(batch), noise_size).normal_()) + noise = to_cuda_if_available(noise) + fake_code = generator(noise) + fake_code = fake_code.detach() # do not propagate to the generator + fake_pred = discriminator(fake_code) + fake_loss = fake_pred.mean(0).view(1) + fake_loss.backward() + + optim_ae.step() + optim_disc.step() + + disc_loss = real_loss + fake_loss + disc_loss = to_cpu_if_available(disc_loss) + disc_losses.append(disc_loss.data.numpy()) + + del disc_loss + 
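+                    # dropping these references releases this step's graph tensors early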
del fake_loss + del real_loss + + # train generator + for _ in range(num_gen_steps): + generator.zero_grad() + + noise = Variable(torch.FloatTensor(len(batch), noise_size).normal_()) + noise = to_cuda_if_available(noise) + gen_code = generator(noise) + fake_pred = discriminator(gen_code) + fake_loss = - fake_pred.mean(0).view(1) + fake_loss.backward() + + optim_gen.step() + + fake_loss = to_cpu_if_available(fake_loss) + gen_losses.append(fake_loss.data.numpy()[0]) + + del fake_loss + + # log epoch metrics for current class + logger.log(epoch_index, num_epochs, "autoencoder", "train_mean_loss", np.mean(ae_losses)) + logger.log(epoch_index, num_epochs, "discriminator", "train_mean_loss", np.mean(disc_losses)) + logger.log(epoch_index, num_epochs, "generator", "train_mean_loss", np.mean(gen_losses)) + + # save models for the epoch + with DelayedKeyboardInterrupt(): + torch.save(autoencoder.state_dict(), output_ae_path) + torch.save(generator.state_dict(), output_gen_path) + torch.save(discriminator.state_dict(), output_disc_path) + logger.flush() + + ae_noise_radius *= ae_noise_anneal + + logger.close() + + +def main(): + options_parser = argparse.ArgumentParser(description="Train ARAE or MC-ARAE. " + + "Define 'metadata' and 'temperature' to use MC-ARAE.") + + options_parser.add_argument("data", type=str, help="Training data. See 'data_format' parameter.") + + options_parser.add_argument("output_autoencoder", type=str, help="Autoencoder output file.") + options_parser.add_argument("output_generator", type=str, help="Generator output file.") + options_parser.add_argument("output_discriminator", type=str, help="Discriminator output file.") + options_parser.add_argument("output_loss", type=str, help="Loss output file.") + + options_parser.add_argument("--input_autoencoder", type=str, help="Autoencoder input file.", default=None) + options_parser.add_argument("--input_generator", type=str, help="Generator input file.", default=None) + options_parser.add_argument("--input_discriminator", type=str, help="Discriminator input file.", default=None) + + options_parser.add_argument("--metadata", type=str, + help="Information about the categorical variables in json format.") + + options_parser.add_argument( + "--validation_proportion", type=float, + default=.1, + help="Ratio of data for validation." + ) + + options_parser.add_argument( + "--data_format", + type=str, + default="sparse", + choices=data_formats, + help="Either a dense numpy array or a sparse csr matrix." + ) + + options_parser.add_argument( + "--code_size", + type=int, + default=128, + help="Dimension of the autoencoder latent space." + ) + + options_parser.add_argument( + "--noise_size", + type=int, + default=128, + help="Dimension of the generator input noise." + ) + + options_parser.add_argument( + "--encoder_hidden_sizes", + type=str, + default="", + help="Size of each hidden layer in the encoder separated by commas (no spaces)." + ) + + options_parser.add_argument( + "--decoder_hidden_sizes", + type=str, + default="", + help="Size of each hidden layer in the decoder separated by commas (no spaces)." + ) + + options_parser.add_argument( + "--batch_size", + type=int, + default=100, + help="Amount of samples per batch." + ) + + options_parser.add_argument( + "--start_epoch", + type=int, + default=0, + help="Starting epoch." + ) + + options_parser.add_argument( + "--num_epochs", + type=int, + default=5000, + help="Number of epochs." 
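+        # One epoch is one pass over the training split, interleaving autoencoder,
+        # critic and generator steps batch by batch.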
+ ) + + options_parser.add_argument( + "--l2_regularization", + type=float, + default=0, + help="L2 regularization weight for every parameter." + ) + + options_parser.add_argument( + "--learning_rate", + type=float, + default=1e-5, + help="Adam learning rate." + ) + + options_parser.add_argument( + "--generator_hidden_sizes", + type=str, + default="100,100,100", + help="Size of each hidden layer in the generator separated by commas (no spaces)." + ) + + options_parser.add_argument( + "--bn_decay", + type=float, + default=0.9, + help="Batch normalization decay for the generator and discriminator." + ) + + options_parser.add_argument( + "--discriminator_hidden_sizes", + type=str, + default="100", + help="Size of each hidden layer in the discriminator separated by commas (no spaces)." + ) + + options_parser.add_argument( + "--num_autoencoder_steps", + type=int, + default=1, + help="Number of successive training steps for the autoencoder." + ) + + options_parser.add_argument( + "--num_discriminator_steps", + type=int, + default=1, + help="Number of successive training steps for the discriminator." + ) + + options_parser.add_argument( + "--num_generator_steps", + type=int, + default=1, + help="Number of successive training steps for the generator." + ) + + options_parser.add_argument( + "--discriminator_clamp", + type=float, + default=0.01, + help="WGAN clamp." + ) + + options_parser.add_argument( + "--autoencoder_noise_radius", + type=float, + default=0, + help="Gaussian noise standard deviation for the latent code (autoencoder regularization)." + ) + + options_parser.add_argument( + "--autoencoder_noise_anneal", + type=float, + default=0.995, + help="Anneal the noise radius by this value after every epoch." + ) + + options_parser.add_argument( + "--temperature", + type=float, + default=None, + help="Gumbel-Softmax temperature." 
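+        # Leave unset for plain ARAE (binary outputs, rounded at sampling time);
+        # set together with --metadata to train MC-ARAE with Gumbel-softmax outputs.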
+ ) + + options_parser.add_argument("--seed", type=int, help="Random number generator seed.", default=42) + + options = options_parser.parse_args() + + if options.seed is not None: + np.random.seed(options.seed) + torch.manual_seed(options.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(options.seed) + + features = loaders[options.data_format](options.data) + data = Dataset(features) + train_data, val_data = data.split(1.0 - options.validation_proportion) + + if options.metadata is not None and options.temperature is not None: + variable_sizes = load_variable_sizes_from_metadata(options.metadata) + temperature = options.temperature + else: + variable_sizes = None + temperature = None + + autoencoder = AutoEncoder( + features.shape[1], + code_size=options.code_size, + encoder_hidden_sizes=parse_int_list(options.encoder_hidden_sizes), + decoder_hidden_sizes=parse_int_list(options.decoder_hidden_sizes), + variable_sizes=variable_sizes + ) + + load_or_initialize(autoencoder, options.input_autoencoder) + + generator = Generator( + options.noise_size, + options.code_size, + hidden_sizes=parse_int_list(options.generator_hidden_sizes), + bn_decay=options.bn_decay + ) + + load_or_initialize(generator, options.input_generator) + + discriminator = Discriminator( + options.code_size, + hidden_sizes=parse_int_list(options.discriminator_hidden_sizes), + bn_decay=options.bn_decay, + critic=True + ) + + load_or_initialize(discriminator, options.input_discriminator) + + train( + autoencoder, + generator, + discriminator, + train_data, + val_data, + options.output_autoencoder, + options.output_generator, + options.output_discriminator, + options.output_loss, + batch_size=options.batch_size, + start_epoch=options.start_epoch, + num_epochs=options.num_epochs, + num_ae_steps=options.num_autoencoder_steps, + num_disc_steps=options.num_discriminator_steps, + num_gen_steps=options.num_generator_steps, + noise_size=options.noise_size, + l2_regularization=options.l2_regularization, + learning_rate=options.learning_rate, + discriminator_clamp=options.discriminator_clamp, + ae_noise_radius=options.autoencoder_noise_radius, + ae_noise_anneal=options.autoencoder_noise_anneal, + variable_sizes=variable_sizes, + temperature=temperature + ) + + +if __name__ == "__main__": + main() diff --git a/multi_categorical_gans/methods/general/__init__.py b/multi_categorical_gans/methods/general/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/multi_categorical_gans/methods/general/autoencoder.py b/multi_categorical_gans/methods/general/autoencoder.py new file mode 100644 index 0000000..c8d3992 --- /dev/null +++ b/multi_categorical_gans/methods/general/autoencoder.py @@ -0,0 +1,40 @@ +from __future__ import print_function + +import torch +import torch.nn as nn + +from multi_categorical_gans.methods.general.decoder import Decoder +from multi_categorical_gans.methods.general.encoder import Encoder + + +class AutoEncoder(nn.Module): + + def __init__(self, input_size, code_size=128, encoder_hidden_sizes=[], decoder_hidden_sizes=[], + variable_sizes=None): + + super(AutoEncoder, self).__init__() + + self.encoder = Encoder(input_size, + code_size, + hidden_sizes=encoder_hidden_sizes) + + self.decoder = Decoder(code_size, + (input_size if variable_sizes is None else variable_sizes), + hidden_sizes=decoder_hidden_sizes) + + def forward(self, inputs, normalize_code=False, training=False, temperature=None): + code = self.encode(inputs, normalize_code=normalize_code) + reconstructed = self.decode(code, 
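+                                     # soft (differentiable) samples while training,
+                                     # hard samples at evaluation time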
training=training, temperature=temperature) + return code, reconstructed + + def encode(self, inputs, normalize_code=False): + code = self.encoder(inputs) + + if normalize_code: + norms = torch.norm(code, 2, 1) + code = torch.div(code, norms.unsqueeze(1).expand_as(code)) + + return code + + def decode(self, code, training=False, temperature=None): + return self.decoder(code, training=training, temperature=temperature) diff --git a/multi_categorical_gans/methods/general/decoder.py b/multi_categorical_gans/methods/general/decoder.py new file mode 100644 index 0000000..8ea13ab --- /dev/null +++ b/multi_categorical_gans/methods/general/decoder.py @@ -0,0 +1,42 @@ +from __future__ import print_function + +import torch.nn as nn + +from multi_categorical_gans.methods.general.multi_categorical import MultiCategorical +from multi_categorical_gans.methods.general.single_output import SingleOutput + + +class Decoder(nn.Module): + + def __init__(self, code_size, output_size, hidden_sizes=[]): + super(Decoder, self).__init__() + + hidden_activation = nn.Tanh() + + previous_layer_size = code_size + hidden_layers = [] + + for layer_size in hidden_sizes: + hidden_layers.append(nn.Linear(previous_layer_size, layer_size)) + hidden_layers.append(hidden_activation) + previous_layer_size = layer_size + + if len(hidden_layers) > 0: + self.hidden_layers = nn.Sequential(*hidden_layers) + else: + self.hidden_layers = None + + if type(output_size) is int: + self.output_layer = SingleOutput(previous_layer_size, output_size, activation=nn.Sigmoid()) + elif type(output_size) is list: + self.output_layer = MultiCategorical(previous_layer_size, output_size) + else: + raise Exception("Invalid output size.") + + def forward(self, code, training=False, temperature=None): + if self.hidden_layers is None: + hidden = code + else: + hidden = self.hidden_layers(code) + + return self.output_layer(hidden, training=training, temperature=temperature) diff --git a/multi_categorical_gans/methods/general/discriminator.py b/multi_categorical_gans/methods/general/discriminator.py new file mode 100644 index 0000000..c12f20d --- /dev/null +++ b/multi_categorical_gans/methods/general/discriminator.py @@ -0,0 +1,32 @@ +from __future__ import print_function + +import torch.nn as nn + + +class Discriminator(nn.Module): + + def __init__(self, input_size, hidden_sizes=(256, 128), bn_decay=0.01, critic=False): + super(Discriminator, self).__init__() + + hidden_activation = nn.LeakyReLU(0.2) + + previous_layer_size = input_size + layers = [] + + for layer_number, layer_size in enumerate(hidden_sizes): + layers.append(nn.Linear(previous_layer_size, layer_size)) + if layer_number > 0 and bn_decay > 0: + layers.append(nn.BatchNorm1d(layer_size, momentum=(1 - bn_decay))) + layers.append(hidden_activation) + previous_layer_size = layer_size + + layers.append(nn.Linear(previous_layer_size, 1)) + + # the critic has a linear output + if not critic: + layers.append(nn.Sigmoid()) + + self.model = nn.Sequential(*layers) + + def forward(self, inputs): + return self.model(inputs).view(-1) diff --git a/multi_categorical_gans/methods/general/encoder.py b/multi_categorical_gans/methods/general/encoder.py new file mode 100644 index 0000000..e777439 --- /dev/null +++ b/multi_categorical_gans/methods/general/encoder.py @@ -0,0 +1,26 @@ +from __future__ import print_function + +import torch.nn as nn + + +class Encoder(nn.Module): + + def __init__(self, input_size, code_size, hidden_sizes=[]): + super(Encoder, self).__init__() + + hidden_activation = nn.Tanh() + + 
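+        # Tanh is applied after every layer, including the final one, so the
+        # code components always lie in (-1, 1).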
previous_layer_size = input_size + + layer_sizes = list(hidden_sizes) + [code_size] + layers = [] + + for layer_size in layer_sizes: + layers.append(nn.Linear(previous_layer_size, layer_size)) + layers.append(hidden_activation) + previous_layer_size = layer_size + + self.hidden_layers = nn.Sequential(*layers) + + def forward(self, inputs): + return self.hidden_layers(inputs) diff --git a/multi_categorical_gans/methods/general/generator.py b/multi_categorical_gans/methods/general/generator.py new file mode 100644 index 0000000..bbe3a3a --- /dev/null +++ b/multi_categorical_gans/methods/general/generator.py @@ -0,0 +1,43 @@ +from __future__ import print_function + +import torch.nn as nn + +from multi_categorical_gans.methods.general.multi_categorical import MultiCategorical +from multi_categorical_gans.methods.general.single_output import SingleOutput + + +class Generator(nn.Module): + + def __init__(self, noise_size, output_size, hidden_sizes=[], bn_decay=0.01): + super(Generator, self).__init__() + + hidden_activation = nn.ReLU() + + previous_layer_size = noise_size + hidden_layers = [] + + for layer_size in hidden_sizes: + hidden_layers.append(nn.Linear(previous_layer_size, layer_size)) + hidden_layers.append(nn.BatchNorm1d(layer_size, momentum=(1 - bn_decay))) + hidden_layers.append(hidden_activation) + previous_layer_size = layer_size + + if len(hidden_layers) > 0: + self.hidden_layers = nn.Sequential(*hidden_layers) + else: + self.hidden_layers = None + + if type(output_size) is int: + self.output = SingleOutput(previous_layer_size, output_size) + elif type(output_size) is list: + self.output = MultiCategorical(previous_layer_size, output_size) + else: + raise Exception("Invalid output size.") + + def forward(self, noise, training=False, temperature=None): + if self.hidden_layers is None: + hidden = noise + else: + hidden = self.hidden_layers(noise) + + return self.output(hidden, training=training, temperature=temperature) diff --git a/multi_categorical_gans/methods/general/multi_categorical.py b/multi_categorical_gans/methods/general/multi_categorical.py new file mode 100644 index 0000000..24412ff --- /dev/null +++ b/multi_categorical_gans/methods/general/multi_categorical.py @@ -0,0 +1,75 @@ +from __future__ import print_function + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torch.distributions.one_hot_categorical import OneHotCategorical + + +class MultiCategorical(nn.Module): + + def __init__(self, input_size, variable_sizes): + super(MultiCategorical, self).__init__() + + self.output_layers = nn.ModuleList() + self.output_activations = nn.ModuleList() + + continuous_size = 0 + for i, variable_size in enumerate(variable_sizes): + # if it is a categorical variable + if variable_size > 1: + # first create the accumulated continuous layer + if continuous_size > 0: + self.output_layers.append(nn.Linear(input_size, continuous_size)) + self.output_activations.append(ContinuousActivation()) + continuous_size = 0 + # create the categorical layer + self.output_layers.append(nn.Linear(input_size, variable_size)) + self.output_activations.append(CategoricalActivation()) + # if not, accumulate continuous variables + else: + continuous_size += 1 + + # create the remaining accumulated continuous layer + if continuous_size > 0: + self.output_layers.append(nn.Linear(input_size, continuous_size)) + self.output_activations.append(ContinuousActivation()) + + def forward(self, inputs, training=True, temperature=None, concat=True): + outputs = [] + for output_layer, 
output_activation in zip(self.output_layers, self.output_activations): + logits = output_layer(inputs) + output = output_activation(logits, training=training, temperature=temperature) + outputs.append(output) + + if concat: + return torch.cat(outputs, dim=1) + else: + return outputs + + +class CategoricalActivation(nn.Module): + + def __init__(self): + super(CategoricalActivation, self).__init__() + + def forward(self, logits, training=True, temperature=None): + # gumbel-softmax (training and evaluation) + if temperature is not None: + return F.gumbel_softmax(logits, hard=not training, tau=temperature) + # softmax training + elif training: + return F.softmax(logits, dim=1) + # softmax evaluation + else: + return OneHotCategorical(logits=logits).sample() + + +class ContinuousActivation(nn.Module): + + def __init__(self): + super(ContinuousActivation, self).__init__() + + def forward(self, logits, training=True, temperature=None): + return F.sigmoid(logits) diff --git a/multi_categorical_gans/methods/general/single_output.py b/multi_categorical_gans/methods/general/single_output.py new file mode 100644 index 0000000..e255c98 --- /dev/null +++ b/multi_categorical_gans/methods/general/single_output.py @@ -0,0 +1,16 @@ +from __future__ import print_function + +import torch.nn as nn + + +class SingleOutput(nn.Module): + + def __init__(self, previous_layer_size, output_size, activation=None): + super(SingleOutput, self).__init__() + if activation is None: + self.model = nn.Linear(previous_layer_size, output_size) + else: + self.model = nn.Sequential(nn.Linear(previous_layer_size, output_size), activation) + + def forward(self, hidden, training=False, temperature=None): + return self.model(hidden) diff --git a/multi_categorical_gans/methods/mc_gumbel/__init__.py b/multi_categorical_gans/methods/mc_gumbel/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/multi_categorical_gans/methods/mc_gumbel/sampler.py b/multi_categorical_gans/methods/mc_gumbel/sampler.py new file mode 100644 index 0000000..e69f452 --- /dev/null +++ b/multi_categorical_gans/methods/mc_gumbel/sampler.py @@ -0,0 +1,114 @@ +from __future__ import print_function + +import argparse +import torch + +import numpy as np + +from torch.autograd.variable import Variable + +from multi_categorical_gans.methods.general.generator import Generator + +from multi_categorical_gans.utils.categorical import load_variable_sizes_from_metadata +from multi_categorical_gans.utils.commandline import parse_int_list +from multi_categorical_gans.utils.cuda import to_cuda_if_available, to_cpu_if_available, load_without_cuda + + +def sample(generator, temperature, num_samples, num_features, batch_size=100, noise_size=128): + generator = to_cuda_if_available(generator) + + generator.train(mode=False) + + samples = np.zeros((num_samples, num_features), dtype=np.float32) + + start = 0 + while start < num_samples: + with torch.no_grad(): + noise = Variable(torch.FloatTensor(batch_size, noise_size).normal_()) + noise = to_cuda_if_available(noise) + batch_samples = generator(noise, training=False, temperature=temperature) + batch_samples = to_cpu_if_available(batch_samples) + batch_samples = batch_samples.data.numpy() + + # do not go further than the desired number of samples + end = min(start + batch_size, num_samples) + # limit the samples taken from the batch based on what is missing + samples[start:end, :] = batch_samples[:min(batch_size, end - start), :] + + # move to next batch + start = end + return samples + + +def main(): + options_parser 
= argparse.ArgumentParser(description="Sample data with MedGAN.") + + options_parser.add_argument("generator", type=str, help="Generator input file.") + + options_parser.add_argument("metadata", type=str, + help="Information about the categorical variables in json format.") + + options_parser.add_argument("num_samples", type=int, help="Number of output samples.") + options_parser.add_argument("num_features", type=int, help="Number of output features.") + options_parser.add_argument("data", type=str, help="Output data.") + + options_parser.add_argument( + "--noise_size", + type=int, + default=128, + help="Dimension of the generator input noise." + ) + + options_parser.add_argument( + "--batch_size", + type=int, + default=100, + help="Amount of samples per batch." + ) + + options_parser.add_argument( + "--generator_hidden_sizes", + type=str, + default="256,128", + help="Size of each hidden layer in the generator separated by commas (no spaces)." + ) + + options_parser.add_argument( + "--generator_bn_decay", + type=float, + default=0.01, + help="Generator batch normalization decay." + ) + + options_parser.add_argument( + "--temperature", + type=float, + default=0.666, + help="Gumbel-Softmax temperature." + ) + + options = options_parser.parse_args() + + generator = Generator( + options.noise_size, + load_variable_sizes_from_metadata(options.metadata), + hidden_sizes=parse_int_list(options.generator_hidden_sizes), + bn_decay=options.generator_bn_decay + ) + + load_without_cuda(generator, options.generator) + + data = sample( + generator, + options.temperature, + options.num_samples, + options.num_features, + batch_size=options.batch_size, + noise_size=options.noise_size, + ) + + np.save(options.data, data) + + +if __name__ == "__main__": + main() diff --git a/multi_categorical_gans/methods/mc_gumbel/trainer.py b/multi_categorical_gans/methods/mc_gumbel/trainer.py new file mode 100644 index 0000000..ffb72b4 --- /dev/null +++ b/multi_categorical_gans/methods/mc_gumbel/trainer.py @@ -0,0 +1,340 @@ +from __future__ import print_function + +import argparse +import torch + +import numpy as np + +from torch.autograd.variable import Variable +from torch.nn import BCELoss +from torch.optim import Adam + +from multi_categorical_gans.datasets.dataset import Dataset +from multi_categorical_gans.datasets.formats import data_formats, loaders + +from multi_categorical_gans.methods.general.discriminator import Discriminator +from multi_categorical_gans.methods.general.generator import Generator + +from multi_categorical_gans.utils.categorical import load_variable_sizes_from_metadata +from multi_categorical_gans.utils.commandline import DelayedKeyboardInterrupt, parse_int_list +from multi_categorical_gans.utils.cuda import to_cuda_if_available, to_cpu_if_available +from multi_categorical_gans.utils.initialization import load_or_initialize +from multi_categorical_gans.utils.logger import Logger + + +def train(generator, + discriminator, + train_data, + val_data, + output_gen_path, + output_disc_path, + output_loss_path, + batch_size=1000, + start_epoch=0, + num_epochs=1000, + num_disc_steps=2, + num_gen_steps=1, + noise_size=128, + l2_regularization=0.001, + learning_rate=0.001, + temperature=0.666 + ): + generator, discriminator = to_cuda_if_available(generator, discriminator) + + optim_gen = Adam(generator.parameters(), weight_decay=l2_regularization, lr=learning_rate) + optim_disc = Adam(discriminator.parameters(), weight_decay=l2_regularization, lr=learning_rate) + + criterion = BCELoss() + + logger = 
Logger(output_loss_path) + + for epoch_index in range(start_epoch, num_epochs): + logger.start_timer() + + # train + generator.train(mode=True) + discriminator.train(mode=True) + + disc_losses = [] + gen_losses = [] + + more_batches = True + train_data_iterator = train_data.batch_iterator(batch_size) + + while more_batches: + # train discriminator + for _ in range(num_disc_steps): + # next batch + try: + batch = next(train_data_iterator) + except StopIteration: + more_batches = False + break + + # using "one sided smooth labels" is one trick to improve GAN training + label_zeros = Variable(torch.zeros(len(batch))) + smooth_label_ones = Variable(torch.FloatTensor(len(batch)).uniform_(0.9, 1)) + + label_zeros, smooth_label_ones = to_cuda_if_available(label_zeros, smooth_label_ones) + + optim_disc.zero_grad() + + # first train the discriminator only with real data + real_features = Variable(torch.from_numpy(batch)) + real_features = to_cuda_if_available(real_features) + real_pred = discriminator(real_features) + real_loss = criterion(real_pred, smooth_label_ones) + real_loss.backward() + + # then train the discriminator only with fake data + noise = Variable(torch.FloatTensor(len(batch), noise_size).normal_()) + noise = to_cuda_if_available(noise) + fake_features = generator(noise, training=True, temperature=temperature) + fake_features = fake_features.detach() # do not propagate to the generator + fake_pred = discriminator(fake_features) + fake_loss = criterion(fake_pred, label_zeros) + fake_loss.backward() + + # finally update the discriminator weights + # using two separated batches is another trick to improve GAN training + optim_disc.step() + + disc_loss = real_loss + fake_loss + disc_loss = to_cpu_if_available(disc_loss) + disc_losses.append(disc_loss.data.numpy()) + + del disc_loss + del fake_loss + del real_loss + + # train generator + for _ in range(num_gen_steps): + optim_gen.zero_grad() + + noise = Variable(torch.FloatTensor(len(batch), noise_size).normal_()) + noise = to_cuda_if_available(noise) + gen_features = generator(noise, training=True, temperature=temperature) + gen_pred = discriminator(gen_features) + + smooth_label_ones = Variable(torch.FloatTensor(len(batch)).uniform_(0.9, 1)) + smooth_label_ones = to_cuda_if_available(smooth_label_ones) + + gen_loss = criterion(gen_pred, smooth_label_ones) + gen_loss.backward() + + optim_gen.step() + + gen_loss = to_cpu_if_available(gen_loss) + gen_losses.append(gen_loss.data.numpy()) + + del gen_loss + + # validate discriminator + generator.train(mode=False) + discriminator.train(mode=False) + + correct = 0.0 + total = 0.0 + for batch in val_data.batch_iterator(batch_size): + # real data discriminator accuracy + with torch.no_grad(): + real_features = Variable(torch.from_numpy(batch)) + real_features = to_cuda_if_available(real_features) + real_pred = discriminator(real_features) + real_pred = to_cpu_if_available(real_pred) + correct += (real_pred.data.numpy().ravel() > .5).sum() + total += len(real_pred) + + # fake data discriminator accuracy + with torch.no_grad(): + noise = Variable(torch.FloatTensor(len(batch), noise_size).normal_()) + noise = to_cuda_if_available(noise) + fake_features = generator(noise, training=False, temperature=temperature) + fake_pred = discriminator(fake_features) + fake_pred = to_cpu_if_available(fake_pred) + correct += (fake_pred.data.numpy().ravel() < .5).sum() + total += len(fake_pred) + + # log epoch metrics for current class + logger.log(epoch_index, num_epochs, "discriminator", "train_mean_loss", 
np.mean(disc_losses)) + logger.log(epoch_index, num_epochs, "generator", "train_mean_loss", np.mean(gen_losses)) + logger.log(epoch_index, num_epochs, "discriminator", "validation_accuracy", correct / total) + + # save models for the epoch + with DelayedKeyboardInterrupt(): + torch.save(generator.state_dict(), output_gen_path) + torch.save(discriminator.state_dict(), output_disc_path) + logger.flush() + + logger.close() + + +def main(): + options_parser = argparse.ArgumentParser(description="Train MC-Gumbel.") + + options_parser.add_argument("data", type=str, help="Training data. See 'data_format' parameter.") + + options_parser.add_argument("metadata", type=str, + help="Information about the categorical variables in json format.") + + options_parser.add_argument("output_generator", type=str, help="Generator output file.") + options_parser.add_argument("output_discriminator", type=str, help="Discriminator output file.") + options_parser.add_argument("output_loss", type=str, help="Loss output file.") + + options_parser.add_argument("--input_generator", type=str, help="Generator input file.", default=None) + options_parser.add_argument("--input_discriminator", type=str, help="Discriminator input file.", default=None) + + options_parser.add_argument( + "--validation_proportion", type=float, + default=.1, + help="Ratio of data for validation." + ) + + options_parser.add_argument( + "--data_format", + type=str, + default="sparse", + choices=data_formats, + help="Either a dense numpy array or a sparse csr matrix." + ) + + options_parser.add_argument( + "--noise_size", + type=int, + default=128, + help="" + ) + + options_parser.add_argument( + "--batch_size", + type=int, + default=1000, + help="Amount of samples per batch." + ) + + options_parser.add_argument( + "--start_epoch", + type=int, + default=0, + help="Starting epoch." + ) + + options_parser.add_argument( + "--num_epochs", + type=int, + default=1000, + help="Number of epochs." + ) + + options_parser.add_argument( + "--l2_regularization", + type=float, + default=0.001, + help="L2 regularization weight for every parameter." + ) + + options_parser.add_argument( + "--learning_rate", + type=float, + default=0.001, + help="Adam learning rate." + ) + + options_parser.add_argument( + "--generator_hidden_sizes", + type=str, + default="256,128", + help="Size of each hidden layer in the generator separated by commas (no spaces)." + ) + + options_parser.add_argument( + "--bn_decay", + type=float, + default=0.9, + help="Batch normalization decay for the generator and discriminator." + ) + + options_parser.add_argument( + "--discriminator_hidden_sizes", + type=str, + default="256,128", + help="Size of each hidden layer in the discriminator separated by commas (no spaces)." + ) + + options_parser.add_argument( + "--num_discriminator_steps", + type=int, + default=2, + help="Number of successive training steps for the discriminator." + ) + + options_parser.add_argument( + "--num_generator_steps", + type=int, + default=1, + help="Number of successive training steps for the generator." + ) + + options_parser.add_argument( + "--temperature", + type=float, + default=0.666, + help="Gumbel-Softmax temperature." 
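+        # Illustrative: for logits of shape (batch, k), a relaxed one-hot sample is
+        #     torch.nn.functional.gumbel_softmax(logits, tau=0.666, hard=False)
+        # Lower tau gives samples closer to one-hot at the cost of noisier gradients.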
+ ) + + options_parser.add_argument("--seed", type=int, help="Random number generator seed.", default=42) + + options = options_parser.parse_args() + + if options.seed is not None: + np.random.seed(options.seed) + torch.manual_seed(options.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(options.seed) + + features = loaders[options.data_format](options.data) + data = Dataset(features) + train_data, val_data = data.split(1.0 - options.validation_proportion) + + variable_sizes = load_variable_sizes_from_metadata(options.metadata) + + generator = Generator( + options.noise_size, + variable_sizes, + hidden_sizes=parse_int_list(options.generator_hidden_sizes), + bn_decay=options.bn_decay + ) + + load_or_initialize(generator, options.input_generator) + + discriminator = Discriminator( + features.shape[1], + hidden_sizes=parse_int_list(options.discriminator_hidden_sizes), + bn_decay=options.bn_decay, + critic=False + ) + + load_or_initialize(discriminator, options.input_discriminator) + + train( + generator, + discriminator, + train_data, + val_data, + options.output_generator, + options.output_discriminator, + options.output_loss, + batch_size=options.batch_size, + start_epoch=options.start_epoch, + num_epochs=options.num_epochs, + num_disc_steps=options.num_discriminator_steps, + num_gen_steps=options.num_generator_steps, + noise_size=options.noise_size, + l2_regularization=options.l2_regularization, + learning_rate=options.learning_rate, + temperature=options.temperature + ) + + +if __name__ == "__main__": + main() diff --git a/multi_categorical_gans/methods/mc_wgan_gp/__init__.py b/multi_categorical_gans/methods/mc_wgan_gp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/multi_categorical_gans/methods/mc_wgan_gp/sampler.py b/multi_categorical_gans/methods/mc_wgan_gp/sampler.py new file mode 100644 index 0000000..579387e --- /dev/null +++ b/multi_categorical_gans/methods/mc_wgan_gp/sampler.py @@ -0,0 +1,106 @@ +from __future__ import print_function + +import argparse +import torch + +import numpy as np + +from torch.autograd.variable import Variable + +from multi_categorical_gans.methods.general.generator import Generator + +from multi_categorical_gans.utils.categorical import load_variable_sizes_from_metadata +from multi_categorical_gans.utils.commandline import parse_int_list +from multi_categorical_gans.utils.cuda import to_cuda_if_available, to_cpu_if_available, load_without_cuda + + +def sample(generator, num_samples, num_features, batch_size=100, noise_size=128): + generator = to_cuda_if_available(generator) + + generator.train(mode=False) + + samples = np.zeros((num_samples, num_features), dtype=np.float32) + + start = 0 + while start < num_samples: + with torch.no_grad(): + noise = Variable(torch.FloatTensor(batch_size, noise_size).normal_()) + noise = to_cuda_if_available(noise) + batch_samples = generator(noise, training=False) + batch_samples = to_cpu_if_available(batch_samples) + batch_samples = batch_samples.data.numpy() + + # do not go further than the desired number of samples + end = min(start + batch_size, num_samples) + # limit the samples taken from the batch based on what is missing + samples[start:end, :] = batch_samples[:min(batch_size, end - start), :] + + # move to next batch + start = end + return samples + + +def main(): + options_parser = argparse.ArgumentParser(description="Sample data with MedGAN.") + + options_parser.add_argument("generator", type=str, help="Generator input file.") + + options_parser.add_argument("metadata", 
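+                                # positional and required: the variable sizes in the
+                                # metadata define the generator's categorical output heads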
type=str, + help="Information about the categorical variables in json format.") + + options_parser.add_argument("num_samples", type=int, help="Number of output samples.") + options_parser.add_argument("num_features", type=int, help="Number of output features.") + options_parser.add_argument("data", type=str, help="Output data.") + + options_parser.add_argument( + "--noise_size", + type=int, + default=128, + help="Dimension of the generator input noise." + ) + + options_parser.add_argument( + "--batch_size", + type=int, + default=100, + help="Amount of samples per batch." + ) + + options_parser.add_argument( + "--generator_hidden_sizes", + type=str, + default="256,128", + help="Size of each hidden layer in the generator separated by commas (no spaces)." + ) + + options_parser.add_argument( + "--generator_bn_decay", + type=float, + default=0.01, + help="Generator batch normalization decay." + ) + + options = options_parser.parse_args() + + generator = Generator( + options.noise_size, + load_variable_sizes_from_metadata(options.metadata), + hidden_sizes=parse_int_list(options.generator_hidden_sizes), + bn_decay=options.generator_bn_decay + ) + + load_without_cuda(generator, options.generator) + + data = sample( + generator, + options.num_samples, + options.num_features, + batch_size=options.batch_size, + noise_size=options.noise_size + ) + + np.save(options.data, data) + + +if __name__ == "__main__": + main() diff --git a/multi_categorical_gans/methods/mc_wgan_gp/trainer.py b/multi_categorical_gans/methods/mc_wgan_gp/trainer.py new file mode 100644 index 0000000..eec95a6 --- /dev/null +++ b/multi_categorical_gans/methods/mc_wgan_gp/trainer.py @@ -0,0 +1,329 @@ +from __future__ import print_function + +import argparse +import torch + +import numpy as np + +from torch.autograd.variable import Variable +from torch.optim import Adam + +from multi_categorical_gans.datasets.dataset import Dataset +from multi_categorical_gans.datasets.formats import data_formats, loaders + +from multi_categorical_gans.methods.general.discriminator import Discriminator +from multi_categorical_gans.methods.general.generator import Generator + +from multi_categorical_gans.utils.categorical import load_variable_sizes_from_metadata +from multi_categorical_gans.utils.commandline import DelayedKeyboardInterrupt, parse_int_list +from multi_categorical_gans.utils.cuda import to_cuda_if_available, to_cpu_if_available +from multi_categorical_gans.utils.initialization import load_or_initialize +from multi_categorical_gans.utils.logger import Logger + + +def calculate_gradient_penalty(discriminator, penalty, real_data, fake_data): + real_data = real_data.data + fake_data = fake_data.data + + alpha = torch.rand(len(real_data), 1) + alpha = alpha.expand(real_data.size()) + alpha = to_cuda_if_available(alpha) + + interpolates = alpha * real_data + ((1 - alpha) * fake_data) + interpolates = Variable(interpolates, requires_grad=True) + discriminator_interpolates = discriminator(interpolates) + + gradients = torch.autograd.grad(outputs=discriminator_interpolates, + inputs=interpolates, + grad_outputs=to_cuda_if_available(torch.ones_like(discriminator_interpolates)), + create_graph=True, retain_graph=True, only_inputs=True)[0] + + return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * penalty + + +def train(generator, + discriminator, + train_data, + val_data, + output_gen_path, + output_disc_path, + output_loss_path, + batch_size=1000, + start_epoch=0, + num_epochs=1000, + num_disc_steps=2, + num_gen_steps=1, + noise_size=128, + 
l2_regularization=0.001, + learning_rate=0.001, + penalty=0.1 + ): + generator, discriminator = to_cuda_if_available(generator, discriminator) + + optim_gen = Adam(generator.parameters(), weight_decay=l2_regularization, lr=learning_rate) + optim_disc = Adam(discriminator.parameters(), weight_decay=l2_regularization, lr=learning_rate) + + logger = Logger(output_loss_path) + + for epoch_index in range(start_epoch, num_epochs): + logger.start_timer() + + # train + generator.train(mode=True) + discriminator.train(mode=True) + + disc_losses = [] + gen_losses = [] + + more_batches = True + train_data_iterator = train_data.batch_iterator(batch_size) + + while more_batches: + # train discriminator + for _ in range(num_disc_steps): + # next batch + try: + batch = next(train_data_iterator) + except StopIteration: + more_batches = False + break + + optim_disc.zero_grad() + + # first train the discriminator only with real data + real_features = Variable(torch.from_numpy(batch)) + real_features = to_cuda_if_available(real_features) + real_pred = discriminator(real_features) + real_loss = - real_pred.mean(0).view(1) + real_loss.backward() + + # then train the discriminator only with fake data + noise = Variable(torch.FloatTensor(len(batch), noise_size).normal_()) + noise = to_cuda_if_available(noise) + fake_features = generator(noise, training=True) + fake_features = fake_features.detach() # do not propagate to the generator + fake_pred = discriminator(fake_features) + fake_loss = fake_pred.mean(0).view(1) + fake_loss.backward() + + # this is the magic from WGAN-GP + gradient_penalty = calculate_gradient_penalty(discriminator, + penalty, + real_features, + fake_features) + + gradient_penalty.backward() + + # finally update the discriminator weights + # using two separated batches is another trick to improve GAN training + optim_disc.step() + + disc_loss = real_loss + fake_loss + gradient_penalty + disc_loss = to_cpu_if_available(disc_loss) + disc_losses.append(disc_loss.data.numpy()) + + del disc_loss + del gradient_penalty + del fake_loss + del real_loss + + # train generator + for _ in range(num_gen_steps): + optim_gen.zero_grad() + + noise = Variable(torch.FloatTensor(len(batch), noise_size).normal_()) + noise = to_cuda_if_available(noise) + gen_features = generator(noise, training=True) + fake_pred = discriminator(gen_features) + fake_loss = - fake_pred.mean(0).view(1) + fake_loss.backward() + + optim_gen.step() + + fake_loss = to_cpu_if_available(fake_loss) + gen_losses.append(fake_loss.data.numpy()) + + del fake_loss + + # log epoch metrics for current class + logger.log(epoch_index, num_epochs, "discriminator", "train_mean_loss", np.mean(disc_losses)) + logger.log(epoch_index, num_epochs, "generator", "train_mean_loss", np.mean(gen_losses)) + + # save models for the epoch + with DelayedKeyboardInterrupt(): + torch.save(generator.state_dict(), output_gen_path) + torch.save(discriminator.state_dict(), output_disc_path) + logger.flush() + + logger.close() + + +def main(): + options_parser = argparse.ArgumentParser(description="Train Gumbel generator and discriminator.") + + options_parser.add_argument("data", type=str, help="Training data. 
See 'data_format' parameter.") + + options_parser.add_argument("metadata", type=str, + help="Information about the categorical variables in json format.") + + options_parser.add_argument("output_generator", type=str, help="Generator output file.") + options_parser.add_argument("output_discriminator", type=str, help="Discriminator output file.") + options_parser.add_argument("output_loss", type=str, help="Loss output file.") + + options_parser.add_argument("--input_generator", type=str, help="Generator input file.", default=None) + options_parser.add_argument("--input_discriminator", type=str, help="Discriminator input file.", default=None) + + options_parser.add_argument( + "--validation_proportion", type=float, + default=.1, + help="Ratio of data for validation." + ) + + options_parser.add_argument( + "--data_format", + type=str, + default="sparse", + choices=data_formats, + help="Either a dense numpy array, a sparse csr matrix or any of those formats in split into several files." + ) + + options_parser.add_argument( + "--noise_size", + type=int, + default=128, + help="" + ) + + options_parser.add_argument( + "--batch_size", + type=int, + default=1000, + help="Amount of samples per batch." + ) + + options_parser.add_argument( + "--start_epoch", + type=int, + default=0, + help="Starting epoch." + ) + + options_parser.add_argument( + "--num_epochs", + type=int, + default=1000, + help="Number of epochs." + ) + + options_parser.add_argument( + "--l2_regularization", + type=float, + default=0.001, + help="L2 regularization weight for every parameter." + ) + + options_parser.add_argument( + "--learning_rate", + type=float, + default=0.001, + help="Adam learning rate." + ) + + options_parser.add_argument( + "--generator_hidden_sizes", + type=str, + default="256,128", + help="Size of each hidden layer in the generator separated by commas (no spaces)." + ) + + options_parser.add_argument( + "--bn_decay", + type=float, + default=0.9, + help="Batch normalization decay for the generator and discriminator." + ) + + options_parser.add_argument( + "--discriminator_hidden_sizes", + type=str, + default="256,128", + help="Size of each hidden layer in the discriminator separated by commas (no spaces)." + ) + + options_parser.add_argument( + "--num_discriminator_steps", + type=int, + default=2, + help="Number of successive training steps for the discriminator." + ) + + options_parser.add_argument( + "--num_generator_steps", + type=int, + default=1, + help="Number of successive training steps for the generator." + ) + + options_parser.add_argument( + "--penalty", + type=float, + default=0.1, + help="WGAN-GP gradient penalty lambda." 
+ ) + + options_parser.add_argument("--seed", type=int, help="Random number generator seed.", default=42) + + options = options_parser.parse_args() + + if options.seed is not None: + np.random.seed(options.seed) + torch.manual_seed(options.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(options.seed) + + features = loaders[options.data_format](options.data) + data = Dataset(features) + train_data, val_data = data.split(1.0 - options.validation_proportion) + + variable_sizes = load_variable_sizes_from_metadata(options.metadata) + + generator = Generator( + options.noise_size, + variable_sizes, + hidden_sizes=parse_int_list(options.generator_hidden_sizes), + bn_decay=options.bn_decay + ) + + load_or_initialize(generator, options.input_generator) + + discriminator = Discriminator( + features.shape[1], + hidden_sizes=parse_int_list(options.discriminator_hidden_sizes), + bn_decay=options.bn_decay, + critic=True + ) + + load_or_initialize(discriminator, options.input_discriminator) + + train( + generator, + discriminator, + train_data, + val_data, + options.output_generator, + options.output_discriminator, + options.output_loss, + batch_size=options.batch_size, + start_epoch=options.start_epoch, + num_epochs=options.num_epochs, + num_disc_steps=options.num_discriminator_steps, + num_gen_steps=options.num_generator_steps, + noise_size=options.noise_size, + l2_regularization=options.l2_regularization, + learning_rate=options.learning_rate, + penalty=options.penalty + ) + + +if __name__ == "__main__": + main() diff --git a/multi_categorical_gans/methods/medgan/__init__.py b/multi_categorical_gans/methods/medgan/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/multi_categorical_gans/methods/medgan/discriminator.py b/multi_categorical_gans/methods/medgan/discriminator.py new file mode 100644 index 0000000..89e5264 --- /dev/null +++ b/multi_categorical_gans/methods/medgan/discriminator.py @@ -0,0 +1,37 @@ +from __future__ import print_function + +import torch +import torch.nn as nn + + +class Discriminator(nn.Module): + + def __init__(self, input_size, hidden_sizes=(256, 128)): + super(Discriminator, self).__init__() + + hidden_activation = nn.LeakyReLU() + + previous_layer_size = input_size * 2 + layers = [] + + for layer_size in hidden_sizes: + layers.append(nn.Linear(previous_layer_size, layer_size)) + layers.append(hidden_activation) + previous_layer_size = layer_size + + layers.append(nn.Linear(previous_layer_size, 1)) + layers.append(nn.Sigmoid()) + + self.model = nn.Sequential(*layers) + + def minibatch_averaging(self, inputs): + """ + This method is explained in the MedGAN paper. 
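+        Concatenating the per-feature batch mean to every sample lets the
+        discriminator see minibatch-level statistics, which makes mode collapse
+        easier to detect; this is why the first layer takes input_size * 2 inputs.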
+ """ + mean_per_feature = torch.mean(inputs, 0) + mean_per_feature_repeated = mean_per_feature.repeat(len(inputs), 1) + return torch.cat((inputs, mean_per_feature_repeated), 1) + + def forward(self, inputs): + inputs = self.minibatch_averaging(inputs) + return self.model(inputs).view(-1) diff --git a/multi_categorical_gans/methods/medgan/generator.py b/multi_categorical_gans/methods/medgan/generator.py new file mode 100644 index 0000000..f62c55f --- /dev/null +++ b/multi_categorical_gans/methods/medgan/generator.py @@ -0,0 +1,42 @@ +from __future__ import print_function + +import torch.nn as nn + + +class Generator(nn.Module): + + def __init__(self, code_size=128, num_hidden_layers=2, bn_decay=0.01): + super(Generator, self).__init__() + + self.modules = [] + self.batch_norms = [] + + for layer_number in range(num_hidden_layers): + self.add_generator_module("hidden_{:d}".format(layer_number + 1), code_size, nn.ReLU(), bn_decay) + self.add_generator_module("output", code_size, nn.Tanh(), bn_decay) + + def add_generator_module(self, name, code_size, activation, bn_decay): + batch_norm = nn.BatchNorm1d(code_size, momentum=(1 - bn_decay)) + module = nn.Sequential( + nn.Linear(code_size, code_size, bias=False), # bias is not necessary because of the batch normalization + batch_norm, + activation + ) + self.modules.append(module) + self.add_module(name, module) + self.batch_norms.append(batch_norm) + + def batch_norm_train(self, mode=True): + for batch_norm in self.batch_norms: + batch_norm.train(mode=mode) + + def forward(self, noise): + """ + This sums are called "shortcut connections" + """ + outputs = noise + + for module in self.modules: + # Cannot write "outputs += module(outputs)" because it is an inplace operation (no differentiable) + outputs = module(outputs) + outputs + return outputs diff --git a/multi_categorical_gans/methods/medgan/pre_trainer.py b/multi_categorical_gans/methods/medgan/pre_trainer.py new file mode 100644 index 0000000..01ac5fe --- /dev/null +++ b/multi_categorical_gans/methods/medgan/pre_trainer.py @@ -0,0 +1,226 @@ +from __future__ import print_function + +import argparse +import torch + +import numpy as np + +from torch.autograd.variable import Variable +from torch.optim import Adam + +from multi_categorical_gans.datasets.dataset import Dataset +from multi_categorical_gans.datasets.formats import data_formats, loaders + +from multi_categorical_gans.methods.general.autoencoder import AutoEncoder + +from multi_categorical_gans.utils.categorical import load_variable_sizes_from_metadata, categorical_variable_loss +from multi_categorical_gans.utils.commandline import DelayedKeyboardInterrupt, parse_int_list +from multi_categorical_gans.utils.cuda import to_cuda_if_available, to_cpu_if_available +from multi_categorical_gans.utils.initialization import load_or_initialize +from multi_categorical_gans.utils.logger import Logger + + +def pre_train(autoencoder, + train_data, + val_data, + output_path, + output_loss_path, + batch_size=100, + start_epoch=0, + num_epochs=100, + l2_regularization=0.001, + learning_rate=0.001, + variable_sizes=None, + temperature=None + ): + autoencoder = to_cuda_if_available(autoencoder) + + optim = Adam(autoencoder.parameters(), weight_decay=l2_regularization, lr=learning_rate) + + logger = Logger(output_loss_path) + + for epoch_index in range(start_epoch, num_epochs): + logger.start_timer() + train_loss = pre_train_epoch(autoencoder, train_data, batch_size, optim, variable_sizes, temperature) + logger.log(epoch_index, num_epochs, 
"autoencoder", "train_mean_loss", np.mean(train_loss)) + + logger.start_timer() + val_loss = pre_train_epoch(autoencoder, val_data, batch_size, None, variable_sizes, temperature) + logger.log(epoch_index, num_epochs, "autoencoder", "validation_mean_loss", np.mean(val_loss)) + + # save models for the epoch + with DelayedKeyboardInterrupt(): + torch.save(autoencoder.state_dict(), output_path) + logger.flush() + + logger.close() + + +def pre_train_epoch(autoencoder, data, batch_size, optim=None, variable_sizes=None, temperature=None): + autoencoder.train(mode=(optim is not None)) + + training = optim is not None + + losses = [] + for batch in data.batch_iterator(batch_size): + if optim is not None: + optim.zero_grad() + + batch = Variable(torch.from_numpy(batch)) + batch = to_cuda_if_available(batch) + + _, batch_reconstructed = autoencoder(batch, + training=training, + temperature=temperature, + normalize_code=False) + + loss = categorical_variable_loss(batch_reconstructed, batch, variable_sizes) + loss.backward() + + if training: + optim.step() + + loss = to_cpu_if_available(loss) + losses.append(loss.data.numpy()) + del loss + return losses + + +def losses_by_class_to_string(losses_by_class): + return ", ".join(["{:.5f}".format(np.mean(losses)) for losses in losses_by_class]) + + +def main(): + options_parser = argparse.ArgumentParser(description="Pre-train MedGAN or MC-MedGAN. " + + "Define 'metadata' and 'temperature' to use MC-MedGAN.") + + options_parser.add_argument("data", type=str, help="Training data. See 'data_format' parameter.") + + options_parser.add_argument("output_model", type=str, help="Model output file.") + options_parser.add_argument("output_loss", type=str, help="Loss output file.") + + options_parser.add_argument("--input_model", type=str, help="Model input file.", default=None) + + options_parser.add_argument("--metadata", type=str, + help="Information about the categorical variables in json format." + + " Only used if temperature is also provided.") + + options_parser.add_argument( + "--validation_proportion", + type=float, + default=.1, + help="Ratio of data for validation." + ) + + options_parser.add_argument( + "--data_format", + type=str, + default="sparse", + choices=data_formats, + help="Either a dense numpy array, a sparse csr matrix or any of those formats in split into several files." + ) + + options_parser.add_argument( + "--code_size", + type=int, + default=128, + help="Dimension of the autoencoder latent space." + ) + + options_parser.add_argument( + "--encoder_hidden_sizes", + type=str, + default="", + help="Size of each hidden layer in the encoder separated by commas (no spaces)." + ) + + options_parser.add_argument( + "--decoder_hidden_sizes", + type=str, + default="", + help="Size of each hidden layer in the decoder separated by commas (no spaces)." + ) + + options_parser.add_argument( + "--batch_size", + type=int, + default=100, + help="Amount of samples per batch." + ) + + options_parser.add_argument( + "--num_epochs", + type=int, + default=100, + help="Number of epochs." + ) + + options_parser.add_argument( + "--l2_regularization", + type=float, + default=0.001, + help="L2 regularization weight for every parameter." + ) + + options_parser.add_argument( + "--learning_rate", + type=float, + default=0.001, + help="Adam learning rate." + ) + + options_parser.add_argument( + "--temperature", + type=float, + default=None, + help="Gumbel-Softmax temperature. Only used if metadata is also provided." 
+
+
+def losses_by_class_to_string(losses_by_class):
+    return ", ".join(["{:.5f}".format(np.mean(losses)) for losses in losses_by_class])
+
+
+def main():
+    options_parser = argparse.ArgumentParser(description="Pre-train MedGAN or MC-MedGAN. "
+                                                         + "Define 'metadata' and 'temperature' to use MC-MedGAN.")
+
+    options_parser.add_argument("data", type=str, help="Training data. See 'data_format' parameter.")
+
+    options_parser.add_argument("output_model", type=str, help="Model output file.")
+    options_parser.add_argument("output_loss", type=str, help="Loss output file.")
+
+    options_parser.add_argument("--input_model", type=str, help="Model input file.", default=None)
+
+    options_parser.add_argument("--metadata", type=str,
+                                help="Information about the categorical variables in json format."
+                                     + " Only used if temperature is also provided.")
+
+    options_parser.add_argument(
+        "--validation_proportion",
+        type=float,
+        default=.1,
+        help="Ratio of data for validation."
+    )
+
+    options_parser.add_argument(
+        "--data_format",
+        type=str,
+        default="sparse",
+        choices=data_formats,
+        help="Either a dense numpy array, a sparse csr matrix, or any of those formats split into several files."
+    )
+
+    options_parser.add_argument(
+        "--code_size",
+        type=int,
+        default=128,
+        help="Dimension of the autoencoder latent space."
+    )
+
+    options_parser.add_argument(
+        "--encoder_hidden_sizes",
+        type=str,
+        default="",
+        help="Size of each hidden layer in the encoder separated by commas (no spaces)."
+    )
+
+    options_parser.add_argument(
+        "--decoder_hidden_sizes",
+        type=str,
+        default="",
+        help="Size of each hidden layer in the decoder separated by commas (no spaces)."
+    )
+
+    options_parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=100,
+        help="Number of samples per batch."
+    )
+
+    options_parser.add_argument(
+        "--num_epochs",
+        type=int,
+        default=100,
+        help="Number of epochs."
+    )
+
+    options_parser.add_argument(
+        "--l2_regularization",
+        type=float,
+        default=0.001,
+        help="L2 regularization weight for every parameter."
+    )
+
+    options_parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=0.001,
+        help="Adam learning rate."
+    )
+
+    options_parser.add_argument(
+        "--temperature",
+        type=float,
+        default=None,
+        help="Gumbel-Softmax temperature. Only used if metadata is also provided."
+    )
+
+    options_parser.add_argument("--seed", type=int, help="Random number generator seed.", default=42)
+
+    options = options_parser.parse_args()
+
+    if options.seed is not None:
+        np.random.seed(options.seed)
+        torch.manual_seed(options.seed)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(options.seed)
+
+    features = loaders[options.data_format](options.data)
+    data = Dataset(features)
+    train_data, val_data = data.split(1.0 - options.validation_proportion)
+
+    if options.metadata is not None and options.temperature is not None:
+        variable_sizes = load_variable_sizes_from_metadata(options.metadata)
+        temperature = options.temperature
+    else:
+        variable_sizes = None
+        temperature = None
+
+    autoencoder = AutoEncoder(
+        features.shape[1],
+        code_size=options.code_size,
+        encoder_hidden_sizes=parse_int_list(options.encoder_hidden_sizes),
+        decoder_hidden_sizes=parse_int_list(options.decoder_hidden_sizes),
+        variable_sizes=variable_sizes
+    )
+
+    load_or_initialize(autoencoder, options.input_model)
+
+    pre_train(
+        autoencoder,
+        train_data,
+        val_data,
+        options.output_model,
+        options.output_loss,
+        batch_size=options.batch_size,
+        num_epochs=options.num_epochs,
+        l2_regularization=options.l2_regularization,
+        learning_rate=options.learning_rate,
+        variable_sizes=variable_sizes,
+        temperature=temperature
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/multi_categorical_gans/methods/medgan/sampler.py b/multi_categorical_gans/methods/medgan/sampler.py
new file mode 100644
index 0000000..433e096
--- /dev/null
+++ b/multi_categorical_gans/methods/medgan/sampler.py
@@ -0,0 +1,159 @@
+from __future__ import print_function
+
+import argparse
+import torch
+
+import numpy as np
+
+from torch.autograd.variable import Variable
+
+from multi_categorical_gans.methods.general.autoencoder import AutoEncoder
+from multi_categorical_gans.methods.medgan.generator import Generator
+
+from multi_categorical_gans.utils.categorical import load_variable_sizes_from_metadata
+from multi_categorical_gans.utils.commandline import parse_int_list
+from multi_categorical_gans.utils.cuda import to_cuda_if_available, to_cpu_if_available, load_without_cuda
+
+
+def sample(autoencoder, generator, num_samples, num_features, batch_size=100, code_size=128, temperature=None,
+           round_features=False):
+
+    autoencoder, generator = to_cuda_if_available(autoencoder, generator)
+
+    autoencoder.train(mode=False)
+    generator.train(mode=False)
+
+    samples = np.zeros((num_samples, num_features), dtype=np.float32)
+
+    start = 0
+    while start < num_samples:
+        with torch.no_grad():
+            noise = Variable(torch.FloatTensor(batch_size, code_size).normal_())
+            noise = to_cuda_if_available(noise)
+            batch_code = generator(noise)
+
+            batch_samples = autoencoder.decode(batch_code,
+                                               training=False,
+                                               temperature=temperature)
+
+        batch_samples = to_cpu_if_available(batch_samples)
+        batch_samples = batch_samples.data.numpy()
+
+        # if rounding is activated (for MedGAN with binary outputs)
+        if round_features:
+            batch_samples = np.round(batch_samples)
+
+        # do not go further than the desired number of samples
+        end = min(start + batch_size, num_samples)
+        # limit the samples taken from the batch based on what is missing
+        samples[start:end, :] = batch_samples[:min(batch_size, end - start), :]
+
+        # move to next batch
+        start = end
+    return samples
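+
+
+# Example call, assuming models are already loaded (illustrative only;
+# the feature count is a placeholder):
+#   samples = sample(autoencoder, generator, num_samples=10000,
+#                    num_features=123, temperature=0.666)    # MC-MedGAN
+#   samples = sample(autoencoder, generator, num_samples=10000,
+#                    num_features=123, round_features=True)  # plain MedGAN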
file.") + options_parser.add_argument("generator", type=str, help="Generator input file.") + options_parser.add_argument("num_samples", type=int, help="Number of output samples.") + options_parser.add_argument("num_features", type=int, help="Number of output features.") + options_parser.add_argument("data", type=str, help="Output data.") + + options_parser.add_argument("--metadata", type=str, + help="Information about the categorical variables in json format.") + + options_parser.add_argument( + "--code_size", + type=int, + default=128, + help="Dimension of the autoencoder latent space." + ) + + options_parser.add_argument( + "--encoder_hidden_sizes", + type=str, + default="", + help="Size of each hidden layer in the encoder separated by commas (no spaces)." + ) + + options_parser.add_argument( + "--decoder_hidden_sizes", + type=str, + default="", + help="Size of each hidden layer in the decoder separated by commas (no spaces)." + ) + + options_parser.add_argument( + "--batch_size", + type=int, + default=100, + help="Amount of samples per batch." + ) + + options_parser.add_argument( + "--generator_hidden_layers", + type=int, + default=2, + help="Number of hidden layers in the generator." + ) + + options_parser.add_argument( + "--generator_bn_decay", + type=float, + default=0.01, + help="Generator batch normalization decay." + ) + + options_parser.add_argument( + "--temperature", + type=float, + default=None, + help="Gumbel-Softmax temperature." + ) + + options = options_parser.parse_args() + + if options.metadata is not None and options.temperature is not None: + variable_sizes = load_variable_sizes_from_metadata(options.metadata) + temperature = options.temperature + else: + variable_sizes = None + temperature = None + + autoencoder = AutoEncoder( + options.num_features, + code_size=options.code_size, + encoder_hidden_sizes=parse_int_list(options.encoder_hidden_sizes), + decoder_hidden_sizes=parse_int_list(options.decoder_hidden_sizes), + variable_sizes=variable_sizes + ) + + autoencoder.load_state_dict(torch.load(options.autoencoder)) + + generator = Generator( + code_size=options.code_size, + num_hidden_layers=options.generator_hidden_layers, + bn_decay=options.generator_bn_decay + ) + + load_without_cuda(generator, options.generator) + + data = sample( + autoencoder, + generator, + options.num_samples, + options.num_features, + batch_size=options.batch_size, + code_size=options.code_size, + temperature=temperature, + round_features=(temperature is None) + ) + + np.save(options.data, data) + + +if __name__ == "__main__": + main() diff --git a/multi_categorical_gans/methods/medgan/trainer.py b/multi_categorical_gans/methods/medgan/trainer.py new file mode 100644 index 0000000..cbc5d87 --- /dev/null +++ b/multi_categorical_gans/methods/medgan/trainer.py @@ -0,0 +1,393 @@ +from __future__ import print_function + +import argparse +import torch + +import numpy as np + +from torch.autograd.variable import Variable +from torch.optim import Adam +from torch.nn import BCELoss + +from multi_categorical_gans.datasets.dataset import Dataset +from multi_categorical_gans.datasets.formats import data_formats, loaders + +from multi_categorical_gans.methods.general.autoencoder import AutoEncoder +from multi_categorical_gans.methods.medgan.discriminator import Discriminator +from multi_categorical_gans.methods.medgan.generator import Generator + +from multi_categorical_gans.utils.categorical import load_variable_sizes_from_metadata +from multi_categorical_gans.utils.commandline import 
diff --git a/multi_categorical_gans/methods/medgan/trainer.py b/multi_categorical_gans/methods/medgan/trainer.py
new file mode 100644
index 0000000..cbc5d87
--- /dev/null
+++ b/multi_categorical_gans/methods/medgan/trainer.py
@@ -0,0 +1,393 @@
+from __future__ import print_function
+
+import argparse
+import torch
+
+import numpy as np
+
+from torch.autograd.variable import Variable
+from torch.optim import Adam
+from torch.nn import BCELoss
+
+from multi_categorical_gans.datasets.dataset import Dataset
+from multi_categorical_gans.datasets.formats import data_formats, loaders
+
+from multi_categorical_gans.methods.general.autoencoder import AutoEncoder
+from multi_categorical_gans.methods.medgan.discriminator import Discriminator
+from multi_categorical_gans.methods.medgan.generator import Generator
+
+from multi_categorical_gans.utils.categorical import load_variable_sizes_from_metadata
+from multi_categorical_gans.utils.commandline import DelayedKeyboardInterrupt, parse_int_list
+from multi_categorical_gans.utils.cuda import to_cuda_if_available, to_cpu_if_available, load_without_cuda
+from multi_categorical_gans.utils.initialization import load_or_initialize
+from multi_categorical_gans.utils.logger import Logger
+
+
+def train(autoencoder,
+          generator,
+          discriminator,
+          train_data,
+          val_data,
+          output_ae_path,
+          output_gen_path,
+          output_disc_path,
+          output_loss_path,
+          batch_size=1000,
+          start_epoch=0,
+          num_epochs=1000,
+          num_disc_steps=2,
+          num_gen_steps=1,
+          code_size=128,
+          l2_regularization=0.001,
+          learning_rate=0.001,
+          temperature=None
+          ):
+    autoencoder, generator, discriminator = to_cuda_if_available(autoencoder, generator, discriminator)
+
+    # the generator is optimized together with the autoencoder decoder
+    optim_gen = Adam(list(generator.parameters()) + list(autoencoder.decoder.parameters()),
+                     weight_decay=l2_regularization, lr=learning_rate)
+
+    optim_disc = Adam(discriminator.parameters(), weight_decay=l2_regularization, lr=learning_rate)
+
+    criterion = BCELoss()
+
+    logger = Logger(output_loss_path)
+
+    for epoch_index in range(start_epoch, num_epochs):
+        logger.start_timer()
+
+        # train
+        autoencoder.train(mode=True)
+        generator.train(mode=True)
+        discriminator.train(mode=True)
+
+        disc_losses = []
+        gen_losses = []
+
+        more_batches = True
+        train_data_iterator = train_data.batch_iterator(batch_size)
+
+        while more_batches:
+            # train discriminator
+            generator.batch_norm_train(mode=False)
+
+            for _ in range(num_disc_steps):
+                # next batch
+                try:
+                    batch = next(train_data_iterator)
+                except StopIteration:
+                    more_batches = False
+                    break
+
+                # one-sided label smoothing is a trick to improve GAN training:
+                # targets for real samples are drawn from U(0.9, 1) instead of being exactly 1
+                label_zeros = Variable(torch.zeros(len(batch)))
+                smooth_label_ones = Variable(torch.FloatTensor(len(batch)).uniform_(0.9, 1))
+
+                label_zeros, smooth_label_ones = to_cuda_if_available(label_zeros, smooth_label_ones)
+
+                optim_disc.zero_grad()
+
+                # first train the discriminator only with real data
+                real_features = Variable(torch.from_numpy(batch))
+                real_features = to_cuda_if_available(real_features)
+                real_pred = discriminator(real_features)
+                real_loss = criterion(real_pred, smooth_label_ones)
+                real_loss.backward()
+
+                # then train the discriminator only with fake data
+                noise = Variable(torch.FloatTensor(len(batch), code_size).normal_())
+                noise = to_cuda_if_available(noise)
+                fake_code = generator(noise)
+                fake_features = autoencoder.decode(fake_code,
+                                                   training=True,
+                                                   temperature=temperature)
+                fake_features = fake_features.detach()  # do not propagate to the generator
+                fake_pred = discriminator(fake_features)
+                fake_loss = criterion(fake_pred, label_zeros)
+                fake_loss.backward()
+
+                # finally update the discriminator weights;
+                # feeding real and fake samples in two separate batches is another trick to improve GAN training
+                optim_disc.step()
+
+                disc_loss = real_loss + fake_loss
+                disc_loss = to_cpu_if_available(disc_loss)
+                disc_losses.append(disc_loss.data.numpy())
+
+                del disc_loss
+                del fake_loss
+                del real_loss
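+
+            # The update above minimizes, per discriminator step (illustrative):
+            #   disc_loss = BCE(D(x_real), u) + BCE(D(x_fake.detach()), 0)
+            # with u ~ U(0.9, 1) and x_fake = decode(G(z)), accumulated over
+            # two backward passes before a single optimizer step.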
+
+            # train generator
+            generator.batch_norm_train(mode=True)
+
+            for _ in range(num_gen_steps):
+                optim_gen.zero_grad()
+
+                noise = Variable(torch.FloatTensor(len(batch), code_size).normal_())
+                noise = to_cuda_if_available(noise)
+                gen_code = generator(noise)
+                gen_features = autoencoder.decode(gen_code,
+                                                  training=True,
+                                                  temperature=temperature)
+                gen_pred = discriminator(gen_features)
+
+                smooth_label_ones = Variable(torch.FloatTensor(len(batch)).uniform_(0.9, 1))
+                smooth_label_ones = to_cuda_if_available(smooth_label_ones)
+
+                gen_loss = criterion(gen_pred, smooth_label_ones)
+                gen_loss.backward()
+
+                optim_gen.step()
+
+                gen_loss = to_cpu_if_available(gen_loss)
+                gen_losses.append(gen_loss.data.numpy())
+
+                del gen_loss
+
+        # validate discriminator
+        autoencoder.train(mode=False)
+        generator.train(mode=False)
+        discriminator.train(mode=False)
+
+        correct = 0.0
+        total = 0.0
+        for batch in val_data.batch_iterator(batch_size):
+            # real data discriminator accuracy
+            with torch.no_grad():
+                real_features = Variable(torch.from_numpy(batch))
+                real_features = to_cuda_if_available(real_features)
+                real_pred = discriminator(real_features)
+                real_pred = to_cpu_if_available(real_pred)
+                correct += (real_pred.data.numpy().ravel() > .5).sum()
+                total += len(real_pred)
+
+            # fake data discriminator accuracy
+            with torch.no_grad():
+                noise = Variable(torch.FloatTensor(len(batch), code_size).normal_())
+                noise = to_cuda_if_available(noise)
+                fake_code = generator(noise)
+                fake_features = autoencoder.decode(fake_code,
+                                                   training=False,
+                                                   temperature=temperature)
+                fake_pred = discriminator(fake_features)
+                fake_pred = to_cpu_if_available(fake_pred)
+                correct += (fake_pred.data.numpy().ravel() < .5).sum()
+                total += len(fake_pred)
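+
+        # Reading the numbers (illustrative): D(real) > 0.5 and D(fake) < 0.5
+        # count as correct, so correct / total is measured over twice the
+        # validation set; around 0.5 the discriminator can no longer tell
+        # real samples from fake ones.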
+
+        # log epoch metrics
+        logger.log(epoch_index, num_epochs, "discriminator", "train_mean_loss", np.mean(disc_losses))
+        logger.log(epoch_index, num_epochs, "generator", "train_mean_loss", np.mean(gen_losses))
+        logger.log(epoch_index, num_epochs, "discriminator", "validation_accuracy", correct / total)
+
+        # save models for the epoch
+        with DelayedKeyboardInterrupt():
+            torch.save(autoencoder.state_dict(), output_ae_path)
+            torch.save(generator.state_dict(), output_gen_path)
+            torch.save(discriminator.state_dict(), output_disc_path)
+            logger.flush()
+
+    logger.close()
+
+
+def main():
+    options_parser = argparse.ArgumentParser(description="Train MedGAN or MC-MedGAN. "
+                                                         + "Define 'metadata' and 'temperature' to use MC-MedGAN.")
+
+    options_parser.add_argument("data", type=str, help="Training data. See 'data_format' parameter.")
+
+    options_parser.add_argument("input_autoencoder", type=str, help="Autoencoder input file.")
+    options_parser.add_argument("output_autoencoder", type=str, help="Autoencoder output file.")
+    options_parser.add_argument("output_generator", type=str, help="Generator output file.")
+    options_parser.add_argument("output_discriminator", type=str, help="Discriminator output file.")
+    options_parser.add_argument("output_loss", type=str, help="Loss output file.")
+
+    options_parser.add_argument("--input_generator", type=str, help="Generator input file.", default=None)
+    options_parser.add_argument("--input_discriminator", type=str, help="Discriminator input file.", default=None)
+
+    options_parser.add_argument("--metadata", type=str,
+                                help="Information about the categorical variables in json format."
+                                     + " Only used if temperature is also provided.")
+
+    options_parser.add_argument(
+        "--validation_proportion", type=float,
+        default=.1,
+        help="Ratio of data for validation."
+    )
+
+    options_parser.add_argument(
+        "--data_format",
+        type=str,
+        default="sparse",
+        choices=data_formats,
+        help="Either a dense numpy array, a sparse csr matrix, or any of those formats split into several files."
+    )
+
+    options_parser.add_argument(
+        "--code_size",
+        type=int,
+        default=128,
+        help="Dimension of the autoencoder latent space."
+    )
+
+    options_parser.add_argument(
+        "--encoder_hidden_sizes",
+        type=str,
+        default="",
+        help="Size of each hidden layer in the encoder separated by commas (no spaces)."
+    )
+
+    options_parser.add_argument(
+        "--decoder_hidden_sizes",
+        type=str,
+        default="",
+        help="Size of each hidden layer in the decoder separated by commas (no spaces)."
+    )
+
+    options_parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=1000,
+        help="Number of samples per batch."
+    )
+
+    options_parser.add_argument(
+        "--start_epoch",
+        type=int,
+        default=0,
+        help="Starting epoch."
+    )
+
+    options_parser.add_argument(
+        "--num_epochs",
+        type=int,
+        default=1000,
+        help="Number of epochs."
+    )
+
+    options_parser.add_argument(
+        "--l2_regularization",
+        type=float,
+        default=0.001,
+        help="L2 regularization weight for every parameter."
+    )
+
+    options_parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=0.001,
+        help="Adam learning rate."
+    )
+
+    options_parser.add_argument(
+        "--generator_hidden_layers",
+        type=int,
+        default=2,
+        help="Number of hidden layers in the generator."
+    )
+
+    options_parser.add_argument(
+        "--generator_bn_decay",
+        type=float,
+        default=0.99,
+        help="Generator batch normalization decay."
+    )
+
+    options_parser.add_argument(
+        "--discriminator_hidden_sizes",
+        type=str,
+        default="256,128",
+        help="Size of each hidden layer in the discriminator separated by commas (no spaces)."
+    )
+
+    options_parser.add_argument(
+        "--num_discriminator_steps",
+        type=int,
+        default=2,
+        help="Number of successive training steps for the discriminator."
+    )
+
+    options_parser.add_argument(
+        "--num_generator_steps",
+        type=int,
+        default=1,
+        help="Number of successive training steps for the generator."
+    )
+
+    options_parser.add_argument(
+        "--temperature",
+        type=float,
+        default=None,
+        help="Gumbel-Softmax temperature. Only used if metadata is also provided."
+    )
+
+    options_parser.add_argument("--seed", type=int, help="Random number generator seed.", default=42)
+
+    options = options_parser.parse_args()
+
+    if options.seed is not None:
+        np.random.seed(options.seed)
+        torch.manual_seed(options.seed)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(options.seed)
+
+    features = loaders[options.data_format](options.data)
+    data = Dataset(features)
+    train_data, val_data = data.split(1.0 - options.validation_proportion)
+
+    if options.metadata is not None and options.temperature is not None:
+        variable_sizes = load_variable_sizes_from_metadata(options.metadata)
+        temperature = options.temperature
+    else:
+        variable_sizes = None
+        temperature = None
+
+    autoencoder = AutoEncoder(
+        features.shape[1],
+        code_size=options.code_size,
+        encoder_hidden_sizes=parse_int_list(options.encoder_hidden_sizes),
+        decoder_hidden_sizes=parse_int_list(options.decoder_hidden_sizes),
+        variable_sizes=variable_sizes
+    )
+
+    load_without_cuda(autoencoder, options.input_autoencoder)
+
+    generator = Generator(
+        code_size=options.code_size,
+        num_hidden_layers=options.generator_hidden_layers,
+        bn_decay=options.generator_bn_decay
+    )
+
+    load_or_initialize(generator, options.input_generator)
+
+    discriminator = Discriminator(
+        features.shape[1],
+        hidden_sizes=parse_int_list(options.discriminator_hidden_sizes)
+    )
+
+    load_or_initialize(discriminator, options.input_discriminator)
+
+    train(
+        autoencoder,
+        generator,
+        discriminator,
+        train_data,
+        val_data,
+        options.output_autoencoder,
+        options.output_generator,
+        options.output_discriminator,
+        options.output_loss,
+        batch_size=options.batch_size,
+        start_epoch=options.start_epoch,
+        num_epochs=options.num_epochs,
+        num_disc_steps=options.num_discriminator_steps,
+        num_gen_steps=options.num_generator_steps,
+        code_size=options.code_size,
+        l2_regularization=options.l2_regularization,
+        learning_rate=options.learning_rate,
+        temperature=temperature
+    )
+
+
+if __name__ == "__main__":
+    main()
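+
+# Example invocation (illustrative only; module path, file names and values
+# are placeholders):
+#   python -m multi_categorical_gans.methods.medgan.trainer \
+#       data/train.npy models/pretrained_ae.torch models/ae.torch \
+#       models/generator.torch models/discriminator.torch losses.csv \
+#       --data_format=dense --metadata=metadata.json --temperature=0.666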
diff --git a/multi_categorical_gans/utils/__init__.py b/multi_categorical_gans/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/multi_categorical_gans/utils/categorical.py b/multi_categorical_gans/utils/categorical.py
new file mode 100644
index 0000000..ca5ec24
--- /dev/null
+++ b/multi_categorical_gans/utils/categorical.py
@@ -0,0 +1,50 @@
+import json
+
+import torch
+import torch.nn.functional as F
+
+
+def load_variable_sizes_from_metadata(metadata_path):
+    with open(metadata_path, "r") as metadata_file:
+        metadata = json.load(metadata_file)
+    return metadata["variable_sizes"]
+
+
+def categorical_variable_loss(reconstructed, original, variable_sizes):
+    # by default use the loss for binary variables
+    if variable_sizes is None:
+        return F.binary_cross_entropy(reconstructed, original)
+    # use the variable sizes when available
+    else:
+        loss = 0
+        start = 0
+        continuous_size = 0
+        for variable_size in variable_sizes:
+            # if it is a categorical variable
+            if variable_size > 1:
+                # add the loss from the accumulated continuous variables
+                if continuous_size > 0:
+                    end = start + continuous_size
+                    batch_reconstructed_variable = reconstructed[:, start:end]
+                    batch_target = original[:, start:end]
+                    loss += F.mse_loss(batch_reconstructed_variable, batch_target)
+                    start = end
+                    continuous_size = 0
+                # add the loss from the categorical variable
+                end = start + variable_size
+                batch_reconstructed_variable = reconstructed[:, start:end]
+                batch_target = torch.argmax(original[:, start:end], dim=1)
+                loss += F.cross_entropy(batch_reconstructed_variable, batch_target)
+                start = end
+            # if not, accumulate continuous variables
+            else:
+                continuous_size += 1
+
+        # add the loss from the remaining accumulated continuous variables
+        if continuous_size > 0:
+            end = start + continuous_size
+            batch_reconstructed_variable = reconstructed[:, start:end]
+            batch_target = original[:, start:end]
+            loss += F.mse_loss(batch_reconstructed_variable, batch_target)
+
+        return loss
diff --git a/multi_categorical_gans/utils/commandline.py b/multi_categorical_gans/utils/commandline.py
new file mode 100644
index 0000000..6656d33
--- /dev/null
+++ b/multi_categorical_gans/utils/commandline.py
@@ -0,0 +1,42 @@
+from __future__ import print_function
+
+import signal
+import sys
+
+
+def parse_int_list(comma_separated_ints):
+    if comma_separated_ints is None or comma_separated_ints == "":
+        return []
+    return [int(i) for i in comma_separated_ints.split(",")]
+
+
+class DelayedKeyboardInterrupt(object):
+
+    SIGNALS = [signal.SIGINT, signal.SIGTERM]
+
+    def __init__(self):
+        self.signal_received = {}
+        self.old_handler = {}
+
+    def __enter__(self):
+        self.signal_received = {}
+        self.old_handler = {}
+        for sig in self.SIGNALS:
+            self.old_handler[sig] = signal.signal(sig, self.handler)
+
+    def handler(self, sig, frame):
+        # remember the signal and deliver it once the critical section is done
+        self.signal_received[sig] = frame
+        print('Delaying received signal', sig)
+
+    def __exit__(self, exc_type, exc_value, exc_traceback):
+        for sig in self.SIGNALS:
+            signal.signal(sig, self.old_handler[sig])
+        for sig, frame in self.signal_received.items():
+            old_handler = self.old_handler[sig]
+            print('Resuming received signal', sig)
+            if callable(old_handler):
+                old_handler(sig, frame)
+            elif old_handler == signal.SIG_DFL:
+                sys.exit(0)
+        self.signal_received = {}
+        self.old_handler = {}
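+
+# Example (illustrative only): postpone Ctrl-C while checkpointing, exactly
+# as the trainers in this repository do:
+#   with DelayedKeyboardInterrupt():
+#       torch.save(model.state_dict(), output_path)
+#       logger.flush()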
diff --git a/multi_categorical_gans/utils/cuda.py b/multi_categorical_gans/utils/cuda.py
new file mode 100644
index 0000000..5be1907
--- /dev/null
+++ b/multi_categorical_gans/utils/cuda.py
@@ -0,0 +1,21 @@
+import torch
+
+
+def to_cuda_if_available(*tensors):
+    if torch.cuda.is_available():
+        tensors = [tensor.cuda() if tensor is not None else None for tensor in tensors]
+    if len(tensors) == 1:
+        return tensors[0]
+    return tensors
+
+
+def to_cpu_if_available(*tensors):
+    if torch.cuda.is_available():
+        tensors = [tensor.cpu() if tensor is not None else None for tensor in tensors]
+    if len(tensors) == 1:
+        return tensors[0]
+    return tensors
+
+
+def load_without_cuda(model, state_dict_path):
+    # map_location keeps CUDA-trained checkpoints loadable on CPU-only machines
+    model.load_state_dict(torch.load(state_dict_path, map_location=lambda storage, loc: storage))
diff --git a/multi_categorical_gans/utils/initialization.py b/multi_categorical_gans/utils/initialization.py
new file mode 100644
index 0000000..f36c259
--- /dev/null
+++ b/multi_categorical_gans/utils/initialization.py
@@ -0,0 +1,20 @@
+import torch.nn as nn
+
+from multi_categorical_gans.utils.cuda import load_without_cuda
+
+
+def initialize_weights(module):
+    if type(module) == nn.Linear:
+        nn.init.xavier_normal_(module.weight)
+        if module.bias is not None:
+            nn.init.constant_(module.bias, 0.0)
+    elif type(module) == nn.BatchNorm1d:
+        module.weight.data.normal_(1.0, 0.02)
+        module.bias.data.fill_(0)
+
+
+def load_or_initialize(module, state_dict_path):
+    if state_dict_path is not None:
+        load_without_cuda(module, state_dict_path)
+    else:
+        module.apply(initialize_weights)
diff --git a/multi_categorical_gans/utils/logger.py b/multi_categorical_gans/utils/logger.py
new file mode 100644
index 0000000..f6a9afa
--- /dev/null
+++ b/multi_categorical_gans/utils/logger.py
@@ -0,0 +1,54 @@
+import csv
+import os
+import time
+
+
+class Logger(object):
+
+    PRINT_FORMAT = "epoch {:d}/{:d} {}-{}: {:.05f} Time: {:.2f} s"
+    CSV_COLUMNS = ["epoch", "model", "metric_name", "metric_value", "time"]
+
+    start_time = None
+
+    def __init__(self, output_path):
+        # append when the file already has content, otherwise write the header
+        if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
+            self.output_file = open(output_path, "a")
+            self.output_writer = csv.DictWriter(self.output_file, fieldnames=self.CSV_COLUMNS)
+        else:
+            self.output_file = open(output_path, "w")
+            self.output_writer = csv.DictWriter(self.output_file, fieldnames=self.CSV_COLUMNS)
+            self.output_writer.writeheader()
+
+        self.start_timer()
+
+    def start_timer(self):
+        self.start_time = time.time()
+
+    def log(self, epoch_index, num_epochs, model_name, metric_name, metric_value):
+        elapsed_time = time.time() - self.start_time
+
+        self.output_writer.writerow({
+            "epoch": epoch_index + 1,
+            "model": model_name,
+            "metric_name": metric_name,
+            "metric_value": metric_value,
+            "time": elapsed_time
+        })
+
+        print(self.PRINT_FORMAT.format(
+            epoch_index + 1,
+            num_epochs,
+            model_name,
+            metric_name,
+            metric_value,
+            elapsed_time
+        ))
+
+    def flush(self):
+        self.output_file.flush()
+
+    def close(self):
+        self.output_file.close()
+
+        self.output_file = None
+        self.output_writer = None
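+
+# Typical usage (illustrative only; the metric value is a placeholder):
+#   logger = Logger("losses.csv")
+#   logger.start_timer()
+#   logger.log(epoch_index=0, num_epochs=100, model_name="generator",
+#              metric_name="train_mean_loss", metric_value=0.693)
+#   logger.flush()
+#   logger.close()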