Skip to content

Commit

Permalink
Add preempt operation (#142)
Browse files Browse the repository at this point in the history
* Refactor and simplify counterfactual handlers

* lint

* Add cond operation

* Add preempt operation

* Add broadcasting test cases

* add broadcasting test cases

* docstring for cond

* Add handler for inserting preempt calls (#143)

* Add handler for inserting preempt calls

* fix

* add a unit test

* add a unit test

* event shape in test

* start docstring

* docstring and lint
  • Loading branch information
eb8680 authored Aug 7, 2023
1 parent 7508a0a commit f6f428d
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 8 deletions.
55 changes: 52 additions & 3 deletions chirho/counterfactual/handlers/counterfactual.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import Any, Dict, TypeVar
from typing import Any, Dict, Generic, Mapping, TypeVar

import pyro
import torch

from chirho.counterfactual.handlers.ambiguity import FactualConditioningMessenger
from chirho.counterfactual.ops import split
from chirho.counterfactual.ops import preempt, split
from chirho.indexed.handlers import IndexPlatesMessenger
from chirho.indexed.ops import get_index_plates
from chirho.interventional.ops import intervene
from chirho.interventional.ops import Intervention, intervene

T = TypeVar("T")

Expand All @@ -26,6 +27,14 @@ def _pyro_intervene(msg: Dict[str, Any]) -> None:
msg["value"] = split(obs, acts, name=msg["name"], **msg["kwargs"])
msg["done"] = True

@staticmethod
def _pyro_preempt(msg: Dict[str, Any]) -> None:
    """
    Route the ``case`` argument of a ``preempt`` message through an
    auxiliary sample site.

    The message's three positional arguments ``(obs, acts, case)`` are
    unpacked, the ``name`` keyword is overwritten with
    ``"__split_<message name>"``, and ``case`` is replaced by the result of a
    :func:`pyro.sample` call at that name which receives the incoming
    ``case`` as its ``obs`` value.

    :param msg: The effect message of the ``preempt`` call. Its ``"args"``
        entry is replaced and its ``"kwargs"`` entry is mutated in place.
    """
    observed, interventions, selected = msg["args"]
    site_name = f"__split_{msg['name']}"
    msg["kwargs"]["name"] = site_name
    # Uniform weights over len(acts) + 1 categories: one per intervention
    # plus one extra (presumably index 0, the unmodified value -- confirm
    # against the indexing used by ``preempt``).
    uniform_weights = torch.ones(len(interventions) + 1)
    # NOTE(review): ``mask(False)`` presumably keeps this auxiliary site from
    # contributing to the log-density -- confirm against Pyro's ``mask`` docs.
    selector = pyro.distributions.Categorical(uniform_weights).mask(False)
    selected = pyro.sample(site_name, selector, obs=selected)
    msg["args"] = (observed, interventions, selected)


class SingleWorldCounterfactual(BaseCounterfactualMessenger):
"""
Expand Down Expand Up @@ -71,3 +80,43 @@ class TwinWorldCounterfactual(IndexPlatesMessenger, BaseCounterfactualMessenger)
@classmethod
def _pyro_split(cls, msg: Dict[str, Any]) -> None:
msg["kwargs"]["name"] = msg["name"] = cls.default_name


class Preemptions(Generic[T], pyro.poutine.messenger.Messenger):
    """
    Effect handler that applies the operation
    :func:`~chirho.counterfactual.ops.preempt` to sample sites in a
    probabilistic program.

    It is similar to the handler :func:`~chirho.observational.handlers.condition`
    for :func:`~chirho.observational.ops.observe`, or the handler
    :func:`~chirho.interventional.handlers.do`
    for :func:`~chirho.interventional.ops.intervene`.
    See the documentation for :func:`~chirho.counterfactual.ops.preempt`
    for more details.

    .. note:: This handler does not allow the direct specification of the
        ``case`` argument to :func:`~chirho.counterfactual.ops.preempt` and
        therefore cannot be used alone. Instead, the ``case`` argument to
        :func:`preempt` is assumed to be set separately by
        :class:`~chirho.counterfactual.handlers.counterfactual.BaseCounterfactualMessenger`
        or one of its subclasses, typically from an auxiliary discrete
        random variable.

    :param actions: A mapping from sample site names to interventions.
    """

    # Interventions to preempt with, keyed by sample site name.
    actions: Mapping[str, Intervention[T]]

    def __init__(self, actions: Mapping[str, Intervention[T]]):
        self.actions = actions
        super().__init__()

    def _pyro_post_sample(self, msg: Dict[str, Any]) -> None:
        """
        Replace the value of a sample site with the result of ``preempt``.

        Sites whose name is not a key of ``self.actions`` are left untouched.
        ``case`` is always passed as ``None`` here.

        :param msg: The sample message; its ``"value"`` entry is overwritten.
        """
        try:
            action = self.actions[msg["name"]]
        except KeyError:
            # Nothing registered for this site: keep the sampled value as is.
            return
        sampled = msg["value"]
        # The event rank is taken from the site's distribution.
        event_dim = len(msg["fn"].event_shape)
        msg["value"] = preempt(
            sampled, (action,), None, event_dim=event_dim, name=msg["name"]
        )
36 changes: 34 additions & 2 deletions chirho/counterfactual/ops.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Tuple, TypeVar
from typing import Optional, Tuple, TypeVar

import pyro

from chirho.indexed.ops import IndexSet, scatter
from chirho.indexed.ops import IndexSet, cond, scatter
from chirho.interventional.ops import Intervention, intervene

S = TypeVar("S")
Expand All @@ -21,3 +21,35 @@ def split(obs: T, acts: Tuple[Intervention[T], ...], **kwargs) -> T:
act_values[IndexSet(**{name: {i + 1}})] = intervene(obs, act, **kwargs)

return scatter(act_values, event_dim=kwargs.get("event_dim", 0))


@pyro.poutine.runtime.effectful(type="preempt")
@pyro.poutine.block(hide_types=["intervene"])
def preempt(
    obs: T, acts: Tuple[Intervention[T], ...], case: Optional[S] = None, **kwargs
) -> T:
    """
    Effectful primitive operation for "preempting" values in a probabilistic program.

    Unlike the counterfactual operation :func:`~chirho.counterfactual.ops.split`,
    which returns multiple values concatenated along a new axis
    via the operation :func:`~chirho.indexed.ops.scatter`,
    :func:`preempt` returns a single value determined by the argument ``case``
    via :func:`~chirho.indexed.ops.cond`.

    In a probabilistic program, a :func:`preempt` call induces a mixture distribution
    over downstream values, whereas :func:`split` would induce a joint distribution.

    :param obs: The observed value.
    :param acts: The interventions to apply.
    :param case: The case to select. If ``None``, ``obs`` is returned unchanged
        and no intervention is applied.
    :param kwargs: Forwarded to :func:`~chirho.interventional.ops.intervene`.
        ``name`` (a string) is required whenever ``case`` is given, because it
        is the key of every :class:`IndexSet` built here; ``event_dim``
        (default ``0``) is passed on to :func:`~chirho.indexed.ops.cond`.
    :return: ``obs`` when ``case`` is ``None``, otherwise the result of
        :func:`~chirho.indexed.ops.cond` over ``obs`` and the intervened values.
    :raises TypeError: If ``case`` is given but ``name`` is missing or not a string.
    """
    if case is None:
        return obs

    name = kwargs.get("name", None)
    if not isinstance(name, str):
        # Without this check the ``**{name: ...}`` expansion below fails with
        # the opaque "keywords must be strings" TypeError; keep the exception
        # type but say what is actually missing.
        raise TypeError(
            "preempt() requires a string `name` keyword argument when `case` "
            f"is given, got {name!r}"
        )

    # Index 0 holds the unmodified value; index i holds the i-th intervention.
    act_values = {IndexSet(**{name: {0}}): obs}
    for index, act in enumerate(acts, start=1):
        act_values[IndexSet(**{name: {index}})] = intervene(obs, act, **kwargs)

    return cond(act_values, case, event_dim=kwargs.get("event_dim", 0))
9 changes: 6 additions & 3 deletions chirho/indexed/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def _scatter_n(values: Dict[IndexSet, T], *, result: Optional[T] = None, **kwarg


@functools.singledispatch
def cond(fst, snd: T, case, **kwargs):
def cond(fst, snd, case: Optional[T] = None, **kwargs):
"""
Selection operation that is the sum-type analogue of :func:`scatter`
in the sense that where :func:`scatter` propagates both of its arguments,
Expand Down Expand Up @@ -303,8 +303,11 @@ def _cond_n(values: Dict[IndexSet, T], case: Union[bool, torch.Tensor], **kwargs
assert all(isinstance(k, IndexSet) for k in values.keys())
result: Optional[T] = None
for indices, value in values.items():
tst = functools.reduce(
operator.or_, [case == index for index in next(iter(indices.values()))]
tst = torch.as_tensor(
functools.reduce(
operator.or_, [case == index for index in next(iter(indices.values()))]
),
dtype=torch.bool,
)
result = cond(result if result is not None else value, value, tst, **kwargs)
return result
Expand Down
57 changes: 57 additions & 0 deletions tests/counterfactual/test_counterfactual_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
SingleWorldFactual,
TwinWorldCounterfactual,
)
from chirho.counterfactual.handlers.counterfactual import Preemptions
from chirho.counterfactual.handlers.selection import SelectFactual
from chirho.counterfactual.ops import preempt, split
from chirho.indexed.ops import IndexSet, gather, indices_of, union
from chirho.interventional.handlers import do
from chirho.interventional.ops import intervene
Expand Down Expand Up @@ -648,6 +650,61 @@ def model():
assert torch.any(tr.nodes["z"]["value"] < 1)


def test_preempt_op_singleworld():
    """
    Check ``split`` and ``preempt`` under ``SingleWorldCounterfactual``.

    The deterministic site ``x_`` (a ``split`` with the single action ``0.0``)
    must equal ``0.0`` everywhere, and the deterministic site ``y_`` (a
    ``preempt`` with the single action ``1.0`` and ``case == 1``) must equal
    ``1.0`` everywhere.
    """

    @SingleWorldCounterfactual()
    @pyro.plate("data", size=1000, dim=-1)
    def model():
        x_sampled = pyro.sample("x", dist.Bernoulli(0.67))
        x_split = split(x_sampled, (torch.tensor(0.0),), name="x", event_dim=0)
        x = pyro.deterministic("x_", x_split, event_dim=0)

        y_sampled = pyro.sample("y", dist.Bernoulli(0.67))
        # Case 1 selects the first (and only) intervention.
        y_case = torch.tensor(1)
        y_preempted = preempt(
            y_sampled, (torch.tensor(1.0),), y_case, name="__y", event_dim=0
        )
        y = pyro.deterministic("y_", y_preempted, event_dim=0)

        z = pyro.sample("z", dist.Bernoulli(0.67))
        return dict(x=x, y=y, z=z)

    tr = pyro.poutine.trace(model).get_trace()
    for site, expected in (("x_", 0.0), ("y_", 1.0)):
        assert torch.all(tr.nodes[site]["value"] == expected)


@pytest.mark.parametrize("cf_dim", [-2, -3, None])
@pytest.mark.parametrize("event_shape", [(), (4,), (4, 3)])
def test_cf_handler_preemptions(cf_dim, event_shape):
    """
    Check the ``Preemptions`` handler together with ``do`` under
    ``MultiWorldCounterfactual``.

    The trace must contain a ``__split_<name>`` site for every preempted site,
    ``w`` must carry no indices, and both ``y`` and ``z`` must carry exactly
    the index ``x={0, 1}`` introduced by the split on ``x``.
    """
    event_dim = len(event_shape)

    splits = {"x": torch.tensor(0.0)}
    preemptions = {"y": torch.tensor(1.0)}

    @do(actions=splits)
    @Preemptions(actions=preemptions)
    @pyro.plate("data", size=1000, dim=-1)
    def model():
        w_prior = dist.Normal(0, 1).expand(event_shape).to_event(event_dim)
        w = pyro.sample("w", w_prior)
        x = pyro.sample("x", dist.Normal(w, 1).to_event(event_dim))
        y = pyro.sample("y", dist.Normal(w + x, 1).to_event(event_dim))
        z = pyro.sample("z", dist.Normal(x + y, 1).to_event(event_dim))
        return dict(w=w, x=x, y=y, z=z)

    # NOTE(review): the assertions are kept inside the handler context;
    # ``indices_of`` presumably consults the index plates of the active
    # counterfactual handler -- confirm before moving them out.
    with MultiWorldCounterfactual(cf_dim):
        tr = pyro.poutine.trace(model).get_trace()

        for site in preemptions:
            assert f"__split_{site}" in tr.nodes

        def site_indices(name):
            return indices_of(tr.nodes[name]["value"], event_dim=event_dim)

        assert site_indices("w") == IndexSet()
        assert site_indices("y") == IndexSet(x={0, 1})
        assert site_indices("z") == IndexSet(x={0, 1})


# Define a helper function to run SVI. (Generally, Pyro users like to have more control over the training process!)
def run_svi_inference(model, n_steps=1000, verbose=True, lr=0.03, **model_kwargs):
guide = AutoMultivariateNormal(model)
Expand Down

0 comments on commit f6f428d

Please sign in to comment.