improve variable importance computation, add backward method (#125)
* improve variable importance

* fix tests
aloctavodia authored Nov 18, 2023
1 parent 2a8b12d commit d7bbfb4
Showing 4 changed files with 145 additions and 82 deletions.
5 changes: 2 additions & 3 deletions pymc_bart/bart.py
@@ -92,9 +92,8 @@ class BART(Distribution):
Controls the prior probability over the number of leaves of the trees.
Should be positive.
split_prior : Optional[List[float]], default None.
Each element of split_prior should be in the [0, 1] interval and the elements should sum to
1. Otherwise they will be normalized.
Defaults to 0, i.e. all covariates have the same prior probability to be selected.
List of positive numbers, one per column in input data.
Defaults to None, i.e. all covariates have the same prior probability of being selected.
split_rules : Optional[List[SplitRule]], default None
List of SplitRule objects, one per column in input data.
Allows using different split rules for different columns. Default is ContinuousSplitRule.
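As a usage note for the updated docstring: a minimal sketch of passing split_prior (the model, data, and weights below are hypothetical, not part of this commit):

import numpy as np
import pymc as pm
import pymc_bart as pmb

X = np.random.normal(size=(50, 3))  # hypothetical data with 3 covariates
Y = X[:, 0] + np.random.normal(scale=0.1, size=50)

with pm.Model():
    # One positive weight per column; the first covariate gets five times
    # the prior weight of the others when proposing split variables.
    mu = pmb.BART("mu", X, Y, m=20, split_prior=[5.0, 1.0, 1.0])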
5 changes: 5 additions & 0 deletions pymc_bart/pgbart.py
@@ -155,6 +155,11 @@ def __init__(
else:
self.split_rules = [ContinuousSplitRule] * self.X.shape[1]

jittered = np.random.normal(self.X, self.X.std(axis=0) / 12)
min_values = np.min(self.X, axis=0)
max_values = np.max(self.X, axis=0)
self.X = np.clip(jittered, min_values, max_values)

init_mean = self.bart.Y.mean()
self.num_observations = self.X.shape[0]
self.num_variates = self.X.shape[1]
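For reference, a standalone sketch of the jitter-and-clip preprocessing added above, on a toy NumPy array (the std/12 noise scale is the one used in the diff):

import numpy as np

X = np.random.default_rng(0).normal(size=(100, 2))  # toy design matrix

# Add Gaussian noise scaled by one twelfth of each column's standard
# deviation, then clip back to the observed per-column range so no
# value leaves the original support.
jittered = np.random.normal(X, X.std(axis=0) / 12)
X_jittered = np.clip(jittered, np.min(X, axis=0), np.max(X, axis=0))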
189 changes: 121 additions & 68 deletions pymc_bart/utils.py
@@ -1,5 +1,6 @@
"""Utility function for variable selection and bart interpretability."""

from itertools import combinations
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

@@ -12,7 +13,6 @@
from scipy.interpolate import griddata
from scipy.signal import savgol_filter
from scipy.stats import norm, pearsonr
from xarray import concat

from .tree import Tree

@@ -71,7 +71,6 @@ def _sample_posterior(
for tree in odim_trees:
p[odim] += tree.predict(x=X, excluded=excluded, shape=leaves_shape)

# pred.reshape((*size_iter, shape, -1))
return pred.transpose((0, 3, 1, 2)).reshape((*size_iter, -1, shape))


@@ -714,11 +713,12 @@ def plot_variable_importance(
bartrv: Variable,
X: npt.NDArray[np.float_],
labels: Optional[List[str]] = None,
sort_vars: bool = True,
method: str = "VI",
figsize: Optional[Tuple[float, float]] = None,
xlabel_angle: float = 0,
samples: int = 100,
random_seed: Optional[int] = None,
) -> Tuple[npt.NDArray[np.int_], List[plt.Axes]]:
) -> Tuple[List[int], List[plt.Axes]]:
"""
Estimates variable importance from the BART-posterior.
@@ -733,10 +733,17 @@
Expand All @@ -733,10 +733,17 @@ def plot_variable_importance(
labels : Optional[List[str]]
List of the names of the covariates. If X is a DataFrame the names of the covariables will
be taken from it and this argument will be ignored.
sort_vars : bool
Whether to sort the variables according to their variable importance. Defaults to True.
method : str
Method used to rank variables. Available options are "VI" (default) and "backward".
The R squared will be computed following this ranking.
"VI" counts how many times each variable is included in the posterior distribution
of trees. "backward" uses a backward search based on the R squared.
"VI" requires less computation time.
figsize : tuple
Figure size. If None it will be defined automatically.
xlabel_angle : float
Rotation angle of the x-axis labels. Defaults to 0. Use values like 45 for
long labels and/or many variables.
samples : int
Number of predictions used to compute correlation for subsets of variables. Defaults to 100.
random_seed : Optional[int]
@@ -747,7 +754,9 @@
idxs: indexes of the covariates from higher to lower relative importance
axes: matplotlib axes
"""
_, axes = plt.subplots(2, 1, figsize=figsize)
rng = np.random.default_rng(random_seed)

all_trees = bartrv.owner.op.all_trees

if bartrv.ndim == 1: # type: ignore
shape = 1
@@ -758,80 +767,124 @@
labels = X.columns
X = X.values

n_draws = idata["posterior"].dims["draw"]
half = n_draws // 2
f_half = idata["sample_stats"]["variable_inclusion"].sel(draw=slice(0, half - 1))
s_half = idata["sample_stats"]["variable_inclusion"].sel(draw=slice(half, n_draws))
n_vars = X.shape[1]

if figsize is None:
figsize = (8, 3)

_, ax = plt.subplots(1, 1, figsize=figsize)

var_imp_chains = concat([f_half, s_half], dim="chain", join="override").mean(("draw")).values
var_imp = idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
if labels is None:
labels_ary = np.arange(len(var_imp))
labels_ary = np.arange(n_vars).astype(str)
else:
labels_ary = np.array(labels)

rng = np.random.default_rng(random_seed)
ticks = np.arange(n_vars, dtype=int)

ticks = np.arange(len(var_imp), dtype=int)
idxs = np.argsort(var_imp)
subsets = [idxs[:-i].tolist() for i in range(1, len(idxs))]
subsets.append(None) # type: ignore
predicted_all = _sample_posterior(
all_trees, X=X, rng=rng, size=samples, excluded=None, shape=shape
)

if sort_vars:
indices = idxs[::-1]
else:
indices = np.arange(len(var_imp))
if method == "VI":
idxs = np.argsort(
idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
)
subsets = [idxs[:-i].tolist() for i in range(1, len(idxs))]
subsets.append(None) # type: ignore

indices: List[int] = list(idxs[::-1])

r2_mean = np.zeros(n_vars)
r2_hdi = np.zeros((n_vars, 2))
for idx, subset in enumerate(subsets):
predicted_subset = _sample_posterior(
all_trees=all_trees,
X=X,
rng=rng,
size=samples,
excluded=subset,
shape=shape,
)
pearson = np.zeros(samples)
for j in range(samples):
pearson[j] = (
pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0]
) ** 2
r2_mean[idx] = np.mean(pearson)
r2_hdi[idx] = az.hdi(pearson)

elif method == "backward":
r2_mean = np.zeros(n_vars)
r2_hdi = np.zeros((n_vars, 2))

variables = set(range(n_vars))
excluded: List[int] = []
indices = []

for i_var in range(0, n_vars):
subsets = _generate_combinations(variables, excluded)
max_pearson = -np.inf
for subset in subsets:
predicted_subset = _sample_posterior(
all_trees=all_trees,
X=X,
rng=rng,
size=samples,
excluded=subset,
shape=shape,
)
pearson = np.zeros(samples)
for j in range(samples):
pearson[j] = (
pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0]
) ** 2
mean_pearson = np.mean(pearson, dtype=float)
if mean_pearson > max_pearson:
max_pearson = mean_pearson
best_subset = subset
best_pearson = pearson

chains_mean = (var_imp / var_imp.sum())[indices]
chains_hdi = az.hdi((var_imp_chains.T / var_imp_chains.sum(axis=1)).T)[indices]
r2_mean[i_var] = max_pearson
r2_hdi[i_var] = az.hdi(best_pearson)

axes[0].errorbar(
indices.extend((set(best_subset) - set(indices)))

excluded.append(best_subset)

indices.extend((set(variables) - set(indices)))

indices = indices[::-1]
r2_mean = r2_mean[::-1]
r2_hdi = r2_hdi[::-1]

new_labels = [
"+ " + ele if index != 0 else ele for index, ele in enumerate(labels_ary[indices])
]

r2_yerr_min = np.clip(r2_mean - r2_hdi[:, 0], 0, None)
r2_yerr_max = np.clip(r2_hdi[:, 1] - r2_mean, 0, None)
ax.errorbar(
ticks,
chains_mean,
np.array((chains_mean - chains_hdi[:, 0], chains_hdi[:, 1] - chains_mean)),
r2_mean,
np.array((r2_yerr_min, r2_yerr_max)),
color="C0",
)
axes[0].set_xticks(ticks)
axes[0].set_xticklabels(labels_ary[indices])
axes[0].set_xlabel("covariables")
axes[0].set_ylabel("importance")

all_trees = bartrv.owner.op.all_trees
ax.axhline(r2_mean[-1], ls="--", color="0.5")
ax.set_xticks(ticks, new_labels, rotation=xlabel_angle)
ax.set_ylabel("R²", rotation=0, labelpad=12)
ax.set_ylim(0, 1)
ax.set_xlim(-0.5, n_vars - 0.5)

predicted_all = _sample_posterior(
all_trees, X=X, rng=rng, size=samples, excluded=None, shape=shape
)
return indices, ax

ev_mean = np.zeros(len(var_imp))
ev_hdi = np.zeros((len(var_imp), 2))
for idx, subset in enumerate(subsets):
predicted_subset = _sample_posterior(
all_trees=all_trees,
X=X,
rng=rng,
size=samples,
excluded=subset,
shape=shape,
)
pearson = np.zeros(samples)
for j in range(samples):
pearson[j] = (
pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0]
) ** 2
ev_mean[idx] = np.mean(pearson)
ev_hdi[idx] = az.hdi(pearson)

axes[1].errorbar(
ticks, ev_mean, np.array((ev_mean - ev_hdi[:, 0], ev_hdi[:, 1] - ev_mean)), color="C0"
)
axes[1].axhline(ev_mean[-1], ls="--", color="0.5")
axes[1].set_xticks(ticks)
axes[1].set_xticklabels(ticks + 1)
axes[1].set_xlabel("number of covariables")
axes[1].set_ylabel("R²", rotation=0, labelpad=12)
axes[1].set_ylim(0, 1)

axes[0].set_xlim(-0.5, len(var_imp) - 0.5)
axes[1].set_xlim(-0.5, len(var_imp) - 0.5)
def _generate_combinations(variables, excluded):
"""
Generate all possible combinations of variables.
"""
all_combinations = combinations(variables, len(excluded))
valid_combinations = [
com for com in all_combinations if not any(ele in com for ele in excluded)
]

return idxs[::-1], axes
return valid_combinations
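A small standalone check of the helper above, copied out of the diff (the example inputs are hypothetical; excluded is taken as a flat list of variable indices):

from itertools import combinations

def _generate_combinations(variables, excluded):
    """Generate all possible combinations of variables."""
    all_combinations = combinations(variables, len(excluded))
    valid_combinations = [
        com for com in all_combinations if not any(ele in com for ele in excluded)
    ]
    return valid_combinations

# Candidate subsets of size len(excluded) that avoid already-excluded
# variables: over 4 variables with variable 3 excluded, this yields
# [(0,), (1,), (2,)].
print(_generate_combinations({0, 1, 2, 3}, [3]))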
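Both the "VI" and "backward" branches score a subset the same way; a minimal sketch of that scoring step, with toy arrays standing in for the posterior predictions:

import numpy as np
from scipy.stats import pearsonr

rng = np.random.default_rng(0)
samples = 10
predicted_all = rng.normal(size=(samples, 80))  # toy full-model draws
predicted_subset = predicted_all + rng.normal(scale=0.3, size=(samples, 80))

# Per-draw R² as the squared Pearson correlation between predictions of
# the full model and of the model with some covariates excluded.
pearson = np.array(
    [pearsonr(predicted_all[j], predicted_subset[j])[0] ** 2 for j in range(samples)]
)
print(pearson.mean())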
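And an end-to-end usage sketch of the reworked function (model and data are hypothetical; method and xlabel_angle are the arguments introduced by this commit):

import numpy as np
import pymc as pm
import pymc_bart as pmb

rng = np.random.default_rng(42)
X = rng.normal(size=(80, 4))
Y = X[:, 0] - 0.5 * X[:, 1] + rng.normal(scale=0.1, size=80)

with pm.Model():
    mu = pmb.BART("mu", X, Y, m=20)
    pm.Normal("y", mu=mu, sigma=0.1, observed=Y)
    idata = pm.sample(tune=200, draws=200, chains=2)

# "VI" ranks by posterior inclusion counts; "backward" greedily drops the
# covariate whose removal hurts R² the least at each step.
indices, ax = pmb.plot_variable_importance(
    idata, mu, X, method="backward", xlabel_angle=45, samples=50, random_seed=0
)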
28 changes: 17 additions & 11 deletions tests/test_bart.py
@@ -98,7 +98,7 @@ def test_shared_variable(response):
y = pm.Normal("y", mu, sigma, observed=Y, shape=mu.shape)
idata = pm.sample(tune=100, draws=100, chains=2, random_seed=3415)
ppc = pm.sample_posterior_predictive(idata)
new_X = pm.set_data({"data_X": X[:3]})
pm.set_data({"data_X": X[:3]})
ppc2 = pm.sample_posterior_predictive(idata)

assert ppc.posterior_predictive["y"].shape == (2, 100, 50)
@@ -160,7 +160,7 @@ def test_sample_posterior(self):
{"instances": 2},
{"var_idx": [0], "smooth": False, "color": "k"},
{"grid": (1, 2), "sharey": "none", "alpha": 1},
{"var_discrete": [0]}
{"var_discrete": [0]},
],
)
def test_ice(self, kwargs):
@@ -178,7 +178,7 @@
},
{"var_idx": [0], "smooth": False, "color": "k"},
{"grid": (1, 2), "sharey": "none", "alpha": 1},
{"var_discrete": [0]}
{"var_discrete": [0]},
],
)
def test_pdp(self, kwargs):
@@ -224,22 +224,28 @@ def test_bart_moment(size, expected):
@pytest.mark.parametrize(
argnames="separate_trees,split_rule",
argvalues=[
(False,pmb.ContinuousSplitRule),
(False,pmb.OneHotSplitRule),
(False,pmb.SubsetSplitRule),
(True,pmb.ContinuousSplitRule)
(False, pmb.ContinuousSplitRule),
(False, pmb.OneHotSplitRule),
(False, pmb.SubsetSplitRule),
(True, pmb.ContinuousSplitRule),
],
ids=["continuous", "one-hot", "subset", "separate-trees"],
)
def test_categorical_model(separate_trees,split_rule):
def test_categorical_model(separate_trees, split_rule):

Y = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])
X = np.concatenate([Y[:, None], np.random.randint(0, 6, size=(9, 4))], axis=1)

with pm.Model() as model:
lo = pmb.BART("logodds", X, Y, m=2, shape=(3, 9),
split_rules=[split_rule]*5,
separate_trees=separate_trees)
lo = pmb.BART(
"logodds",
X,
Y,
m=2,
shape=(3, 9),
split_rules=[split_rule] * 5,
separate_trees=separate_trees,
)
y = pm.Categorical("y", p=pm.math.softmax(lo.T, axis=-1), observed=Y)
idata = pm.sample(random_seed=3415, tune=300, draws=300)
idata = pm.sample_posterior_predictive(idata, predictions=True, extend_inferencedata=True)
