diff --git a/numpyro/handlers.py b/numpyro/handlers.py index 337eadcab..d43748b9a 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -77,6 +77,7 @@ """ from collections import OrderedDict +from functools import reduce import warnings import numpy as np @@ -267,6 +268,20 @@ def process_message(self, msg): msg["stop"] = True +def _eager_expand_fn(fn): + if isinstance(fn, Independent): + reinterpreted_batch_ndims = fn.reinterpreted_batch_ndims + fn = fn.base_dist + else: + reinterpreted_batch_ndims = 0 # no-op for to_event method + if isinstance(fn, ExpandedDistribution): + batch_shape = fn.batch_shape + base_batch_shape = fn.base_dist.batch_shape + appended_shape = batch_shape[:len(batch_shape) - len(base_batch_shape)] + fn = tree_map(lambda x: jnp.broadcast_to(x, appended_shape + jnp.shape(x)), fn.base_dist) + return fn.to_event(reinterpreted_batch_ndims) + + class collapse(trace): """ EXPERIMENTAL Collapses all sites in the context by lazily sampling and @@ -287,14 +302,24 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def process_message(self, msg): - from funsor.terms import Funsor + if msg["type"] != "sample": + return - if msg["type"] == "sample": - if msg["value"] is None: - msg["value"] = msg["name"] + import funsor + + # Eagerly convert fn and value to Funsor. + dim_to_name = {f.dim: f.name for f in msg["cond_indep_stack"]} + dim_to_name.update(self.preserved_plates) + if isinstance(msg["fn"], (Independent, ExpandedDistribution)): + msg["fn"] = _eager_expand_fn(msg["fn"]) + msg["fn"] = funsor.to_funsor(msg["fn"], funsor.Real, dim_to_name) + domain = msg["fn"].inputs["value"] + if msg["value"] is None: + msg["value"] = funsor.Variable(msg["name"], domain) + else: + msg["value"] = funsor.to_funsor(msg["value"], domain, dim_to_name) - if isinstance(msg["fn"], Funsor) or isinstance(msg["value"], (str, Funsor)): - msg["stop"] = True + msg["stop"] = True def __enter__(self): self.preserved_plates = frozenset( @@ -304,15 +329,22 @@ def __enter__(self): return super().__enter__() def __exit__(self, exc_type, exc_value, traceback): - import funsor - _coerce = COERCIONS.pop() assert _coerce is self._coerce super().__exit__(exc_type, exc_value, traceback) if exc_type is not None: + self.trace.clear() + self.preserved_plates.clear() return + if any(site["type"] == "sample" for site in self.trace.values()): + name, log_prob, _, _ = self._get_log_prob() + numpyro.factor(name, log_prob.data) + + def _get_log_prob(self): + import funsor + # Convert delayed statements to pyro.factor() reduced_vars = [] log_prob_terms = [] @@ -322,24 +354,28 @@ def __exit__(self, exc_type, exc_value, traceback): continue if not site["is_observed"]: reduced_vars.append(name) - dim_to_name = {f.dim: f.name for f in site["cond_indep_stack"]} - fn = funsor.to_funsor(site["fn"], funsor.Real, dim_to_name) - value = site["value"] - if not isinstance(value, str): - value = funsor.to_funsor(site["value"], fn.inputs["value"], dim_to_name) - log_prob_terms.append(fn(value=value)) + log_prob_terms.append(site["fn"](value=site["value"])) plates |= frozenset(f.name for f in site["cond_indep_stack"]) - assert log_prob_terms, "nothing to collapse" - reduced_plates = plates - self.preserved_plates - log_prob = funsor.sum_product.sum_product( - funsor.ops.logaddexp, - funsor.ops.add, - log_prob_terms, - eliminate=frozenset(reduced_vars) | reduced_plates, - plates=plates, - ) name = reduced_vars[0] - numpyro.factor(name, log_prob.data) + reduced_vars = frozenset(reduced_vars) + assert log_prob_terms, "nothing to collapse" + reduced_plates = plates - frozenset(self.preserved_plates.values()) + self.trace.clear() + self.preserved_plates.clear() + if reduced_plates: + log_prob = funsor.sum_product.sum_product( + funsor.ops.logaddexp, + funsor.ops.add, + log_prob_terms, + eliminate=frozenset(reduced_vars) | reduced_plates, + plates=plates, + ) + log_joint = NotImplemented + else: + log_joint = reduce(funsor.ops.add, log_prob_terms) + log_prob = log_joint.reduce(funsor.ops.logaddexp, reduced_vars) + + return name, log_prob, log_joint, reduced_vars class condition(Messenger): diff --git a/numpyro/primitives.py b/numpyro/primitives.py index 670564a9d..6a1cbbfef 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -11,6 +11,7 @@ import jax.numpy as jnp import numpyro +from numpyro.distributions.distribution import Distribution from numpyro.util import identity _PYRO_STACK = [] @@ -501,7 +502,8 @@ def process_message(self, msg): cond_indep_stack = msg["cond_indep_stack"] frame = CondIndepStackFrame(self.name, self.dim, self.subsample_size) cond_indep_stack.append(frame) - if msg["type"] == "sample": + # only expand if fn is Distribution, not a Funsor + if msg['type'] == 'sample' and isinstance(msg['fn'], Distribution): expected_shape = self._get_batch_shape(cond_indep_stack) dist_batch_shape = msg["fn"].batch_shape if "sample_shape" in msg["kwargs"]: diff --git a/test/test_handlers.py b/test/test_handlers.py index 745781549..381b19dbc 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -713,7 +713,6 @@ def guide(): svi.update(svi_state) -@pytest.mark.xfail(reason="missing pattern in Funsor") def test_collapse_beta_binomial_plate(): data = np.array([0.0, 1.0, 5.0, 5.0]) @@ -734,6 +733,115 @@ def guide(): svi.update(svi_state) +def test_collapse_normal_normal(): + data = np.array(0.) + + def model(): + x = numpyro.sample("x", dist.Normal(0, 1)) + with handlers.collapse(): + y = numpyro.sample("y", dist.Normal(x, 1.)) + numpyro.sample("z", dist.Normal(y, 1.), obs=data) + + def guide(): + loc = numpyro.param("loc", 0.) + scale = numpyro.param("scale", 1., constraint=constraints.positive) + numpyro.sample("x", dist.Normal(loc, scale)) + + svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO()) + svi_state = svi.init(random.PRNGKey(0)) + svi.update(svi_state) + + +def test_collapse_normal_normal_plate(): + data = np.arange(5.) + + def model(): + x = numpyro.sample("x", dist.Normal(0, 1)) + with handlers.collapse(): + y = numpyro.sample("y", dist.Normal(x, 1.)) + with handlers.plate("data", len(data)): + numpyro.sample("z", dist.Normal(y, 1.), obs=data) + + def guide(): + loc = numpyro.param("loc", 0.) + scale = numpyro.param("scale", 1., constraint=constraints.positive) + numpyro.sample("x", dist.Normal(loc, scale)) + + svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO()) + svi_state = svi.init(random.PRNGKey(0)) + svi.update(svi_state) + + +def test_collapse_normal_plate_normal(): + data = np.arange(5.) + + def model(): + x = numpyro.sample("x", dist.Normal(0, 1)) + with handlers.collapse(): + with handlers.plate("data", len(data)): + y = numpyro.sample("y", dist.Normal(x, 1.)) + numpyro.sample("z", dist.Normal(y, 1.), obs=data) + + def guide(): + loc = numpyro.param("loc", 0.) + scale = numpyro.param("scale", 1., constraint=constraints.positive) + numpyro.sample("x", dist.Normal(loc, scale)) + + svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO()) + svi_state = svi.init(random.PRNGKey(0)) + svi.update(svi_state) + + +@pytest.mark.xfail(reason="missing pattern in Funsor") +def test_collapse_diag_normal_plate_normal(): + d = 3 + data = np.ones((5, d)) + + def model(): + x = numpyro.sample("x", dist.Normal(0, 1)) + with handlers.collapse(): + with handlers.plate("data", len(data)): + y = numpyro.sample("y", dist.Normal(x, 1.).expand([d]).to_event(1)) + numpyro.sample("z", dist.Normal(y, 1.).to_event(1), obs=data) + + def guide(): + loc = numpyro.param("loc", 0.) + scale = numpyro.param("scale", 1., constraint=constraints.positive) + numpyro.sample("x", dist.Normal(loc, scale)) + + svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO()) + svi_state = svi.init(random.PRNGKey(0)) + svi.update(svi_state) + + +@pytest.mark.xfail(reason="missing pattern in Funsor") +def test_collapse_normal_mvn_mvn(): + T, d, S = 5, 2, 3 + data = jnp.ones((T, S)) + + def model(): + x = numpyro.sample("x", dist.Normal(0, 1)) + with handlers.collapse(): + with numpyro.plate("d", d, dim=-1): + beta0 = numpyro.sample("beta0", dist.Normal(x, 1.).expand([d, S]).to_event(1)) + beta = numpyro.sample( + "beta", dist.MultivariateNormal(beta0, scale_tril=jnp.eye(S))) + + # this fails because beta shape is (3,) while it should be (2, 3) + mean = jnp.ones((T, d)) @ beta + with numpyro.plate("data", T, dim=-1): + numpyro.sample("obs", dist.MultivariateNormal(mean, scale_tril=jnp.eye(S)), obs=data) + + def guide(): + loc = numpyro.param("loc", 0.) + scale = numpyro.param("scale", 1., constraint=constraints.positive) + numpyro.sample("x", dist.Normal(loc, scale)) + + svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO()) + svi_state = svi.init(random.PRNGKey(0)) + svi.update(svi_state) + + def test_prng_key(): assert numpyro.prng_key() is None