
Official code for the paper, "This Probably Looks Exactly Like That: An Invertible Prototypical Network"


craymichael/ProtoFlow


ProtoFlow: An Invertible Prototypical Neural Network

This repository contains code for the paper, "This Probably Looks Exactly Like That: An Invertible Prototypical Network," which was accepted to ECCV 2024. The proposed architecture, ProtoFlow, represents prototypical distributions as Gaussians in the latent space of a normalizing flow. The approach enables rich interpretation, effective uncertainty estimation, and a promising research path forward for intrinsically interpretable neural networks.
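Written as equations, the idea can be sketched as follows. This is a standard flow-plus-mixture formulation in our own notation, not necessarily the paper's exact one. It assumes an invertible flow f that maps an input x to a latent z = f(x), and K Gaussian prototypes per class c with mixture weights π_{c,k}, means μ_{c,k}, and covariances Σ_{c,k}:

```latex
\log p(x \mid y = c) = \log \sum_{k=1}^{K} \pi_{c,k}\,
  \mathcal{N}\!\big(f(x);\, \mu_{c,k}, \Sigma_{c,k}\big)
  + \log \left| \det \frac{\partial f(x)}{\partial x} \right|

p(y = c \mid x) = \frac{p(x \mid y = c)\, p(y = c)}
                       {\sum_{c'} p(x \mid y = c')\, p(y = c')}
```

The log-determinant term does not depend on the class, so it cancels in the posterior. It matters for the generative likelihood p(x), but not for which class is predicted.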

Abstract

We combine concept-based neural networks with generative, flow-based classifiers into a novel, intrinsically explainable, exactly invertible approach to supervised learning. Prototypical neural networks, a type of concept-based neural network, represent an exciting way forward in realizing human-comprehensible machine learning without concept annotations, but a human-machine semantic gap continues to haunt current approaches. We find that reliance on indirect interpretation functions for prototypical explanations imposes a severe limit on prototypes' informative power. From this, we posit that invertibly learning prototypes as distributions over the latent space provides more robust, expressive, and interpretable modeling. We propose one such model, called ProtoFlow, by composing a normalizing flow with Gaussian mixture models. ProtoFlow (1) sets a new state-of-the-art in joint generative and predictive modeling and (2) achieves predictive performance comparable to existing prototypical neural networks while enabling richer interpretation.
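The composition of a flow with Gaussian mixtures can be illustrated with a deliberately tiny toy model. It assumes a 1-D affine map as a hypothetical stand-in for the deep flow, and it uses hand-picked prototype parameters. All names here are hypothetical and are not the repository's API. The sketch shows three ingredients: class likelihoods from a Gaussian mixture in latent space, classification via Bayes' rule, and exact inversion of a prototype back to input space.

```python
import math

# Toy 1-D "flow": an invertible affine map z = f(x) = A * x + B.
# Hypothetical stand-in for a deep normalizing flow such as DenseFlow.
A, B = 2.0, -1.0


def flow_forward(x):
    """Map input x to latent z and return (z, log|det df/dx|), here log|A|."""
    return A * x + B, math.log(abs(A))


def flow_inverse(z):
    """Exact inverse of the affine map: x = (z - B) / A."""
    return (z - B) / A


def log_normal(z, mean, std):
    """Log-density of a 1-D Gaussian N(mean, std^2) evaluated at z."""
    return -0.5 * math.log(2 * math.pi * std ** 2) - (z - mean) ** 2 / (2 * std ** 2)


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


# Hypothetical prototypes: each class is a Gaussian mixture in latent space,
# listed as (weight, mean, std) triples, two prototypes per class.
PROTOTYPES = {
    "cat": [(0.5, -3.0, 0.5), (0.5, -2.0, 0.5)],
    "dog": [(0.5, 2.0, 0.5), (0.5, 3.0, 0.5)],
}


def class_log_likelihoods(x):
    """log p(x | y = c) for every class c via the change-of-variables formula."""
    z, log_det = flow_forward(x)
    return {
        c: logsumexp([math.log(w) + log_normal(z, mu, s) for w, mu, s in comps]) + log_det
        for c, comps in PROTOTYPES.items()
    }


def posterior(x):
    """p(y | x) under a uniform class prior; log_det cancels in the normalization."""
    ll = class_log_likelihoods(x)
    norm = logsumexp(list(ll.values()))
    return {c: math.exp(v - norm) for c, v in ll.items()}


def prototype_in_input_space(c, k):
    """Because f is invertible, the k-th prototype mean of class c maps back to x."""
    _, mu, _ = PROTOTYPES[c][k]
    return flow_inverse(mu)
```

For instance, posterior(1.5) assigns nearly all probability to "dog" because f(1.5) = 2.0 coincides with a dog prototype mean, and prototype_in_input_space("dog", 0) returns 1.5. In a flow-based model, the analogous inversion allows a prototype to be visualized directly rather than through an indirect interpretation function.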

Installation

Install the requirements listed in requirements.txt as follows:

pip install -r requirements.txt

Alternatively, the exact environment that was used in this research can be reproduced using conda. After installing conda, create a new environment using the provided environment.yml:

conda env create -f environment.yml

Training

To train an instance of ProtoFlow, use the train.py script; run python train.py --help for usage details. To initialize the model, you will need the pretrained DenseFlow checkpoints, which can be downloaded by following the instructions here. For example, ProtoFlow can be trained on CIFAR-10 as follows:

# Optionally enter your dataset root here
#export DATASET_ROOT='/mnt/data/ml_datasets/'
python train.py --flow_ckpt checkpoints/denseflow/imn32/imagenet32/ \
  --img_size 32 \
  --dataset cifar10 \
  --extra my_test_run \
  -e 10 \
  --batch_steps 32 \
  --batch_size 256 \
  --trainable all \
  --lr 2e-4 \
  --gmm_lr 2e-3 \
  --consistency_loss \
  --protos_per_class 5 \
  --elbo_loss2

To train with PyTorch DDP (DistributedDataParallel, i.e., multi-GPU training), launch the script through torchrun instead:

torchrun --nproc_per_node=2 train.py ...
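torchrun starts one worker process per GPU (two in the command above) and tells each worker its position through the environment variables RANK, LOCAL_RANK, and WORLD_SIZE. The snippet below is a generic sketch of how a script can read them; it is not taken from train.py, whose handling may differ. In a real PyTorch script, LOCAL_RANK is typically used to select the GPU (e.g., torch.cuda.set_device(local_rank)) before calling torch.distributed.init_process_group.

```python
import os


def ddp_env():
    """Read the per-worker settings that torchrun exports as environment variables.

    The defaults cover a plain single-process launch (python train.py ...),
    where torchrun is not involved and these variables are unset.
    """
    return {
        "rank": int(os.environ.get("RANK", 0)),              # global worker index
        "local_rank": int(os.environ.get("LOCAL_RANK", 0)),  # worker index on this node
        "world_size": int(os.environ.get("WORLD_SIZE", 1)),  # total number of workers
    }


if __name__ == "__main__":
    env = ddp_env()
    print(f"worker {env['rank']} of {env['world_size']} on local GPU {env['local_rank']}")
```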

Testing

To evaluate an instance of ProtoFlow, use the test.py script; run python test.py --help for usage details. A pre-trained model can be downloaded (see the following section) and evaluated with this script. For example:

# Optionally enter your dataset root here
#export DATASET_ROOT='/mnt/data/ml_datasets/'
python test.py --resume checkpoints/cifar10/checkpoint.pt \
  --tta \
  --tta_num 5 \
  --num_samples 5 \
  --proto_scores
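The README does not spell out what --tta and --tta_num do. Assuming --tta enables test-time augmentation over --tta_num augmented copies of each test image, a common way to combine the views is to average their predicted class probabilities. The sketch below shows only that generic averaging step, with hypothetical helper names and plain Python lists instead of tensors; test.py may combine views differently (for example, in log space).

```python
def tta_average(prob_lists):
    """Average the class-probability vectors predicted for several augmented views.

    prob_lists: list of equal-length lists, one probability vector per view.
    Returns the element-wise mean, which is again a probability vector.
    """
    if not prob_lists:
        raise ValueError("need at least one view")
    n_views = len(prob_lists)
    n_classes = len(prob_lists[0])
    if any(len(p) != n_classes for p in prob_lists):
        raise ValueError("all views must have the same number of classes")
    return [sum(p[c] for p in prob_lists) / n_views for c in range(n_classes)]


def predict_with_tta(prob_lists):
    """Return (index of the most probable class, averaged probabilities)."""
    avg = tta_average(prob_lists)
    return max(range(len(avg)), key=avg.__getitem__), avg
```

For example, predict_with_tta([[0.6, 0.3, 0.1], [0.4, 0.5, 0.1], [0.7, 0.2, 0.1]]) predicts class 0, even though the second view alone would have picked class 1.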

If you run out of GPU memory, reduce --batch_size.

Pre-trained Models

All checkpoints and configurations (including hyperparameters) for trained models are available here.

License

This repository is distributed under the GNU GPL v2.0 License.

This repository contains code from the following projects:

Citation

@inproceedings{protoflowECCV2024,
    author    = {Carmichael, Zachariah and
                 Redgrave, Timothy and
                 Gonzalez Cedre, Daniel and
                 Scheirer, Walter J.},
    title     = {This Probably Looks Exactly Like That: An Invertible Prototypical Network},
    booktitle = {European Conference on Computer Vision},
    year      = {2024},
    publisher = {Springer Nature},
}

obligatory miata

