Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add gputree support #291

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions explainerdashboard/explainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,12 +321,12 @@ def __init__(
"sklearn-compatible NeuralNet wrapper are supported for now! "
"See https://github.com/skorch-dev/skorch"
)
assert shap in ["tree", "linear", "deep", "kernel", "skorch"], (
"ERROR! Only shap='guess', 'tree', 'linear', ' kernel' or 'skorch' are "
" supported for now!"
assert shap in ["tree", "linear", "deep", "kernel", "skorch", "gputree"], (
"ERROR! Only shap='guess', 'tree', 'linear', ' kernel', 'skorch' "
"or 'gputree' are supported for now!"
)
self.shap = shap
if self.shap in {"kernel", "skorch", "linear"}:
if self.shap in {"kernel", "skorch", "linear", "gputree"}:
print(
f"WARNING: For shap='{self.shap}', shap interaction values can unfortunately "
"not be calculated!"
Expand Down Expand Up @@ -1123,6 +1123,13 @@ def model_predict(data_asarray):
if self.X_background is not None
else shap.sample(self.X, 50),
)
elif self.shap == "gputree":
print(
"Generating self.shap_explainer = shap.explainer.GPUTree(model, X)."
"Make sure you have a cuda enabled GPU and followed installation"
"instructions at https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/explainers/GPUTree.html#" # noqa: E501
)
self._shap_explainer = shap.explainers.GPUTree(self.model, self.X)
return self._shap_explainer

@insert_pos_label
Expand Down