Fixing issue "Samples are outside the support for DiscreteUniform dist…" #1835
base: master
Conversation
numpyro/infer/mixed_hmc.py (outdated)
    lambda idx, support: support[idx],
    z_discrete,
    self._support_enumerates,
)
Doing this might return in-support values, but I worry that the algorithm is wrong. To compute the potential energy correctly in the algorithm, we need to work with in-support values. I think you can pass support_enumerates into self._discrete_proposal_fn and change the proposal logic there.
proposal = random.randint(rng_proposal, (), minval=0, maxval=support_size)
# z_new_flat = z_discrete_flat.at[idx].set(proposal)
z_new_flat = z_discrete_flat.at[idx].set(support_enumerate[proposal])
or, for the modified rw proposal:
i = random.randint(rng_proposal, (), minval=0, maxval=support_size - 1)
# proposal = jnp.where(i >= z_discrete_flat[idx], i + 1, i)
# proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, proposal)
proposal_index = jnp.where(support_size[i] == z_discrete_flat[idx], support_size - 1, i)
proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, support_size[proposal_index])
z_new_flat = z_discrete_flat.at[idx].set(proposal)
or, at the discrete Gibbs proposal:
proposal_index = jnp.where(support_enumerate[i] == z_init_flat[idx], support_size - 1, i)
z_new_flat = z_init_flat.at[idx].set(support_enumerate[proposal_index])
Ok, thank you for the feedback. I will try this.
@fehiepsi how do you debug in numpyro? I tried jax.debug but nothing happens.
I use print most of the time. When actual values are needed, I sometimes use jax.disable_jit().
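For reference, a minimal sketch of both approaches (the toy function and the printed values are made up for illustration):

import jax
import jax.numpy as jnp

@jax.jit
def toy_fn(x):
    # jax.debug.print fires inside jit-compiled/traced code
    jax.debug.print("x = {x}", x=x)
    return x ** 2

toy_fn(jnp.arange(3))

# when concrete values are needed, run the same code with JIT disabled
with jax.disable_jit():
    print(toy_fn(jnp.arange(3)))

Note that inside jit a plain print only runs at trace time and shows abstract tracers, which may be why nothing useful appeared.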
@fehiepsi I have issues with passing enumerated supports and traced values, as the support arrays can have different sizes. I was thinking maybe to just pass the "lower bound of the support" as an offset; combined with support_sizes it should do the trick (a toy sketch below). Are there discrete variables where the support is not a simple discrete range with step 1 between values?
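For a contiguous support, the offset idea would look roughly like this toy sketch; the DiscreteUniform(10, 12) numbers are made up:

from jax import random

support_offset = 10   # lower bound of e.g. DiscreteUniform(10, 12)
support_size = 3

# draw an index in [0, support_size) and shift it by the lower bound
i = random.randint(random.PRNGKey(0), (), minval=0, maxval=support_size)
proposal = support_offset + i   # always lands in {10, 11, 12}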
For modified_rw_proposal I think you used support_size in place of support_enumerate, shouldn't it be:
i = random.randint(rng_proposal, (), minval=0, maxval=support_size - 1)
# proposal = jnp.where(i >= z_discrete_flat[idx], i + 1, i)
# proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, proposal)
proposal_index = jnp.where(support_enumerate[i] == z_discrete_flat[idx], support_size - 1, i)
proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, support_enumerate[proposal_index])
z_new_flat = z_discrete_flat.at[idx].set(proposal)
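To make the index-to-value mapping concrete, here is a toy sketch; the DiscreteUniform(10, 12) site and all concrete values are made up for illustration:

import jax.numpy as jnp
from jax import random

# hypothetical single flat site: DiscreteUniform(10, 12) => enumerated support [10, 11, 12]
support_enumerate = jnp.array([10, 11, 12])
support_size = 3
z_discrete_flat = jnp.array([11])   # current (in-support) value
idx = 0

rng_proposal = random.PRNGKey(0)

# draw an index over the support, then skip the slot that holds the current value
i = random.randint(rng_proposal, (), minval=0, maxval=support_size - 1)
proposal_index = jnp.where(support_enumerate[i] == z_discrete_flat[idx], support_size - 1, i)

# map the index back through the enumerated support, so the proposal is always in-support
z_new_flat = z_discrete_flat.at[idx].set(support_enumerate[proposal_index])

Comparing support values (rather than raw indices) keeps the proposal uniform over the other in-support values even when the support does not start at 0.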
Thanks! Your solutions are super cool! I hadn't thought of different support sizes previously.
numpyro/infer/hmc_gibbs.py (outdated)
self._support_enumerates = np.zeros(
    (len(self._support_sizes), max_length_support_enumerates), dtype=int
)
for i, (name, site) in enumerate(self._prototype_trace.items()):
Great solution! I just have a couple of comments:
- it might be better to loop over names in support_sizes and get the site via site = self._prototype_trace[name]
- we use ravel_pytree to flatten support_sizes, so we might want to keep the same behavior here. I don't have a great solution for this, maybe:

support_enumerates = {}
for name, support_size in self._support_sizes.items():
    site = self._prototype_trace[name]
    enumerate_support = site["fn"].enumerate_support(False)
    padded_enumerate_support = np.pad(enumerate_support, (0, max_length_support_enumerates - enumerate_support.shape[0]))
    padded_enumerate_support = np.broadcast_to(padded_enumerate_support, support_size.shape + (max_length_support_enumerates,))
    support_enumerates[name] = padded_enumerate_support
self._support_enumerates = jax.vmap(lambda x: ravel_pytree(x)[0], in_axes=1, out_axes=1)(support_enumerates)
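To illustrate the padding idea with made-up supports of different sizes:

import numpy as np

# e.g. a DiscreteUniform(10, 12) site and a 4-class Categorical site
enumerate_supports = {
    "x0": np.array([10, 11, 12]),
    "x1": np.array([0, 1, 2, 3]),
}
max_len = max(s.shape[0] for s in enumerate_supports.values())

# pad each support with zeros up to the common length so the rows can be stacked
padded = {name: np.pad(s, (0, max_len - s.shape[0])) for name, s in enumerate_supports.items()}
print(np.stack(list(padded.values())))
# [[10 11 12  0]
#  [ 0  1  2  3]]

The idea is that support_sizes then tells the proposal how much of each padded row is valid, so the zero padding is never indexed.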
@fehiepsi it worked fine with …
I think we need to ravel along the first axis. The second axis (corresponds to …
numpyro/infer/hmc_gibbs.py (outdated)
for site in self._prototype_trace.values()
if site["type"] == "sample"
and site["fn"].has_enumerate_support
and not site["is_observed"]
Nit: it is better to loop over support_sizes: for name, site in self._prototype_trace.items() if name in support_sizes
the first axis is …
we vmap over the batch axis, which is the second axis, i.e. in_axes=1
Could you also add a simple test (as in the issue) for this? You can run …
I applied the lint/format and I added a test.
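For context, a minimal test in the spirit of the issue might look like the sketch below; the model, kernel settings, and sample counts are illustrative, not the exact test added in this PR:

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import HMC, MCMC, MixedHMC

def model():
    numpyro.sample("x", dist.DiscreteUniform(10, 12))
    numpyro.sample("y", dist.Normal(0.0, 1.0))

kernel = MixedHMC(HMC(model, trajectory_length=1.2), num_discrete_updates=5)
mcmc = MCMC(kernel, num_warmup=100, num_samples=100, progress_bar=False)
mcmc.run(jax.random.PRNGKey(0))
samples = mcmc.get_samples()

# every draw of the discrete site must lie inside the DiscreteUniform(10, 12) support
assert jnp.all((samples["x"] >= 10) & (samples["x"] <= 12))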
Ok, but the … So the following line: support_size.shape + (max_length_support_enumerates,) is just equivalent to … Maybe you have an example where …
That is a good point. I thought support sizes contained flattened arrays. Sorry for the confusion. I guess we need to move the enumerate dimension to the first axis before vmapping like you did.
I tried the following direction:

max_length_support_enumerates = np.max(
    [size for size in self._support_sizes.values()]
)
support_enumerates = {}
for name, support_size in self._support_sizes.items():
    site = self._prototype_trace[name]
    enumerate_support = site["fn"].enumerate_support(True).T
    # Only the last dimension that corresponds to support size is padded
    pad_width = [(0, 0) for _ in range(len(enumerate_support.shape) - 1)] + [
        (0, max_length_support_enumerates - enumerate_support.shape[-1])
    ]
    padded_enumerate_support = np.pad(enumerate_support, pad_width)
    support_enumerates[name] = padded_enumerate_support
self._support_enumerates = jax.vmap(
    lambda x: ravel_pytree(x)[0], in_axes=len(support_size.shape), out_axes=1
)(support_enumerates)

which works with the following cases:

def model_1():
    numpyro.sample("x0", dist.DiscreteUniform(10, 12))
    numpyro.sample("x1", dist.Categorical(np.asarray([0.25, 0.25, 0.25, 0.25])))

def model_2():
    numpyro.sample("x0", dist.Categorical(0.25 * jnp.ones((4,))))
    numpyro.sample("x1", dist.Categorical(0.1 * jnp.ones((10,))))

def model_3():
    numpyro.sample("x0", dist.Categorical(0.25 * jnp.ones((3, 4))))
    numpyro.sample("x1", dist.Categorical(0.1 * jnp.ones((3, 10))))

But it fails when I try to batch:

def model_4():
    numpyro.sample("x1", dist.DiscreteUniform(10 * jnp.ones((3,)), 19 * jnp.ones((3,))))

with the following exception, which comes before the code I added (when the …
The … By the way, maybe we need to use …
Hmm, there seems to be a bug at DiscreteUniform.enumerate_support.
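A quick way to poke at that might be a sketch like this; the shapes in the comments are what I would expect from the enumerate_support convention (support dimension first, then batch), not verified output:

import jax.numpy as jnp
import numpyro.distributions as dist

d = dist.DiscreteUniform(10 * jnp.ones((3,)), 19 * jnp.ones((3,)))
print(d.enumerate_support().shape)              # expand=True by default; expected (10, 3)
print(d.enumerate_support(expand=False).shape)  # expected (10, 1), i.e. a singleton batch dim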
@fehiepsi sorry for the delay... other things happened and I couldn't follow up. Yes, let me test this now!
…tests are passing when using changes from PR pyro-ppl#1859
This fixes issue #1834, where MixedHMC sampling with a DiscreteUniform distribution produced samples outside the support because enumerate_support was not used.