diff --git a/explainerdashboard/explainers.py b/explainerdashboard/explainers.py index cbc50a2..5c6bbda 100644 --- a/explainerdashboard/explainers.py +++ b/explainerdashboard/explainers.py @@ -321,12 +321,12 @@ def __init__( "sklearn-compatible NeuralNet wrapper are supported for now! " "See https://github.com/skorch-dev/skorch" ) - assert shap in ["tree", "linear", "deep", "kernel", "skorch"], ( - "ERROR! Only shap='guess', 'tree', 'linear', ' kernel' or 'skorch' are " - " supported for now!" + assert shap in ["tree", "linear", "deep", "kernel", "skorch", "gputree"], ( + "ERROR! Only shap='guess', 'tree', 'linear', ' kernel', 'skorch' " + "or 'gputree' are supported for now!" ) self.shap = shap - if self.shap in {"kernel", "skorch", "linear"}: + if self.shap in {"kernel", "skorch", "linear", "gputree"}: print( f"WARNING: For shap='{self.shap}', shap interaction values can unfortunately " "not be calculated!" @@ -1123,6 +1123,13 @@ def model_predict(data_asarray): if self.X_background is not None else shap.sample(self.X, 50), ) + elif self.shap == "gputree": + print( + "Generating self.shap_explainer = shap.explainer.GPUTree(model, X)." + "Make sure you have a cuda enabled GPU and followed installation" + "instructions at https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/explainers/GPUTree.html#" # noqa: E501 + ) + self._shap_explainer = shap.explainers.GPUTree(self.model, self.X) return self._shap_explainer @insert_pos_label