From d4e8cadaf84cde6417390dc18c06d83c9b9114c4 Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Tue, 17 Dec 2024 10:39:00 -0300 Subject: [PATCH] bump release (#205) * bump release * fix zip and new args --- pymc_bart/__init__.py | 2 +- pymc_bart/bart.py | 2 +- pymc_bart/pgbart.py | 12 ++++++++---- requirements.txt | 2 +- tests/test_bart.py | 6 ++++-- 5 files changed, 15 insertions(+), 9 deletions(-) diff --git a/pymc_bart/__init__.py b/pymc_bart/__init__.py index 18fe054..eee1881 100644 --- a/pymc_bart/__init__.py +++ b/pymc_bart/__init__.py @@ -42,7 +42,7 @@ "plot_variable_importance", "plot_variable_inclusion", ] -__version__ = "0.7.1" +__version__ = "0.8.0" pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART] diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index 94b91c3..eb869d2 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -175,7 +175,7 @@ def get_moment(rv, size, *rv_inputs): return cls.get_moment(rv, size, *rv_inputs) cls.rv_op = bart_op - params = [X, Y, m, alpha, beta, split_prior] + params = [X, Y, m, alpha, beta] return super().__new__(cls, name, *params, **kwargs) @classmethod diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 6a7e26e..1505f15 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -17,6 +17,7 @@ import numpy as np import numpy.typing as npt from numba import njit +from pymc.initial_point import PointType from pymc.model import Model, modelcontext from pymc.pytensorf import inputvars, join_nonshared_inputs, make_shared_replacements from pymc.step_methods.arraystep import ArrayStepShared @@ -125,9 +126,12 @@ def __init__( # noqa: PLR0915 num_particles: int = 10, batch: tuple[float, float] = (0.1, 0.1), model: Optional[Model] = None, + initial_point: PointType | None = None, + compile_kwargs: dict | None = None, # pylint: disable=unused-argument ): model = modelcontext(model) - initial_values = model.initial_point() + if initial_point is None: + initial_point = model.initial_point() if vars is None: vars = model.value_vars else: @@ -150,7 +154,7 @@ def __init__( # noqa: PLR0915 self.m = self.bart.m self.response = self.bart.response - shape = initial_values[value_bart.name].shape + shape = initial_point[value_bart.name].shape self.shape = 1 if len(shape) == 1 else shape[0] @@ -217,8 +221,8 @@ def __init__( # noqa: PLR0915 self.num_particles = num_particles self.indices = list(range(1, num_particles)) - shared = make_shared_replacements(initial_values, vars, model) - self.likelihood_logp = logp(initial_values, [model.datalogp], vars, shared) + shared = make_shared_replacements(initial_point, vars, model) + self.likelihood_logp = logp(initial_point, [model.datalogp], vars, shared) self.all_particles = [ [ParticleTree(self.a_tree) for _ in range(self.m)] for _ in range(self.trees_shape) ] diff --git a/requirements.txt b/requirements.txt index ac9bd07..da634d4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -pymc>=5.16.2, <=5.18 +pymc>=5.16.2, <=5.19.1 arviz>=0.18.0 numba matplotlib diff --git a/tests/test_bart.py b/tests/test_bart.py index c64811a..a003363 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -248,8 +248,10 @@ def test_categorical_model(separate_trees, split_rule): separate_trees=separate_trees, ) y = pm.Categorical("y", p=pm.math.softmax(lo.T, axis=-1), observed=Y) - idata = pm.sample(random_seed=3415, tune=300, draws=300) - idata = pm.sample_posterior_predictive(idata, predictions=True, extend_inferencedata=True) + idata = pm.sample(tune=300, draws=300, random_seed=3415) + idata = pm.sample_posterior_predictive( + idata, predictions=True, extend_inferencedata=True, random_seed=3415 + ) # Fit should be good enough so right category is selected over 50% of time assert (idata.predictions.y.median(["chain", "draw"]) == Y).all()