diff --git a/pymc_experimental/inference/smc/sampling.py b/pymc_experimental/inference/smc/sampling.py index 5f8dcee6..c6d7a818 100644 --- a/pymc_experimental/inference/smc/sampling.py +++ b/pymc_experimental/inference/smc/sampling.py @@ -166,7 +166,7 @@ def sample_smc_blackjax( num_mcmc_steps, kernel, diagnosis, - total_iterations, + int(total_iterations), iterations_to_diagnose, inner_kernel_params, running_time, @@ -198,13 +198,13 @@ def arviz_from_particles(model, particles): ------- """ n_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0] - by_varname = {k.name: v.squeeze()[np.newaxis, :] for k, v in zip(model.value_vars, particles)} + by_varname = {k.name: v for k, v in zip(model.value_vars, particles)} varnames = [v.name for v in model.value_vars] with model: strace = NDArray(name=model.name) strace.setup(n_particles, 0) for particle_index in range(0, n_particles): - strace.record(point={k: by_varname[k][0][particle_index] for k in varnames}) + strace.record(point={k: by_varname[k][particle_index] for k in varnames}) multitrace = MultiTrace((strace,)) return to_inference_data(multitrace, log_likelihood=False) @@ -295,14 +295,7 @@ def blackjax_particles_from_pymc_population(model, pymc_population): order_of_vars = model.value_vars - def _format(var): - variable = pymc_population[var.name] - if len(variable.shape) == 1: - return variable[:, np.newaxis] - else: - return variable - - return [_format(var) for var in order_of_vars] + return [pymc_population[var.name] for var in order_of_vars] def add_to_inference_data( @@ -384,7 +377,7 @@ def get_jaxified_particles_fn(model, graph_outputs): logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[graph_outputs]) def logp_fn_wrap(particles): - return logp_fn(*[p.squeeze() for p in particles])[0] + return logp_fn(*particles)[0] return logp_fn_wrap diff --git a/pymc_experimental/tests/test_blackjax_smc.py b/pymc_experimental/tests/test_blackjax_smc.py index 2cdcf067..4aa0027c 100644 --- a/pymc_experimental/tests/test_blackjax_smc.py +++ b/pymc_experimental/tests/test_blackjax_smc.py @@ -133,7 +133,7 @@ def test_blackjax_particles_from_pymc_population_univariate(): model = fast_model() population = {"x": np.array([2, 3, 4])} blackjax_particles = blackjax_particles_from_pymc_population(model, population) - jax.tree.map(np.testing.assert_allclose, blackjax_particles, [np.array([[2], [3], [4]])]) + jax.tree.map(np.testing.assert_allclose, blackjax_particles, [np.array([2, 3, 4])]) def test_blackjax_particles_from_pymc_population_multivariate(): @@ -147,7 +147,7 @@ def test_blackjax_particles_from_pymc_population_multivariate(): jax.tree.map( np.testing.assert_allclose, blackjax_particles, - [np.array([[0.34614613], [1.09163261], [-0.44526825]]), np.array([[1], [2], [3]])], + [np.array([0.34614613, 1.09163261, -0.44526825]), np.array([1, 2, 3])], ) @@ -168,7 +168,7 @@ def test_blackjax_particles_from_pymc_population_multivariable(): population = {"x": np.array([[2, 3], [5, 6], [7, 9]]), "z": np.array([11, 12, 13])} blackjax_particles = blackjax_particles_from_pymc_population(model, population) - jax.tree.map( + jax.tree_map( np.testing.assert_allclose, blackjax_particles, [np.array([[2, 3], [5, 6], [7, 9]]), np.array([[11], [12], [13]])], @@ -181,7 +181,7 @@ def test_arviz_from_particles(): with model: inference_data = arviz_from_particles(model, particles) - assert inference_data.posterior.sizes == Frozen({"chain": 1, "draw": 3, "x_dim_0": 2}) + assert inference_data.posterior.dims == Frozen({"chain": 1, "draw": 3, "x_dim_0": 2}) assert inference_data.posterior.data_vars.dtypes == Frozen( {"x": dtype("float64"), "z": dtype("float64")} ) @@ -196,7 +196,7 @@ def test_get_jaxified_logprior(): """ logprior = get_jaxified_logprior(fast_model()) for point in [-0.5, 0.0, 0.5]: - jax.tree.map( + jax.tree_map( np.testing.assert_allclose, jax.vmap(logprior)([np.array([point])]), np.log(scipy.stats.norm(0, 1).pdf(point)), @@ -212,7 +212,7 @@ def test_get_jaxified_loglikelihood(): """ loglikelihood = get_jaxified_loglikelihood(fast_model()) for point in [-0.5, 0.0, 0.5]: - jax.tree.map( + jax.tree_map( np.testing.assert_allclose, jax.vmap(loglikelihood)([np.array([point])]), np.log(scipy.stats.norm(point, 1).pdf(0)),