Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing issue Samples are outside the support for DiscreteUniform dist… #1835

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
6 changes: 3 additions & 3 deletions numpyro/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,9 +469,9 @@ def enumerate_support(self, expand=True):
raise NotImplementedError(
"Inhomogeneous `high` not supported by `enumerate_support`."
)
values = (self.low + jnp.arange(np.amax(self.high - self.low) + 1)).reshape(
(-1,) + (1,) * len(self.batch_shape)
)
low = jnp.reshape(self.low, -1)[0]
high = jnp.reshape(self.high, -1)[0]
values = jnp.arange(low, high + 1).reshape((-1,) + (1,) * len(self.batch_shape))
if expand:
values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
return values
Expand Down
82 changes: 72 additions & 10 deletions numpyro/infer/hmc_gibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import numpy as np

import jax
from jax import device_put, grad, jacfwd, random, value_and_grad
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
Expand Down Expand Up @@ -192,12 +193,22 @@ def __getstate__(self):


def _discrete_gibbs_proposal_body_fn(
z_init_flat, unravel_fn, pe_init, potential_fn, idx, i, val
z_init_flat,
unravel_fn,
pe_init,
potential_fn,
idx,
i,
val,
support_size,
support_enumerate,
):
rng_key, z, pe, log_weight_sum = val
rng_key, rng_transition = random.split(rng_key)
proposal = jnp.where(i >= z_init_flat[idx], i + 1, i)
z_new_flat = z_init_flat.at[idx].set(proposal)
proposal_index = jnp.where(
support_enumerate[i] == z_init_flat[idx], support_size - 1, i
)
z_new_flat = z_init_flat.at[idx].set(support_enumerate[proposal_index])
z_new = unravel_fn(z_new_flat)
pe_new = potential_fn(z_new)
log_weight_new = pe_init - pe_new
Expand All @@ -216,7 +227,9 @@ def _discrete_gibbs_proposal_body_fn(
return rng_key, z, pe, log_weight_sum


def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size):
def _discrete_gibbs_proposal(
rng_key, z_discrete, pe, potential_fn, idx, support_size, support_enumerate
):
# idx: current index of `z_discrete_flat` to update
# support_size: support size of z_discrete at the index idx

Expand All @@ -234,6 +247,8 @@ def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support
pe,
potential_fn,
idx,
support_size=support_size,
support_enumerate=support_enumerate,
)
init_val = (rng_key, z_discrete, pe, jnp.array(0.0))
rng_key, z_new, pe_new, _ = fori_loop(0, support_size - 1, body_fn, init_val)
Expand All @@ -242,7 +257,14 @@ def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support


def _discrete_modified_gibbs_proposal(
rng_key, z_discrete, pe, potential_fn, idx, support_size, stay_prob=0.0
rng_key,
z_discrete,
pe,
potential_fn,
idx,
support_size,
support_enumerate,
stay_prob=0.0,
):
assert isinstance(stay_prob, float) and stay_prob >= 0.0 and stay_prob < 1
z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)
Expand All @@ -253,6 +275,8 @@ def _discrete_modified_gibbs_proposal(
pe,
potential_fn,
idx,
support_size=support_size,
support_enumerate=support_enumerate,
)
# like gibbs_step but here, weight of the current value is 0
init_val = (rng_key, z_discrete, pe, jnp.array(-jnp.inf))
Expand All @@ -276,28 +300,41 @@ def _discrete_modified_gibbs_proposal(
return rng_key, z_new, pe_new, log_accept_ratio


def _discrete_rw_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size):
def _discrete_rw_proposal(
rng_key, z_discrete, pe, potential_fn, idx, support_size, support_enumerate
):
rng_key, rng_proposal = random.split(rng_key, 2)
z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)

proposal = random.randint(rng_proposal, (), minval=0, maxval=support_size)
z_new_flat = z_discrete_flat.at[idx].set(proposal)
z_new_flat = z_discrete_flat.at[idx].set(support_enumerate[proposal])
z_new = unravel_fn(z_new_flat)
pe_new = potential_fn(z_new)
log_accept_ratio = pe - pe_new
return rng_key, z_new, pe_new, log_accept_ratio


def _discrete_modified_rw_proposal(
rng_key, z_discrete, pe, potential_fn, idx, support_size, stay_prob=0.0
rng_key,
z_discrete,
pe,
potential_fn,
idx,
support_size,
support_enumerate,
stay_prob=0.0,
):
assert isinstance(stay_prob, float) and stay_prob >= 0.0 and stay_prob < 1
rng_key, rng_proposal, rng_stay = random.split(rng_key, 3)
z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)

i = random.randint(rng_proposal, (), minval=0, maxval=support_size - 1)
proposal = jnp.where(i >= z_discrete_flat[idx], i + 1, i)
proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, proposal)
proposal_index = jnp.where(
support_enumerate[i] == z_discrete_flat[idx], support_size - 1, i
)
proposal = jnp.where(
random.bernoulli(rng_stay, stay_prob), idx, support_enumerate[proposal_index]
)
z_new_flat = z_discrete_flat.at[idx].set(proposal)
z_new = unravel_fn(z_new_flat)
pe_new = potential_fn(z_new)
Expand Down Expand Up @@ -434,6 +471,31 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
and site["fn"].has_enumerate_support
and not site["is_observed"]
}

