This repository implements training and inference for a GAN built only from fully connected (FC) layers, trained on MNIST.
![GAN Tutorial](https://private-user-images.githubusercontent.com/144267687/302553167-3fe95fbd-1340-4396-ae92-66c59b3f0013.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk3NjU4NzIsIm5iZiI6MTczOTc2NTU3MiwicGF0aCI6Ii8xNDQyNjc2ODcvMzAyNTUzMTY3LTNmZTk1ZmJkLTEzNDAtNDM5Ni1hZTkyLTY2YzU5YjNmMDAxMy5wbmc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBVkNPRFlMU0E1M1BRSzRaQSUyRjIwMjUwMjE3JTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDI1MDIxN1QwNDEyNTJaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT1iZWE0YWY1ODcxNDA2ODk3OTlmNGQ5ODE2YTkzMWNhYTZmNjRmMGNhNGNiNzcwNzhlYzVmMGU2MTZmY2I1MzQ0JlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCJ9.Xvm9ze3H2pX9IlfQFRb_k7PTRJGDjGT57WgYyVIqqLU)
![](https://private-user-images.githubusercontent.com/144267687/292857810-4e1fd994-6ec0-4e21-aeee-6b054e72ddab.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk3NjU4NzIsIm5iZiI6MTczOTc2NTU3MiwicGF0aCI6Ii8xNDQyNjc2ODcvMjkyODU3ODEwLTRlMWZkOTk0LTZlYzAtNGUyMS1hZWVlLTZiMDU0ZTcyZGRhYi5wbmc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBVkNPRFlMU0E1M1BRSzRaQSUyRjIwMjUwMjE3JTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDI1MDIxN1QwNDEyNTJaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT1hZjc1NzAzZDViOGE2MTljMTgyOTZmYzFmNDE3NjZlZGI3N2U3NzM3NjMyNWU3YjJmNjNiZWE4ZDU4NjdmZTRmJlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCJ9.domeJU2M6pnTgKatpf1WyvJs9llhe-HAwn2IdMK-uac)
![](https://private-user-images.githubusercontent.com/144267687/292857831-f4bbdafa-a8e2-4a8f-b063-4a4bc00c76fa.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk3NjU4NzIsIm5iZiI6MTczOTc2NTU3MiwicGF0aCI6Ii8xNDQyNjc2ODcvMjkyODU3ODMxLWY0YmJkYWZhLWE4ZTItNGE4Zi1iMDYzLTRhNGJjMDBjNzZmYS5wbmc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBVkNPRFlMU0E1M1BRSzRaQSUyRjIwMjUwMjE3JTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDI1MDIxN1QwNDEyNTJaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT1jNDk0ZmE3MzcwMzQ2ZGI2Njk0NzFmY2MzNWE5YmNlMzhiYTAwZmE4Mjg0MTA4OTYxY2FjMjYwNTk3MjJhM2QwJlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCJ9.XxkKdrrh8ha7PRF6yglkYp9qvOpZQtg-Ow2dBHDEbIE)
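Training follows the two-player game of Goodfellow et al. (2014). The discriminator $D$ learns to separate real MNIST images $x$ from generated images $G(z)$, while the generator $G$ learns to fool it:

$$\min_G \max_D \; \mathbb{E}_{x}\left[\log D(x)\right] + \mathbb{E}_{z}\left[\log\left(1 - D(G(z))\right)\right]$$

The snippet below is a minimal, framework-free sketch of the per-batch losses this objective leads to. It assumes $D$ outputs probabilities and uses the common non-saturating generator loss ($-\log D(G(z))$) suggested in the same paper. The function names are hypothetical, and the repository's training script may compute the losses differently.

```python
import math
from typing import Sequence

EPS = 1e-12  # keeps log() finite when a probability is exactly 0 or 1


def bce(probs: Sequence[float], target: float) -> float:
    """Mean binary cross-entropy of predicted probabilities against one constant label."""
    total = 0.0
    for p in probs:
        total -= target * math.log(p + EPS) + (1.0 - target) * math.log(1.0 - p + EPS)
    return total / len(probs)


def discriminator_loss(d_real: Sequence[float], d_fake: Sequence[float]) -> float:
    """D should output 1 for real images, D(x), and 0 for generated ones, D(G(z))."""
    return bce(d_real, 1.0) + bce(d_fake, 0.0)


def generator_loss(d_fake: Sequence[float]) -> float:
    """Non-saturating loss: G is rewarded when D labels its samples as real."""
    return bce(d_fake, 1.0)
```

For example, a discriminator that outputs 0.5 for every input gives `discriminator_loss([0.5], [0.5])` ≈ 2 ln 2 ≈ 1.386 and `generator_loss([0.5])` ≈ ln 2 ≈ 0.693, which are the values at the point where $D$ cannot tell real from fake.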
To set up the MNIST dataset, follow the instructions at https://github.com/explainingai-code/Pytorch-VAE#data-preparation.
The directory structure should look like this:

    $REPO_ROOT
        -> data
            -> train
                -> images
                    -> 0
                        *.png
                    -> 1
                    ...
                    -> 9
                        *.png
            -> test
                -> images
                    -> 0
                        *.png
                    ...
        -> dataset
        -> tools
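With this layout, the class label of each image is simply the name of its digit folder. A minimal sketch of how one split could be indexed into `(path, label)` pairs using only the standard library is shown below; `index_mnist_split` is a hypothetical helper and not necessarily how `dataset/mnist_dataset.py` reads the files.

```python
from pathlib import Path
from typing import List, Tuple


def index_mnist_split(split_dir: str) -> List[Tuple[str, int]]:
    """Return (image_path, label) pairs for a split laid out as <split>/images/<digit>/*.png.

    The label is taken from the name of the digit sub-folder (0-9).
    """
    images_root = Path(split_dir) / "images"
    samples = []
    for digit_dir in sorted(images_root.iterdir()):
        if not digit_dir.is_dir() or not digit_dir.name.isdigit():
            continue  # skip stray files and non-digit folders
        label = int(digit_dir.name)
        for img_path in sorted(digit_dir.glob("*.png")):
            samples.append((str(img_path), label))
    return samples
```

For example, `index_mnist_split("data/train")` would return one pair per PNG found in the ten digit folders under `data/train/images`.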
- Create a new conda environment with Python 3.8, then run the commands below:
- `git clone https://github.com/explainingai-code/GANs-Pytorch.git`
- `cd GANs-Pytorch`
- `pip install -r requirements.txt`
- `python -m tools.train_gan` for training and saving inference samples
- Ensure the dataset is prepared according to the data preparation instructions above.
- Change the `IM_CHANNELS` field to 3 in `tools/train_gan.py`.
- Uncomment lines 56-59 in the `dataset/mnist_dataset.py` file.
- Place all `*.png` files (or images in whatever format you have) in `data/train/images`.
- Comment out [line 43 of `dataset/mnist_dataset.py`](https://github.com/explainingai-code/GANs-Pytorch/blob/main/dataset/mnist_dataset.py#L43).
- The directory structure should then be `data/train/images/*.png`, with the images placed directly in `images` and no per-class subfolders.
- Change the `IM_PATH` field to `data/train` in `tools/train_gan.py`.
- Change the number of channels and the image size accordingly.
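Because every layer in this GAN is fully connected, each image is flattened into a single vector before it reaches the discriminator, and in an all-FC generator the last layer must emit a vector of that same length. This is why the channel count and image size must match your data: the flattened size is channels × height × width. A small sketch, using a hypothetical helper name (`IM_CHANNELS` is the only field name taken from this README):

```python
def fc_input_dim(im_channels: int, im_height: int, im_width: int) -> int:
    """Length of the flattened image vector consumed (and produced) by fully connected layers."""
    if min(im_channels, im_height, im_width) <= 0:
        raise ValueError("channels and image dimensions must be positive")
    return im_channels * im_height * im_width


# Grayscale MNIST (1 x 28 x 28) versus a 3-channel image of the same size.
print(fc_input_dim(1, 28, 28))  # 784
print(fc_input_dim(3, 28, 28))  # 2352
```

The size grows quickly: a 64 × 64 RGB image already gives `fc_input_dim(3, 64, 64)` = 12288 inputs to the first discriminator layer.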
Outputs are saved every 50 steps in the `samples` directory.
During GAN training, the following outputs are saved:
- The latest model checkpoints for the generator and discriminator, in the `$REPO_ROOT` directory
Every 50 steps, inference is run and the following output is saved:
- A grid of sampled images in `samples/*.png`
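The periodic sampling comes down to two small pieces of logic: deciding when to save (every 50 steps) and tiling a batch of generated images into one grid image before it is written to `samples/`. The sketch below assumes steps are counted from 1 and models images as nested lists of floats instead of tensors. `should_save` and `tile_images` are hypothetical names; the repository may instead use a library helper such as `torchvision.utils.make_grid`, whose `nrow` argument likewise sets the number of images per row.

```python
import math
from typing import List

Image = List[List[float]]  # one single-channel image, stored as rows of pixel values

SAVE_EVERY = 50  # sampling interval described in this README


def should_save(step: int, every: int = SAVE_EVERY) -> bool:
    """True on steps 50, 100, 150, ... when steps are counted from 1."""
    return step > 0 and step % every == 0


def tile_images(images: List[Image], nrow: int, pad: int = 2, pad_value: float = 0.0) -> Image:
    """Tile equally sized images into one grid image with `nrow` images per row."""
    h, w = len(images[0]), len(images[0][0])
    rows = math.ceil(len(images) / nrow)
    grid_h = rows * (h + pad) + pad
    grid_w = nrow * (w + pad) + pad
    grid = [[pad_value] * grid_w for _ in range(grid_h)]
    for idx, img in enumerate(images):
        r, c = divmod(idx, nrow)  # grid position of this image
        top, left = pad + r * (h + pad), pad + c * (w + pad)
        for y in range(h):
            grid[top + y][left:left + w] = img[y]  # copy one pixel row into place
    return grid
```

For instance, 64 samples tiled with `nrow=8` would form an 8 × 8 grid, separated by `pad` pixels of `pad_value`.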
@misc{goodfellow2014generative,
title={Generative Adversarial Networks},
author={Ian J. Goodfellow and Jean Pouget-Abadie and Mehdi Mirza and Bing Xu and David Warde-Farley and Sherjil Ozair and Aaron Courville and Yoshua Bengio},
year={2014},
eprint={1406.2661},
archivePrefix={arXiv},
primaryClass={stat.ML}
}