From cdef8d18afabddcd866657d6ef5cd3e7813c4cc7 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Thu, 5 Nov 2020 23:57:55 -0600 Subject: [PATCH 01/10] subclass Reparam --- numpyro/contrib/ns.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 numpyro/contrib/ns.py diff --git a/numpyro/contrib/ns.py b/numpyro/contrib/ns.py new file mode 100644 index 000000000..cef9cf400 --- /dev/null +++ b/numpyro/contrib/ns.py @@ -0,0 +1,40 @@ +from numpyro.handlers import reparam +from numpyro.infer.reparam import Reparam +from numpyro.infer import log_likelihood + +from jaxns.nested_sampling import NestedSampler +from jaxns.prior_transforms import PriorChain + + +class UnitCubeReparam(Reparam): + # support distributions with quantiles implementation + def __call__(self, name, fn, obs): + assert obs is None, "TransformReparam does not support observe statements" + fn, batch_shape = self._unexpand(fn) + # if fn is transformed distribution, get transform and base_dist, reparam the base_dist + # TODO: extract quantiles function + + +class NestedSampling: + def __init__(self, model, num_live_points, *, sampler_name='slice', **ns_kwargs): + self.model = model + self.num_live_points = num_live_points + self.sampler_name = sampler_name + self.ns_kwargs = ns_kwargs # max_samples, termination_frac,... + + def run(self, rng_key, *args, **kwargs): + # Step 1: reparam the model so that latent sites have Uniform(0, 1) priors + reparam_model = reparam(self.model, + config=lambda msg: None if msg.get("is_observed") else UnitCubeReparam()) + + # Step 2: compute the likelihood of the model + def ll_fn(*params): + return list(log_likelihood(reparam_model, params, *args, batch_ndims=0, **kwargs).values())[0] + + # Step 3: use NestedSampler with empty prior chain + prior_chain = PriorChain() + ns = NestedSampler(ll_fn, prior_chain, collect_samples=True, sampler_name=self.sampler_name) + results = ns(rng_key, self.num_live_points) + samples = results["samples"] + # TODO: transform samples back or rerun reparam_model to get deterministic sites + return samples From 614e81f970f16f6121b185ac0561781fc818daee Mon Sep 17 00:00:00 2001 From: Du Phan Date: Fri, 6 Nov 2020 18:12:10 -0600 Subject: [PATCH 02/10] support collapse plate and test --- numpyro/handlers.py | 8 ++++++-- test/test_handlers.py | 1 - 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/numpyro/handlers.py b/numpyro/handlers.py index c1296c759..d283a96b9 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -277,18 +277,22 @@ def __enter__(self): COERCIONS.append(self._coerce) return super().__enter__() - def __exit__(self, *args, **kwargs): + def __exit__(self, exc_type, exc_value, traceback): + if exc_type is not None: + return super().__exit__(exc_type, exc_value, traceback) import funsor _coerce = COERCIONS.pop() assert _coerce is self._coerce - super().__exit__(*args, **kwargs) + super().__exit__(exc_type, exc_value, traceback) # Convert delayed statements to pyro.factor() reduced_vars = [] log_prob_terms = [] plates = frozenset() for name, site in self.trace.items(): + if site["type"] != "sample": + continue if not site["is_observed"]: reduced_vars.append(name) dim_to_name = {f.dim: f.name for f in site["cond_indep_stack"]} diff --git a/test/test_handlers.py b/test/test_handlers.py index 2b32435cf..ece6779cb 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -569,7 +569,6 @@ def guide(): svi.update(svi_state) -@pytest.mark.xfail(reason="missing pattern in Funsor") def test_collapse_beta_binomial_plate(): data = np.array([0., 1., 5., 5.]) From 6d55b3042b06cfe62daae62a475ef2d4f3ae3b11 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Tue, 10 Nov 2020 23:25:26 -0600 Subject: [PATCH 03/10] add some TODO to work on --- numpyro/contrib/ns.py | 40 ---------------------------------------- numpyro/primitives.py | 2 ++ 2 files changed, 2 insertions(+), 40 deletions(-) delete mode 100644 numpyro/contrib/ns.py diff --git a/numpyro/contrib/ns.py b/numpyro/contrib/ns.py deleted file mode 100644 index cef9cf400..000000000 --- a/numpyro/contrib/ns.py +++ /dev/null @@ -1,40 +0,0 @@ -from numpyro.handlers import reparam -from numpyro.infer.reparam import Reparam -from numpyro.infer import log_likelihood - -from jaxns.nested_sampling import NestedSampler -from jaxns.prior_transforms import PriorChain - - -class UnitCubeReparam(Reparam): - # support distributions with quantiles implementation - def __call__(self, name, fn, obs): - assert obs is None, "TransformReparam does not support observe statements" - fn, batch_shape = self._unexpand(fn) - # if fn is transformed distribution, get transform and base_dist, reparam the base_dist - # TODO: extract quantiles function - - -class NestedSampling: - def __init__(self, model, num_live_points, *, sampler_name='slice', **ns_kwargs): - self.model = model - self.num_live_points = num_live_points - self.sampler_name = sampler_name - self.ns_kwargs = ns_kwargs # max_samples, termination_frac,... - - def run(self, rng_key, *args, **kwargs): - # Step 1: reparam the model so that latent sites have Uniform(0, 1) priors - reparam_model = reparam(self.model, - config=lambda msg: None if msg.get("is_observed") else UnitCubeReparam()) - - # Step 2: compute the likelihood of the model - def ll_fn(*params): - return list(log_likelihood(reparam_model, params, *args, batch_ndims=0, **kwargs).values())[0] - - # Step 3: use NestedSampler with empty prior chain - prior_chain = PriorChain() - ns = NestedSampler(ll_fn, prior_chain, collect_samples=True, sampler_name=self.sampler_name) - results = ns(rng_key, self.num_live_points) - samples = results["samples"] - # TODO: transform samples back or rerun reparam_model to get deterministic sites - return samples diff --git a/numpyro/primitives.py b/numpyro/primitives.py index 403be3f2e..d98b857eb 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -307,6 +307,7 @@ def process_message(self, msg): cond_indep_stack.append(frame) if msg['type'] == 'sample': expected_shape = self._get_batch_shape(cond_indep_stack) + # TODO: get `batch_shape` of a Funsor dist_batch_shape = msg['fn'].batch_shape if 'sample_shape' in msg['kwargs']: dist_batch_shape = msg['kwargs']['sample_shape'] + dist_batch_shape @@ -315,6 +316,7 @@ def process_message(self, msg): trailing_shape = expected_shape[overlap_idx:] broadcast_shape = lax.broadcast_shapes(trailing_shape, tuple(dist_batch_shape)) batch_shape = expected_shape[:overlap_idx] + broadcast_shape + # TODO: `expand` a Funsor msg['fn'] = msg['fn'].expand(batch_shape) if self.size != self.subsample_size: scale = 1. if msg['scale'] is None else msg['scale'] From 6d2e0676844b3726e6eae2fe87efe5ae1939d292 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Wed, 11 Nov 2020 22:58:03 -0600 Subject: [PATCH 04/10] only expand if msg[fn] is a NumPyro Distribution --- numpyro/primitives.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/numpyro/primitives.py b/numpyro/primitives.py index d98b857eb..45650b1fd 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -10,6 +10,7 @@ import numpyro from numpyro.distributions.discrete import PRNGIdentity +from numpyro.distributions.distribution import Distribution from numpyro.util import identity _PYRO_STACK = [] @@ -305,9 +306,9 @@ def process_message(self, msg): cond_indep_stack = msg['cond_indep_stack'] frame = CondIndepStackFrame(self.name, self.dim, self.subsample_size) cond_indep_stack.append(frame) - if msg['type'] == 'sample': + # only expand if fn is Distribution, not a Funsor + if msg['type'] == 'sample' and isinstance(msg['fn'], Distribution): expected_shape = self._get_batch_shape(cond_indep_stack) - # TODO: get `batch_shape` of a Funsor dist_batch_shape = msg['fn'].batch_shape if 'sample_shape' in msg['kwargs']: dist_batch_shape = msg['kwargs']['sample_shape'] + dist_batch_shape @@ -316,7 +317,6 @@ def process_message(self, msg): trailing_shape = expected_shape[overlap_idx:] broadcast_shape = lax.broadcast_shapes(trailing_shape, tuple(dist_batch_shape)) batch_shape = expected_shape[:overlap_idx] + broadcast_shape - # TODO: `expand` a Funsor msg['fn'] = msg['fn'].expand(batch_shape) if self.size != self.subsample_size: scale = 1. if msg['scale'] is None else msg['scale'] From 831db772acd3b11afc4cd8d6086e0611fa1aefcb Mon Sep 17 00:00:00 2001 From: Du Phan Date: Wed, 25 Nov 2020 00:17:59 -0600 Subject: [PATCH 05/10] add various mvn tests --- test/test_handlers.py | 85 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/test/test_handlers.py b/test/test_handlers.py index 981dcd051..3a8ae3e24 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -590,6 +590,91 @@ def guide(): svi.update(svi_state) +def test_collapse_normal_normal(): + data = np.array(0.) + + def model(): + x = numpyro.sample("x", dist.Normal(0, 1)) + with handlers.collapse(): + y = numpyro.sample("y", dist.Normal(x, 1.)) + numpyro.sample("z", dist.Normal(y, 1.), obs=data) + + def guide(): + loc = numpyro.param("loc", 0.) + scale = numpyro.param("scale", 1., constraint=constraints.positive) + numpyro.sample("x", dist.Normal(loc, scale)) + + svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO()) + svi_state = svi.init(random.PRNGKey(0)) + svi.update(svi_state) + + +def test_collapse_normal_normal_plate(): + data = np.arange(5.) + + def model(): + x = numpyro.sample("x", dist.Normal(0, 1)) + with handlers.collapse(): + y = numpyro.sample("y", dist.Normal(x, 1.)) + with handlers.plate("data", len(data)): + numpyro.sample("z", dist.Normal(y, 1.), obs=data) + + def guide(): + loc = numpyro.param("loc", 0.) + scale = numpyro.param("scale", 1., constraint=constraints.positive) + numpyro.sample("x", dist.Normal(loc, scale)) + + svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO()) + svi_state = svi.init(random.PRNGKey(0)) + svi.update(svi_state) + + +def test_collapse_normal_plate_normal(): + data = np.arange(5.) + + def model(): + x = numpyro.sample("x", dist.Normal(0, 1)) + with handlers.collapse(): + with handlers.plate("data", len(data)): + # TODO: address expanded distribution + y = numpyro.sample("y", dist.Normal(x, 1.)) + numpyro.sample("z", dist.Normal(y, 1.), obs=data) + + def guide(): + loc = numpyro.param("loc", 0.) + scale = numpyro.param("scale", 1., constraint=constraints.positive) + numpyro.sample("x", dist.Normal(loc, scale)) + + svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO()) + svi_state = svi.init(random.PRNGKey(0)) + svi.update(svi_state) + + +def test_collapse_normal_mvn_mvn(): + T, d, S = 5, 2, 3 + data = jnp.ones((T, S)) + + def model(): + x = numpyro.sample("x", dist.Exponential(1)) + with handlers.collapse(): + with numpyro.plate("d", d): + beta0 = numpyro.sample("beta0", dist.Normal(0, 1).expand([S]).to_event(1)) + # TODO: address beta0 is a str, which cannot do infer_param_domain + beta = numpyro.sample("beta", dist.MultivariateNormal(beta0, jnp.eye(S))) + # FIXME: beta is a string here, how to apply numeric operators + mean = jnp.ones((T, d)) @ beta + with numpyro.plate("data", T, dim=-2): + numpyro.sample("obs", dist.MultivariateNormal(mean, jnp.eye(S)), obs=data) + + def guide(): + rate = numpyro.param("rate", 1., constraint=constraints.positive) + numpyro.sample("x", dist.Exponential(rate)) + + svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO()) + svi_state = svi.init(random.PRNGKey(0)) + svi.update(svi_state) + + def test_prng_key(): assert numpyro.prng_key() is None From 8f8719f9b1a04073e6d639eb528300a3855e0d39 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Wed, 25 Nov 2020 00:25:25 -0600 Subject: [PATCH 06/10] address expanded distribution --- numpyro/handlers.py | 4 +++- test/test_handlers.py | 1 - 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/numpyro/handlers.py b/numpyro/handlers.py index c5c00b909..010475bcc 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -83,7 +83,7 @@ import jax.numpy as jnp import numpyro -from numpyro.distributions.distribution import COERCIONS +from numpyro.distributions.distribution import COERCIONS, ExpandedDistribution from numpyro.primitives import _PYRO_STACK, Messenger, apply_stack, plate from numpyro.util import not_jax_tracer @@ -268,6 +268,8 @@ def process_message(self, msg): if msg["type"] == "sample": if msg["value"] is None: msg["value"] = msg["name"] + if isinstance(msg["fn"], ExpandedDistribution): + msg["fn"] = msg["fn"].base_dist if isinstance(msg["fn"], Funsor) or isinstance(msg["value"], (str, Funsor)): msg["stop"] = True diff --git a/test/test_handlers.py b/test/test_handlers.py index 3a8ae3e24..989c121cb 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -636,7 +636,6 @@ def model(): x = numpyro.sample("x", dist.Normal(0, 1)) with handlers.collapse(): with handlers.plate("data", len(data)): - # TODO: address expanded distribution y = numpyro.sample("y", dist.Normal(x, 1.)) numpyro.sample("z", dist.Normal(y, 1.), obs=data) From cfdc30d3aacf0e9fe9d9454f8890f0be1153da7c Mon Sep 17 00:00:00 2001 From: Du Phan Date: Wed, 25 Nov 2020 00:28:03 -0600 Subject: [PATCH 07/10] add more todo --- test/test_handlers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_handlers.py b/test/test_handlers.py index 989c121cb..bb679283b 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -657,6 +657,7 @@ def model(): x = numpyro.sample("x", dist.Exponential(1)) with handlers.collapse(): with numpyro.plate("d", d): + # TODO: verify that to_event works here beta0 = numpyro.sample("beta0", dist.Normal(0, 1).expand([S]).to_event(1)) # TODO: address beta0 is a str, which cannot do infer_param_domain beta = numpyro.sample("beta", dist.MultivariateNormal(beta0, jnp.eye(S))) From ac09d788496dbb3376df9f31deefc05dbe0f2303 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Wed, 25 Nov 2020 21:10:15 -0600 Subject: [PATCH 08/10] make expand works under funsor --- numpyro/handlers.py | 94 +++++++++++++++++++++++++++++-------------- test/test_handlers.py | 15 ++++--- 2 files changed, 71 insertions(+), 38 deletions(-) diff --git a/numpyro/handlers.py b/numpyro/handlers.py index 010475bcc..1621ab9df 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -77,13 +77,14 @@ """ from collections import OrderedDict +from functools import reduce import warnings -from jax import lax, random +from jax import lax, random, tree_map import jax.numpy as jnp import numpyro -from numpyro.distributions.distribution import COERCIONS, ExpandedDistribution +from numpyro.distributions.distribution import COERCIONS, ExpandedDistribution, Independent from numpyro.primitives import _PYRO_STACK, Messenger, apply_stack, plate from numpyro.util import not_jax_tracer @@ -245,6 +246,20 @@ def process_message(self, msg): msg['stop'] = True +def _eager_expand_fn(fn): + if isinstance(fn, Independent): + reinterpreted_batch_ndims = fn.reinterpreted_batch_ndims + fn = fn.base_dist + else: + reinterpreted_batch_ndims = 0 # no-op for to_event method + if isinstance(fn, ExpandedDistribution): + batch_shape = fn.batch_shape + base_batch_shape = fn.base_dist.batch_shape + appended_shape = batch_shape[:len(batch_shape) - len(base_batch_shape)] + fn = tree_map(lambda x: jnp.broadcast_to(x, appended_shape + jnp.shape(x)), fn.base_dist) + return fn.to_event(reinterpreted_batch_ndims) + + class collapse(trace): """ EXPERIMENTAL Collapses all sites in the context by lazily sampling and @@ -263,33 +278,48 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def process_message(self, msg): - from funsor.terms import Funsor + if msg["type"] != "sample": + return - if msg["type"] == "sample": - if msg["value"] is None: - msg["value"] = msg["name"] - if isinstance(msg["fn"], ExpandedDistribution): - msg["fn"] = msg["fn"].base_dist + import funsor + + # Eagerly convert fn and value to Funsor. + dim_to_name = {f.dim: f.name for f in msg["cond_indep_stack"]} + dim_to_name.update(self.preserved_plates) + if isinstance(msg["fn"], (Independent, ExpandedDistribution)): + msg["fn"] = _eager_expand_fn(msg["fn"]) + msg["fn"] = funsor.to_funsor(msg["fn"], funsor.Real, dim_to_name) + domain = msg["fn"].inputs["value"] + if msg["value"] is None: + msg["value"] = funsor.Variable(msg["name"], domain) + else: + msg["value"] = funsor.to_funsor(msg["value"], domain, dim_to_name) - if isinstance(msg["fn"], Funsor) or isinstance(msg["value"], (str, Funsor)): - msg["stop"] = True + msg["stop"] = True def __enter__(self): - self.preserved_plates = frozenset(h.name for h in _PYRO_STACK - if isinstance(h, plate)) + self.preserved_plates = {h.dim: h.name for h in _PYRO_STACK + if isinstance(h, plate)} COERCIONS.append(self._coerce) return super().__enter__() def __exit__(self, exc_type, exc_value, traceback): - import funsor - _coerce = COERCIONS.pop() assert _coerce is self._coerce super().__exit__(exc_type, exc_value, traceback) if exc_type is not None: + self.trace.clear() + self.preserved_plates.clear() return + if any(site["type"] == "sample" for site in self.trace.values()): + name, log_prob, _, _ = self._get_log_prob() + numpyro.factor(name, log_prob.data) + + def _get_log_prob(self): + import funsor + # Convert delayed statements to pyro.factor() reduced_vars = [] log_prob_terms = [] @@ -299,24 +329,28 @@ def __exit__(self, exc_type, exc_value, traceback): continue if not site["is_observed"]: reduced_vars.append(name) - dim_to_name = {f.dim: f.name for f in site["cond_indep_stack"]} - fn = funsor.to_funsor(site["fn"], funsor.Real, dim_to_name) - value = site["value"] - if not isinstance(value, str): - value = funsor.to_funsor(site["value"], fn.inputs["value"], dim_to_name) - log_prob_terms.append(fn(value=value)) + log_prob_terms.append(site["fn"](value=site["value"])) plates |= frozenset(f.name for f in site["cond_indep_stack"]) - assert log_prob_terms, "nothing to collapse" - reduced_plates = plates - self.preserved_plates - log_prob = funsor.sum_product.sum_product( - funsor.ops.logaddexp, - funsor.ops.add, - log_prob_terms, - eliminate=frozenset(reduced_vars) | reduced_plates, - plates=plates, - ) name = reduced_vars[0] - numpyro.factor(name, log_prob.data) + reduced_vars = frozenset(reduced_vars) + assert log_prob_terms, "nothing to collapse" + reduced_plates = plates - frozenset(self.preserved_plates.values()) + self.trace.clear() + self.preserved_plates.clear() + if reduced_plates: + log_prob = funsor.sum_product.sum_product( + funsor.ops.logaddexp, + funsor.ops.add, + log_prob_terms, + eliminate=frozenset(reduced_vars) | reduced_plates, + plates=plates, + ) + log_joint = NotImplemented + else: + log_joint = reduce(funsor.ops.add, log_prob_terms) + log_prob = log_joint.reduce(funsor.ops.logaddexp, reduced_vars) + + return name, log_prob, log_joint, reduced_vars class condition(Messenger): diff --git a/test/test_handlers.py b/test/test_handlers.py index bb679283b..0ee6c58a8 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -656,15 +656,14 @@ def test_collapse_normal_mvn_mvn(): def model(): x = numpyro.sample("x", dist.Exponential(1)) with handlers.collapse(): - with numpyro.plate("d", d): - # TODO: verify that to_event works here - beta0 = numpyro.sample("beta0", dist.Normal(0, 1).expand([S]).to_event(1)) - # TODO: address beta0 is a str, which cannot do infer_param_domain - beta = numpyro.sample("beta", dist.MultivariateNormal(beta0, jnp.eye(S))) - # FIXME: beta is a string here, how to apply numeric operators + with numpyro.plate("d", d, dim=-1): + beta0 = numpyro.sample("beta0", dist.Normal(0., 1.).expand([d, S]).to_event(1)) + beta = numpyro.sample( + "beta", dist.MultivariateNormal(beta0, scale_tril=jnp.eye(S))) + mean = jnp.ones((T, d)) @ beta - with numpyro.plate("data", T, dim=-2): - numpyro.sample("obs", dist.MultivariateNormal(mean, jnp.eye(S)), obs=data) + with numpyro.plate("data", T, dim=-1): + numpyro.sample("obs", dist.MultivariateNormal(mean, scale_tril=jnp.eye(S)), obs=data) def guide(): rate = numpyro.param("rate", 1., constraint=constraints.positive) From da3dd01c6953b51048c0e777cbbb72a25eb023d9 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Wed, 25 Nov 2020 22:23:02 -0600 Subject: [PATCH 09/10] mark xfail for failing tests --- test/test_handlers.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/test/test_handlers.py b/test/test_handlers.py index 0ee6c58a8..5d5747754 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -649,25 +649,50 @@ def guide(): svi.update(svi_state) +@pytest.mark.xfail(reason="missing pattern in Funsor") +def test_collapse_diag_normal_plate_normal(): + d = 3 + data = np.ones((5, d)) + + def model(): + x = numpyro.sample("x", dist.Normal(0, 1)) + with handlers.collapse(): + with handlers.plate("data", len(data)): + y = numpyro.sample("y", dist.Normal(x, 1.).expand([d]).to_event(1)) + numpyro.sample("z", dist.Normal(y, 1.).to_event(1), obs=data) + + def guide(): + loc = numpyro.param("loc", 0.) + scale = numpyro.param("scale", 1., constraint=constraints.positive) + numpyro.sample("x", dist.Normal(loc, scale)) + + svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO()) + svi_state = svi.init(random.PRNGKey(0)) + svi.update(svi_state) + + +@pytest.mark.xfail(reason="missing pattern in Funsor") def test_collapse_normal_mvn_mvn(): T, d, S = 5, 2, 3 data = jnp.ones((T, S)) def model(): - x = numpyro.sample("x", dist.Exponential(1)) + x = numpyro.sample("x", dist.Normal(0, 1)) with handlers.collapse(): with numpyro.plate("d", d, dim=-1): - beta0 = numpyro.sample("beta0", dist.Normal(0., 1.).expand([d, S]).to_event(1)) + beta0 = numpyro.sample("beta0", dist.Normal(x, 1.).expand([d, S]).to_event(1)) beta = numpyro.sample( "beta", dist.MultivariateNormal(beta0, scale_tril=jnp.eye(S))) + # this fails because beta shape is (3,) while it should be (2, 3) mean = jnp.ones((T, d)) @ beta with numpyro.plate("data", T, dim=-1): numpyro.sample("obs", dist.MultivariateNormal(mean, scale_tril=jnp.eye(S)), obs=data) def guide(): - rate = numpyro.param("rate", 1., constraint=constraints.positive) - numpyro.sample("x", dist.Exponential(rate)) + loc = numpyro.param("loc", 0.) + scale = numpyro.param("scale", 1., constraint=constraints.positive) + numpyro.sample("x", dist.Normal(loc, scale)) svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO()) svi_state = svi.init(random.PRNGKey(0)) From c2f96efb438150bfd0bb52fd18d31cb8b7754849 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Wed, 25 Nov 2020 22:24:16 -0600 Subject: [PATCH 10/10] lint --- test/test_handlers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_handlers.py b/test/test_handlers.py index 5d5747754..5c8b221bb 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -684,7 +684,7 @@ def model(): beta = numpyro.sample( "beta", dist.MultivariateNormal(beta0, scale_tril=jnp.eye(S))) - # this fails because beta shape is (3,) while it should be (2, 3) + # this fails because beta shape is (3,) while it should be (2, 3) mean = jnp.ones((T, d)) @ beta with numpyro.plate("data", T, dim=-1): numpyro.sample("obs", dist.MultivariateNormal(mean, scale_tril=jnp.eye(S)), obs=data)