sverma88/ADMA_Demo_2018


ADMA Demo 2018

The repository contains a software library for Data Augmentation Services.

Requirements

python 3.x
numpy > 1.13
scipy > 0.19
pillow > 5.2
tensorflow (gpu) > 1.3
MATLAB

Stage 1

Stage 1 has two parts: (i) training the GAN and (ii) training the ensemble classifier.

Train GAN

Run the Python script main.py, located at DAS/Stage-1/Train_GAN/, by typing the following command in the terminal:
python main.py

We provide a modified version of the DCGAN from https://github.com/carpedm20/DCGAN-tensorflow in which the discriminator and the generator are conditioned on the image labels and have fewer layers than the model available at that link. The code, DCGAN_Modified.py, is located at DAS/Stage-1/Train_GAN/

Running main.py downloads the CIFAR-10 dataset automatically to the path specified. The GAN is trained for category 0 against all other categories, using the default parameters specified in main.py. The synthetically generated images and their labels are saved in the DAS/Stage-1/Train_GAN/Geneated_Data folder.
Example files already exist in this path: Images_10_1_2.mat and Labels_10_0_1.mat

Naming convention: Images_alpha_CategoryA_CategoryB and, similarly, Labels_alpha_CategoryA_CategoryB
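The naming convention above can be split back into its fields with a few lines of Python. This is a minimal sketch, not part of the repository: the helper name parse_artifact_name and the ArtifactName type are hypothetical, and the alpha token is kept as text because the README does not say how alpha is encoded in the file name.

```python
from typing import NamedTuple


class ArtifactName(NamedTuple):
    kind: str        # e.g. "Images" or "Labels"
    alpha: str       # alpha token, kept as text (its encoding is not interpreted here)
    category_a: int
    category_b: int


def parse_artifact_name(filename: str) -> ArtifactName:
    """Split '<kind>_<alpha>_<CategoryA>_<CategoryB>[.mat]' into its fields."""
    # Drop only a trailing ".mat"; other dots are left untouched.
    stem = filename[:-4] if filename.endswith(".mat") else filename
    # Split from the right so that a kind containing "_" stays intact.
    kind, alpha, cat_a, cat_b = stem.rsplit("_", 3)
    return ArtifactName(kind, alpha, int(cat_a), int(cat_b))


print(parse_artifact_name("Images_10_1_2.mat"))
```

Splitting from the right means the same helper also handles multi-word prefixes that follow the same pattern.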

To change the categories on which the GAN is trained, edit the file DAS/Stage-1/Train_GAN/DCGAN_Modified.py:
Line 162, fixed_label = 0
Line 163, iter_labels = np.arange((fixed_label + 1), 10)
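To see which categories a given fixed_label selects, the following sketch enumerates them with plain range, which yields the same integers as np.arange((fixed_label + 1), 10). The function category_pairs is hypothetical, and pairing fixed_label with one other category at a time is an assumption suggested by the CategoryA/CategoryB naming, not something confirmed by the script.

```python
def category_pairs(fixed_label: int, num_classes: int = 10):
    """Pair fixed_label with every higher-numbered label.

    range(fixed_label + 1, num_classes) yields the same integers as
    np.arange((fixed_label + 1), 10) when num_classes is 10.
    """
    iter_labels = range(fixed_label + 1, num_classes)
    # Assumption: one (fixed_label, other) pair is processed at a time.
    return [(fixed_label, other) for other in iter_labels]


print(category_pairs(0))  # the default setting, fixed_label = 0
print(category_pairs(7))
```

Because the upper bound is exclusive, the labels run up to 9, and setting fixed_label = 9 leaves no categories to iterate over.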

To change the split ratio of the training dataset used while training the GAN, edit the same file:
Line 159, alpha = 0.1
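As a rough illustration of what a split ratio means, the sketch below keeps a random fraction alpha of the sample indices. It assumes that alpha is the fraction of training samples retained; the README does not spell this out, and the function subsample_indices is hypothetical rather than taken from DCGAN_Modified.py.

```python
import random


def subsample_indices(num_samples: int, alpha: float, seed: int = 0):
    """Return a sorted random subset holding roughly alpha * num_samples indices."""
    if not 0.0 < alpha <= 1.0:
        raise ValueError("alpha must be in (0, 1]")
    keep = max(1, round(alpha * num_samples))
    rng = random.Random(seed)  # fixed seed so the subset is reproducible
    return sorted(rng.sample(range(num_samples), keep))


# CIFAR-10 has 5000 training images per category.
subset = subsample_indices(5000, alpha=0.1)
print(len(subset))
```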

Train Ensemble Classifier

To train the ensemble classifier, run the main MATLAB file TrainEnsemble.m, found at DAS/Stage-1/TrainEnsemClass/

We train the SVM, k-NN and naive Bayes classifiers available in MATLAB R2018a. The code and the necessary functions are in the folder DAS/Stage-1/Train_EnsemClass/
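One common way to combine three such classifiers is a majority vote over their predicted labels. The sketch below illustrates that idea in Python; it is an assumption for illustration only, since the README does not state how TrainEnsemble.m combines the SVM, k-NN and naive Bayes outputs, and the function majority_vote is hypothetical.

```python
from collections import Counter


def majority_vote(predictions):
    """Combine per-classifier label lists into one label per sample.

    predictions: a list with one label list per classifier, all of equal length.
    Ties (possible with more than two classes) go to the label seen first.
    """
    voted = []
    for votes in zip(*predictions):
        voted.append(Counter(votes).most_common(1)[0][0])
    return voted


svm_pred = [1, 2, 2]
knn_pred = [1, 1, 2]
nb_pred = [2, 1, 2]
print(majority_vote([svm_pred, knn_pred, nb_pred]))
```

With three voters and two categories a tie cannot occur, which is one reason an odd number of base classifiers is convenient for pairwise problems.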

The trained parameters of the classifiers are saved in a .mat file in DAS/Stage-1/TrainEnsemClass. An example file named MODEL_X_Y.mat exists in that folder, where
X : Category 1
Y : Category 2

To choose the labels on which the ensemble classifier is trained, edit the file TrainEnsemble.m:
Line 14, fixed_label = 1
Line 15, selected_labels = [(fixed_label+1):10]
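Note that the MATLAB range (fixed_label+1):10 includes both endpoints, unlike a Python range. The sketch below reproduces it in Python and lists the model file names one would expect under the MODEL_X_Y.mat pattern. It assumes one model file is written per (fixed_label, label) pair, which the README suggests but does not confirm; the function expected_model_files is hypothetical.

```python
def expected_model_files(fixed_label: int, last_label: int = 10):
    """List MODEL_X_Y.mat names for X = fixed_label and Y in (fixed_label+1):last_label."""
    # MATLAB's a:b includes b, so the Python upper bound is last_label + 1.
    selected_labels = range(fixed_label + 1, last_label + 1)
    # Assumption: one saved model per (fixed_label, label) pair.
    return [f"MODEL_{fixed_label}_{label}.mat" for label in selected_labels]


for name in expected_model_files(1):
    print(name)
```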

Stage 2

Once the GAN and the ensemble classifier have been trained and their outputs saved in the corresponding locations, move to DAS/Stage-2/ to filter the synthetic images and obtain performance measures for the CNN trained on the augmented datasets.

Filter Unbiased Images

Execute the main MATLAB file Filter_Images.m found at DAS/Stage-2/Filter_Unbiased_Images/ to filter the synthetic images generated by the GAN.

The script needs the paths of the training data, the saved model, and the generated data. They are already set in the Filter_Images.m file, but you can point them to other locations. The details are as follows:
Line 16, path of the trained ensemble classifier's model
Line 17, path of the training data
Line 78, path of the generated data

Once the code terminates, an output file named Batches_alpha_CategoryA_CategoryB is saved in
DAS/Stage-2/Filter_Unbiased_Images/Filtered_Images/
This file contains the test data and its labels, along with batches of training and filtered images for 3-fold cross-validation.
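For readers unfamiliar with 3-fold cross-validation, the sketch below shows one generic way to split sample indices into three train/held-out pairs. It is only an illustration of the technique: the actual layout of the Batches file is not described here, and the function three_fold_batches is hypothetical.

```python
def three_fold_batches(num_items: int, num_folds: int = 3):
    """Return a list of (train_indices, held_out_indices), one pair per fold."""
    # Deal the indices round-robin into num_folds folds.
    folds = [list(range(start, num_items, num_folds)) for start in range(num_folds)]
    splits = []
    for held_out in range(num_folds):
        train = sorted(
            index
            for fold_id, fold in enumerate(folds)
            if fold_id != held_out
            for index in fold
        )
        splits.append((train, folds[held_out]))
    return splits


for train, held_out in three_fold_batches(9):
    print(train, held_out)
```

Each index is held out exactly once, so every sample contributes to both training and evaluation across the three runs.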

Train CNN on Augmented datasets

Run python main.py in the terminal from DAS/Stage-2/Train_CNN/ to train a VGG-style CNN adapted from https://github.com/soumith/DeepLearningFrameworks/blob/master/Tensorflow_CIFAR.ipynb

This trains the CNN on the augmented dataset obtained from the filtering step of Stage 2 of Data Augmentation Services.
The outputs are saved in the DAS/Stage-2/Train_CNN/results folder under the names Accuracy_alpha_CategoryA_CategoryB and Pred_Labels_alpha_CategoryA_CategoryB.

To train the CNN on the true CIFAR training dataset instead, edit the script main.py:
Line 5, change from VGG_CNN_CIFAR import VGG to from VGG_CNN_Baseline import VGG
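If you switch between the two models often, the import could be chosen at run time instead of by editing the file. The sketch below is one possible approach, not how main.py currently works: the mode names "augmented" and "baseline" and both helper functions are hypothetical, while the module names VGG_CNN_CIFAR and VGG_CNN_Baseline and the class name VGG come from the import lines above.

```python
import importlib

# Hypothetical mode names mapped to the two modules named in main.py.
MODEL_MODULES = {
    "augmented": "VGG_CNN_CIFAR",
    "baseline": "VGG_CNN_Baseline",
}


def model_module_name(mode: str) -> str:
    """Return the module that provides VGG for the given mode."""
    try:
        return MODEL_MODULES[mode]
    except KeyError:
        raise ValueError(f"unknown mode {mode!r}; expected one of {sorted(MODEL_MODULES)}")


def load_vgg(mode: str):
    """Import the selected module and return its VGG attribute."""
    module = importlib.import_module(model_module_name(mode))
    return module.VGG


print(model_module_name("baseline"))
```

load_vgg only resolves when it is called from DAS/Stage-2/Train_CNN/, or when that folder is on the Python path.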

Calculate Performance Measures

Run the main MATLAB script DAS/Stage-2/Calculate_Bias_Variance/Calculate_Performace.m to obtain the bias, the variance, and the accuracy of the model after training on the augmented dataset.

The script requires the path to the directory where the CNN model saved its results and the path to the directory where the augmented data is stored. Both parameters are set at the following line numbers in the script:

Line 5, path to the CNN results directory
Line 6, path to the augmented dataset
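To give a sense of what such measures can look like, the sketch below computes bias, variance and accuracy for 0-1 loss from the predictions of several trained models (for example, one per cross-validation fold). It follows one common definition based on the "main prediction" (the most frequent predicted label per test sample). Whether Calculate_Performace.m uses this definition is not stated here, so treat it as an assumption; the function bias_variance_01 is hypothetical.

```python
from collections import Counter


def bias_variance_01(true_labels, model_predictions):
    """Bias, variance and accuracy under 0-1 loss.

    true_labels: list of true labels, one per test sample.
    model_predictions: list of label lists, one per trained model.
    """
    num_samples = len(true_labels)
    num_models = len(model_predictions)
    bias_sum = 0
    variance_sum = 0.0
    correct = 0
    for i, truth in enumerate(true_labels):
        preds = [model[i] for model in model_predictions]
        main = Counter(preds).most_common(1)[0][0]  # main prediction
        bias_sum += int(main != truth)              # is the main prediction wrong?
        variance_sum += sum(p != main for p in preds) / num_models
        correct += sum(p == truth for p in preds)
    return {
        "bias": bias_sum / num_samples,
        "variance": variance_sum / num_samples,
        "accuracy": correct / (num_samples * num_models),
    }


truth = [0, 1, 1, 0]
preds = [[0, 1, 0, 0], [0, 1, 1, 1], [0, 0, 0, 0]]
print(bias_variance_01(truth, preds))
```

In this toy example one of the four main predictions is wrong, so the bias is 0.25, and on average a quarter of the individual predictions disagree with the main prediction, so the variance is also 0.25.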
