
DRESS: Disentangled Representation-based Self-Supervised Meta-Learning for Diverse Tasks [arXiv]

Authors: Wei Cui, Tongzi Wu, Jesse C. Cresswell, Yi Sui, Keyvan Golestan

Summary

This repository contains the official implementation of the paper DRESS: Disentangled Representation-based Self-Supervised Meta-Learning for Diverse Tasks. It includes both training and evaluation code.

Repository Structure

The code files within the repository are organized as follows:

  • main.py: the main entry point of the program.
  • partition_generators.py: generates supervised and self-supervised partitions of each dataset.
  • task_generator.py: generates few-shot learning tasks from any given partition.
  • utils.py: helper functions.
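This README does not describe how task_generator.py builds a task. The sketch below shows only the general idea of drawing an N-way K-shot task from a partition, where a partition assigns a label to every sample. The function name sample_few_shot_task and its parameters are hypothetical and do not reflect the repository's actual code.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def sample_few_shot_task(
    partition_labels: Sequence[int],
    n_way: int,
    k_shot: int,
    n_query: int,
    rng: random.Random,
) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]:
    """Hypothetical sketch: sample one N-way K-shot task from a partition.

    partition_labels[i] is the partition label (class or cluster id) of sample i.
    Returns (support, query) lists of (sample_index, task_label) pairs, where
    task_label is re-indexed to 0..n_way-1 within this task.
    """
    # Group sample indices by their partition label.
    by_label: Dict[int, List[int]] = defaultdict(list)
    for idx, label in enumerate(partition_labels):
        by_label[label].append(idx)

    # Only labels with enough samples for both support and query are eligible.
    eligible = sorted(
        lab for lab, idxs in by_label.items() if len(idxs) >= k_shot + n_query
    )
    if len(eligible) < n_way:
        raise ValueError("not enough eligible partition labels for an n_way task")

    support: List[Tuple[int, int]] = []
    query: List[Tuple[int, int]] = []
    # Pick n_way distinct labels, then disjoint support/query samples per label.
    for task_label, label in enumerate(rng.sample(eligible, n_way)):
        chosen = rng.sample(by_label[label], k_shot + n_query)
        support += [(i, task_label) for i in chosen[:k_shot]]
        query += [(i, task_label) for i in chosen[k_shot:]]
    return support, query
```

For example, with labels [0, 0, 0, 1, 1, 1, 2, 2, 2, 3], n_way=2, k_shot=1 and n_query=2 yield two support and four query samples. Label 3 is never chosen because it has too few samples.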

The sub-folders within the repository are as follows:

  • scripts/: scripts to train, evaluate, and produce visualizations.
  • encoders/: encoder classes for obtaining the latent spaces.
  • dataset_loaders/: scripts for loading each dataset used in the experiments.
  • baselines/: implementations of the baseline methods.
  • analyze_results/: scripts for post-processing results.
  • visualization_results/: visualizations of the tasks constructed by DRESS.

Dataset

Create a folder named data/ under the main directory to hold the raw data. Each dataset used in the experiments is loaded by its own loader script under dataset_loaders/. Prepare the source data as follows:

  • smallNORB: automatically downloaded within our script via the tensorflow_datasets package.
  • shapes3D: download 3dshapes.h5 from Google Cloud Storage and place it under data/shapes3d/.
  • causal3D: download trainset.tar.gz and testset.tar.gz from the dataset homepage and extract them under data/causal3d/train/ and data/causal3d/test/, respectively.
  • MPI3D: download mpi3d_toy.npz from this link and place it under data/mpi3d/.
  • CelebA: automatically downloaded within our script via the torchvision package.
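The helper below is hypothetical and not part of the repository. It creates the folders listed above and reports which manually downloaded files are still missing. It assumes the paths exactly as listed here, but the loaders under dataset_loaders/ are the authority on the names they expect. For causal3D it only checks that the extracted folders are non-empty, because the file names inside the archives are not given here.

```python
from pathlib import Path
from typing import List

# Folders to create under data/ (paths assumed from the list above).
# smallNORB and CelebA are downloaded automatically, so they are not listed.
DIRS = ["shapes3d", "causal3d/train", "causal3d/test", "mpi3d"]

# What must be present after the manual downloads: files must exist,
# directories (the extracted causal3D archives) must be non-empty.
EXPECTED = [
    "shapes3d/3dshapes.h5",
    "causal3d/train",
    "causal3d/test",
    "mpi3d/mpi3d_toy.npz",
]


def make_data_dirs(root: str = "data") -> None:
    """Create the (empty) dataset folders under `root`."""
    for d in DIRS:
        (Path(root) / d).mkdir(parents=True, exist_ok=True)


def _is_ready(path: Path) -> bool:
    # A directory counts as ready only if it has contents; otherwise require a file.
    if path.is_dir():
        return any(path.iterdir())
    return path.is_file()


def missing_data(root: str = "data") -> List[str]:
    """Return the expected paths under `root` that are not in place yet."""
    base = Path(root)
    return [p for p in EXPECTED if not _is_ready(base / p)]
```

A typical workflow would be to call make_data_dirs() once, place the downloaded files, and then check that missing_data() returns an empty list.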

Running Environment

Create an Anaconda environment from the environment.yml file in this repository (e.g., conda env create -f environment.yml).
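The sketch below is a quick sanity check to run after activating the environment. It uses a hypothetical helper, not part of the repository, to test whether the two packages this README names for automatic downloads (tensorflow_datasets and torchvision) are importable. It does not check versions or the rest of environment.yml.

```python
import importlib.util
from typing import Dict, Iterable

# Packages this README names for automatic dataset downloads.
REQUIRED = ("tensorflow_datasets", "torchvision")


def check_packages(names: Iterable[str] = REQUIRED) -> Dict[str, bool]:
    """Map each top-level package name to whether it can be found for import."""
    return {name: importlib.util.find_spec(name) is not None for name in names}
```

Calling check_packages() returns a dict such as {"tensorflow_datasets": True, "torchvision": True} when the environment is set up. Any False entry points to a package missing from the active environment.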