Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SMC: add support for random variables with shape (1,) #336

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 5 additions & 12 deletions pymc_experimental/inference/smc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def sample_smc_blackjax(
num_mcmc_steps,
kernel,
diagnosis,
total_iterations,
int(total_iterations),
iterations_to_diagnose,
inner_kernel_params,
running_time,
Expand Down Expand Up @@ -198,13 +198,13 @@ def arviz_from_particles(model, particles):
-------
"""
n_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0]
by_varname = {k.name: v.squeeze()[np.newaxis, :] for k, v in zip(model.value_vars, particles)}
by_varname = {k.name: v for k, v in zip(model.value_vars, particles)}
varnames = [v.name for v in model.value_vars]
with model:
strace = NDArray(name=model.name)
strace.setup(n_particles, 0)
for particle_index in range(0, n_particles):
strace.record(point={k: by_varname[k][0][particle_index] for k in varnames})
strace.record(point={k: by_varname[k][particle_index] for k in varnames})
multitrace = MultiTrace((strace,))
return to_inference_data(multitrace, log_likelihood=False)

Expand Down Expand Up @@ -295,14 +295,7 @@ def blackjax_particles_from_pymc_population(model, pymc_population):

order_of_vars = model.value_vars

def _format(var):
variable = pymc_population[var.name]
if len(variable.shape) == 1:
return variable[:, np.newaxis]
else:
return variable

return [_format(var) for var in order_of_vars]
return [pymc_population[var.name] for var in order_of_vars]


def add_to_inference_data(
Expand Down Expand Up @@ -384,7 +377,7 @@ def get_jaxified_particles_fn(model, graph_outputs):
logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[graph_outputs])

def logp_fn_wrap(particles):
return logp_fn(*[p.squeeze() for p in particles])[0]
return logp_fn(*particles)[0]

return logp_fn_wrap

Expand Down
12 changes: 6 additions & 6 deletions pymc_experimental/tests/test_blackjax_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def test_blackjax_particles_from_pymc_population_univariate():
model = fast_model()
population = {"x": np.array([2, 3, 4])}
blackjax_particles = blackjax_particles_from_pymc_population(model, population)
jax.tree.map(np.testing.assert_allclose, blackjax_particles, [np.array([[2], [3], [4]])])
jax.tree.map(np.testing.assert_allclose, blackjax_particles, [np.array([2, 3, 4])])


def test_blackjax_particles_from_pymc_population_multivariate():
Expand All @@ -147,7 +147,7 @@ def test_blackjax_particles_from_pymc_population_multivariate():
jax.tree.map(
np.testing.assert_allclose,
blackjax_particles,
[np.array([[0.34614613], [1.09163261], [-0.44526825]]), np.array([[1], [2], [3]])],
[np.array([0.34614613, 1.09163261, -0.44526825]), np.array([1, 2, 3])],
)


Expand All @@ -168,7 +168,7 @@ def test_blackjax_particles_from_pymc_population_multivariable():
population = {"x": np.array([[2, 3], [5, 6], [7, 9]]), "z": np.array([11, 12, 13])}
blackjax_particles = blackjax_particles_from_pymc_population(model, population)

jax.tree.map(
jax.tree_map(
np.testing.assert_allclose,
blackjax_particles,
[np.array([[2, 3], [5, 6], [7, 9]]), np.array([[11], [12], [13]])],
Expand All @@ -181,7 +181,7 @@ def test_arviz_from_particles():
with model:
inference_data = arviz_from_particles(model, particles)

assert inference_data.posterior.sizes == Frozen({"chain": 1, "draw": 3, "x_dim_0": 2})
assert inference_data.posterior.dims == Frozen({"chain": 1, "draw": 3, "x_dim_0": 2})
assert inference_data.posterior.data_vars.dtypes == Frozen(
{"x": dtype("float64"), "z": dtype("float64")}
)
Expand All @@ -196,7 +196,7 @@ def test_get_jaxified_logprior():
"""
logprior = get_jaxified_logprior(fast_model())
for point in [-0.5, 0.0, 0.5]:
jax.tree.map(
jax.tree_map(
np.testing.assert_allclose,
jax.vmap(logprior)([np.array([point])]),
np.log(scipy.stats.norm(0, 1).pdf(point)),
Expand All @@ -212,7 +212,7 @@ def test_get_jaxified_loglikelihood():
"""
loglikelihood = get_jaxified_loglikelihood(fast_model())
for point in [-0.5, 0.0, 0.5]:
jax.tree.map(
jax.tree_map(
np.testing.assert_allclose,
jax.vmap(loglikelihood)([np.array([point])]),
np.log(scipy.stats.norm(point, 1).pdf(0)),
Expand Down
Loading