From 8ba2f2c64f626b9ad737d9e2915c7e61d97392df Mon Sep 17 00:00:00 2001 From: eb8680 Date: Tue, 16 Jul 2024 10:48:00 -0400 Subject: [PATCH] Move cut posterior handlers to robust module (#549) * Move cut posterior handlers to robust module * Update test_handlers_cut.py --- chirho/{observational => robust}/handlers/cut.py | 0 docs/source/dr_learner.ipynb | 2 +- docs/source/sdid.ipynb | 2 +- .../test_handlers_cut.py} | 6 +----- 4 files changed, 3 insertions(+), 7 deletions(-) rename chirho/{observational => robust}/handlers/cut.py (100%) rename tests/{observational/test_cut_posterior_modules.py => robust/test_handlers_cut.py} (98%) diff --git a/chirho/observational/handlers/cut.py b/chirho/robust/handlers/cut.py similarity index 100% rename from chirho/observational/handlers/cut.py rename to chirho/robust/handlers/cut.py diff --git a/docs/source/dr_learner.ipynb b/docs/source/dr_learner.ipynb index 4cdc7800a..3b1fa5ce8 100644 --- a/docs/source/dr_learner.ipynb +++ b/docs/source/dr_learner.ipynb @@ -45,7 +45,7 @@ "import pyro.distributions as dist\n", "from pyro.infer.autoguide import AutoNormal\n", "from chirho.indexed.handlers import IndexPlatesMessenger\n", - "from chirho.observational.handlers.cut import SingleStageCut\n", + "from chirho.robust.handlers.cut import SingleStageCut\n", "from pyro.infer import Predictive\n", "\n", "pyro.settings.set(module_local_params=True)\n", diff --git a/docs/source/sdid.ipynb b/docs/source/sdid.ipynb index c91c37b21..6c5aed456 100644 --- a/docs/source/sdid.ipynb +++ b/docs/source/sdid.ipynb @@ -52,7 +52,7 @@ "import pyro.distributions as dist\n", "from pyro.infer.autoguide import AutoNormal\n", "from chirho.indexed.handlers import IndexPlatesMessenger\n", - "from chirho.observational.handlers.cut import SingleStageCut\n", + "from chirho.robust.handlers.cut import SingleStageCut\n", "from pyro.infer import Predictive\n", "\n", "pyro.settings.set(module_local_params=True)\n", diff --git a/tests/observational/test_cut_posterior_modules.py b/tests/robust/test_handlers_cut.py similarity index 98% rename from tests/observational/test_cut_posterior_modules.py rename to tests/robust/test_handlers_cut.py index 899828cdf..c5bec637e 100644 --- a/tests/observational/test_cut_posterior_modules.py +++ b/tests/robust/test_handlers_cut.py @@ -7,11 +7,7 @@ from pyro.infer.autoguide import AutoMultivariateNormal from chirho.indexed.handlers import IndexPlatesMessenger -from chirho.observational.handlers.cut import ( - CutComplementModule, - CutModule, - SingleStageCut, -) +from chirho.robust.handlers.cut import CutComplementModule, CutModule, SingleStageCut pyro.settings.set(module_local_params=True)