diff --git a/python_scripts/trees_sol_01.py b/python_scripts/trees_sol_01.py
index e97b7e8b2..15eb35204 100644
--- a/python_scripts/trees_sol_01.py
+++ b/python_scripts/trees_sol_01.py
@@ -150,46 +150,38 @@
 # %% tags=["solution"]
 import numpy as np
+from matplotlib import cm
 
-xx = np.linspace(30, 60, 100)
-yy = np.linspace(10, 23, 100)
-xx, yy = np.meshgrid(xx, yy)
-Xfull = pd.DataFrame(
-    {"Culmen Length (mm)": xx.ravel(), "Culmen Depth (mm)": yy.ravel()}
-)
-
-probas = tree.predict_proba(Xfull)
-n_classes = len(np.unique(tree.classes_))
-
+classes = np.unique(tree.classes_)
 _, axs = plt.subplots(ncols=3, nrows=1, sharey=True, figsize=(12, 5))
-plt.suptitle("Predicted probabilities for decision tree model", y=0.8)
-
-for class_of_interest in range(n_classes):
-    axs[class_of_interest].set_title(
-        f"Class {tree.classes_[class_of_interest]}"
-    )
-    imshow_handle = axs[class_of_interest].imshow(
-        probas[:, class_of_interest].reshape((100, 100)),
-        extent=(30, 60, 10, 23),
-        vmin=0.0,
-        vmax=1.0,
-        origin="lower",
-        cmap="viridis",
+plt.suptitle("Predicted probabilities for decision tree model", y=1.05)
+plt.subplots_adjust(bottom=0.45)
+
+for idx, class_of_interest in enumerate(classes):
+    axs[idx].set_title(f"Class {class_of_interest}")
+    disp = DecisionBoundaryDisplay.from_estimator(
+        tree,
+        data_test,
+        response_method="predict_proba",
+        class_of_interest=class_of_interest,
+        ax=axs[idx],
+        vmin=0,
+        vmax=1,
     )
-    axs[class_of_interest].set_xlabel("Culmen Length (mm)")
-    if class_of_interest == 0:
-        axs[class_of_interest].set_ylabel("Culmen Depth (mm)")
-    idx = target_test == tree.classes_[class_of_interest]
-    axs[class_of_interest].scatter(
-        data_test["Culmen Length (mm)"].loc[idx],
-        data_test["Culmen Depth (mm)"].loc[idx],
+    axs[idx].scatter(
+        data_test["Culmen Length (mm)"].loc[target_test == class_of_interest],
+        data_test["Culmen Depth (mm)"].loc[target_test == class_of_interest],
         marker="o",
         c="w",
         edgecolor="k",
     )
+    axs[idx].set_xlabel("Culmen Length (mm)")
+    axs[idx].set_ylabel("Culmen Depth (mm)" if idx == 0 else None)
 
-ax = plt.axes([0.15, 0.04, 0.7, 0.05])
-plt.colorbar(imshow_handle, cax=ax, orientation="horizontal")
+ax = plt.axes([0.15, 0.14, 0.7, 0.05])
+plt.colorbar(
+    cm.ScalarMappable(cmap="viridis"), cax=ax, orientation="horizontal"
+)
 _ = plt.title("Probability")
 
 # %% [markdown] tags=["solution"]
@@ -201,9 +193,9 @@
 # the certainty.
 # ```
 #
-# In future versions of scikit-learn `DecisionBoundaryDisplay` will support a
-# `class_of_interest` parameter that will allow in particular for a
-# visualization of `predict_proba` in multi-class settings.
+# Since scikit-learn v1.4, `DecisionBoundaryDisplay` supports a
+# `class_of_interest` parameter that allows in particular for a visualization of
+# `predict_proba` in multi-class settings.
 #
 # We also plan to make it possible to visualize the `predict_proba` values for
 # the class with the maximum predicted probability (without having to pass a