Fixing issue Samples are outside the support for DiscreteUniform dist… #1835

Open · wants to merge 8 commits into base: master
82 changes: 72 additions & 10 deletions numpyro/infer/hmc_gibbs.py
@@ -192,12 +192,22 @@ def __getstate__(self):


 def _discrete_gibbs_proposal_body_fn(
-    z_init_flat, unravel_fn, pe_init, potential_fn, idx, i, val
+    z_init_flat,
+    unravel_fn,
+    pe_init,
+    potential_fn,
+    idx,
+    i,
+    val,
+    support_size,
+    support_enumerate,
 ):
     rng_key, z, pe, log_weight_sum = val
     rng_key, rng_transition = random.split(rng_key)
-    proposal = jnp.where(i >= z_init_flat[idx], i + 1, i)
-    z_new_flat = z_init_flat.at[idx].set(proposal)
+    proposal_index = jnp.where(
+        support_enumerate[i] == z_init_flat[idx], support_size - 1, i
+    )
+    z_new_flat = z_init_flat.at[idx].set(support_enumerate[proposal_index])
     z_new = unravel_fn(z_new_flat)
     pe_new = potential_fn(z_new)
     log_weight_new = pe_init - pe_new
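
For reference, a standalone sketch of the indexing trick in the lines above (the support values and current value here are hypothetical): the loop index is redirected to the last support element whenever it would land on the current value, so a sweep over 0 .. support_size - 2 proposes every support value exactly once except the current one, even when the support does not start at 0.

import jax.numpy as jnp

support_enumerate = jnp.array([10, 11, 12])  # e.g. DiscreteUniform(10, 12)
support_size = 3
current_value = 11  # value currently stored at z_init_flat[idx]

for i in range(support_size - 1):
    # Redirect the index to the last support element if it points at the
    # current value; otherwise keep it as is.
    proposal_index = jnp.where(
        support_enumerate[i] == current_value, support_size - 1, i
    )
    print(int(support_enumerate[proposal_index]))  # prints 10, then 12
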
@@ -216,7 +226,9 @@ def _discrete_gibbs_proposal_body_fn(
     return rng_key, z, pe, log_weight_sum


-def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size):
+def _discrete_gibbs_proposal(
+    rng_key, z_discrete, pe, potential_fn, idx, support_size, support_enumerate
+):
     # idx: current index of `z_discrete_flat` to update
     # support_size: support size of z_discrete at the index idx

@@ -234,6 +246,8 @@ def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size):
         pe,
         potential_fn,
         idx,
+        support_size=support_size,
+        support_enumerate=support_enumerate,
     )
     init_val = (rng_key, z_discrete, pe, jnp.array(0.0))
     rng_key, z_new, pe_new, _ = fori_loop(0, support_size - 1, body_fn, init_val)
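
To show how the two new arguments reach the loop body, a toy sketch (the body function, values, and accumulated quantity below are made up, not the library code): functools.partial binds the extra arguments up front, and fori_loop then calls the bound function as body_fn(i, val) for i = 0 .. support_size - 2.

from functools import partial

from jax.lax import fori_loop
import jax.numpy as jnp

def toy_body_fn(current_value, i, val, support_size, support_enumerate):
    # fori_loop supplies (i, val); the remaining arguments are pre-bound.
    proposal_index = jnp.where(
        support_enumerate[i] == current_value, support_size - 1, i
    )
    return val + support_enumerate[proposal_index]

body_fn = partial(
    toy_body_fn,
    11,  # current value of the site, bound positionally
    support_size=3,
    support_enumerate=jnp.array([10, 11, 12]),
)
# Sums every proposed value: 10 + 12 = 22; the current value 11 is never proposed.
total = fori_loop(0, 3 - 1, body_fn, jnp.array(0))
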
@@ -242,7 +256,14 @@ def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size):


 def _discrete_modified_gibbs_proposal(
-    rng_key, z_discrete, pe, potential_fn, idx, support_size, stay_prob=0.0
+    rng_key,
+    z_discrete,
+    pe,
+    potential_fn,
+    idx,
+    support_size,
+    support_enumerate,
+    stay_prob=0.0,
 ):
     assert isinstance(stay_prob, float) and stay_prob >= 0.0 and stay_prob < 1
     z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)
@@ -253,6 +274,8 @@ def _discrete_modified_gibbs_proposal(
         pe,
         potential_fn,
         idx,
+        support_size=support_size,
+        support_enumerate=support_enumerate,
     )
     # like gibbs_step but here, weight of the current value is 0
     init_val = (rng_key, z_discrete, pe, jnp.array(-jnp.inf))
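
For intuition about the -jnp.inf initialisation, a small numeric sketch (it assumes the loop accumulates weights with logaddexp, which is what the surrounding code suggests): starting the running log-weight at -inf gives the current value zero weight, so this variant always proposes a move, whereas the plain Gibbs proposal above starts at 0.0 = log(1) and may keep the current value.

import jax.numpy as jnp

pe_init = 2.0                           # potential energy at the current value
pe_candidates = jnp.array([1.5, 3.0])   # energies of the other support values
log_weights = pe_init - pe_candidates   # log weight of each candidate move

# Plain Gibbs: the current value enters with weight exp(0.0) = 1.
gibbs_log_norm = jnp.logaddexp(jnp.array(0.0), jnp.logaddexp(*log_weights))

# Modified Gibbs: starting from -inf gives the current value weight 0.
modified_log_norm = jnp.logaddexp(jnp.array(-jnp.inf), jnp.logaddexp(*log_weights))
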
@@ -276,28 +299,41 @@ def _discrete_modified_gibbs_proposal(
     return rng_key, z_new, pe_new, log_accept_ratio


-def _discrete_rw_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size):
+def _discrete_rw_proposal(
+    rng_key, z_discrete, pe, potential_fn, idx, support_size, support_enumerate
+):
     rng_key, rng_proposal = random.split(rng_key, 2)
     z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)

     proposal = random.randint(rng_proposal, (), minval=0, maxval=support_size)
-    z_new_flat = z_discrete_flat.at[idx].set(proposal)
+    z_new_flat = z_discrete_flat.at[idx].set(support_enumerate[proposal])
     z_new = unravel_fn(z_new_flat)
     pe_new = potential_fn(z_new)
     log_accept_ratio = pe - pe_new
     return rng_key, z_new, pe_new, log_accept_ratio


 def _discrete_modified_rw_proposal(
-    rng_key, z_discrete, pe, potential_fn, idx, support_size, stay_prob=0.0
+    rng_key,
+    z_discrete,
+    pe,
+    potential_fn,
+    idx,
+    support_size,
+    support_enumerate,
+    stay_prob=0.0,
 ):
     assert isinstance(stay_prob, float) and stay_prob >= 0.0 and stay_prob < 1
     rng_key, rng_proposal, rng_stay = random.split(rng_key, 3)
     z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)

     i = random.randint(rng_proposal, (), minval=0, maxval=support_size - 1)
-    proposal = jnp.where(i >= z_discrete_flat[idx], i + 1, i)
-    proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, proposal)
+    proposal_index = jnp.where(
+        support_enumerate[i] == z_discrete_flat[idx], support_size - 1, i
+    )
+    proposal = jnp.where(
+        random.bernoulli(rng_stay, stay_prob), idx, support_enumerate[proposal_index]
+    )
     z_new_flat = z_discrete_flat.at[idx].set(proposal)
     z_new = unravel_fn(z_new_flat)
     pe_new = potential_fn(z_new)
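
The random-walk variant gets the analogous fix: the uniformly drawn integer is now treated as an index into the enumerated support rather than as the proposed value itself. A standalone sketch with hypothetical values:

from jax import random
import jax.numpy as jnp

rng_proposal = random.PRNGKey(0)
support_enumerate = jnp.array([10, 11, 12])  # e.g. DiscreteUniform(10, 12)
support_size = 3

idx_draw = random.randint(rng_proposal, (), minval=0, maxval=support_size)
# Before the change the draw itself (0, 1 or 2) was written into the state,
# which can lie outside the support; indexing into support_enumerate keeps
# the proposal inside {10, 11, 12}.
proposal = support_enumerate[idx_draw]
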
@@ -434,6 +470,32 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
             and site["fn"].has_enumerate_support
             and not site["is_observed"]
         }
+        max_length_support_enumerates = max(
+            (
+                site["fn"].enumerate_support(False).shape[0]
+                for site in self._prototype_trace.values()
+                if site["type"] == "sample"
+                and site["fn"].has_enumerate_support
+                and not site["is_observed"]

Review comment (Member): nit: it is better to loop over support_sizes: for name, site in self._prototype_trace.items() if name in support_sizes

+            )
+        )
+        # All support_enumerates should have the same length to be used in the loop
+        # Each support is padded with zeros to have the same length
+        self._support_enumerates = np.zeros(
+            (len(self._support_sizes), max_length_support_enumerates), dtype=int
+        )
+        for i, (name, site) in enumerate(self._prototype_trace.items()):

Review comment (@fehiepsi, Member, Jul 27, 2024):

great solution! I just have a couple of comments:

  • it might be better to loop over names in support_sizes and get the site via site = self._prototype_trace[name]
  • we use ravel_pytree to flatten support_sizes, so we might want to keep the same behavior here. I don't have a great solution for this, maybe:

support_enumerates = {}
for name, support_size in self._support_sizes.items():
    site = self._prototype_trace[name]
    enumerate_support = site["fn"].enumerate_support(False)
    padded_enumerate_support = np.pad(enumerate_support, (0, max_length_support_enumerates - enumerate_support.shape[0]))
    padded_enumerate_support = np.broadcast_to(padded_enumerate_support, support_size.shape + (max_length_support_enumerates,))
    support_enumerates[name] = padded_enumerate_support

self._support_enumerates = jax.vmap(lambda x: ravel_pytree(x)[0], in_axes=1, out_axes=1)(support_enumerates)
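
To make the shapes in the suggestion above concrete (toy values; the support, the maximum length, and the batch shape are hypothetical): for a site whose enumerated support has 3 values while the longest support across sites has 5, the padding and broadcasting behave roughly like this:

import numpy as np

max_length_support_enumerates = 5
enumerate_support = np.array([10, 11, 12])  # support of one site, length 3
support_size_shape = (2,)                   # e.g. a site with batch shape (2,)

padded = np.pad(
    enumerate_support, (0, max_length_support_enumerates - enumerate_support.shape[0])
)
# padded -> array([10, 11, 12,  0,  0])

broadcast = np.broadcast_to(
    padded, support_size_shape + (max_length_support_enumerates,)
)
# broadcast.shape -> (2, 5): one padded row per element of the flattened site
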

+            if (
+                site["type"] == "sample"
+                and site["fn"].has_enumerate_support
+                and not site["is_observed"]
+            ):
+                self._support_enumerates[
+                    i, : site["fn"].enumerate_support(False).shape[0]
+                ] = site["fn"].enumerate_support(False)
+        self._support_enumerates = jnp.asarray(
+            self._support_enumerates, dtype=jnp.int32
+        )
         self._gibbs_sites = [
             name
             for name, site in self._prototype_trace.items()
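
As a rough standalone illustration of the matrix this initialisation builds (two hypothetical sites with supports of different lengths): each row holds one site's enumerated support, zero-padded on the right so all rows share the width of the longest support.

import numpy as np
import jax.numpy as jnp

# Hypothetical enumerated supports of two discrete sites.
supports = {
    "x": np.array([10, 11, 12]),     # DiscreteUniform(10, 12)
    "y": np.array([0, 1, 2, 3, 4]),  # e.g. a Categorical with 5 categories
}

max_len = max(s.shape[0] for s in supports.values())
support_enumerates = np.zeros((len(supports), max_len), dtype=int)
for i, (name, support) in enumerate(supports.items()):
    support_enumerates[i, : support.shape[0]] = support

support_enumerates = jnp.asarray(support_enumerates, dtype=jnp.int32)
# [[10 11 12  0  0]
#  [ 0  1  2  3  4]]
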
2 changes: 2 additions & 0 deletions numpyro/infer/mixed_hmc.py
@@ -6,6 +6,7 @@

 from jax import grad, jacfwd, lax, random
 from jax.flatten_util import ravel_pytree
+import jax
 import jax.numpy as jnp

 from numpyro.infer.hmc import momentum_generator
@@ -138,6 +139,7 @@ def update_discrete(
                 partial(potential_fn, z_hmc=hmc_state.z),
                 idx,
                 self._support_sizes_flat[idx],
+                self._support_enumerates[idx],
             )
             # Algo 1, line 20: depending on reject or refract, we will update
             # the discrete variable and its corresponding kinetic energy. In case of
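
For context, a minimal sketch of the failure mode this pull request addresses (the model, sizes, and seed below are chosen purely for illustration, using the public numpyro.infer API): with a support that does not start at 0, such as DiscreteUniform(10, 12), the pre-fix Gibbs and random-walk proposals wrote raw indices 0 .. support_size - 1 into the state, which can fall outside the support; with the change the proposals are always drawn from enumerate_support.

from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs

def model():
    # Support is {10, 11, 12}; it does not start at 0.
    x = numpyro.sample("x", dist.DiscreteUniform(10, 12))
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(loc + x, 1.0), obs=11.0)

kernel = DiscreteHMCGibbs(NUTS(model))
mcmc = MCMC(kernel, num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0))
samples = mcmc.get_samples()["x"]
# With the fix, every sample of "x" should lie in {10, 11, 12}.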