Skip to content

Commit

Permalink
fix bug, clean code and add comments (#132)
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia authored Dec 20, 2023
1 parent 9deb502 commit 83f2409
Showing 1 changed file with 46 additions and 34 deletions.
80 changes: 46 additions & 34 deletions pymc_bart/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Utility function for variable selection and bart interpretability."""

from itertools import combinations
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -805,26 +804,35 @@ def plot_variable_importance(
excluded=subset,
shape=shape,
)
pearson = np.zeros(samples)
for j in range(samples):
pearson[j] = (
pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0]
) ** 2
r2_mean[idx] = np.mean(pearson)
r2_hdi[idx] = az.hdi(pearson)
r_2 = np.array(
[
pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0] ** 2
for j in range(samples)
]
)
r2_mean[idx] = np.mean(r_2)
r2_hdi[idx] = az.hdi(r_2)

elif method == "backward":
r2_mean = np.zeros(n_vars)
r2_hdi = np.zeros((n_vars, 2))

variables = set(range(n_vars))
excluded: List[int] = []
least_important_vars: List[int] = []
indices = []

for i_var in range(0, n_vars):
subsets = _generate_combinations(variables, excluded)
max_pearson = -np.inf
# Iterate over each variable to determine its contribution
# least_important_vars tracks the variable with the lowest contribution
# at the current stage. One new varible is added at each iteration.
for i_var in range(n_vars):
# Generate all possible subsets by adding one variable at a time to
# least_important_vars
subsets = generate_sequences(n_vars, i_var, least_important_vars)
max_r_2 = -np.inf

# Iterate over each subset to find the one with the maximum Pearson correlation
for subset in subsets:
# Sample posterior predictions excluding a subset of variables
predicted_subset = _sample_posterior(
all_trees=all_trees,
X=X,
Expand All @@ -833,25 +841,32 @@ def plot_variable_importance(
excluded=subset,
shape=shape,
)
pearson = np.zeros(samples)
# Calculate Pearson correlation for each sample and find the mean
r_2 = np.zeros(samples)
for j in range(samples):
pearson[j] = (
r_2[j] = (
pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0]
) ** 2
mean_pearson = np.mean(pearson, dtype=float)
if mean_pearson > max_pearson:
max_pearson = mean_pearson
best_subset = subset
best_pearson = pearson
mean_r_2 = np.mean(r_2, dtype=float)
# Identify the least important combination of variables
# based on the maximum mean squared Pearson correlation
if mean_r_2 > max_r_2:
max_r_2 = mean_r_2
least_important_subset = subset
r_2_without_least_important_vars = r_2

r2_mean[i_var] = max_pearson
r2_hdi[i_var] = az.hdi(best_pearson)
# Save values for plotting later
r2_mean[i_var] = max_r_2
r2_hdi[i_var] = az.hdi(r_2_without_least_important_vars)

indices.extend((set(best_subset) - set(indices)))
# extend current list of least important variable
least_important_vars += least_important_subset

excluded.append(best_subset)
# add index of removed variable
indices += list(set(least_important_subset) - set(indices))

indices.extend((set(variables) - set(indices)))
# add remaining index
indices += list(set(variables) - set(least_important_vars))

indices = indices[::-1]
r2_mean = r2_mean[::-1]
Expand All @@ -878,13 +893,10 @@ def plot_variable_importance(
return indices, ax


def _generate_combinations(variables, excluded):
"""
Generate all possible combinations of variables.
"""
all_combinations = combinations(variables, len(excluded))
valid_combinations = [
com for com in all_combinations if not any(ele in com for ele in excluded)
]

return valid_combinations
def generate_sequences(n_vars, i_var, include):
"""Generate combinations of variables"""
if i_var:
sequences = [tuple(include + [i]) for i in range(n_vars) if i not in include]
else:
sequences = [()]
return sequences

0 comments on commit 83f2409

Please sign in to comment.