-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Example to marginalize discrete latent variables (WIP) #646
base: enum-messengers
Are you sure you want to change the base?
Conversation
@tbsexton Wonderful tutorial!!! I am looking for its final version. ;) About the shape error, I think you need to declare some dimensions as event dimensions. Alternatively, you can use # beta hyperpriors
u = ny.sample("u", dist.Uniform(0, 1).expand([n_edges]).to_event(1))
v = ny.sample("v", dist.Gamma(1, 20).expand([n_edges]).to_event(1))
Λ = ny.sample("Λ", dist.Beta(u*v, (1-u)*v).to_event(1))
s_ij = squareform(Λ) # adjacency matrix to recover via inference
with ny.plate("n_cascades", n_cascades, dim=-1):
# infer infection source node
x0 = ny.sample("x0", dist.Categorical(probs=ϕ))
src = one_hot(x0, n_nodes)
# simulate ode and realize
p_infectious = batched_diffuse(s_ij, 5, src)
print(s_ij.shape, x0, src.shape, p_infectious.shape)
# infectious = spread_jax(s_ij, one_hot(x0, n_nodes),0, 5)
real = dist.Bernoulli(probs=p_infectious).to_event(1)
return ny.sample("obs", real, obs=infections) This still does not work yet because |
@tbsexton If you also modify
However, I prefer to do
by
and use However, MCMC can't find valid initial parameters for your model. This is probably an issue of our Bernoulli implementation so I will try to debug it.
No, you don't need to |
Interesting tricks, thanks! Implementing them at the moment, though, I'm getting a
|
With |
Oops... sorry, you will need to install the master branch of
|
hmmm different error now:
re: forcing to 1., yes, in the limit all nodes would go to 1, though shouldnt happen after a single iteration. There are two fixes for this
A bit more insight into the "how many time-steps should we assume" problem... i.e how far we should propagate the ODE in each plate before stopping to "measure" the infections?
Still working on this, but I was hoping to avoid the problem for a simpler model at first ^_^ |
@fehiepsi Oops! a couple things from your suggestions I had not implemented right... missed the I also forgot to change the But it runs now! Going to be a bit I guess, since it only seems to be able to process 1-2 it/sec.... |
Yeah, the distribution shape in Pyro (which corresponds Pyro/NumPyro 1-1 with the plate graph - hence enabling advanced inference mechanism) is flexible but it is not straightforward to keep things synced. I think the best resource is this tensor shape tutorial.
Probably it will be better if the chain moves to some useful domain. Currently, it takes 1000 leapfrog steps per sample. I am not sure if GPU helps because I can't access GPU in a few weeks. |
@fehiepsi a few updates:
|
@tbsexton That error is a bug! I'll push a fix soon. About the slowness, that is because your model has many latent variables. But I believe you can use
I think it is compatible with plate because, under the hood, we use a Unit distribution to store that
Does that solve your usage case? |
@tbsexton Unfortunately, we don't support enumerate sites with with ny.plate("n_edges", n_edges, dim=-1):
u = ny.sample("u", dist.Uniform(0, 1).expand([n_edges]))
v = ny.sample("v", dist.Gamma(1, 20).expand([n_edges]))
ρ = ny.sample("ρ", dist.Beta(u * v, (1 - u) * v))
A = ny.sample("A", dist.Bernoulli(probs=ρ))
# resolve the issue: `squareform` does not support batching
s_ij = squareform(ρ * A) You will need to make def squareform(edgelist):
"""edgelist to adj. matrix"""
from numpyro.distributions.util import vec_to_tril_matrix
half = vec_to_tril_matrix(edgelist, diagonal=-1)
full = half + np.swapaxes(half, -2, -1)
return full However, for the later code, it is a bit complicated to handle batch dimensions. It took me a while to realize that with ny.plate("n_cascades", n_cascades, dim=-2):
x0 = ny.sample("x0", dist.Categorical(probs=ϕ))
src = one_hot(x0, n_nodes)
# we can broadcast src using shapes of `src.shape[:-2]` and `s_ij.shape[:-2]`
# but I am lazy to do that so I use `matmul` here
src = (np.broadcast_to(np.eye(s_ij.shape[-1]), s_ij.shape) @ np.swapaxes(src, -2, -1)).squeeze(-1)
p_infectious = diffuse(s_ij, 1, src)
p_infectious = np.clip(p_infectious, a_max=1 - 1e-6, a_min=1e-6)
with ny.plate("n_nodes", n_nodes, dim=-1):
real = dist.Bernoulli(probs=p_infectious)
ny.sample("obs", real, obs=infections) MCMC seems to be pretty fast with the above code and gives high |
@tbsexton FYI, we just release 0.3 version, which supports enumeration over discrete latent variables. I really like this topic and I also have time now so I will look at your tutorial in more details (mainly to study and to make it work with NumPyro). If you have further reference to go through the notebook, please let me know, I would greatly appreciate! |
Awesome! I will get this pr onto v 0.3 soon. I took a bit of a break on this upon realizing the inference was not passing some basic benchmarking tests to recover network structure correctly. I just finished getting a pure optimization version working via pure Jax, and will be getting an example notebook live on one of my own repo's soon. Happy to link it here once it's up to help with this! |
Initial structure and outline, with lead up to inference. Currently shows some kind of broadcasting error.