Skip to content

Commit

Permalink
bump release (#205)
Browse files Browse the repository at this point in the history
* bump release

* fix zip and new args
  • Loading branch information
aloctavodia authored Dec 17, 2024
1 parent bcdf77d commit d4e8cad
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pymc_bart/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
"plot_variable_importance",
"plot_variable_inclusion",
]
__version__ = "0.7.1"
__version__ = "0.8.0"


pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART]
2 changes: 1 addition & 1 deletion pymc_bart/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def get_moment(rv, size, *rv_inputs):
return cls.get_moment(rv, size, *rv_inputs)

cls.rv_op = bart_op
params = [X, Y, m, alpha, beta, split_prior]
params = [X, Y, m, alpha, beta]
return super().__new__(cls, name, *params, **kwargs)

@classmethod
Expand Down
12 changes: 8 additions & 4 deletions pymc_bart/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np
import numpy.typing as npt
from numba import njit
from pymc.initial_point import PointType
from pymc.model import Model, modelcontext
from pymc.pytensorf import inputvars, join_nonshared_inputs, make_shared_replacements
from pymc.step_methods.arraystep import ArrayStepShared
Expand Down Expand Up @@ -125,9 +126,12 @@ def __init__( # noqa: PLR0915
num_particles: int = 10,
batch: tuple[float, float] = (0.1, 0.1),
model: Optional[Model] = None,
initial_point: PointType | None = None,
compile_kwargs: dict | None = None, # pylint: disable=unused-argument
):
model = modelcontext(model)
initial_values = model.initial_point()
if initial_point is None:
initial_point = model.initial_point()
if vars is None:
vars = model.value_vars
else:
Expand All @@ -150,7 +154,7 @@ def __init__( # noqa: PLR0915
self.m = self.bart.m
self.response = self.bart.response

shape = initial_values[value_bart.name].shape
shape = initial_point[value_bart.name].shape

self.shape = 1 if len(shape) == 1 else shape[0]

Expand Down Expand Up @@ -217,8 +221,8 @@ def __init__( # noqa: PLR0915

self.num_particles = num_particles
self.indices = list(range(1, num_particles))
shared = make_shared_replacements(initial_values, vars, model)
self.likelihood_logp = logp(initial_values, [model.datalogp], vars, shared)
shared = make_shared_replacements(initial_point, vars, model)
self.likelihood_logp = logp(initial_point, [model.datalogp], vars, shared)
self.all_particles = [
[ParticleTree(self.a_tree) for _ in range(self.m)] for _ in range(self.trees_shape)
]
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pymc>=5.16.2, <=5.18
pymc>=5.16.2, <=5.19.1
arviz>=0.18.0
numba
matplotlib
Expand Down
6 changes: 4 additions & 2 deletions tests/test_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,10 @@ def test_categorical_model(separate_trees, split_rule):
separate_trees=separate_trees,
)
y = pm.Categorical("y", p=pm.math.softmax(lo.T, axis=-1), observed=Y)
idata = pm.sample(random_seed=3415, tune=300, draws=300)
idata = pm.sample_posterior_predictive(idata, predictions=True, extend_inferencedata=True)
idata = pm.sample(tune=300, draws=300, random_seed=3415)
idata = pm.sample_posterior_predictive(
idata, predictions=True, extend_inferencedata=True, random_seed=3415
)

# Fit should be good enough so right category is selected over 50% of time
assert (idata.predictions.y.median(["chain", "draw"]) == Y).all()

0 comments on commit d4e8cad

Please sign in to comment.