Explanation Invariance and Equivariance


Code Author: Jonathan Crabbé ([email protected])

This repository contains the implementation of the explanation invariance and equivariance metrics, a framework to evaluate the robustness of interpretability methods. For more details, please read our paper: 'Evaluating the Robustness of Interpretability Methods through Explanation Invariance and Equivariance'.

1. Installation

From repository:

  1. Clone the repository.
  2. Install Fortran on your machine (this is a prerequisite for the py3nj package).
  3. Install the required packages from the environment.yml file with conda:
     conda env create -f environment.yml
  4. Activate the created environment:
     conda activate robustxai

When the packages are installed, you are ready to evaluate interpretability methods!
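
Optionally, you can quickly check that the main dependencies import correctly and that a GPU is visible. This is a minimal sanity check, not part of the repository scripts:

import torch
import torchvision
import captum

# Report dependency versions and GPU availability for the freshly created environment.
print(torch.__version__, torchvision.__version__, captum.__version__)
print("CUDA available:", torch.cuda.is_available())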

2. Toy example

Below, you can find a toy demonstration where we evaluate the explanation equivariance score for Integrated Gradients explanations of a Fashion-MNIST classifier. The relevant code can be found in the interpretability folder.

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import FashionMNIST
from torchvision.transforms import transforms
from models.images import AllCNN
from pathlib import Path
from utils.misc import set_random_seed
from captum.attr import IntegratedGradients
from utils.symmetries import Translation2D
from interpretability.robustness import explanation_equivariance_exact
from interpretability.feature import FeatureImportance
from torch.utils.data import Subset

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
data_dir = Path.cwd()/"datasets/fashion_mnist"
set_random_seed(42)

# Load the data
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Pad(10)]
)
test_set = FashionMNIST(data_dir, train=False, transform=transform, download=True)
small_test_set = Subset(test_set, torch.randperm(len(test_set))[:10])
test_loader = DataLoader(small_test_set, batch_size=10, shuffle=False)

# Load the model (randomly initialised here; load trained weights for meaningful scores)
model = AllCNN(latent_dim=50)
model.to(device).eval()

# Define symmetry group
translation = Translation2D(max_dispacement=5)

# Define interpretability method
interpretability_method = FeatureImportance(IntegratedGradients(model))

# Compute equivariance score
explanation_equiv = explanation_equivariance_exact(
    interpretability_method, translation, test_loader, device
).mean().item()
print(f'{explanation_equiv=:.2f}')
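
A score close to 1 indicates that the attributions commute with the input translations. The same recipe works for other Captum attribution methods; below is a minimal sketch, assuming the FeatureImportance wrapper accepts any Captum attribution object:

from captum.attr import Saliency

# Hypothetical variant: wrap a different Captum attribution method and reuse
# the same equivariance evaluation (assumes FeatureImportance is method-agnostic).
saliency_method = FeatureImportance(Saliency(model))
saliency_equiv = explanation_equivariance_exact(
    saliency_method, translation, test_loader, device
).mean().item()
print(f'{saliency_equiv=:.2f}')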

3. Reproducing the paper results

3.A. ECG Dataset

Our script automatically downloads the ECG dataset from Kaggle. To do so, one has to create a Kaggle API token as explained here. Once the token is properly set up (a quick check is sketched at the end of this subsection), one can run our experiments with the script

python -m experiments.ecg --name experiment_name --train --plot

where the --train option only needs to be used once to fit a model for all the experiments. The experiment_name parameter can take the following values:

experiment_name          description
feature_importance       Feature importance equivariance (Section 3.1)
example_importance       Example importance invariance (Section 3.1)
concept_importance       Concept-based invariance (Section 3.1)
enforce_invariance       Improving robustness (Section 3.2)
sensitivity_comparison   Relaxing invariance (Section 3.2)

The resulting plots and data are saved here.
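
Before launching the download, you can verify that a Kaggle API token is available in the default location expected by the Kaggle API. This is a convenience sketch, not part of the repository; the location may differ if KAGGLE_CONFIG_DIR is set:

from pathlib import Path

# Check for a Kaggle API token in the default location (~/.kaggle/kaggle.json).
token_path = Path.home() / ".kaggle" / "kaggle.json"
if token_path.exists():
    print(f"Kaggle token found at {token_path}")
else:
    print("No Kaggle token found; create one from your Kaggle account settings.")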

3.B. Mutagenicity Dataset

One can run our experiments with the script

python -m experiments.mut --name experiment_name --train --plot

where the --train option only needs to be used once to fit a model for all the experiments. The experiment_name parameter can take the following values:

experiment_name          description
feature_importance       Feature importance equivariance (Section 3.1)
example_importance       Example importance invariance (Section 3.1)
concept_importance       Concept-based invariance (Section 3.1)

The resulting plots and data are saved here.

3.C. ModelNet40 Dataset

Our script automatically downloads the ModelNet40 dataset from Kaggle. To do so, one has to create a Kaggle API token as explained here. Once the token is properly set up, one can run our experiments with the script

python -m experiments.mnet --name experiment_name --train --plot

where the --train option only needs to be used once to fit a model for all the experiments. The experiment_name parameter can take the following values:

experiment_name          description
feature_importance       Feature importance equivariance (Section 3.1)
example_importance       Example importance invariance (Section 3.1)
concept_importance       Concept-based invariance (Section 3.1)

The resulting plots and data are saved here.

3.D. FashionMNIST Dataset

One can run our experiments with the script

python -m experiments.fashion_mnist --name experiment_name --train --plot

where the --train option only needs to be used once to fit a model for all the experiments. The experiment_name parameter can take the following values:

experiment_name          description
feature_importance       Feature importance equivariance (Section 3.1)
example_importance       Example importance invariance (Section 3.1)
concept_importance       Concept-based invariance (Section 3.1)
enforce_invariance       Improving robustness (Section 3.2)

The resulting plots and data are saved here.
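
Once the FashionMNIST model has been trained with the --train option, it can replace the randomly initialised network from the Section 2 toy example. Below is a minimal sketch, where the checkpoint path is hypothetical and should be replaced by wherever the experiment script saved its model (AllCNN may also expose its own loading utility):

import torch
from pathlib import Path
from models.images import AllCNN

# Hypothetical checkpoint path: replace with the location used by experiments.fashion_mnist.
checkpoint_path = Path.cwd() / "results/fashion_mnist/model.pt"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the trained weights and reuse the model in the Section 2 example.
model = AllCNN(latent_dim=50)
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
model.to(device).eval()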

3.E. CIFAR100 Dataset

One can run our experiments with the script

python -m experiments.cifar100 --name experiment_name --train --plot

where the --train option only needs to be used once to fit a model for all the experiments. If you have WandB set up, you can also use the --use_wandb option to log the results to your WandB account. The experiment_name parameter can take the following values:

experiment_name          description
feature_importance       Feature importance equivariance (Section 3.3)
example_importance       Example importance invariance (Section 3.3)
concept_importance       Concept-based invariance (Section 3.3)

3.F. STL10 Dataset

One can run our experiments with the script

python -m experiments.stl10 --name experiment_name --train --plot

where the --train option only needs to be used once to fit a model for all the experiments. If you have WandB set up, you can also use the --use_wandb option to log the results to your WandB account. The experiment_name parameter can take the following values:

experiment_name          description
feature_importance       Feature importance equivariance (Section 3.3)
example_importance       Example importance invariance (Section 3.3)
concept_importance       Concept-based invariance (Section 3.3)

The resulting plots and data are saved here.

3.G. IMDb Dataset

One can run our experiments with the script

python -m experiments.imdb --name experiment_name --train --plot

where the --train option only needs to be used once to fit a model for all the experiments. If you have WandB set up, you can also use the --use_wandb option to log the results to your WandB account. The experiment_name parameter can take the following values:

experiment_name          description
feature_importance       Feature importance equivariance (Section 3.3)
example_importance       Example importance invariance (Section 3.3)
concept_importance       Concept-based invariance (Section 3.3)

The resulting plots and data are saved here.

3.H. Failure Modes

To visualize failure modes of various interpretability methods, one can use the Jupyter notebook here. Note that this assumes models have been trained for the FashionMNIST and STL10 datasets. All the parameters in this notebook can be adapted to explore more failure modes.
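
If you prefer to explore failure modes outside the notebook, the per-sample scores returned by the robustness metrics can be used directly. Below is a minimal sketch, reusing the objects from the Section 2 toy example and assuming explanation_equivariance_exact returns one score per example:

import torch

# Find the test image whose Integrated Gradients explanation is least equivariant
# under translations (assumes the Section 2 objects are in scope).
scores = explanation_equivariance_exact(
    interpretability_method, translation, test_loader, device
)
worst_idx = torch.argmin(scores).item()
print(f"Least equivariant example: index {worst_idx}, score {scores[worst_idx].item():.2f}")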

4. Citing

If you use this code, please cite the associated paper:

@misc{crabbé2023evaluating,
      title={Evaluating the Robustness of Interpretability Methods through Explanation Invariance and Equivariance}, 
      author={Jonathan Crabbé and Mihaela van der Schaar},
      year={2023},
      eprint={2304.06715},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
