
GAN Training using CIFAR-100 Dataset

This project demonstrates the implementation of a Generative Adversarial Network (GAN) to generate images similar to those found in the CIFAR-100 dataset. This is a learning-focused project intended to help understand the core concepts of GANs, including generator and discriminator design, loss functions, and training loops.

Project Overview

The project consists of two main components:

  • Generator Model: A neural network that takes random noise as input and generates realistic images.
  • Discriminator Model: A neural network that tries to distinguish between real images from the CIFAR-100 dataset and those generated by the generator.

The models are trained together in an adversarial manner, where the generator aims to fool the discriminator, and the discriminator aims to correctly classify images as real or fake.

Features

  • Uses TensorFlow and Keras to create and train the GAN.
  • Utilizes the CIFAR-100 training set (50,000 32x32 color images), with pixel values scaled to the range [-1, 1] for more stable training (see the data-loading sketch after this list).
  • The generator progressively upsamples from a random noise vector to create a 32x32 RGB image.
  • The discriminator uses convolutional layers to classify images as real or fake.
  • Hinge loss is used for both the generator and the discriminator, and both networks are trained with the Adam optimizer.
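
A minimal data-loading sketch, assuming the notebook uses tf.keras.datasets.cifar100 and a tf.data input pipeline (the BUFFER_SIZE and BATCH_SIZE values here are illustrative):

import tensorflow as tf

BUFFER_SIZE = 50000  # size of the CIFAR-100 training set
BATCH_SIZE = 256

# Load the CIFAR-100 training images and discard the labels
(train_images, _), (_, _) = tf.keras.datasets.cifar100.load_data()
train_images = train_images.astype("float32")
train_images = (train_images - 127.5) / 127.5  # scale pixel values to [-1, 1]

# Shuffle and batch the images with tf.data
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_images)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)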

Dependencies

The following dependencies are required for this project:

  • TensorFlow: Machine learning library used for building and training the models.
  • Matplotlib: For visualization of generated images during training.
  • IPython: To display images during training.

Install the dependencies using the following command:

pip install tensorflow matplotlib ipython

Running the Project

1. Clone the Repository

Clone this repository to your local machine:

git clone <repository_url>

2. Open the Notebook

Open the Jupyter notebook file using Jupyter:

jupyter notebook <notebook_name>.ipynb

3. Train the GAN

Run all cells in the notebook to start training the GAN. The training process will visualize the progress by displaying generated images after each epoch.

4. Customize Parameters

The notebook includes several parameters that can be customized, as shown in the sketch after this list:

  • EPOCHS: Number of training epochs. Default is set to 50.
  • BATCH_SIZE: Number of images in each batch during training. Default is 256.
  • NOISE_DIM: Dimensionality of the random noise vector input to the generator. Default is 100.
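
In the notebook these defaults would typically appear as plain constants, for example:

EPOCHS = 50        # number of full passes over the training set
BATCH_SIZE = 256   # images per gradient update
NOISE_DIM = 100    # length of the random noise vector fed to the generator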

Code Overview

Generator Model

The generator starts with a dense layer that reshapes the input noise vector into a low-resolution activation map, which is then progressively upsampled using transposed convolutional layers to generate a 32x32 image.
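
An illustrative generator sketch in Keras, following the description above; the exact layer widths and kernel sizes are assumptions, not necessarily the notebook's architecture:

from tensorflow.keras import Sequential, layers

NOISE_DIM = 100  # must match the noise dimensionality used during training

def make_generator():
    return Sequential([
        layers.Input(shape=(NOISE_DIM,)),
        # Project the noise vector and reshape it into a 4x4 activation map
        layers.Dense(4 * 4 * 256, use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Reshape((4, 4, 256)),
        # Progressively upsample: 4x4 -> 8x8 -> 16x16 -> 32x32
        layers.Conv2DTranspose(128, 5, strides=2, padding="same", use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Conv2DTranspose(64, 5, strides=2, padding="same", use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        # Final layer outputs a 32x32x3 image in [-1, 1] via tanh
        layers.Conv2DTranspose(3, 5, strides=2, padding="same", use_bias=False, activation="tanh"),
    ])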

Discriminator Model

The discriminator is a convolutional neural network that classifies whether an image is real or generated. It employs Leaky ReLU activations and dropout for regularization.
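
A matching discriminator sketch (convolutions with LeakyReLU and dropout, ending in a single unbounded score); layer sizes are again assumptions:

from tensorflow.keras import Sequential, layers

def make_discriminator():
    return Sequential([
        layers.Input(shape=(32, 32, 3)),
        layers.Conv2D(64, 5, strides=2, padding="same"),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        layers.Conv2D(128, 5, strides=2, padding="same"),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        layers.Flatten(),
        # A single raw score (logit); hinge loss expects unbounded outputs,
        # so there is no sigmoid here
        layers.Dense(1),
    ])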

Loss Functions

  • Generator Loss: Uses hinge loss to encourage the generator to produce images that the discriminator scores as real.
  • Discriminator Loss: Uses hinge loss to reward the discriminator for assigning high scores to real images and low scores to generated ones.
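
A sketch of the standard GAN hinge-loss formulation, assuming the discriminator outputs raw scores rather than probabilities:

import tensorflow as tf

def discriminator_loss(real_output, fake_output):
    # Penalize real scores below +1 and fake scores above -1
    real_loss = tf.reduce_mean(tf.nn.relu(1.0 - real_output))
    fake_loss = tf.reduce_mean(tf.nn.relu(1.0 + fake_output))
    return real_loss + fake_loss

def generator_loss(fake_output):
    # The generator tries to raise the discriminator's score on generated images
    return -tf.reduce_mean(fake_output)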

Training Process

The training loop involves generating images, calculating losses for both the generator and discriminator, and using gradient descent to update their weights.
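
A sketch of one training step with tf.GradientTape, reusing the model and loss sketches above; the variable names (generator, discriminator, the two optimizers) are assumptions about the notebook's code:

import tensorflow as tf

generator = make_generator()
discriminator = make_discriminator()
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, NOISE_DIM])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)
    # Apply one gradient update to each network
    gen_grads = gen_tape.gradient(gen_loss, generator.trainable_variables)
    disc_grads = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gen_grads, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(disc_grads, discriminator.trainable_variables))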

Best Practices

  • GPU Acceleration: Use a GPU to accelerate the training process, as training GANs is computationally intensive.
  • Regular Checkpoints: Save model checkpoints periodically to avoid losing progress due to unexpected interruptions (see the checkpointing sketch after this list).
  • Visualize Progress: Use the provided visualization functions to observe the evolution of generated images during training.
  • Batch Normalization: Ensure proper use of batch normalization in the generator to stabilize training.
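
A checkpointing sketch with tf.train.Checkpoint, assuming the models and optimizers from the training sketch above; the checkpoint directory name is illustrative:

import tensorflow as tf

checkpoint = tf.train.Checkpoint(
    generator=generator,
    discriminator=discriminator,
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
)
manager = tf.train.CheckpointManager(checkpoint, "./training_checkpoints", max_to_keep=3)

manager.save()  # call periodically, e.g. every few epochs

if manager.latest_checkpoint:
    checkpoint.restore(manager.latest_checkpoint)  # resume after an interruption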

Results

Generated images are saved at each epoch, allowing you to observe the improvement in quality over time. You can view these images in the working directory.

License

This project is licensed under the MIT License.

Acknowledgements

This project is inspired by the TensorFlow DCGAN tutorial and serves as an educational resource for those looking to learn more about generative adversarial networks.