# All support_enumerates should have the same length to be used in the loop
# Each support is padded with zeros to have the same length
# ravel is used to maintain a consistent behaviour with `support_sizes`

max_length_support_enumerates = np.max(
[size for size in self._support_sizes.values()]
)

support_enumerates = {}
for name, support_size in self._support_sizes.items():
site = self._prototype_trace[name]
enumerate_support = site["fn"].enumerate_support(True).T
# Only the last dimension that corresponds to support size is padded
pad_width = [(0, 0) for _ in range(len(enumerate_support.shape) - 1)] + [
(0, max_length_support_enumerates - enumerate_support.shape[-1])
]
padded_enumerate_support = np.pad(enumerate_support, pad_width)

support_enumerates[name] = padded_enumerate_support

self._support_enumerates = jax.vmap(
lambda x: ravel_pytree(x)[0], in_axes=len(support_size.shape), out_axes=1
)(support_enumerates)

self._gibbs_sites = [
name
for name, site in self._prototype_trace.items()
Expand Down
1 change: 1 addition & 0 deletions numpyro/infer/mixed_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def update_discrete(
partial(potential_fn, z_hmc=hmc_state.z),
idx,
self._support_sizes_flat[idx],
self._support_enumerates[idx],
)
# Algo 1, line 20: depending on reject or refract, we will update
# the discrete variable and its corresponding kinetic energy. In case of
Expand Down
89 changes: 89 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3423,3 +3423,92 @@ def test_gaussian_random_walk_linear_recursive_equivalence():
x2 = dist2.sample(random.PRNGKey(7))
assert jnp.allclose(x1, x2.squeeze())
assert jnp.allclose(dist1.log_prob(x1), dist2.log_prob(x2))


def test_discrete_uniform_with_mixedhmc():
    """Check that MixedHMC keeps discrete samples inside their support.

    Each model below is run with every combination of the ``random_walk`` and
    ``modified`` options of ``MixedHMC``. For every discrete site, all
    posterior samples must lie within the inclusive ``[low, high]`` range
    of that site's support.
    """
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import HMC, MCMC, MixedHMC

    def sample_mixedhmc(model_fn, num_samples, **kwargs):
        # Wrap a plain HMC kernel in MixedHMC; `kwargs` selects the discrete
        # proposal variant (random_walk / modified).
        kernel = HMC(model_fn, trajectory_length=1.2)
        kernel = MixedHMC(kernel, num_discrete_updates=20, **kwargs)
        mcmc = MCMC(kernel, num_warmup=100, num_samples=num_samples, progress_bar=False)
        # Fixed seed so that a failure is reproducible.
        mcmc.run(jax.random.PRNGKey(0))
        return mcmc.get_samples()

    num_samples = 1000
    mixed_hmc_kwargs = [
        {"random_walk": False, "modified": False},
        {"random_walk": True, "modified": False},
        {"random_walk": True, "modified": True},
        {"random_walk": False, "modified": True},
    ]

    # Case 1: one discrete uniform with one categorical
    def model_1():
        numpyro.sample("x0", dist.DiscreteUniform(10, 12))
        numpyro.sample("x1", dist.Categorical(np.asarray([0.25, 0.25, 0.25, 0.25])))

    # Case 2: 2 categorical with different support lengths
    def model_2():
        numpyro.sample("x0", dist.Categorical(0.25 * jnp.ones((4,))))
        numpyro.sample("x1", dist.Categorical(0.1 * jnp.ones((10,))))

    # Case 3: 2 categorical with different support lengths and batched by 3
    def model_3():
        numpyro.sample("x0", dist.Categorical(0.25 * jnp.ones((3, 4))))
        numpyro.sample("x1", dist.Categorical(0.1 * jnp.ones((3, 10))))

    # Case 4: 1 categorical and 1 discrete uniform with different support
    # lengths, both batched by 3
    def model_4():
        dist0 = dist.Categorical(0.25 * jnp.ones((3, 4)))
        numpyro.sample("x0", dist0)
        dist1 = dist.DiscreteUniform(10 * jnp.ones((3,)), 19 * jnp.ones((3,)))
        numpyro.sample("x1", dist1)

    # (model, {site name: (inclusive lower bound, inclusive upper bound)})
    cases = [
        (model_1, {"x0": (10, 12), "x1": (0, 3)}),
        (model_2, {"x0": (0, 3), "x1": (0, 9)}),
        (model_3, {"x0": (0, 3), "x1": (0, 9)}),
        # The upper bound for x1 is 19 (the `high` of the DiscreteUniform),
        # so that a sample of 20 is reported as outside the support.
        (model_4, {"x0": (0, 3), "x1": (10, 19)}),
    ]

    for model_fn, bounds in cases:
        for kwargs in mixed_hmc_kwargs:
            samples = sample_mixedhmc(model_fn, num_samples, **kwargs)
            for site_name, (low, high) in bounds.items():
                site_samples = samples[site_name]
                assert jnp.all(
                    (site_samples >= low) & (site_samples <= high)
                ), f"Failed with {kwargs=}"


# Allow running the MixedHMC support check directly as a script, without pytest.
if __name__ == "__main__":
    test_discrete_uniform_with_mixedhmc()
Loading