Skip to content

Commit

Permalink
add submodels arguments to plot subsets (#200)
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia authored Nov 29, 2024
1 parent 9ec4de8 commit b20d074
Showing 1 changed file with 47 additions and 29 deletions.
76 changes: 47 additions & 29 deletions pymc_bart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,7 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non
Parameters
----------
idata: InferenceData
idata : InferenceData
InferenceData containing a collection of BART_trees in sample_stats group
X : npt.NDArray[np.float64]
The covariate matrix.
Expand Down Expand Up @@ -784,7 +784,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
Parameters
----------
idata: InferenceData
idata : InferenceData
InferenceData containing a collection of BART_trees in sample_stats group
bartrv : BART Random Variable
BART variable once the model that include it has been fitted.
Expand Down Expand Up @@ -949,8 +949,10 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912

indices = least_important_vars[::-1]

labels = np.array(["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)])

vi_results = {
"indices": indices,
"indices": np.asarray(indices),
"labels": labels[indices],
"r2_mean": r2_mean,
"r2_hdi": r2_hdi,
Expand All @@ -962,8 +964,9 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912

def plot_variable_importance(
vi_results: dict,
labels=None,
figsize=None,
submodels: Optional[Union[list[int], np.ndarray, tuple[int, ...]]] = None,
labels: Optional[list[str]] = None,
figsize: Optional[tuple[float, float]] = None,
plot_kwargs: Optional[dict[str, Any]] = None,
ax: Optional[plt.Axes] = None,
):
Expand All @@ -974,8 +977,11 @@ def plot_variable_importance(
----------
vi_results: Dictionary
Dictionary computed with `compute_variable_importance`
X : npt.NDArray[np.float64]
The covariate matrix.
submodels : Optional[Union[list[int], np.ndarray]]
List of the indices of the submodels to plot. Defaults to None, all variables are ploted.
The indices correspond to order computed by `compute_variable_importance`.
For example `submodels=[0,1]` will plot the two most important variables.
`submodels=[1,0]` is equivalent as values are sorted before use.
labels : Optional[list[str]]
List of the names of the covariates. If X is a DataFrame the names of the covariables will
be taken from it and this argument will be ignored.
Expand All @@ -995,11 +1001,15 @@ def plot_variable_importance(
-------
axes: matplotlib axes
"""
if submodels is None:
submodels = np.sort(vi_results["indices"])
else:
submodels = np.sort(submodels)

indices = vi_results["indices"]
r2_mean = vi_results["r2_mean"]
r2_hdi = vi_results["r2_hdi"]
preds = vi_results["preds"]
indices = vi_results["indices"][submodels]
r2_mean = vi_results["r2_mean"][submodels]
r2_hdi = vi_results["r2_hdi"][submodels]
preds = vi_results["preds"][submodels]
preds_all = vi_results["preds_all"]
samples = preds.shape[1]

Expand All @@ -1016,9 +1026,7 @@ def plot_variable_importance(
_, ax = plt.subplots(1, 1, figsize=figsize)

if labels is None:
labels = vi_results["labels"]

labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)]
labels = vi_results["labels"][submodels]

r_2_ref = np.array([pearsonr2(preds_all[j], preds_all[j + 1]) for j in range(samples - 1)])

Expand Down Expand Up @@ -1059,21 +1067,27 @@ def plot_variable_importance(
def plot_scatter_submodels(
vi_results: dict,
func: Optional[Callable] = None,
submodels: Optional[Union[list[int], np.ndarray]] = None,
grid: str = "long",
labels=None,
labels: Optional[list[str]] = None,
figsize: Optional[tuple[float, float]] = None,
plot_kwargs: Optional[dict[str, Any]] = None,
axes: Optional[plt.Axes] = None,
):
ax: Optional[plt.Axes] = None,
) -> list[plt.Axes]:
"""
Plot submodel's predictions against reference-model's predictions.
Parameters
----------
vi_results: Dictionary
vi_results : Dictionary
Dictionary computed with `compute_variable_importance`
func : Optional[Callable], by default None.
Arbitrary function to apply to the predictions. Defaults to the identity function.
submodels : Optional[Union[list[int], np.ndarray]]
List of the indices of the submodels to plot. Defaults to None, all variables are ploted.
The indices correspond to order computed by `compute_variable_importance`.
For example `submodels=[0,1]` will plot the two most important variables.
`submodels=[1,0]` is equivalent as values are sorted before use.
grid : str or tuple
How to arrange the subplots. Defaults to "long", one subplot below the other.
Other options are "wide", one subplot next to each other or a tuple indicating the number
Expand All @@ -1092,20 +1106,23 @@ def plot_scatter_submodels(
-------
axes: matplotlib axes
"""
indices = vi_results["indices"]
preds = vi_results["preds"]
if submodels is None:
submodels = np.sort(vi_results["indices"])
else:
submodels = np.sort(submodels)

indices = vi_results["indices"][submodels]
preds = vi_results["preds"][submodels]
preds_all = vi_results["preds_all"]

if axes is None:
_, axes = _get_axes(grid, len(indices), True, True, figsize)
if ax is None:
_, ax = _get_axes(grid, len(indices), True, True, figsize)

if plot_kwargs is None:
plot_kwargs = {}

if labels is None:
labels = vi_results["labels"]

labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)]
labels = vi_results["labels"][submodels]

if func is not None:
preds = func(preds)
Expand All @@ -1114,22 +1131,23 @@ def plot_scatter_submodels(
min_ = min(np.min(preds), np.min(preds_all))
max_ = max(np.max(preds), np.max(preds_all))

for pred, x_label, ax in zip(preds, labels, axes.ravel()):
ax.plot(
for pred, x_label, axi in zip(preds, labels, ax.ravel()):
axi.plot(
pred,
preds_all,
marker=plot_kwargs.get("marker_scatter", "."),
ls="",
color=plot_kwargs.get("color_scatter", "C0"),
alpha=plot_kwargs.get("alpha_scatter", 0.1),
)
ax.set_xlabel(x_label)
ax.axline(
axi.set_xlabel(x_label)
axi.axline(
[min_, min_],
[max_, max_],
color=plot_kwargs.get("color_ref", "0.5"),
ls=plot_kwargs.get("ls_ref", "--"),
)
return ax


def generate_sequences(n_vars, i_var, include):
Expand Down

0 comments on commit b20d074

Please sign in to comment.