This project implements a Deep Convolutional GAN (DCGAN) that generates handwritten digit images similar to those in the MNIST dataset. The repository includes the model definitions, a training script, and utilities for managing data and saving generated images, and is meant to be run inside its own Python virtual environment.
A DCGAN is a type of Generative Adversarial Network (GAN) in which both networks are built from convolutional layers. In this project:
- The Generator learns to generate new digit images by taking random noise as input.
- The Discriminator learns to distinguish between real MNIST images and fake images generated by the Generator.
- The two models are trained together in an adversarial process, where the Generator improves at producing realistic images, and the Discriminator improves at detecting fakes.
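The adversarial process described above is commonly written as the minimax game from the original GAN paper (Goodfellow et al., 2014), where $D(x)$ is the Discriminator's estimated probability that an image $x$ is real and $G(z)$ is the image the Generator produces from noise $z$:

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
$$

In practice, many implementations train the Generator to maximize $\log D(G(z))$ instead of minimizing $\log(1 - D(G(z)))$, because it gives stronger gradients early in training. This README does not state which form `train.py` uses.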
Clone the repository:
git clone https://github.com/letsdoitbycode/GAN-model-implementation.git
cd GAN-model-implementation
Create a virtual environment and activate it:
python -m venv venv
source venv/bin/activate  # On Windows use `venv\Scripts\activate`
Install the required packages:
pip install -r requirements.txt
# Alternatively, install the packages directly:
pip install torch torchvision numpy Pillow
Run the training script:
- The training script will download the MNIST dataset (if not already present) and begin training the GAN.
python train.py
Project structure:
GAN-model-implementation/
├── data/ # Directory for downloaded MNIST dataset
├── generated_images/ # Directory to save generated images during training
├── models/
│ ├── generator.py # Generator model code
│ └── discriminator.py # Discriminator model code
├── train.py # Script to train the GAN
├── utils.py # Utility functions (e.g., image saving, initialization)
├── requirements.txt # Python dependencies
└── README.md # Project description and setup instructions
models/generator.py
Defines the architecture of the Generator network, which creates fake MNIST images from random noise.
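For orientation, here is a minimal sketch of a DCGAN-style Generator for 28×28 MNIST images. It is an assumption, not the contents of `models/generator.py`: the width `ngf=64` and the 7×7 starting feature map are illustrative choices, while `nz=100` matches the noise dimension stated in this README.

```python
import torch
import torch.nn as nn


class Generator(nn.Module):
    """Illustrative DCGAN-style generator: (N, nz, 1, 1) noise -> (N, 1, 28, 28) image."""

    def __init__(self, nz=100, ngf=64):
        super().__init__()
        self.net = nn.Sequential(
            # (N, nz, 1, 1) -> (N, ngf*4, 7, 7)
            nn.ConvTranspose2d(nz, ngf * 4, kernel_size=7, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # (N, ngf*4, 7, 7) -> (N, ngf*2, 14, 14)
            nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # (N, ngf*2, 14, 14) -> (N, 1, 28, 28)
            nn.ConvTranspose2d(ngf * 2, 1, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),  # outputs in [-1, 1], matching images normalized to that range
        )

    def forward(self, z):
        return self.net(z)
```

Each stride-2 transposed convolution doubles the spatial size (7 → 14 → 28), which is the usual DCGAN way of upsampling without pooling layers.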
models/discriminator.py
Defines the architecture of the Discriminator network, which estimates whether an input image is real or fake.
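A matching Discriminator might look like the sketch below. Again this is an assumption rather than the code in `models/discriminator.py`: it uses strided convolutions with LeakyReLU (slope 0.2) as in the DCGAN paper, and ends in a Sigmoid so it can be trained with `nn.BCELoss`; the width `ndf=64` is illustrative.

```python
import torch
import torch.nn as nn


class Discriminator(nn.Module):
    """Illustrative DCGAN-style discriminator: (N, 1, 28, 28) image -> (N,) probability of 'real'."""

    def __init__(self, ndf=64):
        super().__init__()
        self.net = nn.Sequential(
            # (N, 1, 28, 28) -> (N, ndf, 14, 14); no BatchNorm on the first layer
            nn.Conv2d(1, ndf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # (N, ndf, 14, 14) -> (N, ndf*2, 7, 7)
            nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # (N, ndf*2, 7, 7) -> (N, 1, 1, 1)
            nn.Conv2d(ndf * 2, 1, kernel_size=7, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x).view(-1)  # flatten to shape (N,)
```

If you prefer `nn.BCEWithLogitsLoss` (numerically more stable), drop the final Sigmoid and treat the output as a logit.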
train.py
Script for training the GAN. It sets up the data, trains the Generator and Discriminator, and saves generated images at intervals.
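The core of a GAN training loop is one alternating update per batch. The sketch below shows that step under stated assumptions: the Discriminator ends in a Sigmoid, the Generator takes noise of shape `(N, nz, 1, 1)`, and the Generator uses the non-saturating loss. The function name `train_step` is hypothetical; `train.py` may organize this differently.

```python
import torch
import torch.nn as nn


def train_step(G, D, opt_G, opt_D, real, nz=100):
    """One alternating GAN update (hypothetical helper): Discriminator first, then Generator."""
    criterion = nn.BCELoss()  # assumes D outputs probabilities (Sigmoid at the end)
    n = real.size(0)
    ones = torch.ones(n, device=real.device)
    zeros = torch.zeros(n, device=real.device)

    # Discriminator update: push D(real) toward 1 and D(G(z)) toward 0
    opt_D.zero_grad()
    z = torch.randn(n, nz, 1, 1, device=real.device)
    fake = G(z)
    loss_D = (criterion(D(real).view(-1), ones)
              + criterion(D(fake.detach()).view(-1), zeros))  # detach: no grads into G here
    loss_D.backward()
    opt_D.step()

    # Generator update: push D(G(z)) toward 1, i.e. maximize log D(G(z))
    opt_G.zero_grad()
    loss_G = criterion(D(fake).view(-1), ones)
    loss_G.backward()
    opt_G.step()

    return loss_D.item(), loss_G.item()
```

A typical call creates both optimizers with `torch.optim.Adam(..., lr=2e-4, betas=(0.5, 0.999))`, the settings recommended in the DCGAN paper, and then calls `train_step` for every batch of every epoch. This README does not say which optimizer settings `train.py` uses.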
utils.py
Contains utility functions, such as `save_generated_images` to save generated images during training, and `initialize_weights` to initialize model weights.
requirements.txt
Lists the Python dependencies required for this project.
README.md
Provides an overview, setup instructions, and details on the project structure.
Generated images are saved in the generated_images/ directory. As training progresses, you can watch the images evolve to resemble handwritten digits more closely.
You may want to include a few sample images from the generated_images/ directory here after running the project.
- Training Duration: Training time depends on your hardware and is much shorter on a CUDA-capable GPU than on a CPU. The script is set to run for 50 epochs, but you can experiment with other settings.
- Noise Dimension: The input noise dimension for the Generator is set to 100. You can experiment with this to see how it impacts the quality of generated images.
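When experimenting with the noise dimension, two details matter: the Generator's first layer must accept the new size, and comparisons are easier if the same noise batch is reused across epochs. The sketch below assumes a convolutional Generator that takes noise shaped `(N, nz, 1, 1)`; the helper `make_fixed_noise` and its defaults are hypothetical.

```python
import torch

NZ = 100  # noise dimension stated in this README; must match the Generator's input channels


def make_fixed_noise(n=64, nz=NZ, seed=0, device="cpu"):
    # A seeded torch.Generator yields the same batch on every call,
    # so images saved at different epochs (or for different settings) are comparable.
    g = torch.Generator().manual_seed(seed)
    return torch.randn(n, nz, 1, 1, generator=g).to(device)
```

If you change `nz` (for example to 64 or 128), the Generator must be constructed with the same value, since in a DCGAN-style Generator the first transposed convolution's `in_channels` equals the noise dimension.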
Contributions are welcome! Please open an issue or submit a pull request for any changes or improvements.