Skip to content

Commit

Permalink
staging-causality (explainable reasoning as a separate module) (#441)
Browse files Browse the repository at this point in the history
* removed explanation code (#411)

* add `preempt` and its tests (#414)

* added preempt and its tests

* refactored test to use new code

* lint

* added uniform proposal (#416)

* add consequent_differs and a test thereof (#418)

* added preempt and its tests

* refactored test to use new code

* lint

* added consequent_differs and a test thereof

---------

Co-authored-by: eb8680 <[email protected]>

* add Preemptions and a test thereof (#420)

* added preempt and its tests

* refactored test to use new code

* lint

* added Preemptions and a test thereof

* added `undo_split` and a test thereof (#422)

* import

---------

Co-authored-by: Eli <[email protected]>

* added SplitSubsets and a test thereof (#424)

* added preempt and its tests

* refactored test to use new code

* lint

* added Preemptions and a test thereof

* added `undo_split` and a test thereof

* added SplitSubsets and a test thereof

* added `undo_split` and a test thereof (#422)

* import

---------

Co-authored-by: Eli <[email protected]>
Co-authored-by: eb8680 <[email protected]>

* added `random intervention` and a test thereof (#442)

* add `SearchForExplanation` (#445)

* added `random intervention` and a test thereof

* added `SearchForExplanation` and a test thereof

---------

Co-authored-by: eb8680 <[email protected]>

* explainable documentation update (#457)

* docstrings update WIP

* small typo

* Remove explainable references from chirho.counterfactual (#458)

* implement `soft_eq` and `soft_neq`  (#472)

* add soft_neq and a few tests in test_ops

* sof eq WIP

* conversion to soft_eq

* defaulting soft_eq scale to .1

* Tavares conditions have landed

* soft_neq with tests

* added docstring for soft_neq

* refactor soft_eq and soft_neq

* tests

* remove failing tests

* comment

* sign

* move soft_neq to internals for now

* nit

* remove comment

---------

Co-authored-by: Eli <[email protected]>

* Add explainable module to Sphinx build (#481)

* Add explainable module to sphinx build

* Add explainable module to sphinx build

* Reorganize code in chirho.explainable (#485)

* Add explainable module to sphinx build

* Add explainable module to sphinx build

* reorganize codebase

* remove empty file

* rename alternatives

* reorganize test files

* sphinx

* remove test_defaults

* implement soft_neq and use downstream in the explainable module (#490)

* that's it

* lint

* lint

* preemptions type as required by runner lint

* persistent lint typing error, using Any

* reverting to Preemptions

* add revised actual causality notebook (#446)

* added `random intervention` and a test thereof

* added `SearchForExplanation` and a test thereof

* added ac notebook with a test

* small cleanup

* that's it

* rerun actual causality and added it to rst

* actual causality nb WIP

* lint

* proofreading

* lint

* preemptions type as required by runner lint

* persistent lint typing error, using Any

* reverting to Preemptions

* lint, add actual_causality to notebook tests again

* removed redundant alternatives.py

* remove redundant split_subset.py

* fixing docstrings

* format & lint

* pulled explanation from ru-propagate-...

* docstrings in explanation.py

* typo in ac nb

---------

Co-authored-by: eb8680 <[email protected]>

* allow for empty witness in SearchForCause (#491)

* allow for empty witness in SearchForCause

* removed mwc_empty from a test

* allow empty ant and con

* Implement `InferConstraints` (#493)

* experiment done

* cleanup, docstring

* eliminate args from ExtractSupport, move to components

* lint

---------

Co-authored-by: eb8680 <[email protected]>
Co-authored-by: Eli <[email protected]>
  • Loading branch information
3 people authored Jan 12, 2024
1 parent 4cfa291 commit 804b846
Show file tree
Hide file tree
Showing 21 changed files with 3,287 additions and 998 deletions.
85 changes: 3 additions & 82 deletions chirho/counterfactual/handlers/counterfactual.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from __future__ import annotations

from typing import Any, Dict, Generic, Mapping, TypeVar
from typing import Any, Dict, TypeVar

import pyro
import torch

from chirho.counterfactual.handlers.ambiguity import FactualConditioningMessenger
from chirho.counterfactual.ops import preempt, split
from chirho.counterfactual.ops import split
from chirho.indexed.handlers import IndexPlatesMessenger
from chirho.indexed.ops import get_index_plates
from chirho.interventional.ops import Intervention, intervene
from chirho.interventional.ops import intervene

T = TypeVar("T")

Expand All @@ -25,10 +24,6 @@ class BaseCounterfactualMessenger(FactualConditioningMessenger):
:func:`~chirho.interventional.ops.intervene` by instantiating the primitive operation
:func:`~chirho.counterfactual.ops.split`, which is then subsequently handled by subclasses
such as :class:`~chirho.counterfactual.handlers.counterfactual.MultiWorldCounterfactual`.
In addition, :class:`~chirho.counterfactual.handlers.counterfactual.BaseCounterfactualMessenger`
handles :func:`~chirho.counterfactual.ops.preempt` operations by introducing an auxiliary categorical
variable at each of the preempted addresses.
"""

@staticmethod
Expand Down Expand Up @@ -196,77 +191,3 @@ class TwinWorldCounterfactual(IndexPlatesMessenger, BaseCounterfactualMessenger)
@classmethod
def _pyro_split(cls, msg: Dict[str, Any]) -> None:
msg["kwargs"]["name"] = msg["name"] = cls.default_name


class Preemptions(Generic[T], pyro.poutine.messenger.Messenger):
"""
Effect handler that applies the operation :func:`~chirho.counterfactual.ops.preempt`
to sample sites in a probabilistic program,
similar to the handler :func:`~chirho.observational.handlers.condition`
for :func:`~chirho.observational.ops.observe` .
or the handler :func:`~chirho.interventional.handlers.do`
for :func:`~chirho.interventional.ops.intervene` .
See the documentation for :func:`~chirho.counterfactual.ops.preempt` for more details.
This handler introduces an auxiliary discrete random variable at each preempted sample site
whose name is the name of the sample site prefixed by ``prefix``, and
whose value is used as the ``case`` argument to :func:`preempt`,
to determine whether the preemption returns the present value of the site
or the new value specified for the site in ``actions``
The distributions of the auxiliary discrete random variables are parameterized by ``bias``.
By default, ``bias == 0`` and the value returned by the sample site is equally likely
to be the factual case (i.e. the present value of the site) or one of the counterfactual cases
(i.e. the new value(s) specified for the site in ``actions``).
When ``0 < bias <= 0.5``, the preemption is less than equally likely to occur.
When ``-0.5 <= bias < 0``, the preemption is more than equally likely to occur.
More specifically, the probability of the factual case is ``0.5 - bias``,
and the probability of each counterfactual case is ``(0.5 + bias) / num_actions``,
where ``num_actions`` is the number of counterfactual actions for the sample site (usually 1).
:param actions: A mapping from sample site names to interventions.
:param bias: The scalar bias towards not intervening. Must be between -0.5 and 0.5.
:param prefix: The prefix for naming the auxiliary discrete random variables.
"""

actions: Mapping[str, Intervention[T]]
prefix: str
bias: float

def __init__(
self,
actions: Mapping[str, Intervention[T]],
*,
prefix: str = "__witness_split_",
bias: float = 0.0,
):
assert -0.5 <= bias <= 0.5, "bias must be between -0.5 and 0.5"
self.actions = actions
self.bias = bias
self.prefix = prefix
super().__init__()

def _pyro_post_sample(self, msg):
try:
action = self.actions[msg["name"]]
except KeyError:
return

action = (action,) if not isinstance(action, tuple) else action
num_actions = len(action) if isinstance(action, tuple) else 1
weights = torch.tensor(
[0.5 - self.bias] + ([(0.5 + self.bias) / num_actions] * num_actions),
device=msg["value"].device,
)
case_dist = pyro.distributions.Categorical(probs=weights)
case = pyro.sample(f"{self.prefix}{msg['name']}", case_dist)

msg["value"] = preempt(
msg["value"],
action,
case,
event_dim=len(msg["fn"].event_shape),
name=f"{self.prefix}{msg['name']}",
)
Loading

0 comments on commit 804b846

Please sign in to comment.