Our team of 5 worked on generating novel fashion images from the existing DeepFashion dataset using the popular generative adversarial network (GAN) architecture. We experimented with several different architectures, losses, learning rates, and training techniques; training is a particular stress point when building GANs.
For the most part, we used the Deep Convolutional GAN (DCGAN) architecture. Although we experimented with adding dropout, changing the dimensionality of some layers, and changing the loss from BCEWithLogitsLoss to MSE (as proposed in LSGAN), the basic design of ConvTranspose2d (upsampling) layers and Conv2d (standard convolutional) layers was maintained.
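As a rough illustration of the layout described above, the following is a minimal DCGAN-style sketch rather than the exact model in this repository. The 64x64 output size, the latent size of 100, the channel widths, and the names `Generator`, `Discriminator`, and `make_criterion` are all assumptions made for the example.

```python
import torch
import torch.nn as nn

LATENT_DIM = 100  # assumed noise size; not stated in this README


class Generator(nn.Module):
    """Maps (N, latent_dim, 1, 1) noise to a (N, 3, 64, 64) image."""

    def __init__(self, latent_dim=LATENT_DIM, base=64):
        super().__init__()

        def up(c_in, c_out, stride, pad):
            # ConvTranspose2d upsamples; BatchNorm and ReLU follow DCGAN practice
            return [nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False),
                    nn.BatchNorm2d(c_out),
                    nn.ReLU(inplace=True)]

        self.net = nn.Sequential(
            *up(latent_dim, base * 8, 1, 0),  # 1x1   -> 4x4
            *up(base * 8, base * 4, 2, 1),    # 4x4   -> 8x8
            *up(base * 4, base * 2, 2, 1),    # 8x8   -> 16x16
            *up(base * 2, base, 2, 1),        # 16x16 -> 32x32
            nn.ConvTranspose2d(base, 3, 4, 2, 1, bias=False),  # 32x32 -> 64x64
            nn.Tanh(),                        # pixel values in [-1, 1]
        )

    def forward(self, z):
        return self.net(z)


class Discriminator(nn.Module):
    """Maps a (N, 3, 64, 64) image to one raw score (logit) per image."""

    def __init__(self, base=64):
        super().__init__()

        def down(c_in, c_out, norm=True):
            # Strided Conv2d halves the spatial resolution
            layers = [nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False)]
            if norm:
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.net = nn.Sequential(
            *down(3, base, norm=False),       # 64x64 -> 32x32
            *down(base, base * 2),            # 32x32 -> 16x16
            *down(base * 2, base * 4),        # 16x16 -> 8x8
            *down(base * 4, base * 8),        # 8x8   -> 4x4
            nn.Conv2d(base * 8, 1, 4, 1, 0, bias=False),  # 4x4 -> 1x1
        )

    def forward(self, x):
        return self.net(x).view(-1)


def make_criterion(use_lsgan=False):
    # BCEWithLogitsLoss applies the sigmoid internally; MSELoss on the raw
    # scores gives the least-squares (LSGAN-style) objective instead.
    return nn.MSELoss() if use_lsgan else nn.BCEWithLogitsLoss()
```

Because the discriminator returns raw scores with no final sigmoid, switching between the two losses in this sketch only requires changing the criterion, not the network.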
In addition, we give the generator a meaningful input to start from rather than pure random noise. While the generator is conventionally given random noise sampled from a normal distribution, we pass the input image through a ResNet, extract an encoded vector (which contains compressed information about the image), and concatenate this encoding with random noise. The combined vector is then given to the generator.
The idea behind this method is to help the generator by giving it an input that already carries some structure, which will hopefully make it easier to map that input to a convincing output that can fool the discriminator.
We trained on both local GPUs and the Colab platform, with varying degrees of success. Although Colab's access to GPUs is often unstable, we were able to train for 8-9 hours without interruption. However, Colab changes the GPU assigned to each runtime depending on the load on its servers, so training is not always consistent even with the same code. In the future, we hope to gain access to more stable and powerful GPUs, which may help produce better results.
The model did not seem to converge, and the quality of outputs did not increase significantly after this point.
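Since the assigned GPU can differ between Colab runtimes, it can help to record which device a run actually used. The snippet below is a small, optional sketch; the helper name `describe_device` is hypothetical and not part of this repository.

```python
import torch


def describe_device():
    """Return a short label for the device PyTorch will train on."""
    if torch.cuda.is_available():
        # Name of the first visible GPU, as reported by the CUDA driver
        return "cuda: " + torch.cuda.get_device_name(0)
    return "cpu"


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Training on", describe_device())
```

Logging this label alongside each run makes it easier to tell whether a difference between two runs coincides with a change of hardware.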
Pre-requisites:
- torch
- torchvision
- tqdm
- matplotlib
- time (part of the Python standard library, no installation needed)
Clone the repository
git clone https://github.com/Data-Science-Community-SRM/FashionGen
Abhishek Saxena
Aditya Shukla
Aradhya Tripathi
Made with ❤️ by DS Community SRM