Obscure NotImplementedError for Categorical #545
Comments
@tbsexton This is a duplicate of #542.
The latent site Btw, we are targeting support for discrete latent variables in the next release. It requires a ton of work, which is underway.
Ahh. Ok, interesting. Definitely looking forward to that release! NumPyro has been a joy to use, coming from a long-time pymc3 user. I will have to think this out a bit more... unfortunately, estimating the "patient-zero" is part of the inference problem, for each observation in the plate. For now I might simply use the Dirichlet and manually set the patient-zero via some kind of soft-argmax.
@fehiepsi It seems the classic workaround, a continuous approximation to a categorical sample, is the Gumbel-softmax. Are there plans to implement the Gumbel distribution w/ mixin from Pyro?
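For readers unfamiliar with the Gumbel-softmax relaxation mentioned above, here is a minimal, dependency-free sketch of the idea. It is not NumPyro or Pyro code; the function name `gumbel_softmax_sample` and the `temperature` parameter are assumptions made for illustration. In a JAX setting the same arithmetic could presumably be written with `jax.random.gumbel` and `jax.nn.softmax`.

```python
import math
import random


def gumbel_softmax_sample(logits, temperature, rng):
    """Draw one relaxed categorical sample (a point on the probability simplex).

    Hypothetical helper for illustration only, not a NumPyro API.
    """
    noisy = []
    for logit in logits:
        # Gumbel(0, 1) noise via the inverse CDF: g = -log(-log(u)).
        u = max(rng.random(), 1e-12)  # clamp to avoid log(0)
        g = -math.log(-math.log(u))
        noisy.append((logit + g) / temperature)
    # Numerically stable softmax over the perturbed, temperature-scaled logits.
    m = max(noisy)
    exps = [math.exp(x - m) for x in noisy]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
probs = [0.1, 0.2, 0.7]
logits = [math.log(p) for p in probs]
sample = gumbel_softmax_sample(logits, temperature=0.5, rng=rng)
```

As the temperature approaches zero the sample approaches a one-hot vector, and its argmax is distributed according to `probs` (the Gumbel-max trick); higher temperatures give smoother samples that are easier to differentiate through.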
@fehiepsi am I correct that this model will be enabled by enumeration after your Funsor integration?
@tbsexton That would be a great "good first issue". :) @fritzo Yes, that is my purpose for the integration with funsor.
Closed as a duplicate of #542.
@tbsexton With #572, it is possible to marginalize discrete latent variables, so I think your original model should work. Currently, we have some tests for Gaussian mixture and latent Bernoulli models. Personally, I would like to turn your model into an example to illustrate this new functionality of NumPyro. If you agree, could you suggest a dataset to run your model on?
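To illustrate what marginalizing a discrete latent variable means in the Gaussian mixture case mentioned above, here is a small standard-library sketch. It only shows the underlying arithmetic (summing the component assignment out with a log-sum-exp); it is not how #572 or funsor implement enumeration, and all function names are hypothetical.

```python
import math


def normal_logpdf(x, mean, std):
    """Log density of a univariate normal distribution."""
    z = (x - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def joint_terms(x, weights, means, stds):
    """log p(z = k) + log p(x | z = k) for every component k."""
    return [math.log(w) + normal_logpdf(x, m, s)
            for w, m, s in zip(weights, means, stds)]


def mixture_loglik(x, weights, means, stds):
    """log p(x), with the discrete assignment z summed out."""
    return logsumexp(joint_terms(x, weights, means, stds))


def assignment_posterior(x, weights, means, stds):
    """p(z = k | x), recovered after the fact from the same joint terms."""
    terms = joint_terms(x, weights, means, stds)
    norm = logsumexp(terms)
    return [math.exp(t - norm) for t in terms]


weights, means, stds = [0.3, 0.7], [-2.0, 2.0], [1.0, 1.0]
loglik = mixture_loglik(1.5, weights, means, stds)
posterior = assignment_posterior(1.5, weights, means, stds)
```

Because the discrete variable no longer appears in the log density, a gradient-based sampler only has to handle the continuous parameters; the assignment probabilities can still be recovered afterwards, as `assignment_posterior` does.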
@fehiepsi I would love that. Actually, I've been working on a paper that may include that model, but I honestly haven't been able to test it out until now! So I have code that would easily synthesize data for it (it's a form of network "backboning", like this work but with diffusion dynamics baked in). Would it be possible for me to get set up on the appropriate branch and submit a pull request with the example? As a notebook or a jupytext script?
Awesome!! I can't wait for your contribution. :D We used notebooks and put them in this folder. Please feel free to fork #572 and create a PR.
Somewhat new to NumPyro, though more familiar with JAX, so apologies if this is a known issue.
I modelled the boilerplate on the baseball and time-series forecasting examples, and am working on a network inference problem (see here for an older JAX version with discussion).
Setup looks like:
Where infections is an array with columns as nodes (0=susceptible, 1=infected) and rows as unique observations, simulated from a "ground-truth" network and different source nodes.
Running based on documentation examples results in the following error that I'm having quite a hard time parsing (sorry for the wall of text):