From fbaef4869da2c5fcfe45beb7e8202af2b9a7ae85 Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Fri, 5 Apr 2024 16:53:04 -0300 Subject: [PATCH 1/5] fixing serialization to netcdf --- pymc_experimental/inference/smc/sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_experimental/inference/smc/sampling.py b/pymc_experimental/inference/smc/sampling.py index 5f8dcee6..ae866c9a 100644 --- a/pymc_experimental/inference/smc/sampling.py +++ b/pymc_experimental/inference/smc/sampling.py @@ -166,7 +166,7 @@ def sample_smc_blackjax( num_mcmc_steps, kernel, diagnosis, - total_iterations, + int(total_iterations), iterations_to_diagnose, inner_kernel_params, running_time, From d99b860ec9b2e79215228f5f6641210b269f7953 Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Fri, 5 Apr 2024 17:53:55 -0300 Subject: [PATCH 2/5] potential fix --- notebooks/example2.py | 47 ++++++++++++++++++++ pymc_experimental/inference/smc/sampling.py | 14 +++--- pymc_experimental/tests/test_blackjax_smc.py | 21 +++++---- 3 files changed, 66 insertions(+), 16 deletions(-) create mode 100644 notebooks/example2.py diff --git a/notebooks/example2.py b/notebooks/example2.py new file mode 100644 index 00000000..34a5bec0 --- /dev/null +++ b/notebooks/example2.py @@ -0,0 +1,47 @@ +import pymc as pm + +from pymc_experimental.inference.smc.sampling import sample_smc_blackjax +import numpy as np + +with pm.Model() as model: + a = pm.Normal("a", mu=10, sigma=10) + b = pm.Normal("b", mu=10, sigma=10) + # either of the following lines produces an error + d = pm.Dirichlet("d", [1, 1]) + + trace = sample_smc_blackjax( + n_particles=1000, + kernel="HMC", + inner_kernel_params={ + "step_size": 0.01, + "integration_steps": 20, + }, + iterations_to_diagnose=10, + target_essn=0.5, + num_mcmc_steps=10, + ) + + +real_a = 0.2 +real_b = 2 +x = np.linspace(1, 100) +y = real_a * x + real_b + np.random.normal(0, 2, len(x)) + + +with pm.Model() as model: + a = pm.Normal("a", mu=10, sigma=10) + b = pm.Normal("b", mu=10, sigma=10) + # either of the following lines produces an error + c = pm.Normal("c", mu=10, sigma=10, shape=(1,)) + + trace = sample_smc_blackjax( + n_particles=1000, + kernel="HMC", + inner_kernel_params={ + "step_size": 0.01, + "integration_steps": 20, + }, + iterations_to_diagnose=10, + target_essn=0.5, + num_mcmc_steps=10, + ) \ No newline at end of file diff --git a/pymc_experimental/inference/smc/sampling.py b/pymc_experimental/inference/smc/sampling.py index ae866c9a..fe634ed0 100644 --- a/pymc_experimental/inference/smc/sampling.py +++ b/pymc_experimental/inference/smc/sampling.py @@ -198,13 +198,13 @@ def arviz_from_particles(model, particles): ------- """ n_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0] - by_varname = {k.name: v.squeeze()[np.newaxis, :] for k, v in zip(model.value_vars, particles)} + by_varname = {k.name: v for k, v in zip(model.value_vars, particles)} varnames = [v.name for v in model.value_vars] with model: strace = NDArray(name=model.name) strace.setup(n_particles, 0) for particle_index in range(0, n_particles): - strace.record(point={k: by_varname[k][0][particle_index] for k in varnames}) + strace.record(point={k: by_varname[k][particle_index] for k in varnames}) multitrace = MultiTrace((strace,)) return to_inference_data(multitrace, log_likelihood=False) @@ -297,10 +297,10 @@ def blackjax_particles_from_pymc_population(model, pymc_population): def _format(var): variable = pymc_population[var.name] - if len(variable.shape) == 1: - return variable[:, np.newaxis] - else: - return variable + #if len(variable.shape) == 1: + # return variable[:, np.newaxis] + #else: + return variable return [_format(var) for var in order_of_vars] @@ -384,7 +384,7 @@ def get_jaxified_particles_fn(model, graph_outputs): logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[graph_outputs]) def logp_fn_wrap(particles): - return logp_fn(*[p.squeeze() for p in particles])[0] + return logp_fn(*particles)[0] return logp_fn_wrap diff --git a/pymc_experimental/tests/test_blackjax_smc.py b/pymc_experimental/tests/test_blackjax_smc.py index ebb71f13..927d52b6 100644 --- a/pymc_experimental/tests/test_blackjax_smc.py +++ b/pymc_experimental/tests/test_blackjax_smc.py @@ -133,7 +133,7 @@ def test_blackjax_particles_from_pymc_population_univariate(): model = fast_model() population = {"x": np.array([2, 3, 4])} blackjax_particles = blackjax_particles_from_pymc_population(model, population) - jax.tree_map(np.testing.assert_allclose, blackjax_particles, [np.array([[2], [3], [4]])]) + jax.tree_map(np.testing.assert_allclose, blackjax_particles, [np.array([2, 3, 4])]) def test_blackjax_particles_from_pymc_population_multivariate(): @@ -142,12 +142,13 @@ def test_blackjax_particles_from_pymc_population_multivariate(): z = pm.Normal("z", 0, 1) y = pm.Normal("y", x + z, 1, observed=0) - population = {"x": np.array([0.34614613, 1.09163261, -0.44526825]), "z": np.array([1, 2, 3])} + population = {"x": np.array([0.34614613, 1.09163261, -0.44526825]), + "z": np.array([1, 2, 3])} blackjax_particles = blackjax_particles_from_pymc_population(model, population) jax.tree_map( np.testing.assert_allclose, blackjax_particles, - [np.array([[0.34614613], [1.09163261], [-0.44526825]]), np.array([[1], [2], [3]])], + [np.array([0.34614613, 1.09163261, -0.44526825]), np.array([1, 2, 3])], ) @@ -158,30 +159,32 @@ def simple_multivariable_model(): """ with pm.Model() as model: x = pm.Normal("x", 0, 1, shape=2) - z = pm.Normal("z", 0, 1) - y = pm.Normal("y", z, 1, observed=0) + z = pm.Normal("z", 0, 1, shape=(1,)) + y = pm.Normal("y", z, np.array([1,]), observed=np.array([0,])) + return model def test_blackjax_particles_from_pymc_population_multivariable(): model = simple_multivariable_model() - population = {"x": np.array([[2, 3], [5, 6], [7, 9]]), "z": np.array([11, 12, 13])} + population = {"x": np.array([[2, 3], [5, 6], [7, 9]]), + "z": np.array([[11,], [12,], [13,]])} blackjax_particles = blackjax_particles_from_pymc_population(model, population) jax.tree_map( np.testing.assert_allclose, blackjax_particles, - [np.array([[2, 3], [5, 6], [7, 9]]), np.array([[11], [12], [13]])], + [np.array([[2, 3], [5, 6], [7, 9]]), np.array([[11,], [12,], [13,]])], ) def test_arviz_from_particles(): model = simple_multivariable_model() - particles = [np.array([[2, 3], [5, 6], [7, 9]]), np.array([[11], [12], [13]])] + particles = [np.array([[2, 3], [5, 6], [7, 9]]), np.array([[11,],[12,],[13]])] with model: inference_data = arviz_from_particles(model, particles) - assert inference_data.posterior.dims == Frozen({"chain": 1, "draw": 3, "x_dim_0": 2}) + assert inference_data.posterior.dims == Frozen({"chain": 1, "draw": 3, "x_dim_0": 2, "z_dim_0":1}) assert inference_data.posterior.data_vars.dtypes == Frozen( {"x": dtype("float64"), "z": dtype("float64")} ) From da3af5ef5f2a95c0ac71d9840c2ddefbe0e41418 Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Fri, 5 Apr 2024 17:59:06 -0300 Subject: [PATCH 3/5] removing comment --- pymc_experimental/inference/smc/sampling.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/pymc_experimental/inference/smc/sampling.py b/pymc_experimental/inference/smc/sampling.py index fe634ed0..3325e0b1 100644 --- a/pymc_experimental/inference/smc/sampling.py +++ b/pymc_experimental/inference/smc/sampling.py @@ -295,14 +295,7 @@ def blackjax_particles_from_pymc_population(model, pymc_population): order_of_vars = model.value_vars - def _format(var): - variable = pymc_population[var.name] - #if len(variable.shape) == 1: - # return variable[:, np.newaxis] - #else: - return variable - - return [_format(var) for var in order_of_vars] + return [pymc_population[var.name]for var in order_of_vars] def add_to_inference_data( From a838a491c95d662f467ce3bc44d5bd1cf1c1ff8c Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Wed, 17 Apr 2024 09:26:09 -0300 Subject: [PATCH 4/5] removing unused file --- notebooks/example2.py | 47 ------------------------------------------- 1 file changed, 47 deletions(-) delete mode 100644 notebooks/example2.py diff --git a/notebooks/example2.py b/notebooks/example2.py deleted file mode 100644 index 34a5bec0..00000000 --- a/notebooks/example2.py +++ /dev/null @@ -1,47 +0,0 @@ -import pymc as pm - -from pymc_experimental.inference.smc.sampling import sample_smc_blackjax -import numpy as np - -with pm.Model() as model: - a = pm.Normal("a", mu=10, sigma=10) - b = pm.Normal("b", mu=10, sigma=10) - # either of the following lines produces an error - d = pm.Dirichlet("d", [1, 1]) - - trace = sample_smc_blackjax( - n_particles=1000, - kernel="HMC", - inner_kernel_params={ - "step_size": 0.01, - "integration_steps": 20, - }, - iterations_to_diagnose=10, - target_essn=0.5, - num_mcmc_steps=10, - ) - - -real_a = 0.2 -real_b = 2 -x = np.linspace(1, 100) -y = real_a * x + real_b + np.random.normal(0, 2, len(x)) - - -with pm.Model() as model: - a = pm.Normal("a", mu=10, sigma=10) - b = pm.Normal("b", mu=10, sigma=10) - # either of the following lines produces an error - c = pm.Normal("c", mu=10, sigma=10, shape=(1,)) - - trace = sample_smc_blackjax( - n_particles=1000, - kernel="HMC", - inner_kernel_params={ - "step_size": 0.01, - "integration_steps": 20, - }, - iterations_to_diagnose=10, - target_essn=0.5, - num_mcmc_steps=10, - ) \ No newline at end of file From 633b7706023855959accaa3997848b44b90618cb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Apr 2024 12:29:09 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pymc_experimental/inference/smc/sampling.py | 2 +- pymc_experimental/tests/test_blackjax_smc.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pymc_experimental/inference/smc/sampling.py b/pymc_experimental/inference/smc/sampling.py index 3325e0b1..c6d7a818 100644 --- a/pymc_experimental/inference/smc/sampling.py +++ b/pymc_experimental/inference/smc/sampling.py @@ -295,7 +295,7 @@ def blackjax_particles_from_pymc_population(model, pymc_population): order_of_vars = model.value_vars - return [pymc_population[var.name]for var in order_of_vars] + return [pymc_population[var.name] for var in order_of_vars] def add_to_inference_data( diff --git a/pymc_experimental/tests/test_blackjax_smc.py b/pymc_experimental/tests/test_blackjax_smc.py index f817d2f5..4aa0027c 100644 --- a/pymc_experimental/tests/test_blackjax_smc.py +++ b/pymc_experimental/tests/test_blackjax_smc.py @@ -142,8 +142,7 @@ def test_blackjax_particles_from_pymc_population_multivariate(): z = pm.Normal("z", 0, 1) y = pm.Normal("y", x + z, 1, observed=0) - population = {"x": np.array([0.34614613, 1.09163261, -0.44526825]), - "z": np.array([1, 2, 3])} + population = {"x": np.array([0.34614613, 1.09163261, -0.44526825]), "z": np.array([1, 2, 3])} blackjax_particles = blackjax_particles_from_pymc_population(model, population) jax.tree.map( np.testing.assert_allclose,