Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds get_variable_inclusion function #214

Merged
merged 2 commits into from
Dec 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ methods in the current release of PyMC-BART.
=============================

.. automodule:: pymc_bart
:members: BART, PGBART, plot_pdp, plot_ice, plot_variable_importance, plot_convergence, ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule
:members: BART, PGBART, compute_variable_importance, get_variable_inclusion, plot_convergence, plot_ice, plot_pdp, plot_scatter_submodels, plot_variable_importance, plot_variable_inclusion, ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule
2 changes: 2 additions & 0 deletions pymc_bart/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pymc_bart.split_rules import ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule
from pymc_bart.utils import (
compute_variable_importance,
get_variable_inclusion,
plot_convergence,
plot_ice,
plot_pdp,
Expand All @@ -33,6 +34,7 @@
"OneHotSplitRule",
"SubsetSplitRule",
"compute_variable_importance",
"get_variable_inclusion",
"plot_convergence",
"plot_ice",
"plot_pdp",
Expand Down
68 changes: 50 additions & 18 deletions pymc_bart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,50 @@ def _smooth_mean(
return x_data, y_data


def get_variable_inclusion(idata, X, labels=None, to_kulprit=False):
    """
    Get the normalized variable inclusion from BART model.

    Parameters
    ----------
    idata : InferenceData
        InferenceData containing a collection of BART_trees in sample_stats group
    X : npt.NDArray
        The covariate matrix. Only inspected to pick up column names when it is a DataFrame.
    labels : Optional[list[str]]
        List of the names of the covariates. If X is a DataFrame the names of the covariates will
        be taken from it and this argument will be ignored. If neither is available, the
        positional indices ("0", "1", ...) are used as names.
    to_kulprit : bool
        If True, the function will return a list of lists with the variable names.
        This list can be passed as a path to Kulprit's project method. Defaults to False.

    Returns
    -------
    VI_norm : npt.NDArray
        Normalized variable inclusion, sorted from higher to lower. Only returned when
        ``to_kulprit`` is False.
    labels : list[str]
        List of the names of the covariates, in the same (descending inclusion) order as
        ``VI_norm``. Only returned when ``to_kulprit`` is False.
    path : list[list[str]]
        Only returned when ``to_kulprit`` is True (instead of the tuple above): nested prefixes
        of the inclusion-ordered names, of length 0 up to ``n_vars - 1``.

    Raises
    ------
    ValueError
        If the number of labels does not match the number of variables in ``idata``.
    """
    VIs = idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
    VI_norm = VIs / VIs.sum()

    # argsort is ascending; reverse it to rank from most to least included.
    indices = np.argsort(VI_norm)[::-1]
    n_vars = len(indices)

    if hasattr(X, "columns") and hasattr(X, "to_numpy"):
        labels = X.columns

    if labels is None:
        labels = np.arange(n_vars).astype(str)

    # Normalise every accepted label container (list, ndarray, pandas Index) to an
    # ndarray: ``.to_list()`` only exists on pandas objects, and we need fancy indexing.
    labels = np.asarray(labels)
    if labels.shape != (n_vars,):
        raise ValueError(
            f"Expected {n_vars} labels (one per covariate), got array of shape {labels.shape}."
        )

    # Reorder the names with the same permutation applied to the values so that
    # position i of both outputs refers to the same covariate.
    label_list = labels[indices].tolist()

    if to_kulprit:
        # Prefix lengths are kept as in the original implementation (0 .. n_vars - 1).
        return [label_list[:idx] for idx in range(n_vars)]

    return VI_norm[indices], label_list


def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=None, ax=None):
"""
Plot normalized variable inclusion from BART model.
Expand Down Expand Up @@ -720,26 +764,15 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non

Returns
-------
idxs: indexes of the covariates from higher to lower relative importance
axes: matplotlib axes
"""
if plot_kwargs is None:
plot_kwargs = {}

VIs = idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
VIs = VIs / VIs.sum()
idxs = np.argsort(VIs)

indices = idxs[::-1]
n_vars = len(indices)

if hasattr(X, "columns") and hasattr(X, "to_numpy"):
labels = X.columns
VI_norm, labels = get_variable_inclusion(idata, X, labels)
n_vars = len(labels)

if labels is None:
labels = np.arange(n_vars).astype(str)

new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])]
new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)]

ticks = np.arange(n_vars, dtype=int)

Expand All @@ -749,19 +782,18 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non
if ax is None:
_, ax = plt.subplots(1, 1, figsize=figsize)

ax.axhline(1 / n_vars, color="0.5", linestyle="--")
ax.plot(
VIs[indices],
VI_norm,
color=plot_kwargs.get("color", "k"),
marker=plot_kwargs.get("marker", "o"),
ls=plot_kwargs.get("ls", "-"),
)

ax.set_xticks(ticks, new_labels, rotation=plot_kwargs.get("rotation", 0))

ax.axhline(1 / n_vars, color="0.5", linestyle="--")
ax.set_ylim(0, 1)

return idxs, ax
return ax


def compute_variable_importance( # noqa: PLR0915 PLR0912
Expand Down
Loading