diff --git a/chirho/__init__.py b/chirho/__init__.py
index 278e99674..00a671c5a 100644
--- a/chirho/__init__.py
+++ b/chirho/__init__.py
@@ -2,4 +2,5 @@
 Project short description.
 """
 
-__version__ = "0.0.1"
+
+__version__ = "0.2.0"
diff --git a/chirho/interventional/handlers.py b/chirho/interventional/handlers.py
index 3bfbd694c..5e2d231bd 100644
--- a/chirho/interventional/handlers.py
+++ b/chirho/interventional/handlers.py
@@ -118,4 +118,9 @@ def _pyro_post_sample(self, msg):
         )
 
 
-do = pyro.poutine.handlers._make_handler(Interventions)[1]
+if isinstance(pyro.poutine.handlers._make_handler(Interventions), tuple):
+    do = pyro.poutine.handlers._make_handler(Interventions)[1]
+else:
+
+    @pyro.poutine.handlers._make_handler(Interventions)
+    def do(fn: Callable, actions: Mapping[Hashable, AtomicIntervention[T]]): ...
diff --git a/chirho/observational/handlers/condition.py b/chirho/observational/handlers/condition.py
index 3bd4e6614..dc91415ca 100644
--- a/chirho/observational/handlers/condition.py
+++ b/chirho/observational/handlers/condition.py
@@ -110,4 +110,10 @@ def _pyro_sample(self, msg):
         self._current_site = None
 
 
-condition = pyro.poutine.handlers._make_handler(Observations)[1]
+if isinstance(pyro.poutine.handlers._make_handler(Observations), tuple):
+    # backwards compatibility
+    condition = pyro.poutine.handlers._make_handler(Observations)[1]
+else:
+
+    @pyro.poutine.handlers._make_handler(Observations)
+    def condition(fn: Callable, data: Mapping[str, Observation[T]]): ...
diff --git a/chirho/robust/internals/nmc.py b/chirho/robust/internals/nmc.py
index 342abdcc0..d12468432 100644
--- a/chirho/robust/internals/nmc.py
+++ b/chirho/robust/internals/nmc.py
@@ -124,6 +124,7 @@ class BatchedNMCLogMarginalLikelihood(Generic[P, T], torch.nn.Module):
         used to approximate marginal distribution, defaults to 1
     :type num_samples: int, optional
     """
+
     model: Callable[P, Any]
     guide: Optional[Callable[P, Any]]
     num_samples: int
diff --git a/chirho/robust/ops.py b/chirho/robust/ops.py
index 86ed8f89d..95f80973a 100644
--- a/chirho/robust/ops.py
+++ b/chirho/robust/ops.py
@@ -16,8 +16,7 @@ class Functional(Protocol[P, S]):
 
     def __call__(
         self, __model: Callable[P, Any], *models: Callable[P, Any]
-    ) -> Callable[P, S]:
-        ...
+    ) -> Callable[P, S]: ...
 
 
 def influence_fn(
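Note on the handler changes above: pyro.poutine.handlers._make_handler returned a (Messenger, handler) tuple in older Pyro releases but acts as a decorator factory in newer ones, so both branches of the isinstance shim should yield the same public API. A minimal smoke check of that assumption (the model and site names below are illustrative, not taken from this diff):

import torch
import pyro
import pyro.distributions as dist

from chirho.interventional.handlers import do
from chirho.observational.handlers.condition import condition


def model():
    # toy model with one latent site "x" and one observable site "y"
    x = pyro.sample("x", dist.Normal(0.0, 1.0))
    return pyro.sample("y", dist.Normal(x, 1.0))


# whichever branch of the shim defined `do` and `condition`, both should
# still compose as ordinary Pyro effect handlers
with do(actions={"x": torch.tensor(1.0)}), condition(data={"y": torch.tensor(0.5)}):
    model()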
diff --git a/tests/observational/test_cut_posterior_modules.py b/tests/observational/test_cut_posterior_modules.py
index be1faf0da..899828cdf 100644
--- a/tests/observational/test_cut_posterior_modules.py
+++ b/tests/observational/test_cut_posterior_modules.py
@@ -212,10 +212,7 @@ def run_svi_inference(model, n_steps=1000, verbose=True, lr=0.03, **model_kwargs
 def analytical_linear_gaussian_cut_posterior(data):
     post_sd_mod_one = math.sqrt((1 + NUM_SAMPS_MODULE_ONE / SIGMA_ONE**2) ** (-1))
     pr_eta_cut = dist.Normal(
-        1
-        / SIGMA_ONE**2
-        * data["w"].sum()
-        / (1 + NUM_SAMPS_MODULE_ONE / SIGMA_ONE**2),
+        1 / SIGMA_ONE**2 * data["w"].sum() / (1 + NUM_SAMPS_MODULE_ONE / SIGMA_ONE**2),
         scale=post_sd_mod_one,
     )
     post_mean_mod_two = lambda eta: (  # noqa
diff --git a/tests/robust/test_handlers.py b/tests/robust/test_handlers.py
index e43015282..12bb81c8f 100644
--- a/tests/robust/test_handlers.py
+++ b/tests/robust/test_handlers.py
@@ -28,11 +28,14 @@
     (SimpleModel, lambda _: SimpleGuide(), {"y"}, None),
     pytest.param(
         SimpleModel,
-        pyro.infer.autoguide.AutoNormal,
+        lambda m: pyro.infer.autoguide.AutoNormal(pyro.poutine.block(hide=["y"])(m)),
         {"y"},
         1,
-        marks=pytest.mark.xfail(
-            reason="torch.func autograd doesnt work with PyroParam"
+        marks=(
+            [pytest.mark.xfail(reason="torch.func autograd doesnt work with PyroParam")]
+            if tuple(map(int, pyro.__version__.split("+")[0].split(".")[:3]))
+            <= (1, 8, 6)
+            else []
         ),
     ),
 ]
diff --git a/tests/robust/test_internals_compositions.py b/tests/robust/test_internals_compositions.py
index 6f9dfde8d..ade682726 100644
--- a/tests/robust/test_internals_compositions.py
+++ b/tests/robust/test_internals_compositions.py
@@ -49,9 +49,11 @@ def test_empirical_fisher_vp_nmclikelihood_cg_composition():
     )
 
     v = {
-        k: torch.ones_like(v).unsqueeze(0)
-        if k != "model.guide.loc_a"
-        else torch.zeros_like(v).unsqueeze(0)
+        k: (
+            torch.ones_like(v).unsqueeze(0)
+            if k != "model.guide.loc_a"
+            else torch.zeros_like(v).unsqueeze(0)
+        )
         for k, v in log_prob_params.items()
     }
 
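The repeated marks=(...) change in these test parametrizations swaps an unconditional xfail for a version gate: the mark only applies on Pyro releases up to 1.8.6, where torch.func autograd does not work with PyroParam. The version parsing, extracted here as a sketch for readability (pyro_version and needs_xfail are local names, not part of the diff):

import pyro

# strip any local-build suffix such as "1.8.6+abc123", then keep the first
# three numeric components of the release string
pyro_version = tuple(map(int, pyro.__version__.split("+")[0].split(".")[:3]))

# apply the xfail mark only on affected releases; newer Pyro should pass
needs_xfail = pyro_version <= (1, 8, 6)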
k.endswith("a_unconstrained")): assert not torch.isclose( v, torch.zeros_like(v) ).all(), f"eif for {k} was zero" diff --git a/tests/robust/test_ops.py b/tests/robust/test_ops.py index e3d5e5290..8fa320dcf 100644 --- a/tests/robust/test_ops.py +++ b/tests/robust/test_ops.py @@ -28,11 +28,14 @@ (SimpleModel, lambda _: SimpleGuide(), {"y"}, None), pytest.param( SimpleModel, - pyro.infer.autoguide.AutoNormal, + lambda m: pyro.infer.autoguide.AutoNormal(pyro.poutine.block(hide=["y"])(m)), {"y"}, 1, - marks=pytest.mark.xfail( - reason="torch.func autograd doesnt work with PyroParam" + marks=( + [pytest.mark.xfail(reason="torch.func autograd doesnt work with PyroParam")] + if tuple(map(int, pyro.__version__.split("+")[0].split(".")[:3])) + <= (1, 8, 6) + else [] ), ), ]