Skip to content

Commit

Permalink
Add a BiasedPreemptions handler (#239)
Browse files Browse the repository at this point in the history
* Add BiasedPreemptions handler

* test and tiny refactoring
  • Loading branch information
eb8680 authored Sep 4, 2023
1 parent 7c2f163 commit d372f96
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 10 deletions.
96 changes: 90 additions & 6 deletions chirho/counterfactual/handlers/counterfactual.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,14 @@ def _pyro_intervene(msg: Dict[str, Any]) -> None:
@staticmethod
def _pyro_preempt(msg: Dict[str, Any]) -> None:
obs, acts, case = msg["args"]
msg["kwargs"]["name"] = f"__split_{msg['name']}"
if msg["kwargs"].get("name", None) is None:
msg["kwargs"]["name"] = msg["name"]

if case is not None:
return

case_dist = pyro.distributions.Categorical(torch.ones(len(acts) + 1))
case = pyro.sample(msg["kwargs"]["name"], case_dist.mask(False), obs=case)
case = pyro.sample(msg["name"], case_dist.mask(False), obs=case)
msg["args"] = (obs, acts, case)


Expand Down Expand Up @@ -100,23 +105,102 @@ class Preemptions(Generic[T], pyro.poutine.messenger.Messenger):
or one of its subclasses, typically from an auxiliary discrete random variable.
:param actions: A mapping from sample site names to interventions.
:param prefix: Prefix usable for naming any auxiliary random variables.
"""

actions: Mapping[str, Intervention[T]]
prefix: str

def __init__(self, actions: Mapping[str, Intervention[T]]):
def __init__(
self, actions: Mapping[str, Intervention[T]], *, prefix: str = "__split_"
):
self.actions = actions
self.prefix = prefix
super().__init__()

def _pyro_post_sample(self, msg: Dict[str, Any]) -> None:
def _pyro_post_sample(self, msg):
try:
action = self.actions[msg["name"]]
except KeyError:
return
msg["value"] = preempt(
msg["value"],
(action,),
(action,) if not isinstance(action, tuple) else action,
None,
event_dim=len(msg["fn"].event_shape),
name=msg["name"],
name=f"{self.prefix}{msg['name']}",
)


class BiasedPreemptions(pyro.poutine.messenger.Messenger):
"""
Effect handler that applies the operation :func:`~chirho.counterfactual.ops.preempt`
to sample sites in a probabilistic program,
similar to the handler :func:`~chirho.observational.handlers.condition`
for :func:`~chirho.observational.ops.observe` .
or the handler :func:`~chirho.interventional.handlers.do`
for :func:`~chirho.interventional.ops.intervene` .
See the documentation for :func:`~chirho.counterfactual.ops.preempt` for more details.
This handler introduces an auxiliary discrete random variable at each preempted sample site
whose name is the name of the sample site prefixed by ``prefix``, and
whose value is used as the ``case`` argument to :func:`preempt`,
to determine whether the preemption returns the present value of the site
or the new value specified for the site in ``actions``
The distributions of the auxiliary discrete random variables are parameterized by ``bias``.
By default, ``bias == 0`` and the value returned by the sample site is equally likely
to be the factual case (i.e. the present value of the site) or one of the counterfactual cases
(i.e. the new value(s) specified for the site in ``actions``).
When ``0 < bias <= 0.5``, the preemption is less than equally likely to occur.
When ``-0.5 <= bias < 0``, the preemption is more than equally likely to occur.
More specifically, the probability of the factual case is ``0.5 - bias``,
and the probability of each counterfactual case is ``(0.5 + bias) / num_actions``,
where ``num_actions`` is the number of counterfactual actions for the sample site (usually 1).
:param actions: A mapping from sample site names to interventions.
:param bias: The scalar bias towards the factual case. Must be between -0.5 and 0.5.
:param prefix: The prefix for naming the auxiliary discrete random variables.
"""

actions: Mapping[str, Intervention[torch.Tensor]]
bias: float
prefix: str

def __init__(
self,
actions: Mapping[str, Intervention[torch.Tensor]],
*,
bias: float = 0.0,
prefix: str = "__witness_split_",
):
assert -0.5 <= bias <= 0.5, "bias must be between -0.5 and 0.5"
self.actions = actions
self.bias = bias
self.prefix = prefix
super().__init__()

def _pyro_post_sample(self, msg):
try:
action = self.actions[msg["name"]]
except KeyError:
return

action = (action,) if not isinstance(action, tuple) else action
num_actions = len(action) if isinstance(action, tuple) else 1
weights = torch.tensor(
[0.5 - self.bias] + ([(0.5 + self.bias) / num_actions] * num_actions),
device=msg["value"].device,
)
case_dist = pyro.distributions.Categorical(probs=weights)
case = pyro.sample(f"{self.prefix}{msg['name']}", case_dist)

msg["value"] = preempt(
msg["value"],
action,
case,
event_dim=len(msg["fn"].event_shape),
name=f"{self.prefix}{msg['name']}",
)
15 changes: 11 additions & 4 deletions tests/counterfactual/test_counterfactual_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
SingleWorldFactual,
TwinWorldCounterfactual,
)
from chirho.counterfactual.handlers.counterfactual import Preemptions
from chirho.counterfactual.handlers.counterfactual import BiasedPreemptions, Preemptions
from chirho.counterfactual.handlers.selection import SelectFactual
from chirho.counterfactual.ops import preempt, split
from chirho.indexed.ops import IndexSet, gather, indices_of, union
Expand Down Expand Up @@ -675,14 +675,14 @@ def model():

@pytest.mark.parametrize("cf_dim", [-2, -3, None])
@pytest.mark.parametrize("event_shape", [(), (4,), (4, 3)])
def test_cf_handler_preemptions(cf_dim, event_shape):
@pytest.mark.parametrize("use_biased_preemption", [False, True])
def test_cf_handler_preemptions(cf_dim, event_shape, use_biased_preemption):
event_dim = len(event_shape)

splits = {"x": torch.tensor(0.0)}
preemptions = {"y": torch.tensor(1.0)}

@do(actions=splits)
@Preemptions(actions=preemptions)
@pyro.plate("data", size=1000, dim=-1)
def model():
w = pyro.sample(
Expand All @@ -693,7 +693,14 @@ def model():
z = pyro.sample("z", dist.Normal(x + y, 1).to_event(len(event_shape)))
return dict(w=w, x=x, y=y, z=z)

with MultiWorldCounterfactual(cf_dim):
if use_biased_preemption:
preemption_handler = BiasedPreemptions(
actions=preemptions, bias=0.1, prefix="__split_"
)
else:
preemption_handler = Preemptions(actions=preemptions)

with MultiWorldCounterfactual(cf_dim), preemption_handler:
tr = pyro.poutine.trace(model).get_trace()
assert all(f"__split_{k}" in tr.nodes for k in preemptions.keys())
assert indices_of(tr.nodes["w"]["value"], event_dim=event_dim) == IndexSet()
Expand Down

0 comments on commit d372f96

Please sign in to comment.