ARAE and MC-ARAE

Based on the code from jakezhaojb/ARAE.

ARAE

An example using the MIX SMALL dataset:

mkdir -p models/arae/synthetic/mix_small
mkdir -p samples/arae/synthetic/mix_small

Training:

python multi_categorical_gans/methods/arae/trainer.py \
    --data_format=sparse \
    --code_size=65 \
    --noise_size=10 \
    --encoder_hidden_sizes="" \
    --decoder_hidden_sizes="" \
    --batch_size=100 \
    --num_epochs=5000 \
    --l2_regularization=0 \
    --learning_rate=1e-5 \
    --generator_hidden_sizes=100,100,100 \
    --bn_decay=0.9 \
    --discriminator_hidden_sizes=100 \
    --num_autoencoder_steps=1 \
    --num_discriminator_steps=1 \
    --num_generator_steps=1 \
    --autoencoder_noise_radius=0 \
    --seed=123 \
    data/synthetic/mix_small/synthetic-train.features.npz \
    models/arae/synthetic/mix_small/autoencoder.torch \
    models/arae/synthetic/mix_small/generator.torch \
    models/arae/synthetic/mix_small/discriminator.torch \
    models/arae/synthetic/mix_small/loss.csv

Sampling:

python multi_categorical_gans/methods/arae/sampler.py \
    --code_size=65 \
    --noise_size=10 \
    --encoder_hidden_sizes="" \
    --decoder_hidden_sizes="" \
    --batch_size=1000 \
    --generator_hidden_sizes=100,100,100 \
    --generator_bn_decay=0.9 \
    models/arae/synthetic/mix_small/autoencoder.torch \
    models/arae/synthetic/mix_small/generator.torch \
    10000 65 \
    samples/arae/synthetic/mix_small/sample.features.npy

MC-ARAE

MC-ARAE additionally needs the metadata file (which contains the variable size information) and the Gumbel-softmax temperature.
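To illustrate why both inputs matter, here is a minimal NumPy sketch of a Gumbel-softmax applied separately to each categorical variable. It is not the code used by the trainer. The function names and the `variable_sizes` list are hypothetical stand-ins for whatever the metadata file actually provides.

```python
import numpy as np


def gumbel_softmax(logits, temperature, rng):
    """Draw a relaxed one-hot sample from each row of `logits`."""
    # Gumbel(0, 1) noise via the inverse-CDF trick; the lower bound avoids log(0).
    uniform = rng.uniform(low=1e-10, high=1.0, size=logits.shape)
    gumbel = -np.log(-np.log(uniform))
    scaled = (logits + gumbel) / temperature
    # Numerically stable softmax over the last axis.
    scaled -= scaled.max(axis=-1, keepdims=True)
    exp = np.exp(scaled)
    return exp / exp.sum(axis=-1, keepdims=True)


def multi_categorical_sample(logits, variable_sizes, temperature, rng):
    """Apply a separate Gumbel-softmax to each variable's slice of the output."""
    outputs = []
    start = 0
    for size in variable_sizes:
        block = logits[:, start:start + size]
        outputs.append(gumbel_softmax(block, temperature, rng))
        start += size
    return np.concatenate(outputs, axis=1)


rng = np.random.default_rng(123)
variable_sizes = [2, 3, 4]  # hypothetical; real sizes would come from the metadata
logits = rng.normal(size=(5, sum(variable_sizes)))
sample = multi_categorical_sample(logits, variable_sizes, temperature=0.666, rng=rng)
```

Each variable's block of `sample` sums to 1 per row. Lower temperatures push each block closer to a hard one-hot vector, while higher temperatures give smoother outputs.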

An example using the MIX SMALL dataset:

mkdir -p models/mc-arae/synthetic/mix_small
mkdir -p samples/mc-arae/synthetic/mix_small

Training:

python multi_categorical_gans/methods/arae/trainer.py \
    --metadata=data/synthetic/mix_small/metadata.json \
    --data_format=sparse \
    --code_size=65 \
    --noise_size=10 \
    --encoder_hidden_sizes="" \
    --decoder_hidden_sizes="" \
    --batch_size=100 \
    --num_epochs=5000 \
    --l2_regularization=0 \
    --learning_rate=1e-5 \
    --generator_hidden_sizes=100,100,100 \
    --bn_decay=0.99 \
    --discriminator_hidden_sizes=100 \
    --num_autoencoder_steps=1 \
    --num_discriminator_steps=1 \
    --num_generator_steps=1 \
    --autoencoder_noise_radius=0 \
    --seed=123 \
    --temperature=0.666 \
    data/synthetic/mix_small/synthetic-train.features.npz \
    models/mc-arae/synthetic/mix_small/autoencoder.torch \
    models/mc-arae/synthetic/mix_small/generator.torch \
    models/mc-arae/synthetic/mix_small/discriminator.torch \
    models/mc-arae/synthetic/mix_small/loss.csv

Sampling:

python multi_categorical_gans/methods/arae/sampler.py \
    --metadata=data/synthetic/mix_small/metadata.json \
    --code_size=65 \
    --noise_size=10 \
    --encoder_hidden_sizes="" \
    --decoder_hidden_sizes="" \
    --batch_size=1000 \
    --generator_hidden_sizes=100,100,100 \
    --generator_bn_decay=0.99 \
    --temperature=0.666 \
    models/mc-arae/synthetic/mix_small/autoencoder.torch \
    models/mc-arae/synthetic/mix_small/generator.torch \
    10000 65 \
    samples/mc-arae/synthetic/mix_small/sample.features.npy
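
A quick sanity check on the generated file can be sketched as follows. It assumes the sampler output is a dense NumPy array with one row per sample and one column per feature, and that each variable is expected to occupy a hard one-hot block. The helper `check_one_hot_blocks` and the toy sizes are hypothetical and not part of this repository.

```python
import numpy as np


def check_one_hot_blocks(samples, variable_sizes):
    """Return one boolean per row: True when every variable block is one-hot."""
    if samples.shape[1] != sum(variable_sizes):
        raise ValueError("variable sizes do not add up to the number of features")
    valid = np.ones(samples.shape[0], dtype=bool)
    start = 0
    for size in variable_sizes:
        block = samples[:, start:start + size]
        # A one-hot block contains only 0s and 1s and sums to exactly 1.
        is_binary = np.isin(block, (0, 1)).all(axis=1)
        valid &= is_binary & (block.sum(axis=1) == 1)
        start += size
    return valid


# Toy stand-in for the sampled array: two variables of sizes 2 and 3.
toy = np.array(
    [
        [1, 0, 0, 1, 0],      # valid
        [1, 1, 0, 0, 1],      # first block has two active categories
        [0, 1, 0.5, 0.5, 0],  # second block is not binary
    ],
    dtype=np.float32,
)
valid = check_one_hot_blocks(toy, [2, 3])
print(valid.tolist())  # [True, False, False]
```

For the real output, `toy` would be replaced by `np.load("samples/mc-arae/synthetic/mix_small/sample.features.npy")` and the sizes by the ones stored in the metadata file, assuming the file has that layout.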