FastSHAP is an amortized approach for calculating Shapley value estimates for many examples. It involves training an explainer model to output Shapley value estimates in a single forward pass, using an objective function inspired by KernelSHAP.
FastSHAP was introduced in this paper [1]. This repository provides a PyTorch implementation; see here for a similar TensorFlow implementation (also provided by the paper authors).
The easiest way to get started is to clone this repository and install the package in your Python environment:
```bash
pip install .
```
We often want to explain how a model makes its predictions. Given a convention for removing subsets of features (e.g., marginalizing them out), one of the most popular ways of generating explanations is Shapley values. Shapley values essentially consider every possible subset of features, as well as the corresponding predictions when different features are made available to the model, and they summarize the contribution that each feature makes to the prediction. (See this blog post for an introduction to Shapley values.)
The challenge with Shapley values is that they are computationally costly to calculate. Several approximations exist, such as KernelSHAP for the model-agnostic case and TreeSHAP for the case of tree-based models. FastSHAP provides a different kind of approximation: rather than running an algorithm separately for each example to be explained, FastSHAP trains a model (a neural network) to output Shapley value estimates in a single forward pass.
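Concretely, the explainer is just a network whose output contains one value per (feature, class) pair, so explaining a batch of examples amounts to a single forward pass. A minimal sketch for a small tabular problem (the layer sizes and names below are illustrative assumptions, not this package's API):

```python
import torch
import torch.nn as nn

# Hypothetical problem dimensions (illustrative only).
num_features, num_classes = 12, 2

# A simple MLP explainer: one Shapley value estimate per (feature, class) pair.
explainer = nn.Sequential(
    nn.Linear(num_features, 128),
    nn.ReLU(),
    nn.Linear(128, 128),
    nn.ReLU(),
    nn.Linear(128, num_features * num_classes),
)

x = torch.randn(32, num_features)                       # a batch of 32 examples
phi = explainer(x).view(32, num_features, num_classes)  # Shapley value estimates
```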
To run FastSHAP, you need the following ingredients:
- A predictive model to explain. The original model can be anything, including a simple model for tabular data (e.g., XGBoost) or a large CNN for image data (e.g., ResNet).
- A convention for holding out features. Shapley values require holding out different groups of features to observe how the predictions change (see [2] for a discussion of many approaches). Our experiments typically use a surrogate model trained to replicate the original model's predictions; this approach was introduced by [3] and is also discussed in [2, 4] (a minimal surrogate is sketched after this list). An alternative approach is to train a model that directly accommodates missing features (discussed in [2] and demonstrated in this notebook).
- An explainer model architecture. The explanations are generated by a model (a neural network) rather than by a Monte Carlo estimation algorithm (e.g., KernelSHAP), so an appropriate architecture must be chosen. The explainer should generally be an MLP for tabular data and a CNN for image data, and its output must be the same size as the input. Our experiments use an explainer that outputs explanations for all classes simultaneously, but the class variable can alternatively be provided as a model input.
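The following sketch shows one way the surrogate might look for tabular data: a model that takes the input together with a binary mask of observed features and is trained so its masked predictions match the original model's predictions on the full input. The names, sizes, and mask distribution are illustrative assumptions, not this package's API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

num_features, num_classes = 12, 2

class Surrogate(nn.Module):
    """Approximates the original model's prediction given a subset of features."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * num_features, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x, mask):
        # Zero out held-out features and append the mask, so the network can
        # distinguish "missing" from "observed but zero-valued".
        return self.net(torch.cat([x * mask, mask], dim=1))

def surrogate_loss(surrogate, x, target_probs):
    # target_probs: the original model's predicted probabilities on the full input.
    # Mask a random subset of features (Bernoulli(0.5) is one simple choice) and
    # minimize the KL divergence to the original model's prediction.
    mask = torch.bernoulli(torch.full_like(x, 0.5))
    log_probs = F.log_softmax(surrogate(x, mask), dim=1)
    return F.kl_div(log_probs, target_probs, reduction='batchmean')
```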
For more details about the method, see the paper [1].
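Putting the pieces together, here is a simplified sketch of the KernelSHAP-inspired objective used to train the explainer, reusing the explainer and surrogate shapes from the sketches above. The sampling helper and training step are illustrative, not part of the package; see the paper [1] and the notebooks for the exact objective and training loop.

```python
import torch

def sample_shapley_kernel_subsets(batch, d):
    # Sample coalition sizes k in {1, ..., d-1} with probability proportional to
    # 1 / (k * (d - k)) (Shapley kernel weights aggregated over subsets of the
    # same size), then draw a uniformly random subset of that size per example.
    k_values = torch.arange(1, d)
    weights = 1.0 / (k_values * (d - k_values)).float()
    sizes = k_values[torch.multinomial(weights / weights.sum(), batch, replacement=True)]
    masks = torch.zeros(batch, d)
    for i, k in enumerate(sizes):
        masks[i, torch.randperm(d)[:int(k)]] = 1.0
    return masks

def fastshap_step(x, explainer, surrogate, optimizer, num_classes, num_samples=8):
    batch, d = x.shape
    phi = explainer(x).view(batch, d, num_classes)  # [batch, features, classes]

    with torch.no_grad():
        v_full = surrogate(x, torch.ones(batch, d)).softmax(dim=1)
        v_empty = surrogate(x, torch.zeros(batch, d)).softmax(dim=1)

    # Additive efficient normalization: attributions sum to v(full) - v(empty).
    phi = phi + (v_full - v_empty - phi.sum(dim=1)).unsqueeze(1) / d

    # Monte Carlo estimate of the least-squares objective over sampled subsets.
    loss = 0.0
    for _ in range(num_samples):
        S = sample_shapley_kernel_subsets(batch, d)
        with torch.no_grad():
            v_S = surrogate(x, S).softmax(dim=1)
        pred = v_empty + torch.einsum('bf,bfc->bc', S, phi)
        loss = loss + ((v_S - pred) ** 2).mean()
    loss = loss / num_samples

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

In this sketch the optimizer would be built over the explainer's parameters only (e.g., `torch.optim.Adam(explainer.parameters())`), since the surrogate is frozen at this stage.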
For usage examples, see the following notebooks:
- Census: this notebook shows how to train FastSHAP on the census income (or "adult") dataset, a common tabular dataset. The original model is LightGBM, and explanations are generated by first training a surrogate model (an MLP) and then training a FastSHAP explainer model (also an MLP).
- CIFAR-10: this notebook shows how to train FastSHAP for image models using CIFAR-10. The original model and the surrogate are both ResNet-18s, and the FastSHAP explainer model is a UNet (because its outputs must be image-sized; a shape-only sketch follows this list). Similarly, this notebook shows how to train FastSHAP with a single model trained to accommodate missing features.
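For the image case, the key constraint is that the explainer's output must match the spatial size of the input, with one attribution map per class. A shape-only sketch (the notebooks use a UNet; this tiny convolutional network is only meant to show the expected tensor shapes):

```python
import torch
import torch.nn as nn

num_classes = 10  # CIFAR-10

# A toy convolutional explainer: one 32x32 attribution map per class.
explainer = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(32, num_classes, kernel_size=3, padding=1),
)

x = torch.randn(8, 3, 32, 32)   # a batch of CIFAR-10 images
phi = explainer(x)              # attribution maps, shape [8, 10, 32, 32]
```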
- Neil Jethani ([email protected])
- Mukund Sudarshan ([email protected])
- Ian Covert ([email protected])
- Su-In Lee
- Rajesh Ranganath
[1] Neil Jethani*, Mukund Sudarshan*, Ian Covert*, Su-In Lee, Rajesh Ranganath. "FastSHAP: Real-Time Shapley Value Estimation."
[2] Ian Covert, Scott Lundberg, Su-In Lee. "Explaining by Removing: A Unified Framework for Model Explanation." arXiv preprint arXiv:2011.14878
[3] Christopher Frye, Damien de Mijolla, Tom Begley, Laurence Cowton, Megan Stanley, Ilya Feige. "Shapley Explainability on the Data Manifold." ICLR 2021
[4] Neil Jethani, Mukund Sudarshan, Yindalon Aphinyanaphongs, Rajesh Ranganath. "Have We Learned to Explain?: How Interpretability Methods Can Learn to Encode Predictions in their Interpretations." AISTATS 2021