Skip to content

Commit

Permalink
fix bug with shapes (#208)
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia authored Dec 23, 2024
1 parent 1ec251b commit 2f0b3aa
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pymc_bart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912

r2_mean = np.zeros(n_vars)
r2_hdi = np.zeros((n_vars, 2))
preds = np.zeros((n_vars, samples, bartrv.eval().shape[0]))
preds = np.zeros((n_vars, samples, *bartrv.eval().T.shape))

if method == "backward_VI":
if fixed >= n_vars:
Expand Down
1 change: 1 addition & 0 deletions tests/test_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,4 @@ def test_categorical_model(separate_trees, split_rule):

# Fit should be good enough so right category is selected over 50% of time
assert (idata.predictions.y.median(["chain", "draw"]) == Y).all()
assert pmb.compute_variable_importance(idata, bartrv=lo, X=X)["preds"].shape == (5, 50, 9, 3)

0 comments on commit 2f0b3aa

Please sign in to comment.