
A conditional DCGAN, in TensorFlow, for generating hand-written digits from the MNIST dataset.


Snag9311/conditional-DCGAN-for-MNIST

 
 


Conditional DCGAN for MNIST

This is a generative model for the hand-written digits of the MNIST dataset. It combines the DCGAN architecture recommended in Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks (Radford et al.) with the label conditioning proposed in Conditional Generative Adversarial Nets (Mirza and Osindero).

Why a Conditional GAN?

In my last project, I used a DCGAN to generate MNIST digits in an unsupervised fashion: although MNIST is a labeled dataset, I discarded the labels at the start and never used them. This worked, but those labels carry a great deal of useful information. It would have been nice to let the GAN benefit from that additional input, and to be able to specify which digit I wanted the trained generator to create.

Conditional GANs address these shortcomings by feeding the labels into both the Generator and the Discriminator.

This has a couple of effects. First, in the unsupervised DCGAN, the random input vector z controlled everything about the resulting digit, including which digit it was. Since the labels take over that role in a conditional GAN, z here encodes all the other features (rotation, style, and so on).
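As a rough illustration of what "feeding the labels in" can mean, here is a minimal pure-Python sketch of a common convolutional adaptation of the Mirza and Osindero idea: the Generator sees z concatenated with a one-hot label, and the Discriminator sees the image with the one-hot label broadcast as extra per-pixel channels. The helper names (one_hot, generator_input, discriminator_input) are hypothetical, and the exact conditioning used in trainer/architecture.py may differ.

```python
# Hypothetical sketch of label conditioning; NOT the code in trainer/architecture.py.
NUM_CLASSES = 10  # digits 0-9


def one_hot(label, num_classes=NUM_CLASSES):
    """Return a one-hot list encoding `label`."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


def generator_input(z, label):
    """Concatenate the noise vector z with the one-hot label."""
    return list(z) + one_hot(label)


def discriminator_input(image, label):
    """Append the one-hot label as constant extra channels at every pixel.

    image: H x W nested list of floats (one grayscale channel).
    Returns an H x W x (1 + NUM_CLASSES) nested list.
    """
    code = one_hot(label)
    return [[[pixel] + code for pixel in row] for row in image]


# Usage: a 3-dimensional z conditioned on the digit 7.
print(len(generator_input([0.1, -0.3, 0.5], 7)))  # -> 13
```

Because the label is an explicit input, z is free to capture everything else about the digit.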

Feeding in the labels also affected training. The architecture that had worked in my last project quickly suffered from mode collapse when I used its counterpart here. It appears that the labels made the Discriminator's job easier, allowing it to "win" the minimax game prematurely. The Generator then lost the gradients it needed to learn and started outputting identical black images.
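This README doesn't say which generator loss the code uses. Assuming the original minimax objective, where the Generator minimizes log(1 - D(G(z))) and D is a sigmoid of the Discriminator's logit, the sketch below shows how a Discriminator that wins too early can starve the Generator: the gradient with respect to the logit is -D(G(z)), which shrinks toward zero as D(G(z)) approaches 0. The widely used non-saturating alternative, -log D(G(z)), is included for comparison; neither function is taken from this repository.

```python
import math


def sigmoid(x):
    """Discriminator output D = sigmoid(logit)."""
    return 1.0 / (1.0 + math.exp(-x))


def saturating_grad(logit):
    """d/dlogit of log(1 - D): the original minimax generator objective (assumed)."""
    return -sigmoid(logit)


def non_saturating_grad(logit):
    """d/dlogit of -log(D): the common non-saturating alternative."""
    return -(1.0 - sigmoid(logit))


# A more negative logit means the Discriminator is more confident the sample is fake.
for logit in [0.0, -2.0, -5.0, -10.0]:
    print(f"D(G(z))={sigmoid(logit):.5f}  "
          f"saturating={saturating_grad(logit):.5f}  "
          f"non-saturating={non_saturating_grad(logit):.5f}")
```

As the Discriminator grows confident, the saturating gradient all but vanishes, while the non-saturating one stays near -1.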

Using fewer layers and larger filters stabilized training. See trainer/architecture.py for details.
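When trying fewer or different layers, it helps to confirm that the Generator still ends at 28x28 and that the Discriminator's feature maps shrink as expected. This sketch assumes stride-2 convolutions with "same" padding (as in TensorFlow/Keras), where a transposed convolution multiplies the spatial size by its stride and a regular convolution divides it, rounding up. The 7x7 starting size and the layer counts are hypothetical, not read from trainer/architecture.py.

```python
import math


def generator_sizes(start, strides):
    """Spatial size after each 'same'-padded transposed conv: size * stride."""
    sizes = [start]
    for s in strides:
        sizes.append(sizes[-1] * s)
    return sizes


def discriminator_sizes(start, strides):
    """Spatial size after each 'same'-padded conv: ceil(size / stride)."""
    sizes = [start]
    for s in strides:
        sizes.append(math.ceil(sizes[-1] / s))
    return sizes


# Hypothetical shallow layout: project z to 7x7, then two stride-2 upsamplings.
print(generator_sizes(7, [2, 2]))       # -> [7, 14, 28]
print(discriminator_sizes(28, [2, 2]))  # -> [28, 14, 7]
```

With "same" padding the filter size does not change these shapes, so larger filters widen each layer's receptive field without breaking the 28x28 output.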

Results

Once I had a suitable architecture, the cDCGAN converged relatively quickly. Below are four randomly sampled digits from each class (0-9), generated by the finished model:

Trained Model

To use:

  1. Download the trained model here.

  2. Unzip it and move it into the project directory.

  3. From the project directory, run python -m trainer.task --sample [NUM_SAMPLES_PER_CLASS]. By default, the results are saved to the samples/all_samples folder.

If the trained model is stored somewhere else, include --checkpoint-dir [YOUR_PATH] in the command.

To save the samples to another location, include --sample-dir [YOUR_PATH] in the command.
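The --sample [NUM_SAMPLES_PER_CLASS] option asks for a fixed number of digits per class. Here is a minimal sketch of how such a class-ordered label batch can be built, with each label then paired with its own random z; sample_labels is a hypothetical helper, not a function from this repository.

```python
NUM_CLASSES = 10  # digits 0-9


def sample_labels(samples_per_class, num_classes=NUM_CLASSES):
    """Hypothetical helper: labels ordered as all 0s, then all 1s, and so on."""
    return [c for c in range(num_classes) for _ in range(samples_per_class)]


labels = sample_labels(4)
print(len(labels))  # -> 40
print(labels[:6])   # -> [0, 0, 0, 0, 1, 1]
```

Keeping the labels grouped by class makes it easy to lay the outputs out as one row per digit.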

Train Your Own (MNIST)

If you want to tweak this code and train your own version from scratch, you can find the main code in trainer/task.py. To train, you will need to:

  1. Download the MNIST data here.
  2. cd into the project directory.
  3. Run python -m trainer.task --data-dir [YOUR_PATH_TO_MNIST_DATA] to start training.
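If you need to inspect or convert the MNIST files yourself, the standard distribution uses the IDX format: two zero bytes, a type code (0x08 for unsigned bytes), the number of dimensions, each dimension as a big-endian 32-bit integer, and then the raw data. The sketch below is a stdlib-only reader for the unsigned-byte case. It assumes you downloaded the usual files (e.g. train-images-idx3-ubyte.gz), and it is not the loader in trainer/dataset_loader.py.

```python
import gzip
import struct


def parse_idx(data):
    """Parse an unsigned-byte IDX buffer; return (dims, flat list of values)."""
    zero1, zero2, dtype, ndim = struct.unpack(">BBBB", data[:4])
    if zero1 != 0 or zero2 != 0 or dtype != 0x08:
        raise ValueError("expected an unsigned-byte IDX file")
    # Each dimension is a big-endian unsigned 32-bit integer.
    dims = struct.unpack(">" + "I" * ndim, data[4:4 + 4 * ndim])
    body = data[4 + 4 * ndim:]
    expected = 1
    for d in dims:
        expected *= d
    if len(body) != expected:
        raise ValueError(f"expected {expected} values, got {len(body)}")
    return dims, list(body)


def load_idx(path):
    """Read a (possibly gzipped) IDX file from disk."""
    opener = gzip.open if path.endswith(".gz") else open
    with opener(path, "rb") as f:
        return parse_idx(f.read())
```

For the training images, dims comes back as (60000, 28, 28); for the labels, as (60000,).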

Train Your Own (Other Dataset)

If you have a dataset of low-resolution, categorically labeled images and want to generate new ones with this code, you should only need to:

  1. Edit the trainer/architecture.py file for your desired input image size, number of label categories, and architecture. DCGANs are very sensitive to architecture, so you may need to try multiple configurations.

  2. Edit the _load_data method in trainer/dataset_loader.py to unpack your dataset and shape it into the expected format.

  3. Edit trainer/train_config.py to set your preferred training configuration (batch size, number of epochs, output file paths, etc.). Since I tend to train in the cloud, I keep separate file path defaults for local and remote training; hopefully this is useful to you as well. Set TrainConfig.is_local = True/False to toggle between local and remote modes.
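This README doesn't spell out the format _load_data must produce. As a sketch under stated assumptions, the following combines a common DCGAN convention (Radford et al. scale training images to [-1, 1] to match the Generator's tanh output) with one-hot labels. to_gan_format is hypothetical; check trainer/dataset_loader.py for the shapes the code actually expects.

```python
def to_gan_format(pixels, labels, height, width, num_classes):
    """Hypothetical helper (assumed target format, not the repo's):
    flat uint8 pixels -> list of H x W x 1 images scaled to [-1, 1],
    plus one-hot label vectors.
    """
    size = height * width
    if len(pixels) != size * len(labels):
        raise ValueError("pixel count does not match len(labels) * height * width")
    images = []
    for i in range(len(labels)):
        flat = pixels[i * size:(i + 1) * size]
        scaled = [p / 127.5 - 1.0 for p in flat]  # 0..255 -> -1..1
        images.append([[[scaled[r * width + c]] for c in range(width)]
                       for r in range(height)])
    one_hot = []
    for y in labels:
        vec = [0.0] * num_classes
        vec[y] = 1.0
        one_hot.append(vec)
    return images, one_hot
```

If your images are RGB, the trailing channel dimension becomes 3, and the input size in trainer/architecture.py must change to match.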

To start training, run python -m trainer.task from the project directory.

I hope this is helpful!

Acknowledgements
