-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Cut Posterior Implementation + Examples (#229)
* initial handler for cut posterior during pair session with Eli * implementation of cut posterior * pair programming session with Eli for cut posterior implementation * removed unneeded plate class * fixed bug in factor constraint name * fixed bug with subsample error * uncomitted linear gaussian test before major changes * passes closed form discrete test case * cleaned up test file * recovery of exact posterior from saturated VI family * index module has failing test * small lint change * bayessdid basic functionality implemented * skeleton for svi cut inference added * cut vi method implemented * adding synthetic control weights plotting * cleaned up notebook and added plots * removes pyro.factor statement for stitching * merged in master and fixed causal_pyro import errors * testing for gaussian linear case * fixed analytical formula for linear gaussian case * fixed mask issue * added single stage cut estimator * removed discrete test that checks compares svi output * linting * removed unneeded imports * added failing test when composing trace and indexcut * added replay check * initial implementation * extended to a GLM and match data experiment in paper * cleaned up notebook cells; more documentation needed still * cleaned up tests * cleaned up notebook * forgot to save changes * added uncertainty figure * cleaned up notebook * fixed linting error for python 3.8 * addressing some of eli's refactoring suggestions * refactored to inherit from dependent mask messenger (Eli's suggestion) * added missing fn argument * commit just to document a test that will be removed testing single cut implementations * revert back to removing old SingleStageCut implementation * updated imports and reran notebooks * issues with running svi with local params turned on * removed two stage approach * consolidated tests --------- Co-authored-by: Raj Agrawal <[email protected]>
- Loading branch information
1 parent
f6f428d
commit 19a395c
Showing
6 changed files
with
1,741 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
from typing import Any, Dict, Optional, Set, TypeVar | ||
|
||
import pyro | ||
import torch | ||
|
||
from chirho.indexed.handlers import DependentMaskMessenger, add_indices | ||
from chirho.indexed.ops import IndexSet, gather, indexset_as_mask, scatter | ||
|
||
T = TypeVar("T") | ||
|
||
|
||
class CutModule(pyro.poutine.messenger.Messenger): | ||
""" | ||
Converts a Pyro model into a module using the "cut" operation | ||
""" | ||
|
||
vars: Set[str] | ||
|
||
def __init__(self, vars: Set[str]): | ||
self.vars = vars | ||
super().__init__() | ||
|
||
def _pyro_sample(self, msg: Dict[str, Any]) -> None: | ||
# There are 4 cases to consider for a sample site: | ||
# 1. The site appears in self.vars and is observed | ||
# 2. The site appears in self.vars and is not observed | ||
# 3. The site does not appear in self.vars and is observed | ||
# 4. The site does not appear in self.vars and is not observed | ||
if msg["name"] not in self.vars: | ||
if msg["is_observed"]: | ||
# use mask to remove the contribution of this observed site to the model log-joint | ||
msg["mask"] = ( | ||
msg["mask"] if msg["mask"] is not None else True | ||
) & torch.tensor(False, dtype=torch.bool).expand(msg["fn"].batch_shape) | ||
else: | ||
pass | ||
|
||
# For sites that do not appear in module, rename them to avoid naming conflict | ||
if msg["name"] not in self.vars: | ||
msg["name"] = f"{msg['name']}_nuisance" | ||
|
||
|
||
class CutComplementModule(pyro.poutine.messenger.Messenger): | ||
vars: Set[str] | ||
|
||
def __init__(self, vars: Set[str]): | ||
self.vars = vars | ||
super().__init__() | ||
|
||
def _pyro_sample(self, msg: Dict[str, Any]) -> None: | ||
# There are 4 cases to consider for a sample site: | ||
# 1. The site appears in self.vars and is observed | ||
# 2. The site appears in self.vars and is not observed | ||
# 3. The site does not appear in self.vars and is observed | ||
# 4. The site does not appear in self.vars and is not observed | ||
if msg["name"] in self.vars: | ||
# use mask to remove the contribution of this observed site to the model log-joint | ||
msg["mask"] = ( | ||
msg["mask"] if msg["mask"] is not None else True | ||
) & torch.tensor(False, dtype=torch.bool).expand(msg["fn"].batch_shape) | ||
|
||
|
||
class SingleStageCut(DependentMaskMessenger): | ||
""" | ||
Represent module and complement in a single Pyro model using plates | ||
""" | ||
|
||
vars: Set[str] | ||
name: str | ||
|
||
def __init__(self, vars: Set[str], *, name: str = "__cut_plate"): | ||
self.vars = vars | ||
self.name = name | ||
super().__init__() | ||
|
||
def __enter__(self): | ||
add_indices(IndexSet(**{self.name: {0, 1}})) | ||
return super().__enter__() | ||
|
||
def get_mask( | ||
self, | ||
dist: pyro.distributions.Distribution, | ||
value: Optional[torch.Tensor], | ||
device: torch.device = torch.device("cpu"), | ||
name: Optional[str] = None, | ||
) -> torch.Tensor: | ||
return indexset_as_mask( | ||
IndexSet(**{self.name: {0 if name in self.vars else 1}}) | ||
) | ||
|
||
def _pyro_post_sample(self, msg: Dict[str, Any]) -> None: | ||
if pyro.poutine.util.site_is_subsample(msg): | ||
return | ||
|
||
if (not msg["is_observed"]) and (msg["name"] in self.vars): | ||
# discard the second value | ||
value_one = gather( | ||
msg["value"], | ||
IndexSet(**{self.name: {0}}), | ||
event_dim=msg["fn"].event_dim, | ||
) | ||
|
||
msg["value"] = scatter( | ||
{ | ||
IndexSet(**{self.name: {0}}): value_one, | ||
IndexSet(**{self.name: {1}}): value_one.detach(), | ||
}, | ||
event_dim=msg["fn"].event_dim, | ||
) |
Oops, something went wrong.