Skip to content

Commit

Permalink
Cut Posterior Implementation + Examples (#229)
Browse files Browse the repository at this point in the history
* initial handler for cut posterior during pair session with Eli

* implementation of cut posterior

* pair programming session with Eli for cut posterior implementation

* removed unneeded plate class

* fixed bug in factor constraint name

* fixed bug with subsample error

* uncomitted linear gaussian test before major changes

* passes closed form discrete test case

* cleaned up test file

* recovery of exact posterior from saturated VI family

* index module has failing test

* small lint change

* bayessdid basic functionality implemented

* skeleton for svi cut inference added

* cut vi method implemented

* adding synthetic control weights plotting

* cleaned up notebook and added plots

* removes pyro.factor statement for stitching

* merged in master and fixed causal_pyro import errors

* testing for gaussian linear case

* fixed analytical formula for linear gaussian case

* fixed mask issue

* added single stage cut estimator

* removed discrete test that checks compares svi output

* linting

* removed unneeded imports

* added failing test when composing trace and indexcut

* added replay check

* initial implementation

* extended to a GLM and match data experiment in paper

* cleaned up notebook cells; more documentation needed still

* cleaned up tests

* cleaned up notebook

* forgot to save changes

* added uncertainty figure

* cleaned up notebook

* fixed linting error for python 3.8

* addressing some of eli's refactoring suggestions

* refactored to inherit from dependent mask messenger (Eli's suggestion)

* added missing fn argument

* commit just to document a test that will be removed testing single cut implementations

* revert back to removing old SingleStageCut implementation

* updated imports and reran notebooks

* issues with running svi with local params turned on

* removed two stage approach

* consolidated tests

---------

Co-authored-by: Raj Agrawal <[email protected]>
  • Loading branch information
agrawalraj and agrawalraj authored Aug 7, 2023
1 parent f6f428d commit 19a395c
Show file tree
Hide file tree
Showing 6 changed files with 1,741 additions and 1 deletion.
2 changes: 2 additions & 0 deletions chirho/counterfactual/handlers/selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def get_mask(
dist: pyro.distributions.Distribution,
value: Optional[torch.Tensor],
device: torch.device = torch.device("cpu"),
name: Optional[str] = None,
) -> torch.Tensor:
indices = get_factual_indices()
return ~indexset_as_mask(indices, device=device) # negate == complement
Expand All @@ -62,6 +63,7 @@ def get_mask(
dist: pyro.distributions.Distribution,
value: Optional[torch.Tensor] = None,
device: torch.device = torch.device("cpu"),
name: Optional[str] = None,
) -> torch.Tensor:
indices = get_factual_indices()
return indexset_as_mask(indices, device=device)
12 changes: 11 additions & 1 deletion chirho/indexed/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,20 @@ def get_mask(
dist: pyro.distributions.Distribution,
value: Optional[torch.Tensor],
device: torch.device = torch.device("cpu"),
name: Optional[str] = None,
) -> torch.Tensor:
raise NotImplementedError

def _pyro_sample(self, msg: Dict[str, Any]) -> None:
if pyro.poutine.util.site_is_subsample(msg):
return

device = get_sample_msg_device(msg["fn"], msg["value"])
mask = self.get_mask(msg["fn"], msg["value"], device=device)
name = msg["name"] if "name" in msg else None
mask = self.get_mask(msg["fn"], msg["value"], device=device, name=name)
msg["mask"] = mask if msg["mask"] is None else msg["mask"] & mask

# expand distribution to make sure two copies of a variable are sampled
msg["fn"] = msg["fn"].expand(
torch.broadcast_shapes(msg["fn"].batch_shape, mask.shape)
)
109 changes: 109 additions & 0 deletions chirho/observational/handlers/cut.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from typing import Any, Dict, Optional, Set, TypeVar

import pyro
import torch

from chirho.indexed.handlers import DependentMaskMessenger, add_indices
from chirho.indexed.ops import IndexSet, gather, indexset_as_mask, scatter

T = TypeVar("T")


class CutModule(pyro.poutine.messenger.Messenger):
"""
Converts a Pyro model into a module using the "cut" operation
"""

vars: Set[str]

def __init__(self, vars: Set[str]):
self.vars = vars
super().__init__()

def _pyro_sample(self, msg: Dict[str, Any]) -> None:
# There are 4 cases to consider for a sample site:
# 1. The site appears in self.vars and is observed
# 2. The site appears in self.vars and is not observed
# 3. The site does not appear in self.vars and is observed
# 4. The site does not appear in self.vars and is not observed
if msg["name"] not in self.vars:
if msg["is_observed"]:
# use mask to remove the contribution of this observed site to the model log-joint
msg["mask"] = (
msg["mask"] if msg["mask"] is not None else True
) & torch.tensor(False, dtype=torch.bool).expand(msg["fn"].batch_shape)
else:
pass

# For sites that do not appear in module, rename them to avoid naming conflict
if msg["name"] not in self.vars:
msg["name"] = f"{msg['name']}_nuisance"


class CutComplementModule(pyro.poutine.messenger.Messenger):
vars: Set[str]

def __init__(self, vars: Set[str]):
self.vars = vars
super().__init__()

def _pyro_sample(self, msg: Dict[str, Any]) -> None:
# There are 4 cases to consider for a sample site:
# 1. The site appears in self.vars and is observed
# 2. The site appears in self.vars and is not observed
# 3. The site does not appear in self.vars and is observed
# 4. The site does not appear in self.vars and is not observed
if msg["name"] in self.vars:
# use mask to remove the contribution of this observed site to the model log-joint
msg["mask"] = (
msg["mask"] if msg["mask"] is not None else True
) & torch.tensor(False, dtype=torch.bool).expand(msg["fn"].batch_shape)


class SingleStageCut(DependentMaskMessenger):
"""
Represent module and complement in a single Pyro model using plates
"""

vars: Set[str]
name: str

def __init__(self, vars: Set[str], *, name: str = "__cut_plate"):
self.vars = vars
self.name = name
super().__init__()

def __enter__(self):
add_indices(IndexSet(**{self.name: {0, 1}}))
return super().__enter__()

def get_mask(
self,
dist: pyro.distributions.Distribution,
value: Optional[torch.Tensor],
device: torch.device = torch.device("cpu"),
name: Optional[str] = None,
) -> torch.Tensor:
return indexset_as_mask(
IndexSet(**{self.name: {0 if name in self.vars else 1}})
)

def _pyro_post_sample(self, msg: Dict[str, Any]) -> None:
if pyro.poutine.util.site_is_subsample(msg):
return

if (not msg["is_observed"]) and (msg["name"] in self.vars):
# discard the second value
value_one = gather(
msg["value"],
IndexSet(**{self.name: {0}}),
event_dim=msg["fn"].event_dim,
)

msg["value"] = scatter(
{
IndexSet(**{self.name: {0}}): value_one,
IndexSet(**{self.name: {1}}): value_one.detach(),
},
event_dim=msg["fn"].event_dim,
)
Loading

0 comments on commit 19a395c

Please sign in to comment.