Skip to content

Commit

Permalink
Refactor counterfactuals to use observe and condition (#176)
Browse files Browse the repository at this point in the history
* Separate observe and condition

* Split up files and create observational handlers folder

* imports

* lint

* rename test

* add test about commutativity of do and condition

* doc

* union

* Refactor counterfactuals to use observe

* appease mypy

* Vindex fixes particle errors

* update backdoor

* update slc

* fix particle test case

* add cf commutativity test

* fix bug

* revert slc handler order

* add predictive smoke test

* nit

* elbo

* reorder test

* Add a stronger infer_discrete test

* move notebooks to separate branch

* test

* chirho

* merge fail

* Update and re-run example notebooks with new condition (#178)

* Update and re-run backdoor and SLC notebooks

* deepscm

* cevae

* import

* mediation

* merge

* update notebooks

* merge

* merge 2

* toc

* populate autodoc

* tweak

* Restores (via cherry-pick) Notebook Link and Formatting Changes (#205)

* fixed outline rendering

* fixes outline links in mediation notebook.

* fixes outline and links for backdoor notebook.

* fixes outline links in cevae notebook.

* fixes slc notebook outline links.

* adds outline back into deep scm notebook.

* address remaining reference issues, building now with now warnings

---------

Co-authored-by: Sam Witty <[email protected]>

---------

Co-authored-by: Andy Zane <[email protected]>
Co-authored-by: Sam Witty <[email protected]>
  • Loading branch information
3 people authored Jul 10, 2023
1 parent 2af3f73 commit f659fa8
Show file tree
Hide file tree
Showing 15 changed files with 769 additions and 413 deletions.
256 changes: 63 additions & 193 deletions chirho/counterfactual/handlers/ambiguity.py
Original file line number Diff line number Diff line change
@@ -1,179 +1,89 @@
from typing import Any, Dict, Literal, Optional, TypedDict, TypeVar, Union
import functools
from typing import TypeVar

import pyro
import pyro.distributions as dist
import pyro.infer.reparam
import torch

from chirho.counterfactual.handlers.selection import (
SelectCounterfactual,
SelectFactual,
get_factual_indices,
)
from chirho.counterfactual.internals import expand_obs_value_inplace_
from chirho.indexed.ops import gather, indices_of, scatter, union
from chirho.counterfactual.internals import no_ambiguity, site_is_ambiguous
from chirho.indexed.ops import gather, get_index_plates, indices_of, scatter
from chirho.observational.ops import observe

T = TypeVar("T")


def site_is_ambiguous(msg: Dict[str, Any]) -> bool:
class FactualConditioningMessenger(pyro.poutine.messenger.Messenger):
"""
Helper function used with :func:`pyro.condition` to determine
whether a site is observed or ambiguous.
A sample site is ambiguous if it is marked observed, is downstream of an intervention,
and the observed value's index variables are a strict subset of the distribution's
indices and hence require clarification of which entries of the random variable
are fixed/observed (as opposed to random/unobserved).
"""
if not (
msg["type"] == "sample"
and msg["is_observed"]
and not pyro.poutine.util.site_is_subsample(msg)
):
return False

try:
return not msg["infer"]["_specified_conditioning"]
except KeyError:
value_indices = indices_of(msg["value"], event_dim=len(msg["fn"].event_shape))
dist_indices = indices_of(msg["fn"])
return (
bool(union(value_indices, dist_indices)) and value_indices != dist_indices
)


def no_ambiguity(msg: Dict[str, Any]) -> Dict[str, Any]:
"""
Helper function used with :func:`pyro.poutine.infer_config` to inform
:class:`AmbiguousConditioningReparam` that all ambiguity in the current
context has been resolved.
"""
return {"_specified_conditioning": True}


class AmbiguousConditioningReparam(pyro.infer.reparam.reparam.Reparam):
"""
Abstract base class for reparameterizers that handle ambiguous conditioning.
"""

pass


class AmbiguousConditioningStrategy(pyro.infer.reparam.strategies.Strategy):
"""
Abstract base class for strategies that handle ambiguous conditioning.
"""

pass


CondStrategy = Union[
Dict[str, AmbiguousConditioningReparam], AmbiguousConditioningStrategy
]


class AmbiguousConditioningReparamMessenger(
pyro.poutine.reparam_messenger.ReparamMessenger
):
config: CondStrategy

def _pyro_sample(self, msg: pyro.infer.reparam.reparam.ReparamMessage) -> None:
if site_is_ambiguous(msg):
expand_obs_value_inplace_(msg)
msg["infer"]["_specified_conditioning"] = False
super()._pyro_sample(msg)
msg["infer"]["_specified_conditioning"] = True


class ConditionReparamMsg(TypedDict):
fn: pyro.distributions.Distribution
value: torch.Tensor
is_observed: Literal[True]


class ConditionReparamArgMsg(ConditionReparamMsg):
name: str


class FactualConditioningReparam(AmbiguousConditioningReparam):
"""
Factual conditioning reparameterizer.
This :class:`pyro.infer.reparam.reparam.Reparam` is used to resolve inherent
semantic ambiguity in conditioning in the presence of interventions by
splitting the observed value into a factual and counterfactual component,
associating the observed value with the factual random variable,
and sampling the counterfactual random variable from its prior.
"""

@pyro.poutine.infer_config(config_fn=no_ambiguity)
def apply(self, msg: ConditionReparamArgMsg) -> ConditionReparamMsg:
with SelectFactual():
fv = pyro.sample(msg["name"] + "_factual", msg["fn"], obs=msg["value"])

with SelectCounterfactual():
cv = pyro.sample(msg["name"] + "_counterfactual", msg["fn"])

event_dim = len(msg["fn"].event_shape)
fw_indices = get_factual_indices()
new_value: torch.Tensor = scatter(
fv, fw_indices, result=cv.clone(), event_dim=event_dim
)
new_fn = dist.Delta(new_value, event_dim=event_dim).mask(False)
return {"fn": new_fn, "value": new_value, "is_observed": True}


class MinimalFactualConditioning(AmbiguousConditioningStrategy):
"""
Reparameterization strategy for handling ambiguity in conditioning, for use with
Effect handler for handling ambiguity in conditioning, for use with
counterfactual semantics handlers such as :class:`MultiWorldCounterfactual` .
:class:`MinimalFactualConditioning` applies :class:`FactualConditioningReparam`
instances to all ambiguous sample sites in a model.
.. note::
A sample site is ambiguous if it is marked observed, is downstream of an intervention,
and the observed value's index variables are a strict subset of the distribution's
indices and hence require clarification of which entries of the random variable
are fixed/observed (as opposed to random/unobserved).
"""

def configure(
self, msg: pyro.infer.reparam.reparam.ReparamMessage
) -> Optional[FactualConditioningReparam]:
if not site_is_ambiguous(msg):
return None

return FactualConditioningReparam()


class ConditionTransformReparamMsg(TypedDict):
fn: pyro.distributions.TransformedDistribution
value: torch.Tensor
is_observed: Literal[True]
def _pyro_post_sample(self, msg: dict) -> None:
# expand latent values to include all index plates
if not msg["is_observed"] and not pyro.poutine.util.site_is_subsample(msg):
rv, value, event_dim = msg["fn"], msg["value"], len(msg["fn"].event_shape)
index_plates = get_index_plates()

new_shape = list(value.shape)
for k in set(indices_of(rv)) - set(indices_of(value, event_dim=event_dim)):
dim = index_plates[k].dim
new_shape = [1] * ((event_dim - dim) - len(new_shape)) + new_shape
new_shape[dim - event_dim] = rv.batch_shape[dim]

class ConditionTransformReparamArgMsg(ConditionTransformReparamMsg):
name: str
msg["value"] = value.expand(tuple(new_shape))

def _pyro_observe(self, msg: dict) -> None:
if "name" not in msg["kwargs"]:
msg["kwargs"]["name"] = msg["name"]

class ConditionTransformReparam(AmbiguousConditioningReparam):
def apply(
self, msg: ConditionTransformReparamArgMsg
) -> ConditionTransformReparamMsg:
name, fn, value = msg["name"], msg["fn"], msg["value"]
if not site_is_ambiguous(msg):
return

msg["value"] = self._dispatched_observe(*msg["args"], name=msg["name"])
msg["done"] = True
msg["stop"] = True

@functools.singledispatchmethod
def _dispatched_observe(self, rv, obs: torch.Tensor, name: str) -> torch.Tensor:
raise NotImplementedError

@_dispatched_observe.register(dist.FoldedDistribution)
@_dispatched_observe.register(dist.Distribution)
def _observe_dist(
self, rv: dist.Distribution, obs: torch.Tensor, name: str
) -> torch.Tensor:
with pyro.poutine.infer_config(config_fn=no_ambiguity):
with SelectFactual():
fv = pyro.sample(name + "_factual", rv, obs=obs)

with SelectCounterfactual():
cv = pyro.sample(name + "_counterfactual", rv)

event_dim = len(rv.event_shape)
fw_indices = get_factual_indices()
new_value: torch.Tensor = scatter(
fv, fw_indices, result=cv.clone(), event_dim=event_dim
)
new_rv = dist.Delta(new_value, event_dim=event_dim).mask(False)
return pyro.sample(name, new_rv, obs=new_value)

@_dispatched_observe.register
def _observe_tfmdist(
self, rv: dist.TransformedDistribution, value: torch.Tensor, name: str
) -> torch.Tensor:
tfm = (
fn.transforms[-1]
if len(fn.transforms) == 1
else dist.transforms.ComposeTransformModule(fn.transforms)
rv.transforms[-1]
if len(rv.transforms) == 1
else dist.transforms.ComposeTransformModule(rv.transforms)
)
noise_dist = fn.base_dist
noise_dist = rv.base_dist
noise_event_dim = len(noise_dist.event_shape)
obs_event_dim = len(fn.event_shape)
obs_event_dim = len(rv.event_shape)

# factual world
with SelectFactual(), pyro.poutine.infer_config(config_fn=no_ambiguity):
Expand All @@ -188,7 +98,8 @@ def apply(
obs_noise = gather(obs_noise, fw, event_dim=noise_event_dim).expand(
obs_noise.shape
)
obs_noise = pyro.sample(name + "_noise_prior", noise_dist, obs=obs_noise)
# obs_noise = pyro.sample(name + "_noise_prior", noise_dist, obs=obs_noise)
obs_noise = observe(noise_dist, obs_noise, name=name + "_noise_prior")

# counterfactual world
with SelectCounterfactual(), pyro.poutine.infer_config(config_fn=no_ambiguity):
Expand All @@ -201,45 +112,4 @@ def apply(
value, fw, result=cf_obs_value.clone(), event_dim=obs_event_dim
)
new_fn = dist.Delta(new_value, event_dim=obs_event_dim).mask(False)
return {"fn": new_fn, "value": new_value, "is_observed": msg["is_observed"]}


class AutoFactualConditioning(MinimalFactualConditioning):
"""
Reparameterization strategy for handling ambiguity in conditioning, for use with
counterfactual semantics handlers such as :class:`MultiWorldCounterfactual` .
When the distribution is a :class:`pyro.distributions.TransformedDistribution`,
:class:`AutoFactualConditioning` automatically applies :class:`ConditionTransformReparam`
to the site. Otherwise, it behaves like :class:`MinimalFactualConditioning` .
.. note::
This strategy is applied by default via :class:`MultiWorldCounterfactual`
and :class:`TwinWorldCounterfactual` unless otherwise specified.
.. note::
A sample site is ambiguous if it is marked observed, is downstream of an intervention,
and the observed value's index variables are a strict subset of the distribution's
indices and hence require clarification of which entries of the random variable
are fixed/observed (as opposed to random/unobserved).
"""

def configure(
self, msg: pyro.infer.reparam.reparam.ReparamMessage
) -> Optional[FactualConditioningReparam]:
if not site_is_ambiguous(msg):
return None

fn = msg["fn"]
while hasattr(fn, "base_dist"):
if isinstance(fn, dist.FoldedDistribution):
return FactualConditioningReparam()
elif isinstance(fn, dist.TransformedDistribution):
return ConditionTransformReparam()
else:
fn = fn.base_dist

return super().configure(msg)
return pyro.sample(name, new_fn, obs=new_value)
17 changes: 4 additions & 13 deletions chirho/counterfactual/handlers/counterfactual.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
from typing import Any, Dict, Optional, TypeVar
from typing import Any, Dict, TypeVar

import pyro

from chirho.counterfactual.handlers.ambiguity import (
AmbiguousConditioningReparamMessenger,
AutoFactualConditioning,
CondStrategy,
)
from chirho.counterfactual.handlers.ambiguity import FactualConditioningMessenger
from chirho.counterfactual.ops import split
from chirho.indexed.handlers import IndexPlatesMessenger
from chirho.indexed.ops import get_index_plates
Expand All @@ -15,16 +11,11 @@
T = TypeVar("T")


class BaseCounterfactualMessenger(AmbiguousConditioningReparamMessenger):
class BaseCounterfactualMessenger(FactualConditioningMessenger):
"""
Base class for counterfactual handlers.
"""

def __init__(self, config: Optional[CondStrategy] = None):
if config is None:
config = AutoFactualConditioning()
super().__init__(config=config)

@staticmethod
def _pyro_intervene(msg: Dict[str, Any]) -> None:
msg["stop"] = True
Expand Down Expand Up @@ -70,7 +61,7 @@ def _pyro_split(cls, msg: Dict[str, Any]) -> None:
name = msg["name"] if msg["name"] is not None else cls.default_name
index_plates = get_index_plates()
if name in index_plates:
name = f"{name}_{len(index_plates)}"
name = f"{name}__dup_{len(index_plates)}"
msg["kwargs"]["name"] = msg["name"] = name


Expand Down
47 changes: 25 additions & 22 deletions chirho/counterfactual/internals.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,29 @@
import pyro.infer.reparam
import torch
from typing import Any, Dict

from chirho.indexed.ops import indices_of, union

def expand_obs_value_inplace_(msg: pyro.infer.reparam.reparam.ReparamMessage) -> None:

def site_is_ambiguous(msg: Dict[str, Any]) -> bool:
"""
Helper function used with :func:`observe` to determine
whether a site is observed or ambiguous.
A sample site is ambiguous if it is marked observed, is downstream of an intervention,
and the observed value's index variables are a strict subset of the distribution's
indices and hence require clarification of which entries of the random variable
are fixed/observed (as opposed to random/unobserved).
"""
rv, obs = msg["args"][:2]
value_indices = indices_of(obs, event_dim=len(rv.event_shape))
dist_indices = indices_of(rv)
return (
bool(union(value_indices, dist_indices)) and value_indices != dist_indices
) or not msg["infer"].get("_specified_conditioning", True)


def no_ambiguity(msg: Dict[str, Any]) -> Dict[str, Any]:
"""
Slightly gross workaround that mutates the msg in place
to avoid triggering overzealous validation logic in
:class:~`pyro.poutine.reparam.ReparamMessenger`
that uses cheaper tensor shape and identity equality checks as
a conservative proxy for an expensive tensor value equality check.
(see https://github.com/pyro-ppl/pyro/blob/685c7adee65bbcdd6bd6c84c834a0a460f2224eb/pyro/poutine/reparam_messenger.py#L99) # noqa: E501
This workaround is correct because these reparameterizers do not change
the observed entries, it just packs counterfactual values around them;
the equality check being approximated by that logic would still pass.
Helper function used with :func:`pyro.poutine.infer_config` to inform
:class:`FactualConditioningMessenger` that all ambiguity in the current
context has been resolved.
"""
msg["value"] = torch.as_tensor(msg["value"])
msg["infer"]["orig_shape"] = msg["value"].shape
_custom_init = getattr(msg["value"], "_pyro_custom_init", False)
msg["value"] = msg["value"].expand(
torch.broadcast_shapes(
msg["fn"].batch_shape + msg["fn"].event_shape,
msg["value"].shape,
)
)
setattr(msg["value"], "_pyro_custom_init", _custom_init)
return {"_specified_conditioning": True}
Loading

0 comments on commit f659fa8

Please sign in to comment.