diff --git a/pymc_bart/__init__.py b/pymc_bart/__init__.py index c10b8f8..8774803 100644 --- a/pymc_bart/__init__.py +++ b/pymc_bart/__init__.py @@ -36,7 +36,7 @@ "plot_pdp", "plot_variable_importance", ] -__version__ = "0.7.0" +__version__ = "0.7.1" pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART] diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index 969baf4..a21bda5 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -37,19 +37,22 @@ class BARTRV(RandomVariable): """Base class for BART.""" name: str = "BART" - ndim_supp = 1 - ndims_params: List[int] = [2, 1, 0, 0, 0, 1] + signature = "(m,n),(m),(),(),() -> (m)" dtype: str = "floatX" _print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}") all_trees = List[List[List[Tree]]] def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): # pylint: disable=arguments-renamed - return dist_params[0].shape[:1] + idx = dist_params[0].ndim - 2 + return [dist_params[0].shape[idx]] @classmethod def rng_fn( # pylint: disable=W0237 - cls, rng=None, X=None, Y=None, m=None, alpha=None, beta=None, split_prior=None, size=None + cls, rng=None, X=None, Y=None, m=None, alpha=None, beta=None, size=None ): + if not size: + size = None + if not cls.all_trees: if size is not None: return np.full((size[0], cls.Y.shape[0]), cls.Y.mean()) @@ -96,9 +99,6 @@ class BART(Distribution): List of SplitRule objects, one per column in input data. Allows using different split rules for different columns. Default is ContinuousSplitRule. Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables. - shape: : Optional[Tuple], default None - Specify the output shape. If shape is different from (len(X)) (the default), train a - separate tree for each value in other dimensions. separate_trees : Optional[bool], default False When training multiple trees (by setting a shape parameter), the default behavior is to learn a joint tree structure and only have different leaf values for each. diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 91a9beb..6de7a53 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -114,7 +114,10 @@ class PGBART(ArrayStepShared): name = "pgbart" default_blocked = False generates_stats = True - stats_dtypes = [{"variable_inclusion": object, "tune": bool}] + stats_dtypes_shapes: dict[str, tuple[type, list]] = { + "variable_inclusion": (object, []), + "tune": (bool, []), + } def __init__( # noqa: PLR0915 self, @@ -227,7 +230,7 @@ def __init__( # noqa: PLR0915 def astep(self, _): variable_inclusion = np.zeros(self.num_variates, dtype="int") - upper = min(self.lower + self.batch[~self.tune], self.m) + upper = min(self.lower + self.batch[not self.tune], self.m) tree_ids = range(self.lower, upper) self.lower = upper if upper < self.m else 0 diff --git a/requirements.txt b/requirements.txt index e741cef..ac9bd07 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -pymc<=5.16.2 +pymc>=5.16.2, <=5.18 arviz>=0.18.0 numba matplotlib diff --git a/tests/test_bart.py b/tests/test_bart.py index dfbd86f..e56735e 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -3,7 +3,7 @@ import pytest from numpy.testing import assert_almost_equal, assert_array_equal from pymc.initial_point import make_initial_point_fn -from pymc.logprob.basic import joint_logp +from pymc.logprob.basic import transformed_conditional_logp import pymc_bart as pmb @@ -12,7 +12,7 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True): fn = make_initial_point_fn( model=model, return_transformed=False, - default_strategy="moment", + default_strategy="support_point", ) moment = fn(0)["x"] expected = np.asarray(expected) @@ -27,7 +27,7 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True): if check_finite_logp: logp_moment = ( - joint_logp( + transformed_conditional_logp( (model["x"],), rvs_to_values={model["x"]: pm.math.constant(moment)}, rvs_to_transforms={}, @@ -53,7 +53,7 @@ def test_bart_vi(response): mu = pmb.BART("mu", X, Y, m=10, response=response) sigma = pm.HalfNormal("sigma", 1) y = pm.Normal("y", mu, sigma, observed=Y) - idata = pm.sample(random_seed=3415) + idata = pm.sample(tune=200, draws=200, random_seed=3415) var_imp = ( idata.sample_stats["variable_inclusion"] .stack(samples=("chain", "draw")) @@ -77,8 +77,8 @@ def test_missing_data(response): with pm.Model() as model: mu = pmb.BART("mu", X, Y, m=10, response=response) sigma = pm.HalfNormal("sigma", 1) - y = pm.Normal("y", mu, sigma, observed=Y) - idata = pm.sample(tune=100, draws=100, chains=1, random_seed=3415) + pm.Normal("y", mu, sigma, observed=Y) + pm.sample(tune=100, draws=100, chains=1, random_seed=3415) @pytest.mark.parametrize( @@ -91,7 +91,7 @@ def test_shared_variable(response): Y = np.random.normal(0, 1, size=50) with pm.Model() as model: - data_X = pm.MutableData("data_X", X) + data_X = pm.Data("data_X", X) mu = pmb.BART("mu", data_X, Y, m=2, response=response) sigma = pm.HalfNormal("sigma", 1) y = pm.Normal("y", mu, sigma, observed=Y, shape=mu.shape) @@ -116,7 +116,7 @@ def test_shape(response): with pm.Model() as model: w = pmb.BART("w", X, Y, m=2, response=response, shape=(2, 250)) y = pm.Normal("y", w[0], pm.math.abs(w[1]), observed=Y) - idata = pm.sample(random_seed=3415) + idata = pm.sample(tune=50, draws=10, random_seed=3415) assert model.initial_point()["w"].shape == (2, 250) assert idata.posterior.coords["w_dim_0"].data.size == 2 @@ -133,7 +133,7 @@ class TestUtils: mu = pmb.BART("mu", X, Y, m=10) sigma = pm.HalfNormal("sigma", 1) y = pm.Normal("y", mu, sigma, observed=Y) - idata = pm.sample(random_seed=3415) + idata = pm.sample(tune=200, draws=200, random_seed=3415) def test_sample_posterior(self): all_trees = self.mu.owner.op.all_trees