Conform to recent changes in pymc (#194)
* conform to recent changes in pymc

* update version

* fix shapes
aloctavodia authored on Nov 7, 2024
Parent: 1741d7d · Commit: b9f4567
Showing 5 changed files with 23 additions and 20 deletions.
2 changes: 1 addition & 1 deletion pymc_bart/__init__.py
@@ -36,7 +36,7 @@
"plot_pdp",
"plot_variable_importance",
]
__version__ = "0.7.0"
__version__ = "0.7.1"


pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART]
14 changes: 7 additions & 7 deletions pymc_bart/bart.py
@@ -37,19 +37,22 @@ class BARTRV(RandomVariable):
"""Base class for BART."""

name: str = "BART"
ndim_supp = 1
ndims_params: List[int] = [2, 1, 0, 0, 0, 1]
signature = "(m,n),(m),(),(),() -> (m)"
dtype: str = "floatX"
_print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}")
all_trees = List[List[List[Tree]]]

def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): # pylint: disable=arguments-renamed
return dist_params[0].shape[:1]
idx = dist_params[0].ndim - 2
return [dist_params[0].shape[idx]]

@classmethod
def rng_fn( # pylint: disable=W0237
cls, rng=None, X=None, Y=None, m=None, alpha=None, beta=None, split_prior=None, size=None
cls, rng=None, X=None, Y=None, m=None, alpha=None, beta=None, size=None
):
if not size:
size = None

if not cls.all_trees:
if size is not None:
return np.full((size[0], cls.Y.shape[0]), cls.Y.mean())
@@ -96,9 +99,6 @@ class BART(Distribution):
         List of SplitRule objects, one per column in input data.
         Allows using different split rules for different columns. Default is ContinuousSplitRule.
         Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables.
-    shape: : Optional[Tuple], default None
-        Specify the output shape. If shape is different from (len(X)) (the default), train a
-        separate tree for each value in other dimensions.
     separate_trees : Optional[bool], default False
         When training multiple trees (by setting a shape parameter), the default behavior is to
         learn a joint tree structure and only have different leaf values for each.
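The class-attribute hunk above tracks PyTensor's newer `RandomVariable` interface, where a single gufunc-style `signature` string replaces the older `ndim_supp`/`ndims_params` attributes. A minimal sketch of that convention, assuming a recent PyTensor; the `RowMeanRV` op here is hypothetical and only mirrors the shape logic of `BARTRV`:

```python
import numpy as np
from pytensor.tensor.random.op import RandomVariable


class RowMeanRV(RandomVariable):
    """Hypothetical op: one draw per row of an (m, n) parameter matrix."""

    name = "row_mean"
    # One (m, n) matrix input, vector-valued (m,) support -- the same
    # gufunc notation BARTRV now uses instead of ndim_supp/ndims_params.
    signature = "(m,n)->(m)"
    dtype = "floatX"

    @classmethod
    def rng_fn(cls, rng, X, size=None):
        # Draw centered on each row mean of the parameter matrix.
        return rng.normal(loc=X.mean(axis=-1), scale=1.0, size=size)
```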
7 changes: 5 additions & 2 deletions pymc_bart/pgbart.py
@@ -114,7 +114,10 @@ class PGBART(ArrayStepShared):
name = "pgbart"
default_blocked = False
generates_stats = True
stats_dtypes = [{"variable_inclusion": object, "tune": bool}]
stats_dtypes_shapes: dict[str, tuple[type, list]] = {
"variable_inclusion": (object, []),
"tune": (bool, []),
}

def __init__( # noqa: PLR0915
self,
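The `stats_dtypes` list was superseded in recent PyMC by `stats_dtypes_shapes`, which maps each sampler stat to a `(dtype, shape)` pair; an empty shape `[]` denotes a scalar recorded at every draw. A sketch of the convention as assumed from this hunk (the `accept_rate` stat is invented for illustration):

```python
# Declaring per-draw sampler stats on a custom step method (sketch).
stats_dtypes_shapes: dict[str, tuple[type, list]] = {
    "variable_inclusion": (object, []),  # arbitrary Python object per draw
    "tune": (bool, []),                  # scalar flag: tuning draw or not
    "accept_rate": (float, []),          # hypothetical extra scalar stat
}
```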
@@ -227,7 +230,7 @@ def __init__(  # noqa: PLR0915
     def astep(self, _):
         variable_inclusion = np.zeros(self.num_variates, dtype="int")

-        upper = min(self.lower + self.batch[~self.tune], self.m)
+        upper = min(self.lower + self.batch[not self.tune], self.m)
         tree_ids = range(self.lower, upper)
         self.lower = upper if upper < self.m else 0

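The `astep` change swaps bitwise `~` for logical `not`. On a plain Python bool, `~` is integer complement (`~True == -2`, `~False == -1`), so the old indexing only returned the intended element because `self.batch` happens to have length two. A quick illustration:

```python
batch = (10, 5)   # e.g. (trees per step while tuning, trees per step after)
tune = True

batch[~tune]      # ~True == -2 -> batch[-2] == 10 (works only for len 2)
batch[not tune]   # not True is False -> index 0 -> 10, and states the intent
```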
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
-pymc<=5.16.2
+pymc>=5.16.2, <=5.18
 arviz>=0.18.0
 numba
 matplotlib
18 changes: 9 additions & 9 deletions tests/test_bart.py
@@ -3,7 +3,7 @@
 import pytest
 from numpy.testing import assert_almost_equal, assert_array_equal
 from pymc.initial_point import make_initial_point_fn
-from pymc.logprob.basic import joint_logp
+from pymc.logprob.basic import transformed_conditional_logp

 import pymc_bart as pmb

@@ -12,7 +12,7 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True):
     fn = make_initial_point_fn(
         model=model,
         return_transformed=False,
-        default_strategy="moment",
+        default_strategy="support_point",
     )
     moment = fn(0)["x"]
     expected = np.asarray(expected)
@@ -27,7 +27,7 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True):

     if check_finite_logp:
         logp_moment = (
-            joint_logp(
+            transformed_conditional_logp(
                 (model["x"],),
                 rvs_to_values={model["x"]: pm.math.constant(moment)},
                 rvs_to_transforms={},
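Alongside the `moment` to `support_point` rename of the initial-point strategy, `joint_logp` was replaced by `transformed_conditional_logp` in recent PyMC; the call layout in the test stays the same. A minimal sketch of the call, assuming pymc >= 5.16 (the toy model is illustrative):

```python
import pymc as pm
from pymc.logprob.basic import transformed_conditional_logp

with pm.Model() as toy:
    x = pm.Normal("x", 0.0, 1.0)

# Log-probability terms for x at a fixed value, with no transforms applied.
logp_terms = transformed_conditional_logp(
    (toy["x"],),
    rvs_to_values={toy["x"]: pm.math.constant(0.0)},
    rvs_to_transforms={},
)
```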
@@ -53,7 +53,7 @@ def test_bart_vi(response):
         mu = pmb.BART("mu", X, Y, m=10, response=response)
         sigma = pm.HalfNormal("sigma", 1)
         y = pm.Normal("y", mu, sigma, observed=Y)
-        idata = pm.sample(random_seed=3415)
+        idata = pm.sample(tune=200, draws=200, random_seed=3415)
         var_imp = (
             idata.sample_stats["variable_inclusion"]
             .stack(samples=("chain", "draw"))
@@ -77,8 +77,8 @@ def test_missing_data(response):
     with pm.Model() as model:
         mu = pmb.BART("mu", X, Y, m=10, response=response)
         sigma = pm.HalfNormal("sigma", 1)
-        y = pm.Normal("y", mu, sigma, observed=Y)
-        idata = pm.sample(tune=100, draws=100, chains=1, random_seed=3415)
+        pm.Normal("y", mu, sigma, observed=Y)
+        pm.sample(tune=100, draws=100, chains=1, random_seed=3415)


 @pytest.mark.parametrize(
@@ -91,7 +91,7 @@ def test_shared_variable(response):
     Y = np.random.normal(0, 1, size=50)

     with pm.Model() as model:
-        data_X = pm.MutableData("data_X", X)
+        data_X = pm.Data("data_X", X)
         mu = pmb.BART("mu", data_X, Y, m=2, response=response)
         sigma = pm.HalfNormal("sigma", 1)
         y = pm.Normal("y", mu, sigma, observed=Y, shape=mu.shape)
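`pm.MutableData` (and `pm.ConstantData`) were consolidated into a single `pm.Data` in recent PyMC; values can still be swapped later with `pm.set_data`. A minimal usage sketch with illustrative arrays:

```python
import numpy as np
import pymc as pm

X_train = np.random.normal(size=(50, 3))
X_new = np.random.normal(size=(10, 3))

with pm.Model() as model:
    data_X = pm.Data("data_X", X_train)      # replaces pm.MutableData
    coefs = pm.Normal("coefs", 0.0, 1.0, shape=3)
    mu = data_X @ coefs
    pm.Normal("y", mu, 1.0, shape=mu.shape)  # shape follows the data

with model:
    pm.set_data({"data_X": X_new})           # swap in new predictors
```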
@@ -116,7 +116,7 @@ def test_shape(response):
     with pm.Model() as model:
         w = pmb.BART("w", X, Y, m=2, response=response, shape=(2, 250))
         y = pm.Normal("y", w[0], pm.math.abs(w[1]), observed=Y)
-        idata = pm.sample(random_seed=3415)
+        idata = pm.sample(tune=50, draws=10, random_seed=3415)

     assert model.initial_point()["w"].shape == (2, 250)
     assert idata.posterior.coords["w_dim_0"].data.size == 2
@@ -133,7 +133,7 @@ class TestUtils:
         mu = pmb.BART("mu", X, Y, m=10)
         sigma = pm.HalfNormal("sigma", 1)
         y = pm.Normal("y", mu, sigma, observed=Y)
-        idata = pm.sample(random_seed=3415)
+        idata = pm.sample(tune=200, draws=200, random_seed=3415)

     def test_sample_posterior(self):
         all_trees = self.mu.owner.op.all_trees
