Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more tests for collapse handler #809

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Prev Previous commit
Next Next commit
address expanded distribution
fehiepsi committed Nov 25, 2020
commit 8f8719f9b1a04073e6d639eb528300a3855e0d39
4 changes: 3 additions & 1 deletion numpyro/handlers.py
Original file line number Diff line number Diff line change
@@ -83,7 +83,7 @@
import jax.numpy as jnp

import numpyro
from numpyro.distributions.distribution import COERCIONS
from numpyro.distributions.distribution import COERCIONS, ExpandedDistribution
from numpyro.primitives import _PYRO_STACK, Messenger, apply_stack, plate
from numpyro.util import not_jax_tracer

@@ -268,6 +268,8 @@ def process_message(self, msg):
if msg["type"] == "sample":
if msg["value"] is None:
msg["value"] = msg["name"]
if isinstance(msg["fn"], ExpandedDistribution):
msg["fn"] = msg["fn"].base_dist

if isinstance(msg["fn"], Funsor) or isinstance(msg["value"], (str, Funsor)):
msg["stop"] = True
1 change: 0 additions & 1 deletion test/test_handlers.py
Original file line number Diff line number Diff line change
@@ -636,7 +636,6 @@ def model():
x = numpyro.sample("x", dist.Normal(0, 1))
with handlers.collapse():
with handlers.plate("data", len(data)):
# TODO: address expanded distribution
y = numpyro.sample("y", dist.Normal(x, 1.))
numpyro.sample("z", dist.Normal(y, 1.), obs=data)