sample_smc #319

Open
myravian opened this issue Mar 26, 2024 · 7 comments

@myravian

myravian commented Mar 26, 2024

Hi there,

I was testing `pymc_experimental/inference/smc/sampling.py` and noticed the following issues:

  • Inference doesn't seem to work with `pm.Dirichlet`: it raises a shape error at `tmp = logp_fn(*[p.squeeze() for p in particles])[0]`.
  • `arviz_from_particles` doesn't seem to handle RVs with `shape=(1,)`.
  • Converting the InferenceData to netCDF fails because integrations is neither an `int` nor an `np.array`.
  • The InferenceData doesn't include the marginal likelihood. Do you think it will be implemented in the future, or is it just not possible?
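On the first two points: `p.squeeze()` drops every length-1 axis, so a per-particle value of shape `(1,)` becomes a scalar. A `pm.Dirichlet("d", [1, 1])` hits the same case, because the simplex transform maps its K=2 components to K-1=1 free coordinate. Below is a minimal NumPy sketch of splitting particles with `reshape` instead, which keeps those axes. `split_particles` is a hypothetical helper, not pymc-experimental's code, and the flat `(n_particles, total_dim)` layout is an assumption for illustration.

```python
import numpy as np

def split_particles(flat_particles, value_shapes):
    """Hypothetical helper: split an (n_particles, total_dim) array into
    per-variable arrays of shape (n_particles, *value_shape).
    reshape keeps length-1 axes, whereas squeeze would drop them."""
    n_particles = flat_particles.shape[0]
    out, start = {}, 0
    for name, shape in value_shapes.items():
        size = int(np.prod(shape, dtype=int))  # 1 for scalars, since prod(()) == 1
        block = flat_particles[:, start:start + size]
        out[name] = block.reshape((n_particles, *shape))
        start += size
    return out

# Value shapes on the sampler's (transformed) space, mirroring the repro model.
# Dirichlet([1, 1]) has 2 components but 1 free coordinate after the simplex transform.
value_shapes = {"a": (), "b": (), "c": (1,), "d_simplex__": (1,)}
flat = np.random.default_rng(0).normal(size=(1000, 4))

parts = split_particles(flat, value_shapes)
print(parts["a"].shape)                # (1000,)
print(parts["c"].shape)                # (1000, 1)
print(parts["c"][0].shape)             # (1,)  one particle, shape preserved
print(parts["c"][0].squeeze().shape)   # ()    squeeze removed the axis
```

The same idea applies when building the posterior for ArviZ: reshaping to `(chain, draw, *value_shape)` rather than squeezing keeps `shape=(1,)` variables intact.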

Thanks a lot for the BlackJAX SMC implementation; it's very useful!

Cheers,
VIan

PS: here's some code that reproduces the errors:

```python
import numpy as np
import pymc as pm

# Assumed import path: the sampler lives in pymc_experimental/inference/smc/sampling.py
from pymc_experimental.inference.smc.sampling import sample_smc_blackjax as sample_smc

real_a = 0.2
real_b = 2
x = np.linspace(1, 100)
y = real_a * x + real_b + np.random.normal(0, 2, len(x))

with pm.Model() as model:
    a = pm.Normal("a", mu=10, sigma=10)
    b = pm.Normal("b", mu=10, sigma=10)
    # uncommenting either of the following lines produces an error
    # c = pm.Normal("c", mu=10, sigma=10, shape=(1,))
    # d = pm.Dirichlet("d", [1, 1])

    trace = sample_smc(
        n_particles=1000,
        kernel="HMC",
        inner_kernel_params={
            "step_size": 0.01,
            "integration_steps": 20,
        },
        iterations_to_diagnose=10,
        target_essn=0.5,
        num_mcmc_steps=10,
    )
```
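On the netCDF point: when writing netCDF, xarray validates attributes and accepts only strings, numbers, NumPy arrays, and lists/tuples of those, so a dict or a non-NumPy array stored in `attrs` makes `to_netcdf` fail. A minimal pre-save cleanup sketch follows, assuming the offending value lives in a group's `attrs`; `netcdf_safe_attrs` is a hypothetical helper, not part of ArviZ or pymc-experimental.

```python
import json
from numbers import Number

import numpy as np

def netcdf_safe_attrs(attrs):
    """Hypothetical helper: return a copy of `attrs` whose values xarray can
    serialize to netCDF (str, numbers, NumPy arrays)."""
    safe = {}
    for key, value in attrs.items():
        if isinstance(value, (bool, np.bool_)):
            safe[key] = int(value)  # store booleans as 0/1
        elif isinstance(value, (str, Number, np.number)):
            safe[key] = value
        elif isinstance(value, dict):
            safe[key] = json.dumps(value, default=str)  # nested dicts -> JSON string
        else:
            try:
                arr = np.asarray(value)  # e.g. lists, tuples, JAX arrays
            except Exception:
                arr = None
            if arr is not None and arr.dtype != object:
                safe[key] = arr
            else:
                safe[key] = str(value)  # last resort: text representation
    return safe

# Usage with the InferenceData returned by sample_smc (not executed here):
# for group in trace.groups():
#     ds = getattr(trace, group)
#     ds.attrs = netcdf_safe_attrs(ds.attrs)
# trace.to_netcdf("trace.nc")
```

Converting to JSON strings rather than dropping values keeps the sampler settings recoverable with `json.loads` after reading the file back.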

@ciguaran
Contributor

Hi, I can tackle this. Could someone assign the issue to me?

@ciguaran
Contributor

ciguaran commented Apr 5, 2024

@myravian, could you try running your example from this branch? I may have a fix: https://github.com/ciguaran/pymc-experimental/tree/ciguaran_fix_smc_bj. Also, I'm very interested to know what you are using SMC for; it would be great if it could become an example notebook on how to use it! Let me know.

@myravian
Author

myravian commented Apr 10, 2024

Unfortunately I still have the same error message:

```
  File "/local/home/vleboute/work/MULTIGRIS/mgris/sampling_smc_ciguaran.py", line 150, in sample_smc_blackjax
    total_iterations, particles, diagnosis = inference_loop(
                                             ^^^^^^^^^^^^^^^
  File "/local/home/vleboute/work/MULTIGRIS/mgris/sampling_smc_ciguaran.py", line 267, in inference_loop
    n_iter, final_state, _, diagnosis = jax.lax.while_loop(
                                        ^^^^^^^^^^^^^^^^^^^
  File "/local/home/vleboute/work/MULTIGRIS/mgris/sampling_smc_ciguaran.py", line 262, in one_step
    state, info = kernel.step(subk, state)
                  ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/local/home/vleboute/miniconda3/lib/python3.11/site-packages/blackjax/smc/adaptive_tempered.py", line 167, in step_fn
    return kernel(
           ^^^^^^^
  File "/local/home/vleboute/miniconda3/lib/python3.11/site-packages/blackjax/smc/adaptive_tempered.py", line 101, in kernel
    return tempered_kernel(rng_key, state, num_mcmc_steps, lmbda, mcmc_parameters)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/local/home/vleboute/miniconda3/lib/python3.11/site-packages/blackjax/smc/tempered.py", line 143, in kernel
    smc_state, info = smc.base.step(
                      ^^^^^^^^^^^^^^
  File "/local/home/vleboute/miniconda3/lib/python3.11/site-packages/blackjax/smc/base.py", line 140, in step
    particles, update_info = update_fn(keys, particles)
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/local/home/vleboute/miniconda3/lib/python3.11/site-packages/blackjax/smc/tempered.py", line 131, in mcmc_kernel
    state = mcmc_init_fn(position, tempered_logposterior_fn)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/local/home/vleboute/miniconda3/lib/python3.11/site-packages/blackjax/mcmc/hmc.py", line 89, in init
    logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position)
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/local/home/vleboute/miniconda3/lib/python3.11/site-packages/blackjax/smc/tempered.py", line 126, in tempered_logposterior_fn
    logprior = logprior_fn(position)
               ^^^^^^^^^^^^^^^^^^^^^
  File "/local/home/vleboute/work/MULTIGRIS/mgris/sampling_smc_ciguaran.py", line 380, in logp_fn_wrap
    return logp_fn(*particles)[0]
           ^^^^^^^^^^^^^^^^^^^
  File "/tmp/tmpoc516ktt", line 29, in jax_funcified_fgraph
    tensor_variable_13 = dimshuffle_1(d_simplex_)
                         ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/local/home/vleboute/miniconda3/lib/python3.11/site-packages/pytensor/link/jax/dispatch/elemwise.py", line 69, in dimshuffle
    res = jnp.transpose(x, op.transposition)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/local/home/vleboute/miniconda3/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 681, in transpose
    return lax.transpose(a, axes_)
           ^^^^^^^^^^^^^^^^^^^^^^^
TypeError: transpose permutation isn't a permutation of operand dimensions, got permutation (0,) for operand shape (1000, 1).
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```
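The last frames show the mismatch: the compiled logp applies a DimShuffle with permutation `(0,)`, which only fits a 1-D value such as one particle's `d_simplex_` of shape `(1,)`, but the value it received had shape `(1000, 1)`, one extra leading axis of size `n_particles`. A NumPy sketch of the same failure and of calling the function once per particle; `per_particle_logp` is a hypothetical stand-in, and NumPy raises `ValueError` where JAX raised `TypeError`.

```python
import numpy as np

def per_particle_logp(d_value):
    # Hypothetical stand-in for a compiled logp term: the fixed axis
    # permutation (0,) mimics the DimShuffle and only accepts 1-D input.
    x = np.transpose(d_value, (0,))
    return -0.5 * np.sum(x ** 2)

batch = np.ones((1000, 1))  # (n_particles, 1), the shape from the error message

try:
    per_particle_logp(batch)  # extra particle axis -> permutation mismatch
except ValueError as err:
    print("batch input fails:", err)

# Calling the function once per particle gives it the (1,) shape it expects.
# In JAX code, mapping a per-particle function over axis 0 is what jax.vmap does.
logps = np.array([per_particle_logp(p) for p in batch])
print(logps.shape)  # (1000,)
```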

I don't have a straightforward, simple illustration of how I use SMC, but the gist is this: I ran an astrophysical code to compute predictions for several hundred thousand parameter sets, and from a set of observables I try to infer the parameters. There are of course various issues, such as the regularity/completeness of the grid and interpolation, but the main issue is the complex, multimodal posterior distributions that we expect. In all the proof-of-concept and validation tests we ran, SMC has been a great way to probe the prior space and to handle such difficult posteriors (provided the kernel parameters are well tuned). I'm by no means an expert in statistics and I rely a lot on empirical knowledge, so I'm sure I'm not doing everything right, though!
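The traceback shows the sampler runs BlackJAX's adaptive tempered SMC, which suits multimodal posteriors because particles spread over the prior are moved gradually from prior to posterior. Below is a minimal NumPy sketch of that scheme, not the library's implementation: the next temperature is found by bisection so the normalized ESS stays near `target_essn`, particles are resampled multinomially, and a random-walk Metropolis kernel (swapped in for HMC) runs `num_mcmc_steps` times. The bimodal target, prior, and step size are hypothetical. As a by-product it accumulates a log marginal likelihood estimate, the quantity asked about in the original issue.

```python
import numpy as np

PRIOR_SD = 5.0  # hypothetical prior: N(0, 5^2)

def log_prior(x):
    # Unnormalized is fine: only differences enter the Metropolis ratio.
    return -0.5 * (x / PRIOR_SD) ** 2

def log_lik(x):
    # Hypothetical normalized bimodal likelihood: 0.5 N(x; -3, 0.5^2) + 0.5 N(x; 3, 0.5^2)
    log_norm = -np.log(0.5 * np.sqrt(2 * np.pi))
    left = -0.5 * ((x + 3) / 0.5) ** 2 + log_norm
    right = -0.5 * ((x - 3) / 0.5) ** 2 + log_norm
    return np.logaddexp(left, right) + np.log(0.5)

def logsumexp(v):
    m = np.max(v)
    return m + np.log(np.sum(np.exp(v - m)))

def essn(logw):
    # Normalized effective sample size of log-weights, in (0, 1].
    w = np.exp(logw - np.max(logw))
    return w.sum() ** 2 / (len(w) * np.sum(w ** 2))

def next_beta(ll, beta, target_essn):
    # Bisection for the temperature at which ESS/N of the incremental weights drops to ~target_essn.
    if essn((1.0 - beta) * ll) >= target_essn:
        return 1.0
    lo, hi = beta, 1.0
    for _ in range(50):
        mid = 0.5 * (lo + hi)
        if essn((mid - beta) * ll) >= target_essn:
            lo = mid
        else:
            hi = mid
    return lo

def tempered_smc(log_prior, log_lik, n_particles=2000, target_essn=0.5,
                 num_mcmc_steps=10, step_size=1.0, seed=0):
    rng = np.random.default_rng(seed)
    x = rng.normal(0.0, PRIOR_SD, n_particles)  # exact draws from the prior
    beta, log_z = 0.0, 0.0
    while beta < 1.0:
        ll = log_lik(x)
        new_beta = next_beta(ll, beta, target_essn)
        logw = (new_beta - beta) * ll
        log_z += logsumexp(logw) - np.log(n_particles)  # log mean incremental weight
        w = np.exp(logw - np.max(logw))
        x = x[rng.choice(n_particles, size=n_particles, p=w / w.sum())]  # multinomial resampling
        beta = new_beta
        for _ in range(num_mcmc_steps):  # random-walk Metropolis on prior * lik**beta
            prop = x + step_size * rng.standard_normal(n_particles)
            log_ratio = (log_prior(prop) + beta * log_lik(prop)) - (log_prior(x) + beta * log_lik(x))
            accept = np.log(rng.uniform(size=n_particles)) < log_ratio
            x = np.where(accept, prop, x)
    return x, log_z

particles, log_z = tempered_smc(log_prior, log_lik)
print(f"fraction in right mode: {np.mean(particles > 0):.2f}, log evidence ~ {log_z:.2f}")
```

Because every stage resamples to equal weights, each stage's evidence increment is simply the log of the mean incremental weight, and summing those increments over stages estimates log Z.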

@ciguaran
Contributor

Could you share a full Python file that reproduces the error? I've run the example you posted at the very beginning and it works for me 🤔.

@myravian
Author

Here is the script:
```python
import pymc as pm

from sampling_smc_ciguaran import sample_smc_blackjax as sample_smc

with pm.Model() as model:
    c = pm.Normal("c", mu=10, sigma=10, shape=(1,))
    d = pm.Dirichlet("d", [1, 1])

    trace = sample_smc(
        n_particles=1000,
        kernel="HMC",
        inner_kernel_params={
            "step_size": 0.01,  # small values better
            "integration_steps": 20,
        },
        iterations_to_diagnose=10,
        target_essn=0.5,
        num_mcmc_steps=10,
    )
```
Maybe it has to do with the BlackJAX/JAX versions (1.1.0/0.4.21 on my system).
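To compare environments when a version mismatch is suspected, a small standard-library sketch can report what is installed. The distribution names listed are assumptions about which packages matter for this issue; `installed_versions` is a hypothetical helper.

```python
from importlib import metadata

def installed_versions(packages=("pymc", "pymc-experimental", "blackjax", "jax", "jaxlib")):
    """Hypothetical helper: map each distribution name to its installed
    version, or None if it is not installed."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions

for name, version in installed_versions().items():
    print(f"{name}: {version or 'not installed'}")
```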

@ciguaran
Contributor

Hi! I was able to run the example you just shared after installing pymc-experimental from the branch:

```
pip install git+https://github.com/ciguaran/pymc-experimental@ciguaran_fix_smc_bj
```

Is it possible that you are still using pymc-experimental from master?
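The earlier traceback runs through a local copy, `mgris/sampling_smc_ciguaran.py`, rather than the installed package, so it can help to confirm which file each module is actually imported from. A minimal standard-library sketch; `module_origin` is a hypothetical helper, and the module names checked are the two used in this thread.

```python
import importlib.util

def module_origin(module_name):
    """Hypothetical helper: return the file a module would be imported from,
    or None if it (or one of its parent packages) cannot be found."""
    try:
        spec = importlib.util.find_spec(module_name)
    except ImportError:  # includes ModuleNotFoundError for missing parent packages
        return None
    return spec.origin if spec is not None else None

for name in ("pymc_experimental.inference.smc.sampling", "sampling_smc_ciguaran"):
    print(name, "->", module_origin(name))
```

If the first path points into site-packages from a master install, or the second one resolves at all, the branch fix is not what is being executed.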

@myravian
Author

You're right, I was not using the proper version. I just tested it and it seems to work fine. Thanks for the modifications!
