Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing issue Samples are outside the support for DiscreteUniform dist… #1835

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
7 changes: 7 additions & 0 deletions numpyro/infer/hmc_gibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,13 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
and site["fn"].has_enumerate_support
and not site["is_observed"]
}
self._support_enumerates = {
name: site["fn"].enumerate_support(False)
for name, site in self._prototype_trace.items()
if site["type"] == "sample"
and site["fn"].has_enumerate_support
and not site["is_observed"]
}
self._gibbs_sites = [
name
for name, site in self._prototype_trace.items()
Expand Down
6 changes: 6 additions & 0 deletions numpyro/infer/mixed_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from jax import grad, jacfwd, lax, random
from jax.flatten_util import ravel_pytree
import jax
import jax.numpy as jnp

from numpyro.infer.hmc import momentum_generator
Expand Down Expand Up @@ -301,6 +302,11 @@ def body_fn(i, vals):
adapt_state=adapt_state,
)

z_discrete = jax.tree.map(
lambda idx, support: support[idx],
z_discrete,
self._support_enumerates,
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doing this might return in-support values but I worry that the algorithms are wrong. To compute potential energy correctly in the algorithm, we need to work with in-support values. I think you can pass support_enumerates into self._discrete_proposal_fn and change the proposal logic there.

    proposal = random.randint(rng_proposal, (), minval=0, maxval=support_size)
    # z_new_flat = z_discrete_flat.at[idx].set(proposal)
    z_new_flat = z_discrete_flat.at[idx].set(support_enumerate[proposal])

or for modified rw proposal

    i = random.randint(rng_proposal, (), minval=0, maxval=support_size - 1)
    # proposal = jnp.where(i >= z_discrete_flat[idx], i + 1, i)
    # proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, proposal)
    proposal_index = jnp.where(support_size[i] == z_discrete_flat[idx], support_size - 1, i)
    proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, support_size[proposal_index])
    z_new_flat = z_discrete_flat.at[idx].set(proposal)

or at discrete gibbs proposal

    proposal_index = jnp.where(support_enumerate[i] == z_init_flat[idx], support_size - 1, i)
    z_new_flat = z_init_flat.at[idx].set(support_enumerate[proposal_index])

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, thank you for the feedback. I will try this.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fehiepsi how do you debug in numpyro? I tried jax.debug. but nothing happens.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I use print most of the time. When actual values are needed, I sometimes use jax.disable_jit()

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fehiepsi I have issues with passing enumerate supports and traced values as the support arrays can have different sizes. I was thinking maybe to just pass the "lower bound of the support" as offset and combined with support_sizes it should make the trick. Are there discrete variables where the support is not a simple discrete range with step 1 between values?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for modified_rw_proposal I think you used support_size in place of support_enumerate, shouldn't it be:

    i = random.randint(rng_proposal, (), minval=0, maxval=support_size - 1)
    # proposal = jnp.where(i >= z_discrete_flat[idx], i + 1, i)
    # proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, proposal)
    proposal_index = jnp.where(support_enumerate[i] == z_discrete_flat[idx], support_size - 1, i)
    proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, support_enumerate[proposal_index])
    z_new_flat = z_discrete_flat.at[idx].set(proposal)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks! your solutions are super cool! I haven't thought of different support sizes previously.

z = {**z_discrete, **hmc_state.z}
return MixedHMCState(z, hmc_state, rng_key, accept_prob)

Expand Down