Overview of Code Structure

To help users better understand and use our codebase, we briefly overview the functionality and implementation of each package and each module. Please see the documentation in each file for more details.

Scripts

  • assemble.py is an auxiliary script that assembles two MUNIT generators into a new one for inference. Specifically, it exports a new MUNIT generator that consists of the content encoder and the decoder of generator A and the style encoder of generator B.
  • distill.py is a script for distillation. The distiller supports ResNet and SPADE models (with option --distiller: e.g., spade, resnet). You can specify the teacher model with options --teacher_netG and --teacher_ngf, and load the pretrained teacher weights with --restore_teacher_G_path. Similarly, you can specify the student model with options --student_netG and --student_ngf, and load the pretrained student weights with --restore_student_G_path. We also support pruning before distillation: specify the model you would like to prune with options --pretrained_netG and --pretrained_ngf, and load its weights with --restore_pretrain_G_path.
  • evolution_search.py is a script for evolution search. Once you have obtained your supernet weights, you can use this script to search for the best-performing subnet. It loads a saved supernet model from --restore_G_path and saves the search results to --output_dir.
  • export.py is an auxiliary script to extract a specific subnet from a supernet and export it. You need to specify the supernet model with --model and --ngf, and the model weights with --input_path. To extract a specific subnet, provide the subnet configuration with --config_str; the exported model will be saved to --output_path.
  • get_real_stat.py is an auxiliary script to get the statistical information of the ground-truth images to compute FID. You need to specify the dataset with options --dataroot, --dataset_mode and the direction you would like to train with --direction.
  • latency.py is a general-purpose test script to measure the latency of the models. The usage is almost the same as test.py.
  • merge.py is an auxiliary script to merge multiple searching results. It is usually used in manually-split parallel searching.
  • remove_spectral_norm.py is an auxiliary script to remove the spectral normalization of the GauGAN model.
  • search.py is a script for evaluating all candidate subnets. Once you have obtained your supernet weights, you can use this script to evaluate the performance of the candidate subnets. It loads a saved supernet model from --restore_G_path and saves the evaluation results to --output_path. See our training tutorials of Fast GAN Compression and GAN Compression for more details.
  • select_arch.py is an auxiliary script that parses the output pickle produced by search.py and selects the architecture configurations you want.
  • test.py is a general-purpose test script. Once you have obtained your model weights, you can use this script to test your model. It loads a saved model from --restore_G_path and saves the results to --results_dir.
  • train.py is a general-purpose script for training the original models. It works for various models (with option --model: e.g., pix2pix, cycle_gan) and different datasets (with option --dataset_mode: e.g., aligned, unaligned). See our training tutorials of Fast GAN Compression and GAN Compression for more details.
  • train_supernet.py is a script for "once-for-all" network training and fine-tuning. The "once-for-all" network supports ResNet and SPADE models (with option --supernet: e.g., spade, resnet). As with distillation, you can specify the teacher model with options --teacher_netG and --teacher_ngf, and load the pretrained teacher weights with --restore_teacher_G_path. Similarly, you can specify the student model with options --student_netG and --student_ngf, and load the pretrained student weights with --restore_student_G_path. Moreover, when training a supernet, you need to specify the candidate subnet set with option --config_set. When fine-tuning a specific subnet, you need to specify the chosen subnet configuration with option --config_str.
  • trainer.py is a module that implements the training logic for train.py, distill.py and train_supernet.py.
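
As a rough illustration of how the teacher and student options described above fit together, here is a minimal argparse sketch. The option names are taken from the descriptions above; the types, the defaults, the function name build_distill_parser, and the example values are assumptions made for illustration and may differ from the repository's actual option classes.

```python
import argparse


def build_distill_parser():
    """Hypothetical parser mirroring the teacher/student options named above."""
    parser = argparse.ArgumentParser(description="distillation options (sketch)")
    # Option names come from the text; types and defaults are assumptions.
    parser.add_argument("--distiller", type=str, default="resnet")
    parser.add_argument("--teacher_netG", type=str, default=None)
    parser.add_argument("--teacher_ngf", type=int, default=None)
    parser.add_argument("--restore_teacher_G_path", type=str, default=None)
    parser.add_argument("--student_netG", type=str, default=None)
    parser.add_argument("--student_ngf", type=int, default=None)
    parser.add_argument("--restore_student_G_path", type=str, default=None)
    return parser


# Placeholder values for illustration only.
opt = build_distill_parser().parse_args([
    "--distiller", "resnet",
    "--teacher_ngf", "64",
    "--student_ngf", "32",
    "--restore_teacher_G_path", "/path/to/teacher.pth",
])
print(opt.distiller, opt.teacher_ngf, opt.student_ngf)
```

Options that are not passed, such as --restore_student_G_path here, stay at None in this sketch, which a script could treat as "train the student from scratch".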

Directories

configs directory contains modules related to the search space configuration used in "once-for-all" network training.

  • __init__.py contains an encoding and a decoding function of the configuration description string.
  • channel_configs.py is a module that implements the configuration search space class used in "once-for-all" network training.
  • munit_configs.py is a module that defines search spaces for the MUNIT "once-for-all" network.
  • resnet_configs.py is a module that defines search spaces for the ResNet-based "once-for-all" network.
  • single_configs.py is a module that implements a configuration set class that only contains a single configuration. Usually, it is used in fine-tuning and testing.
  • spade_configs.py is a module that defines search spaces for the SPADE-based "once-for-all" network.
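
To make the encoding and decoding of a configuration description string more concrete, here is a minimal sketch. It assumes that a configuration is a dict holding a list of channel widths and that the string joins those widths with underscores; the function names encode_config and decode_config and the dict key "channels" are assumptions, so check configs/__init__.py for the actual format.

```python
def encode_config(config):
    """Join the channel widths into a string such as '32_24_16' (assumed format)."""
    return "_".join(str(c) for c in config["channels"])


def decode_config(config_str):
    """Parse a string such as '32_24_16' back into a configuration dict."""
    return {"channels": [int(c) for c in config_str.split("_")]}


# Round trip: encoding and then decoding recovers the original configuration.
config = {"channels": [32, 24, 16]}
config_str = encode_config(config)
restored = decode_config(config_str)
print(config_str, restored)
```

A compact string like this is convenient to pass on the command line, for example as the value of --config_str.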

data directory contains all the modules related to data loading and preprocessing. To add a custom dataset class called dummy, you need to add a file called dummy_dataset.py and define a subclass DummyDataset inherited from BaseDataset. You need to implement four functions: __init__ (initialize the class; you need to first call BaseDataset.__init__(self, opt)), __len__ (return the size of the dataset), __getitem__ (get a data point), and optionally modify_commandline_options (add dataset-specific options and set default options). You can then use the dataset class by specifying the flag --dataset_mode dummy. Below we explain each file in detail.

  • __init__.py implements the interface between this package and training and test scripts. Other scripts will call dataset = create_dataset(opt) to create a dataset for training given the option opt. They can also call dataset = create_eval_dataset(opt) to create a dataset for evaluation given the option opt.
  • aligned_dataset.py includes a dataset class that can load image pairs for pix2pix. It assumes a single image directory /path/to/data/train, which contains image pairs in the form of {A,B}. See here on how to prepare aligned datasets. During test time, you need to prepare a directory /path/to/data/val as test data.
  • base_dataset.py implements an abstract base class (ABC) for datasets. It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
  • cityscapes_dataset.py includes a dataset class that can load cityscapes datasets for GauGAN.
  • coco_dataset.py includes a dataset class that can load Coco-Stuff datasets for GauGAN.
  • image_folder.py implements an image folder class. We modify the official PyTorch image folder code so that this class can load images from both the current directory and its subdirectories.
  • single_dataset.py includes a dataset class that can load a set of single images specified by the path --dataroot /path/to/data. It can be used for generating CycleGAN results only for one side with the model option --model test.
  • spade_dataset.py implements an abstract base class for the datasets for GauGAN model.
  • unaligned_dataset.py includes a dataset class that can load unaligned/unpaired datasets. It assumes two directories that host the training images: /path/to/data/trainA for domain A and /path/to/data/trainB for domain B. You can then train the model with the dataset flag --dataroot /path/to/data. Similarly, you need to prepare two directories /path/to/data/testA and /path/to/data/testB for test time.
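
The custom dataset recipe described at the top of this section (a DummyDataset subclass of BaseDataset) can be sketched as follows. So that the example runs on its own, BaseDataset here is a simplified stand-in rather than the class in base_dataset.py. The (parser, is_train) signature of modify_commandline_options, the --dummy_size option, and the keys of the returned dict are assumptions for illustration.

```python
import argparse


class BaseDataset:
    """Simplified stand-in for the repository's BaseDataset (assumption)."""

    def __init__(self, opt):
        self.opt = opt
        self.root = opt.dataroot

    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser


class DummyDataset(BaseDataset):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        # Hypothetical dataset-specific option with a default value.
        parser.add_argument("--dummy_size", type=int, default=4)
        return parser

    def __init__(self, opt):
        # Call the base initializer first, as the text requires.
        BaseDataset.__init__(self, opt)
        self.items = list(range(opt.dummy_size))

    def __len__(self):
        # Size of the dataset.
        return len(self.items)

    def __getitem__(self, index):
        # One data point; a real dataset would load and transform an image here.
        return {"A": self.items[index], "A_paths": f"{self.root}/{index}.png"}


parser = argparse.ArgumentParser()
parser.add_argument("--dataroot", type=str, default="/path/to/data")
parser = DummyDataset.modify_commandline_options(parser, is_train=True)
opt = parser.parse_args([])
dataset = DummyDataset(opt)
print(len(dataset), dataset[1])
```

In the repository itself, the class would live in data/dummy_dataset.py and be selected with --dataset_mode dummy rather than instantiated by hand.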

datasets directory contains some scripts to prepare the datasets you will use.

distillers directory contains modules related to distillation for different model architectures.

  • __init__.py implements the interface between this package and distill scripts. distill.py calls from distillers import create_distiller and distiller = create_distiller(opt) to create a distiller given the option opt. You also need to call distiller.setup(opt) to properly initialize the model.
  • base_munit_distiller.py implements an abstract base class for the distiller for MUNIT architectures. It also includes commonly used helper functions for intermediate distillation, which can be later used in subclasses.
  • base_resnet_distiller.py implements an abstract base class for the distiller for ResNet architectures. It also includes commonly used helper functions for intermediate distillation, which can be later used in subclasses.
  • base_spade_distiller.py implements an abstract base class for the distiller for SPADE architectures. It also includes commonly used helper functions for intermediate distillation, which can be later used in subclasses.
  • munit_distiller.py implements the distiller class for MUNIT architectures as a subclass of the base class in base_munit_distiller.py.
  • resnet_distiller.py implements the distiller class for ResNet architectures as a subclass of the base class in base_resnet_distiller.py.
  • spade_distiller.py implements the distiller class for SPADE architectures as a subclass of the base class in base_spade_distiller.py.
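
The create_distiller(opt) and distiller.setup(opt) calling pattern mentioned for __init__.py can be sketched with a small registry. This is only an illustration under stated assumptions: the classes below are empty placeholders, the registry dict is hypothetical, and the repository may resolve the distiller class differently (for example by importing a module by name).

```python
from types import SimpleNamespace


class BaseDistiller:
    """Placeholder base class; the real distillers hold networks and losses."""

    def __init__(self, opt):
        self.opt = opt
        self.is_setup = False

    def setup(self, opt):
        # A real setup would load weights and create schedulers; here we only flag it.
        self.is_setup = True


class ResnetDistiller(BaseDistiller):
    pass


class SpadeDistiller(BaseDistiller):
    pass


# Hypothetical mapping from the --distiller value to a class.
_DISTILLERS = {"resnet": ResnetDistiller, "spade": SpadeDistiller}


def create_distiller(opt):
    """Return the distiller selected by opt.distiller."""
    if opt.distiller not in _DISTILLERS:
        raise ValueError(f"unknown distiller: {opt.distiller!r}")
    return _DISTILLERS[opt.distiller](opt)


opt = SimpleNamespace(distiller="resnet")
distiller = create_distiller(opt)
distiller.setup(opt)
print(type(distiller).__name__, distiller.is_setup)
```

Keeping construction and setup(opt) as separate steps matches the two calls described above.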

metric directory contains modules related to evaluation metrics.

models directory contains modules related to original model training, testing and network architectures.

supernets directory contains modules related to "once-for-all" network training and fine-tuning for different model architectures (currently only the ResNet architecture).

  • __init__.py implements the interface between this package and "once-for-all" network training scripts. train_supernet.py calls from supernets import create_supernet and supernet = create_supernet(opt) to create a "once-for-all" network given the option opt. You also need to call supernet.setup(opt) to properly initialize the "once-for-all" network.
  • resnet_supernet.py implements the class for "once-for-all" training of ResNet-based architectures as a subclass of the base class in base_resnet_distiller.py.
  • spade_supernet.py implements the class for "once-for-all" training of SPADE-based architectures as a subclass of the base class in base_spade_distiller.py.
  • munit_supernet.py implements the class for "once-for-all" training of MUNIT-based architectures as a subclass of the base class in base_munit_distiller.py.

options directory includes our option modules: training options, distill options, search options, "once-for-all" training options, test options, and basic options (the base class of all other options).

  • base_options.py includes options that are used in both training and test. It also implements a few helper functions for parsing, printing, and saving the options, and it gathers the additional options defined in the modify_commandline_options functions of both the dataset class and the model class.
  • distill_options.py includes options that are only used during distillation.
  • evolution_options.py includes options that are only used during evolution search.
  • search_options.py includes options that are only used during search.
  • supernet_options.py includes options that are only used during "once-for-all" network training and fine-tuning.
  • test_options.py includes options that are only used during test time.
  • train_options.py includes options that are only used during the original model training time.

util directory includes a miscellaneous collection of useful helper functions.

  • html.py implements a module that saves images into a single HTML file. It consists of functions such as add_header (add a text header to the HTML file), add_images (add a row of images to the HTML file), and save (save the HTML file to disk). It is based on dominate, a Python library for creating and manipulating HTML documents using a DOM API.
  • image_pool.py implements an image buffer that stores previously generated images. This buffer enables us to update discriminators using a history of generated images rather than the ones produced by the latest generators. The original idea was discussed in this paper. The size of the buffer is controlled by the flag --pool_size.
  • logger.py provides a class for logging the training information. It also implements an interface to TensorBoard.
  • util.py consists of simple helper functions such as tensor2im (convert a tensor array to a numpy image array) and load_network (load a network from a specific checkpoint).
  • weight_transfer.py implements a function that transfers the weights of a teacher network to a smaller student network of the same architecture. In effect, it acts as pruning.
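
The image buffer described for image_pool.py can be sketched as follows. This is a simplified, framework-free version that handles one image per call; the class and method names, and the 50% chance of swapping in a stored image once the buffer is full, are assumptions and may not match the repository's implementation.

```python
import random


class ImagePool:
    """Buffer of previously generated images (simplified sketch)."""

    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, image):
        # A pool of size 0 disables the buffer: always return the new image.
        if self.pool_size == 0:
            return image
        # While the buffer is filling up, store the image and return it as-is.
        if len(self.images) < self.pool_size:
            self.images.append(image)
            return image
        # Once full, with probability 0.5 (assumed) return a stored image
        # and put the new image in its place.
        if random.random() > 0.5:
            idx = random.randrange(self.pool_size)
            old = self.images[idx]
            self.images[idx] = image
            return old
        # Otherwise return the new image and leave the buffer unchanged.
        return image


pool = ImagePool(pool_size=2)
first = pool.query("fake_0")
second = pool.query("fake_1")
third = pool.query("fake_2")  # may be "fake_0", "fake_1", or "fake_2"
```

The discriminator would then be updated on the value returned by query rather than directly on the latest generator output.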