diff --git a/chirho/counterfactual/handlers/counterfactual.py b/chirho/counterfactual/handlers/counterfactual.py index cbed31726..03b67eb72 100644 --- a/chirho/counterfactual/handlers/counterfactual.py +++ b/chirho/counterfactual/handlers/counterfactual.py @@ -1,15 +1,14 @@ from __future__ import annotations -from typing import Any, Dict, Generic, Mapping, TypeVar +from typing import Any, Dict, TypeVar import pyro -import torch from chirho.counterfactual.handlers.ambiguity import FactualConditioningMessenger -from chirho.counterfactual.ops import preempt, split +from chirho.counterfactual.ops import split from chirho.indexed.handlers import IndexPlatesMessenger from chirho.indexed.ops import get_index_plates -from chirho.interventional.ops import Intervention, intervene +from chirho.interventional.ops import intervene T = TypeVar("T") @@ -25,10 +24,6 @@ class BaseCounterfactualMessenger(FactualConditioningMessenger): :func:`~chirho.interventional.ops.intervene` by instantiating the primitive operation :func:`~chirho.counterfactual.ops.split`, which is then subsequently handled by subclasses such as :class:`~chirho.counterfactual.handlers.counterfactual.MultiWorldCounterfactual`. - - In addition, :class:`~chirho.counterfactual.handlers.counterfactual.BaseCounterfactualMessenger` - handles :func:`~chirho.counterfactual.ops.preempt` operations by introducing an auxiliary categorical - variable at each of the preempted addresses. """ @staticmethod @@ -196,77 +191,3 @@ class TwinWorldCounterfactual(IndexPlatesMessenger, BaseCounterfactualMessenger) @classmethod def _pyro_split(cls, msg: Dict[str, Any]) -> None: msg["kwargs"]["name"] = msg["name"] = cls.default_name - - -class Preemptions(Generic[T], pyro.poutine.messenger.Messenger): - """ - Effect handler that applies the operation :func:`~chirho.counterfactual.ops.preempt` - to sample sites in a probabilistic program, - similar to the handler :func:`~chirho.observational.handlers.condition` - for :func:`~chirho.observational.ops.observe` . - or the handler :func:`~chirho.interventional.handlers.do` - for :func:`~chirho.interventional.ops.intervene` . - - See the documentation for :func:`~chirho.counterfactual.ops.preempt` for more details. - - This handler introduces an auxiliary discrete random variable at each preempted sample site - whose name is the name of the sample site prefixed by ``prefix``, and - whose value is used as the ``case`` argument to :func:`preempt`, - to determine whether the preemption returns the present value of the site - or the new value specified for the site in ``actions`` - - The distributions of the auxiliary discrete random variables are parameterized by ``bias``. - By default, ``bias == 0`` and the value returned by the sample site is equally likely - to be the factual case (i.e. the present value of the site) or one of the counterfactual cases - (i.e. the new value(s) specified for the site in ``actions``). - When ``0 < bias <= 0.5``, the preemption is less than equally likely to occur. - When ``-0.5 <= bias < 0``, the preemption is more than equally likely to occur. - - More specifically, the probability of the factual case is ``0.5 - bias``, - and the probability of each counterfactual case is ``(0.5 + bias) / num_actions``, - where ``num_actions`` is the number of counterfactual actions for the sample site (usually 1). - - :param actions: A mapping from sample site names to interventions. - :param bias: The scalar bias towards not intervening. Must be between -0.5 and 0.5. - :param prefix: The prefix for naming the auxiliary discrete random variables. - """ - - actions: Mapping[str, Intervention[T]] - prefix: str - bias: float - - def __init__( - self, - actions: Mapping[str, Intervention[T]], - *, - prefix: str = "__witness_split_", - bias: float = 0.0, - ): - assert -0.5 <= bias <= 0.5, "bias must be between -0.5 and 0.5" - self.actions = actions - self.bias = bias - self.prefix = prefix - super().__init__() - - def _pyro_post_sample(self, msg): - try: - action = self.actions[msg["name"]] - except KeyError: - return - - action = (action,) if not isinstance(action, tuple) else action - num_actions = len(action) if isinstance(action, tuple) else 1 - weights = torch.tensor( - [0.5 - self.bias] + ([(0.5 + self.bias) / num_actions] * num_actions), - device=msg["value"].device, - ) - case_dist = pyro.distributions.Categorical(probs=weights) - case = pyro.sample(f"{self.prefix}{msg['name']}", case_dist) - - msg["value"] = preempt( - msg["value"], - action, - case, - event_dim=len(msg["fn"].event_shape), - name=f"{self.prefix}{msg['name']}", - ) diff --git a/chirho/counterfactual/handlers/explanation.py b/chirho/counterfactual/handlers/explanation.py deleted file mode 100644 index 47afb350d..000000000 --- a/chirho/counterfactual/handlers/explanation.py +++ /dev/null @@ -1,300 +0,0 @@ -from __future__ import annotations - -import collections.abc -import contextlib -import functools -import itertools -from typing import Callable, Iterable, Mapping, TypeVar, Union - -import pyro -import torch - -from chirho.counterfactual.handlers.counterfactual import Preemptions -from chirho.counterfactual.handlers.selection import get_factual_indices -from chirho.indexed.ops import IndexSet, cond, gather, indices_of, scatter_n -from chirho.interventional.handlers import do -from chirho.interventional.ops import Intervention -from chirho.observational.handlers.condition import Factors - -S = TypeVar("S") -T = TypeVar("T") - - -def undo_split(antecedents: Iterable[str] = [], event_dim: int = 0) -> Callable[[T], T]: - """ - A helper function that undoes an upstream :func:`~chirho.counterfactual.ops.split` operation, - meant to be used to create arguments to pass to :func:`~chirho.interventional.ops.intervene` , - :func:`~chirho.counterfactual.ops.split` or :func:`~chirho.counterfactual.ops.preempt`. - Works by gathering the factual value and scattering it back into two alternative cases. - - :param antecedents: A list of upstream intervened sites which induced the :func:`split` to be reversed. - :param event_dim: The event dimension of the value to be preempted. - :return: A callable that applied to a site value object returns a site value object in which - the factual value has been scattered back into two alternative cases. - """ - - def _undo_split(value: T) -> T: - antecedents_ = [ - a for a in antecedents if a in indices_of(value, event_dim=event_dim) - ] - - factual_value = gather( - value, - IndexSet(**{antecedent: {0} for antecedent in antecedents_}), - event_dim=event_dim, - ) - - # TODO exponential in len(antecedents) - add an indexed.ops.expand to do this cheaply - return scatter_n( - { - IndexSet( - **{antecedent: {ind} for antecedent, ind in zip(antecedents_, inds)} - ): factual_value - for inds in itertools.product(*[[0, 1]] * len(antecedents_)) - }, - event_dim=event_dim, - ) - - return _undo_split - - -def consequent_differs( - antecedents: Iterable[str] = [], eps: float = -1e8, event_dim: int = 0 -) -> Callable[[T], torch.Tensor]: - """ - A helper function for assessing whether values at a site differ from their observed values, assigning - `eps` if a value differs from its observed state and `0.0` otherwise. - - :param antecedents: A list of names of upstream intervened sites to consider when assessing differences. - :param eps: A numerical value assigned if the values differ, defaults to -1e8. - :param event_dim: The event dimension of the value object. - - :return: A callable which applied to a site value object (`consequent`), returns a tensor where each - element indicates whether the corresponding element of `consequent` differs from its factual value - (`eps` if there is a difference, `0.0` otherwise). - """ - - def _consequent_differs(consequent: T) -> torch.Tensor: - indices = IndexSet( - **{ - name: ind - for name, ind in get_factual_indices().items() - if name in antecedents - } - ) - not_eq: torch.Tensor = consequent != gather( - consequent, indices, event_dim=event_dim - ) - for _ in range(event_dim): - not_eq = torch.all(not_eq, dim=-1, keepdim=False) - return cond(eps, 0.0, not_eq, event_dim=event_dim) - - return _consequent_differs - - -@functools.singledispatch -def uniform_proposal( - support: pyro.distributions.constraints.Constraint, - **kwargs, -) -> pyro.distributions.Distribution: - """ - This function heuristically constructs a probability distribution over a specified - support. The choice of distribution depends on the type of support provided. - - - If the support is `real`, it creates a wide Normal distribution - and standard deviation, defaulting to (0,100). - - If the support is `boolean`, it creates a Bernoulli distribution with a fixed logit of 0, - corresponding to success probability .5. - - If the support is an `interval`, the transformed distribution is centered around the - midpoint of the interval. - - :param support: The support used to create the probability distribution. - :param kwargs: Additional keyword arguments. - :return: A uniform probability distribution over the specified support. - """ - if support is pyro.distributions.constraints.real: - return pyro.distributions.Normal(0, 10).mask(False) - elif support is pyro.distributions.constraints.boolean: - return pyro.distributions.Bernoulli(logits=torch.zeros(())) - else: - tfm = pyro.distributions.transforms.biject_to(support) - base = uniform_proposal(pyro.distributions.constraints.real, **kwargs) - return pyro.distributions.TransformedDistribution(base, tfm) - - -@uniform_proposal.register -def _uniform_proposal_indep( - support: pyro.distributions.constraints.independent, - *, - event_shape: torch.Size = torch.Size([]), - **kwargs, -) -> pyro.distributions.Distribution: - d = uniform_proposal(support.base_constraint, event_shape=event_shape, **kwargs) - return d.expand(event_shape).to_event(support.reinterpreted_batch_ndims) - - -@uniform_proposal.register -def _uniform_proposal_integer( - support: pyro.distributions.constraints.integer_interval, - **kwargs, -) -> pyro.distributions.Distribution: - if support.lower_bound != 0: - raise NotImplementedError( - "integer_interval with lower_bound > 0 not yet supported" - ) - n = support.upper_bound - support.lower_bound + 1 - return pyro.distributions.Categorical(probs=torch.ones((n,))) - - -def random_intervention( - support: pyro.distributions.constraints.Constraint, - name: str, -) -> Callable[[torch.Tensor], torch.Tensor]: - """ - Creates a random-valued intervention for a single sample site, determined by - by the distribution support, and site name. - - :param support: The support constraint for the sample site. - :param name: The name of the auxiliary sample site. - - :return: A function that takes a ``torch.Tensor`` as input - and returns a random sample over the pre-specified support of the same - event shape as the input tensor. - - Example:: - - >>> support = pyro.distributions.constraints.real - >>> intervention_fn = random_intervention(support, name="random_value") - >>> with chirho.interventional.handlers.do(actions={"x": intervention_fn}): - ... x = pyro.deterministic("x", torch.tensor(2.)) - >>> assert x != 2 - """ - - def _random_intervention(value: torch.Tensor) -> torch.Tensor: - event_shape = value.shape[len(value.shape) - support.event_dim :] - proposal_dist = uniform_proposal( - support, - event_shape=event_shape, - ) - return pyro.sample(name, proposal_dist) - - return _random_intervention - - -@contextlib.contextmanager -def SearchForCause( - actions: Mapping[str, Intervention[T]], - *, - bias: float = 0.0, - prefix: str = "__cause_split_", -): - """ - A context manager used for a stochastic search of minimal but-for causes among potential interventions. - On each run, nodes listed in `actions` are randomly selected and intervened on with probability `.5 + bias` - (that is, preempted with probability `.5-bias`). The sampling is achieved by adding stochastic binary preemption - nodes associated with intervention candidates. If a given preemption node has value `0`, the corresponding - intervention is executed. See tests in `tests/counterfactual/test_handlers_explanation.py` for examples. - - :param actions: A mapping of sites to interventions. - :param bias: The scalar bias towards not intervening. Must be between -0.5 and 0.5, defaults to 0.0. - :param prefix: A prefix used for naming additional preemption nodes. Defaults to "__cause_split_". - """ - # TODO support event_dim != 0 propagation in factual_preemption - preemptions = { - antecedent: undo_split(antecedents=[antecedent]) - for antecedent in actions.keys() - } - - with do(actions=actions): - with Preemptions(actions=preemptions, bias=bias, prefix=prefix): - yield - - -@contextlib.contextmanager -def ExplainCauses( - antecedents: Union[ - Mapping[str, Intervention[T]], - Mapping[str, pyro.distributions.constraints.Constraint], - ], - witnesses: Union[Mapping[str, Intervention[T]], Iterable[str]], - consequents: Union[ - Mapping[str, Callable[[T], Union[float, torch.Tensor]]], Iterable[str] - ], - *, - antecedent_bias: float = 0.0, - witness_bias: float = 0.0, - consequent_eps: float = -1e8, - antecedent_prefix: str = "__antecedent_", - witness_prefix: str = "__witness_", - consequent_prefix: str = "__consequent_", -): - """ - Effect handler used for causal explanation search. On each run: - - 1. The antecedent nodes are intervened on with the values in ``antecedents`` \ - using :func:`~chirho.counterfactual.ops.split` . \ - Unless alternative interventions are provided, \ - counterfactual values are uniformly sampled for each antecedent node \ - using :func:`~chirho.counterfactual.handlers.explanation.uniform_proposal` \ - given its support as a :class:`~pyro.distributions.constraints.Constraint` . - - 2. These interventions are randomly :func:`~chirho.counterfactual.ops.preempt`-ed \ - using :func:`~chirho.counterfactual.handlers.explanation.undo_split` \ - by a :func:`~chirho.counterfactual.handlers.explanation.SearchForCause` handler. - - 3. The witness nodes are randomly :func:`~chirho.counterfactual.ops.preempt`-ed \ - to be kept at the values given in ``witnesses`` . - - 4. A :func:`~pyro.factor` node is added tracking whether the consequent nodes differ \ - between the factual and counterfactual worlds. - - :param antecedents: A mapping from antecedent names to interventions. - :param witnesses: A mapping from witness names to interventions. - :param consequents: A mapping from consequent names to factor functions. - """ - if isinstance( - next(iter(antecedents.values())), - pyro.distributions.constraints.Constraint, - ): - antecedents = { - a: random_intervention(s, name=f"{antecedent_prefix}_proposal_{a}") - for a, s in antecedents.items() - } - - if not isinstance(witnesses, collections.abc.Mapping): - witnesses = { - w: undo_split(antecedents=list(antecedents.keys())) for w in witnesses - } - - if not isinstance(consequents, collections.abc.Mapping): - consequents = { - c: consequent_differs( - antecedents=list(antecedents.keys()), eps=consequent_eps - ) - for c in consequents - } - - if len(consequents) == 0: - raise ValueError("must have at least one consequent") - - if len(antecedents) == 0: - raise ValueError("must have at least one antecedent") - - if set(consequents.keys()) & set(antecedents.keys()): - raise ValueError("consequents and possible antecedents must be disjoint") - - if set(consequents.keys()) & set(witnesses.keys()): - raise ValueError("consequents and possible witnesses must be disjoint") - - antecedent_handler = SearchForCause( - actions=antecedents, bias=antecedent_bias, prefix=antecedent_prefix - ) - - witness_handler = Preemptions( - actions=witnesses, bias=witness_bias, prefix=witness_prefix - ) - - consequent_handler = Factors(factors=consequents, prefix=consequent_prefix) - - with antecedent_handler, witness_handler, consequent_handler: - yield diff --git a/chirho/counterfactual/ops.py b/chirho/counterfactual/ops.py index 8817d6042..394c51995 100644 --- a/chirho/counterfactual/ops.py +++ b/chirho/counterfactual/ops.py @@ -1,11 +1,11 @@ from __future__ import annotations import functools -from typing import Optional, Tuple, TypeVar +from typing import Tuple, TypeVar import pyro -from chirho.indexed.ops import IndexSet, cond_n, scatter_n +from chirho.indexed.ops import IndexSet, scatter_n from chirho.interventional.ops import Intervention, intervene S = TypeVar("S") @@ -40,35 +40,3 @@ def split(obs: T, acts: Tuple[Intervention[T], ...], **kwargs) -> T: act_values[IndexSet(**{name: {i + 1}})] = intervene(obs, act, **kwargs) return scatter_n(act_values, event_dim=kwargs.get("event_dim", 0)) - - -@pyro.poutine.runtime.effectful(type="preempt") -@functools.partial(pyro.poutine.block, hide_types=["intervene"]) -def preempt( - obs: T, acts: Tuple[Intervention[T], ...], case: Optional[S] = None, **kwargs -) -> T: - """ - Effectful primitive operation for "preempting" values in a probabilistic program. - - Unlike the counterfactual operation :func:`~chirho.counterfactual.ops.split`, - which returns multiple values concatenated along a new axis - via the operation :func:`~chirho.indexed.ops.scatter`, - :func:`preempt` returns a single value determined by the argument ``case`` - via :func:`~chirho.indexed.ops.cond` . - - In a probabilistic program, a :func:`preempt` call induces a mixture distribution - over downstream values, whereas :func:`split` would induce a joint distribution. - - :param obs: The observed value. - :param acts: The interventions to apply. - :param case: The case to select. - """ - if case is None: - return obs - - name = kwargs.get("name", None) - act_values = {IndexSet(**{name: {0}}): obs} - for i, act in enumerate(acts): - act_values[IndexSet(**{name: {i + 1}})] = intervene(obs, act, **kwargs) - - return cond_n(act_values, case, event_dim=kwargs.get("event_dim", 0)) diff --git a/chirho/explainable/handlers/__init__.py b/chirho/explainable/handlers/__init__.py new file mode 100644 index 000000000..d1baf7ae0 --- /dev/null +++ b/chirho/explainable/handlers/__init__.py @@ -0,0 +1,4 @@ +from .components import random_intervention # noqa: F401 +from .components import ExtractSupports, undo_split # noqa: F401 +from .explanation import SearchForExplanation, SplitSubsets # noqa: F401 +from .preemptions import Preemptions # noqa: F401 diff --git a/chirho/explainable/handlers/components.py b/chirho/explainable/handlers/components.py new file mode 100644 index 000000000..689e2ddc1 --- /dev/null +++ b/chirho/explainable/handlers/components.py @@ -0,0 +1,156 @@ +import itertools +from typing import Callable, Iterable, MutableMapping, TypeVar + +import pyro +import pyro.distributions.constraints as constraints +import torch + +from chirho.counterfactual.handlers.selection import get_factual_indices +from chirho.explainable.internals import soft_neq, uniform_proposal +from chirho.indexed.ops import IndexSet, gather, indices_of, scatter_n + +S = TypeVar("S") +T = TypeVar("T") + + +def random_intervention( + support: constraints.Constraint, + name: str, +) -> Callable[[torch.Tensor], torch.Tensor]: + """ + Creates a random-valued intervention for a single sample site, determined by + by the distribution support, and site name. + + :param support: The support constraint for the sample site. + :param name: The name of the auxiliary sample site. + + :return: A function that takes a `torch.Tensor` as input + and returns a random sample over the pre-specified support of the same + event shape as the input tensor. + + Example:: + + >>> support = pyro.distributions.constraints.real + >>> intervention_fn = random_intervention(support, name="random_value") + >>> with chirho.interventional.handlers.do(actions={"x": intervention_fn}): + ... x = pyro.deterministic("x", torch.tensor(2.)) + >>> assert x != 2 + """ + + def _random_intervention(value: torch.Tensor) -> torch.Tensor: + event_shape = value.shape[len(value.shape) - support.event_dim :] + proposal_dist = uniform_proposal( + support, + event_shape=event_shape, + ) + return pyro.sample(name, proposal_dist) + + return _random_intervention + + +def undo_split( + support: constraints.Constraint, antecedents: Iterable[str] = [] +) -> Callable[[T], T]: + """ + A helper function that undoes an upstream :func:`~chirho.counterfactual.ops.split` operation, + meant to be used to create arguments to pass to :func:`~chirho.interventional.ops.intervene` , + :func:`~chirho.counterfactual.ops.split` or :func:`~chirho.explainable.ops.preempt`. + Works by gathering the factual value and scattering it back into two alternative cases. + + :param support: The support constraint for the site at which :func:`split` is being undone. + :param antecedents: A list of upstream intervened sites which induced the :func:`split` + to be reversed. + :return: A callable that applied to a site value object returns a site value object in which + the factual value has been scattered back into two alternative cases. + """ + + def _undo_split(value: T) -> T: + antecedents_ = [ + a + for a in antecedents + if a in indices_of(value, event_dim=support.event_dim) + ] + + factual_value = gather( + value, + IndexSet(**{antecedent: {0} for antecedent in antecedents_}), + event_dim=support.event_dim, + ) + + # TODO exponential in len(antecedents) - add an indexed.ops.expand to do this cheaply + return scatter_n( + { + IndexSet( + **{antecedent: {ind} for antecedent, ind in zip(antecedents_, inds)} + ): factual_value + for inds in itertools.product(*[[0, 1]] * len(antecedents_)) + }, + event_dim=support.event_dim, + ) + + return _undo_split + + +def consequent_differs( + support: constraints.Constraint, + antecedents: Iterable[str] = [], + **kwargs, +) -> Callable[[T], torch.Tensor]: + """ + A helper function for assessing whether values at a site differ from their observed values, assigning + a small negative value close to zero if a value differs from its observed state + and a large negative value otherwise. + + :param support: The support constraint for the consequent site. + :param antecedents: A list of names of upstream intervened sites to consider when assessing differences. + + :return: A callable which applied to a site value object (`consequent`), returns a tensor where each + element indicates whether the corresponding element of `consequent` differs from its factual value. + """ + + def _consequent_differs(consequent: T) -> torch.Tensor: + indices = IndexSet( + **{ + name: ind + for name, ind in get_factual_indices().items() + if name in antecedents + } + ) + diff = soft_neq( + support, + consequent, + gather(consequent, indices, event_dim=support.event_dim), + **kwargs, + ) + return diff + + return _consequent_differs + + +class ExtractSupports(pyro.poutine.messenger.Messenger): + """ + A Pyro Messenger for inferring distribution constraints. + + :return: An instance of ``ExtractSupports`` with a new attribute: ``supports``, + a dictionary mapping variable names to constraints for all variables in the model. + + Example: + + >>> def mixed_supports_model(): + >>> uniform_var = pyro.sample("uniform_var", dist.Uniform(1, 10)) + >>> normal_var = pyro.sample("normal_var", dist.Normal(3, 15)) + >>> with ExtractSupports() as s: + ... mixed_supports_model() + >>> print(s.supports) + """ + + supports: MutableMapping[str, pyro.distributions.constraints.Constraint] + + def __init__(self): + super(ExtractSupports, self).__init__() + + self.supports = {} + + def _pyro_post_sample(self, msg: dict) -> None: + if not pyro.poutine.util.site_is_subsample(msg): + self.supports[msg["name"]] = msg["fn"].support diff --git a/chirho/explainable/handlers/explanation.py b/chirho/explainable/handlers/explanation.py new file mode 100644 index 000000000..008de3b4d --- /dev/null +++ b/chirho/explainable/handlers/explanation.py @@ -0,0 +1,157 @@ +import contextlib +from typing import Callable, Mapping, TypeVar, Union + +import pyro.distributions.constraints as constraints +import torch + +from chirho.explainable.handlers.components import ( + consequent_differs, + random_intervention, + undo_split, +) +from chirho.explainable.handlers.preemptions import Preemptions +from chirho.interventional.handlers import do +from chirho.interventional.ops import Intervention +from chirho.observational.handlers.condition import Factors + +S = TypeVar("S") +T = TypeVar("T") + + +@contextlib.contextmanager +def SplitSubsets( + supports: Mapping[str, constraints.Constraint], + actions: Mapping[str, Intervention[T]], + *, + bias: float = 0.0, + prefix: str = "__cause_split_", +): + """ + A context manager used for a stochastic search of minimal but-for causes among potential interventions. + On each run, nodes listed in `actions` are randomly selected and intervened on with probability `.5 + bias` + (that is, preempted with probability `.5-bias`). The sampling is achieved by adding stochastic binary preemption + nodes associated with intervention candidates. If a given preemption node has value `0`, the corresponding + intervention is executed. See tests in `tests/explainable/test_handlers_explanation.py` for examples. + + :param supports: A mapping of sites to their support constraints. + :param actions: A mapping of sites to interventions. + :param bias: The scalar bias towards not intervening. Must be between -0.5 and 0.5, defaults to 0.0. + :param prefix: A prefix used for naming additional preemption nodes. Defaults to `__cause_split_`. + """ + preemptions = { + antecedent: undo_split(supports[antecedent], antecedents=[antecedent]) + for antecedent in actions.keys() + } + + with do(actions=actions): + with Preemptions(actions=preemptions, bias=bias, prefix=prefix): + yield + + +@contextlib.contextmanager +def SearchForExplanation( + antecedents: Union[ + Mapping[str, Intervention[T]], + Mapping[str, constraints.Constraint], + ], + witnesses: Union[ + Mapping[str, Intervention[T]], Mapping[str, constraints.Constraint] + ], + consequents: Union[ + Mapping[str, Callable[[T], Union[float, torch.Tensor]]], + Mapping[str, constraints.Constraint], + ], + *, + antecedent_bias: float = 0.0, + witness_bias: float = 0.0, + consequent_scale: float = 1e-2, + antecedent_prefix: str = "__antecedent_", + witness_prefix: str = "__witness_", + consequent_prefix: str = "__consequent_", +): + """ + Effect handler used for causal explanation search. On each run: + + 1. The antecedent nodes are intervened on with the values in ``antecedents`` \ + using :func:`~chirho.counterfactual.ops.split` . \ + Unless alternative interventions are provided, \ + counterfactual values are uniformly sampled for each antecedent node \ + using :func:`~chirho.explainable.internals.uniform_proposal` \ + given its support as a :class:`~pyro.distributions.constraints.Constraint`. + + 2. These interventions are randomly :func:`~chirho.explainable.ops.preempt`-ed \ + using :func:`~chirho.explainable.handlers.undo_split` \ + by a :func:`~chirho.explainable.handlers.SplitSubsets` handler. + + 3. The witness nodes are randomly :func:`~chirho.explainable.ops.preempt`-ed \ + to be kept at the values given in ``witnesses``. + + 4. A :func:`~pyro.factor` node is added tracking whether the consequent nodes differ \ + between the factual and counterfactual worlds. + + :param antecedents: A mapping from antecedent names to interventions or to constraints. + :param witnesses: A mapping from witness names to interventions or to constraints. + :param consequents: A mapping from consequent names to factor functions or to constraints. + """ + if antecedents and isinstance( + next(iter(antecedents.values())), + constraints.Constraint, + ): + antecedents_supports = {a: s for a, s in antecedents.items()} + antecedents = { + a: random_intervention(s, name=f"{antecedent_prefix}_proposal_{a}") + for a, s in antecedents.items() + } + else: + antecedents_supports = {a: constraints.boolean for a in antecedents.keys()} + # TODO generalize to non-scalar antecedents + + if witnesses and isinstance( + next(iter(witnesses.values())), + constraints.Constraint, + ): + witnesses = { + w: undo_split(s, antecedents=list(antecedents.keys())) + for w, s in witnesses.items() + } + + if consequents and isinstance( + next(iter(consequents.values())), + constraints.Constraint, + ): + consequents = { + c: consequent_differs( + support=s, + antecedents=list(antecedents.keys()), + scale=consequent_scale, + ) + for c, s in consequents.items() + } + + if len(consequents) == 0: + raise ValueError("must have at least one consequent") + + if len(antecedents) == 0: + raise ValueError("must have at least one antecedent") + + if set(consequents.keys()) & set(antecedents.keys()): + raise ValueError("consequents and possible antecedents must be disjoint") + + if set(consequents.keys()) & set(witnesses.keys()): + raise ValueError("consequents and possible witnesses must be disjoint") + + antecedent_handler = SplitSubsets( + supports=antecedents_supports, + actions=antecedents, + bias=antecedent_bias, + prefix=antecedent_prefix, + ) + + witness_handler: Preemptions = Preemptions( + actions=witnesses, bias=witness_bias, prefix=witness_prefix + ) + + consequent_handler = Factors(factors=consequents, prefix=consequent_prefix) + + with antecedent_handler, witness_handler, consequent_handler: + yield diff --git a/chirho/explainable/handlers/preemptions.py b/chirho/explainable/handlers/preemptions.py new file mode 100644 index 000000000..4268d730f --- /dev/null +++ b/chirho/explainable/handlers/preemptions.py @@ -0,0 +1,84 @@ +from typing import Generic, Mapping, TypeVar + +import pyro +import torch + +from chirho.explainable.ops import preempt +from chirho.interventional.ops import Intervention + +S = TypeVar("S") +T = TypeVar("T") + + +class Preemptions(Generic[T], pyro.poutine.messenger.Messenger): + """ + Effect handler that applies the operation :func:`~chirho.explainable.ops.preempt` + to sample sites in a probabilistic program, + similar to the handler :func:`~chirho.observational.handlers.condition` + for :func:`~chirho.observational.ops.observe` . + or the handler :func:`~chirho.interventional.handlers.do` + for :func:`~chirho.interventional.ops.intervene` . + + See the documentation for :func:`~chirho.explainable.ops.preempt` for more details. + + This handler introduces an auxiliary discrete random variable at each preempted sample site + whose name is the name of the sample site prefixed by ``prefix``, and + whose value is used as the ``case`` argument to :func:`preempt`, + to determine whether the preemption returns the present value of the site + or the new value specified for the site in ``actions`` + + The distributions of the auxiliary discrete random variables are parameterized by ``bias``. + By default, ``bias == 0`` and the value returned by the sample site is equally likely + to be the factual case (i.e. the present value of the site) or one of the counterfactual cases + (i.e. the new value(s) specified for the site in ``actions``). + When ``0 < bias <= 0.5``, the preemption is less than equally likely to occur. + When ``-0.5 <= bias < 0``, the preemption is more than equally likely to occur. + + More specifically, the probability of the factual case is ``0.5 - bias``, + and the probability of each counterfactual case is ``(0.5 + bias) / num_actions``, + where ``num_actions`` is the number of counterfactual actions for the sample site (usually 1). + + :param actions: A mapping from sample site names to interventions. + :param bias: The scalar bias towards not intervening. Must be between -0.5 and 0.5. + :param prefix: The prefix for naming the auxiliary discrete random variables. + """ + + actions: Mapping[str, Intervention[T]] + prefix: str + bias: float + + def __init__( + self, + actions: Mapping[str, Intervention[T]], + *, + prefix: str = "__witness_split_", + bias: float = 0.0, + ): + assert -0.5 <= bias <= 0.5, "bias must be between -0.5 and 0.5" + self.actions = actions + self.bias = bias + self.prefix = prefix + super().__init__() + + def _pyro_post_sample(self, msg): + try: + action = self.actions[msg["name"]] + except KeyError: + return + + action = (action,) if not isinstance(action, tuple) else action + num_actions = len(action) if isinstance(action, tuple) else 1 + weights = torch.tensor( + [0.5 - self.bias] + ([(0.5 + self.bias) / num_actions] * num_actions), + device=msg["value"].device, + ) + case_dist = pyro.distributions.Categorical(probs=weights) + case = pyro.sample(f"{self.prefix}{msg['name']}", case_dist) + + msg["value"] = preempt( + msg["value"], + action, + case, + event_dim=len(msg["fn"].event_shape), + name=f"{self.prefix}{msg['name']}", + ) diff --git a/chirho/explainable/internals/__init__.py b/chirho/explainable/internals/__init__.py new file mode 100644 index 000000000..b1fe9440e --- /dev/null +++ b/chirho/explainable/internals/__init__.py @@ -0,0 +1 @@ +from .defaults import soft_eq, soft_neq, uniform_proposal # noqa: F401 diff --git a/chirho/explainable/internals/defaults.py b/chirho/explainable/internals/defaults.py new file mode 100644 index 000000000..9a7af9d53 --- /dev/null +++ b/chirho/explainable/internals/defaults.py @@ -0,0 +1,175 @@ +import functools +from typing import TypeVar + +import pyro +import pyro.distributions as dist +import pyro.distributions.constraints as constraints +import torch +from torch.distributions import biject_to + +from chirho.indexed.ops import cond + +S = TypeVar("S") +T = TypeVar("T") + + +@functools.singledispatch +def uniform_proposal( + support: pyro.distributions.constraints.Constraint, + **kwargs, +) -> pyro.distributions.Distribution: + """ + This function heuristically constructs a probability distribution over a specified + support. The choice of distribution depends on the type of support provided. + + - If the support is ``real``, it creates a wide Normal distribution + and standard deviation, defaulting to ``(0,10)``. + - If the support is ``boolean``, it creates a Bernoulli distribution with a fixed logit of ``0``, + corresponding to success probability ``.5``. + - If the support is an ``interval``, the transformed distribution is centered around the + midpoint of the interval. + + :param support: The support used to create the probability distribution. + :param kwargs: Additional keyword arguments. + :return: A uniform probability distribution over the specified support. + """ + if support is pyro.distributions.constraints.real: + return pyro.distributions.Normal(0, 10).mask(False) + elif support is pyro.distributions.constraints.boolean: + return pyro.distributions.Bernoulli(logits=torch.zeros(())) + else: + tfm = pyro.distributions.transforms.biject_to(support) + base = uniform_proposal(pyro.distributions.constraints.real, **kwargs) + return pyro.distributions.TransformedDistribution(base, tfm) + + +@uniform_proposal.register +def _uniform_proposal_indep( + support: pyro.distributions.constraints.independent, + *, + event_shape: torch.Size = torch.Size([]), + **kwargs, +) -> pyro.distributions.Distribution: + d = uniform_proposal(support.base_constraint, event_shape=event_shape, **kwargs) + return d.expand(event_shape).to_event(support.reinterpreted_batch_ndims) + + +@uniform_proposal.register +def _uniform_proposal_integer( + support: pyro.distributions.constraints.integer_interval, + **kwargs, +) -> pyro.distributions.Distribution: + if support.lower_bound != 0: + raise NotImplementedError( + "integer_interval with lower_bound > 0 not yet supported" + ) + n = support.upper_bound - support.lower_bound + 1 + return pyro.distributions.Categorical(probs=torch.ones((n,))) + + +@functools.singledispatch +def soft_eq(support: constraints.Constraint, v1: T, v2: T, **kwargs) -> torch.Tensor: + """ + Computes soft equality between two values ``v1`` and ``v2`` given a distribution constraint ``support``. + Returns a negative value if there is a difference (the larger the difference, the lower the value) + and tends to a low value as ``v1`` and ``v2`` tend to each other. + + :param support: A distribution constraint. + :params v1, v2: the values to be compared. + :param kwargs: Additional keywords arguments passed further; `scale` adjusts the softness of the inequality. + :return: A tensor of log probabilities capturing the soft equality between ``v1`` and ``v2``, + depends on the support and scale. + :raises TypeError: If boolean tensors have different data types. + + Comment: if the support is boolean, setting ``scale = 1e-8`` results in a value close to ``0.0`` if the values + are equal and a large negative number ``<=1e-8`` otherwise. + """ + if not isinstance(v1, torch.Tensor) or not isinstance(v2, torch.Tensor): + raise NotImplementedError("Soft equality is only implemented for tensors.") + elif support.is_discrete: + raise NotImplementedError( + "Soft equality is not implemented for arbitrary discrete distributions." + ) + elif support is constraints.real: # base case + scale = kwargs.get("scale", 0.1) + return dist.Normal(0.0, scale).log_prob(v1 - v2) + else: + tfm = biject_to(support) + v1_inv = tfm.inv(v1) + ldj = tfm.log_abs_det_jacobian(v1_inv, v1) + v2_inv = tfm.inv(v2) + ldj = ldj + tfm.log_abs_det_jacobian(v2_inv, v2) + for _ in range(tfm.codomain.event_dim - tfm.domain.event_dim): + ldj = torch.sum(ldj, dim=-1) + return soft_eq(tfm.domain, v1_inv, v2_inv, **kwargs) + ldj + + +@soft_eq.register +def _soft_eq_independent(support: constraints.independent, v1: T, v2: T, **kwargs): + result = soft_eq(support.base_constraint, v1, v2, **kwargs) + for _ in range(support.reinterpreted_batch_ndims): + result = torch.sum(result, dim=-1) + return result + + +@soft_eq.register(type(constraints.boolean)) +def _soft_eq_boolean(support, v1: torch.Tensor, v2: torch.Tensor, **kwargs): + assert support is constraints.boolean + scale = kwargs.get("scale", 0.1) + return torch.log(cond(scale, 1 - scale, v1 == v2, event_dim=0)) + + +@soft_eq.register +def _soft_eq_integer_interval( + support: constraints.integer_interval, v1: torch.Tensor, v2: torch.Tensor, **kwargs +): + scale = kwargs.get("scale", 0.1) + width = support.upper_bound - support.lower_bound + 1 + return dist.Binomial(total_count=width, probs=scale).log_prob(torch.abs(v1 - v2)) + + +@soft_eq.register(type(constraints.integer)) +def _soft_eq_integer(support, v1: torch.Tensor, v2: torch.Tensor, **kwargs): + scale = kwargs.get("scale", 0.1) + return dist.Poisson(rate=scale).log_prob(torch.abs(v1 - v2)) + + +@soft_eq.register(type(constraints.positive_integer)) +@soft_eq.register(type(constraints.nonnegative_integer)) +def _soft_eq_positive_integer(support, v1: T, v2: T, **kwargs): + return soft_eq(constraints.integer, v1, v2, **kwargs) + + +@functools.singledispatch +def soft_neq(support: constraints.Constraint, v1: T, v2: T, **kwargs) -> torch.Tensor: + """ + Computes soft inequality between two values `v1` and `v2` given a distribution constraint `support`. + Tends to `1-log(scale)` as the difference between the value increases, and tends to + `log(scale)` as `v1` and `v2` tend to each other, summing elementwise over tensors. + + :param support: A distribution constraint. + :params v1, v2: the values to be compared. + :param kwargs: Additional keywords arguments: + `scale` to adjust the softness of the inequality. + :return: A tensor of log probabilities capturing the soft inequality between `v1` and `v2`. + :raises TypeError: If boolean tensors have different data types. + :raises NotImplementedError: If arguments are not tensors. + """ + if not isinstance(v1, torch.Tensor) or not isinstance(v2, torch.Tensor): + raise NotImplementedError("Soft equality is only implemented for tensors.") + elif support.is_discrete: # for discrete pmf, soft_neq = 1 - soft_eq (in log space) + return torch.log(-torch.expm1(soft_eq(support, v1, v2, **kwargs))) + elif support is constraints.real: # base case + scale = kwargs.get("scale", 0.1) + return torch.log(2 * dist.Normal(0.0, scale).cdf(torch.abs(v1 - v2)) - 1) + else: + tfm = biject_to(support) + return soft_neq(tfm.domain, tfm.inv(v1), tfm.inv(v2), **kwargs) + + +@soft_neq.register +def _soft_neq_independent(support: constraints.independent, v1: T, v2: T, **kwargs): + result = soft_neq(support.base_constraint, v1, v2, **kwargs) + for _ in range(support.reinterpreted_batch_ndims): + result = torch.sum(result, dim=-1) + return result diff --git a/chirho/explainable/ops.py b/chirho/explainable/ops.py new file mode 100644 index 000000000..54e07a684 --- /dev/null +++ b/chirho/explainable/ops.py @@ -0,0 +1,42 @@ +import functools +from typing import Optional, Tuple, TypeVar + +import pyro + +from chirho.indexed.ops import IndexSet, cond_n +from chirho.interventional.ops import Intervention, intervene + +S = TypeVar("S") +T = TypeVar("T") + + +@pyro.poutine.runtime.effectful(type="preempt") +@functools.partial(pyro.poutine.block, hide_types=["intervene"]) +def preempt( + obs: T, acts: Tuple[Intervention[T], ...], case: Optional[S] = None, **kwargs +) -> T: + """ + Effectful primitive operation for "preempting" values in a probabilistic program. + + Unlike the counterfactual operation :func:`~chirho.counterfactual.ops.split`, + which returns multiple values concatenated along a new axis + via the operation :func:`~chirho.indexed.ops.scatter`, + :func:`preempt` returns a single value determined by the argument ``case`` + via :func:`~chirho.indexed.ops.cond`. + + In a probabilistic program, a :func:`preempt` call induces a mixture distribution + over downstream values, whereas :func:`split` would induce a joint distribution. + + :param obs: The observed value. + :param acts: The interventions to apply. + :param case: The case to select. + """ + if case is None: + return obs + + name = kwargs.get("name", None) + act_values = {IndexSet(**{name: {0}}): obs} + for i, act in enumerate(acts): + act_values[IndexSet(**{name: {i + 1}})] = intervene(obs, act, **kwargs) + + return cond_n(act_values, case, event_dim=kwargs.get("event_dim", 0)) diff --git a/docs/source/actual_causality.ipynb b/docs/source/actual_causality.ipynb new file mode 100644 index 000000000..da663e2a1 --- /dev/null +++ b/docs/source/actual_causality.ipynb @@ -0,0 +1,1747 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Actual Causality and the modified Halpern-Pearl definition" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Summary**\n", + "\n", + "The **Explainable Reasoning with ChiRho** package aims to provide a systematic, unified approach to actual causality and causal explanation computations in terms of different probabilistic queries over expanded causal models that are constructed from a single generic program transformation applied to an arbitrary causal model represented as a ChiRho program. The approach of reducing causal queries to probabilistic computations on transformed causal models is the foundational idea behind all of ChiRho. Where “actual causality” or \"causal explanation\" queries differ is their use of auxiliary variables representing uncertainty over which interventions or preemptions to apply, implicitly inducing a search space over counterfactuals.\n", + "\n", + "The goal of this notebook is to illustrate how the package can be used to provide approximate method of answering actual causality queries in line with the so-called Halpern-Pearl modified definition of actual causality [(J. Halpern, MIT Press, 2016)](https://mitpress.mit.edu/9780262537131/actual-causality/).\n", + "\n", + "\n", + "In another notebook, we illustrate how the package can be used to answer analogous queries related to causal explanation, as defined in the same book." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Outline**\n", + "\n", + "[Intuitions and formalization](#intuitions-and-formalization)\n", + " \n", + "- [Structural causal models](#structural-causal-models)\n", + "\n", + "- [Halpern-Pearl modified definition of actual causality](#halpern-pearl-modified-definition-of-actual-causality)\n", + "\n", + "[Implementation](#implementation)\n", + "\n", + "[Examples](#examples)\n", + "\n", + "- [Comments on example selection](#comments-on-example-selection)\n", + " \n", + "- [Stone-throwing](#stone-throwing)\n", + "\n", + "- [Forest fire](#forest-fire)\n", + "\n", + "- [Doctors](#doctors)\n", + "\n", + "- [Friendly fire](#friendly-fire)\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Intuitions and formalization" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Actual causality (sometimes called **token causality** or **specific causality**) is usually contrasted with type causality (sometimes called **general causality**). While the latter is concerned with general statements (such as \"smoking causes cancer\"), actual causality focuses on particular events. For illustration, consider the following causality-related questions:\n", + "\n", + "- **Friendly Fire**: On March 24, 2002, A B-52 bomber fired a Joint Direct Attack Munition at a US battalion command post, killing three and injuring twenty special forces soldiers. Out of multiple potential contributing factors, which were **actually** responsible for the incident?\n", + " \n", + "- **Schizophrenia** : The disease arises from the interaction between multiple genetic and environmental factors. Given a particular patient and what we know about them, which of these factors **actually** caused her state?\n", + " \n", + "- **Explainable AI**: Your loan application has been refused. The bank representative informs you the decision was made using predictive modeling to estimate the probability of default. They give you a list of various factors considered in the prediction. But which of these factors **actually** resulted in the rejection, and what were their contributions?\n", + " \n", + "These are questions about **actual causality**. While having answers to such questions is not directly useful for prediction tasks, they are useful for understanding how we can prevent undesirable outcomes similar to ones that we have observed or promote the occurrence of desirable outcomes in contexts similar to the ones in which they had been observed. These context-sensitive causality questions are also an essential element of blame and responsibility assignments, and of at least one prominent account of the notion of explanation.\n", + "\n", + "The general intuition behind the notion of actual causality that we will focus on is that a certain state of antecedent nodes is the cause of a given state of the consequent nodes if there is a part of the actual reality such that if it is kept fixed at what it actually is, and we intervened on the antecedent nodes to be in a different state, the consequent nodes would no longer be in the observed states. A proper explication of this notion requires the context of structural causal models - we first explain what these are, and then move on to the definition." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Structural causal models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "While statistical information might help address questions of actual causality, is not sufficient. One requires causal theories that explain how the relevant aspects of the world function, as well as information about the actual facts pertaining to the specific case. For this reason, the notion on which we focus in this notebook is formulated within the framework of structural causal models, which can represent such information.\n", + "\n", + "The notion is defined in the context of a deterministic structural causal model (SCMs). One major component thereof is a selection of **variables**. For instance, in a very simple model for a forest-fire problem, we might consider a model with three endogenous binary variables: $FF$ (forest fire), $L$ (lightning), and $MD$ (match dropped) whose values are determined by the values of other variables, and two exogenous noise variables $U_{MD}$ and $U_L$ that determine the values of $MD$ and $L$. Moreover, some of those variables/nodes are connected by means of directed **edges**. For instance, in the example at hand, the model contains two edges that go from $U_MD$ to $MD$ and from $U_L$ to $L$ respectively, and two edges that go from $L$ to $FF$ and from $MD$ to $FF$. Each influence is associated with a **structural equation** - for instance, $FF = max(L, MD)$ indicates that a forest fire occurs if either of the two factors occurs. SCMs come also with a **context**, which is the values of **exogenous variables** whose values are not determined by the structural equations, but rather by factors outside the model. In our example, one context might be that both a match has been dropped and a lightning occurred.\n", + "\n", + "More formally, a causal model $M$ is a tuple $\\langle S, F\\rangle$, where:\n", + "\n", + "- $S$ is a **signature**, that is a tuple $\\langle U, V, R\\rangle$, where $U$ is a set of exogenous variables, $V$ is a set of endogenous variables and $R: U \\cup V \\mapsto R(Y)$, where $R(Y)\\neq \\emptyset$, that is $R$ assigns non-empty ranges to exogenous and endogenous variables.\n", + "\n", + "- To each endogenous $X\\in V$, $F$ assigns a function $F_X$, which maps the cross-product of ranges of all variables other than $X$ to $R(X)$. In other words, $F_X$ determines the value of $X$ given the values of other variables in the model (some of them might be redundant in a given equation). The intuition is that these functions correspond to structural equations of the form $X = F_X(U, V)$ which are to be read from right to left: if the values of $U\\cup V$ are fixed to be such-and-such, say $\\vec{u}$ and $\\vec{v}$, this causes $X$ to take the value $F_X(\\vec{u}, \\vec{v})$.\n", + "\n", + "A **deterministic causal model** (also called **causal setting**), $\\langle M, \\vec{u}\\rangle$ is a causal model $M$ together with fixed settings $\\vec{u}$ of its exogenous variables $U$. To intervene, say, to make $Y$ have value $y$, is to replace the structural equation for $Y$ of the form $Y = F_Y(U, V)$ with $Y = y$. $\\langle M, \\vec{u}\\rangle \\models [Y \\leftarrow y](X = x)$ means: in the deterministic model obtained from $\\langle M, \\vec{u}\\rangle$ by intervening on $Y$ to have value $y$ $X$ has value $x$. Sometimes, instead of $X = x$, one might be interested in a more general claim $\\varphi$ involving potentially multiple variables, in which case the notation is $\\langle M, \\vec{u}\\rangle \\models [Y \\leftarrow y](\\varphi)$. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Halpern-Pearl modified definition of actual causality" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It is important to recognize that the straightforward counterfactual strategy, which asks whether the event would have occurred if the antecedent had not taken place, is inadequate as a definition of actual causality. A simple example can help illustrate this point. Suppose I throw a stone, which hits and shatters a bottle. However, just a second later, Bill also throws a stone at the bottle but misses, solely because the bottle was already shattered by my stone. In this scenario, the intuition is that my throw is the cause of the bottle shattering, even though the bottle would still have shattered if I hadn't thrown the stone. \n", + "This highlights the need for a more elaborate account that considers the actual state, taking into consideration the fact that Bill's stone did not, in fact, hit the bottle. One such account involves the following definition of actual causality:\n", + "\n", + "Given an SCM $M$ and a vector of its exogenous variable settings $\\vec{u}$ we'll write $(M, \\vec{u})\\models [ \\vec{Y} \\leftarrow \\vec{y}]\\psi$ just in case $\\psi$ holds in $(M',\\vec{u})$, where $M'$ is the intervened model obtained by replacing the structural equation(s) for $\\vec{Y}$ in $M$ with $\\vec{Y_i} = \\vec{y_i}$. \n", + "\n", + "We say that $\\vec{X}=\\vec{x}$ is an actual cause of $\\varphi$ in $(M,\\vec{u})$ just in case:\n", + "\n", + "AC1. Factivity: $(M, \\vec{u}) \\models [\\vec{X} = \\vec{x} \\wedge \\varphi]$\n", + "\n", + "AC2. Necessity:\n", + "\n", + "$\\exists \\vec{W}, \\vec{x}'(M, \\vec{u})\\models [\\vec{X} \\leftarrow \\vec{x}', \\vec{W} = \\vec{w}^{\\star}] \\neg \\varphi$,\n", + "where $\\vec{w}^\\star$ are the actual values of $\\vec{W}$, i.e. $(M, \\vec{u}) \\models \\vec{W} = \\vec{w}^\\star$\n", + "\n", + "AC3. Minimality: $\\vec{X}$ is a subset-minimal set of potential causes satisfying AC2\n", + "\n", + "AC1 requires that both the antecedent and the consequent hold. The intuition behind AC2 is that for $\\vec{X}=\\vec{x}$ to be the actual cause of $\\varphi$, there needs to be a vector of witness nodes $\\vec{W}$ and a vector $\\vec{x'}$ of *alternative* settings of $\\vec{X}$ such that if $\\vec{W}$ are intervened to have their actual values $\\vec{w^\\star}$, and $\\vec{X}$ are intervened to have values $\\vec{x'}$, $\\varphi$ no longer holds in the resulting model. AC3 requires that the antecedent should be a minimal one satisfying AC2." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Implementation\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import pandas as pd\n", + "import pyro\n", + "import pyro.distributions as dist\n", + "import pyro.distributions.constraints as constraints\n", + "import pyro.infer\n", + "import torch\n", + "from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual\n", + "from chirho.explainable.handlers import Preemptions, SplitSubsets, SearchForExplanation\n", + "from chirho.indexed.ops import IndexSet, gather, indices_of\n", + "from chirho.observational.handlers.condition import condition\n", + "\n", + "\n", + "smoke_test = ('CI' in os.environ)\n", + "num_samples = 10 if smoke_test else 200" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Instead of full enumeration, we will be approximating the answers with sampling. In particular, answering an actual causality query requires investigating the consequences of intervening on all possible witness candidate nodes in all possible combinations thereof to have the values they actually have in a given model. While complete enumeration would work for smaller models, we implement a more general approximate method, which draws random sets of witness nodes multiple times and intervenes on those sampled sets. For smaller models (as the one used in our examples), complete coverage of all possible combinations is easily obtained. For larger models complete enumeration becomes less feasible.\n", + "\n", + "\n", + "An SCM in this context is represented by a ChiRho model, where the exogenous variables are stochastic and introduced using `pyro.sample`, and all the endogenous variables are determined by these, and introduced by `pyro.deterministic` (read on for examples). For simplicity we often assume most of the nodes are binary (this assumption can be weakened, read on for details), and that the nodes are discrete. \n", + "\n", + "The key role in this implementation is played by (1) the `SearchForExplanation` handler. It takes `antecedents`, `witnesses`, `consequents`, `antecedent_bias` and `witness_bias` and, roughly speaking, makes three steps:\n", + "\n", + "(A) It randomly intervenes on some of the antecedents (each antecedent node having probability `0.5 - bias` of being intervened on, with non-null bias to prefer smaller antedecedent sets) to have an alternative value (either pre-specified, or randomly selected, depending on whether we pass a list of concrete interventions, or distribution constraints).\n", + "\n", + "(B) randomly preempts some of the witnesses intervening on them to have the observed value in all counterfactual worlds (the probability of witness preemption is `0.5 + witness_bias`). The intuition here is that the witness-preempted nodes are the part of the actual context that is assumed to be kept fixed in a given interventional scenario (a sample covers multiple such scenarios).\n", + "\n", + "(C) adds sites with `log_probs` tracking whether the counterfactual value of any of the consequents is different from its observed value, marking cases where it doesn't with an extremely low `log_prob` (and a value negligibly close to 0 otherwise). \n", + "\n", + "Since those steps are achieved by adding new sites to the model, the model trace can now be inspected to test for actual causality. In particular, if the `log_prob` of the site added in (C) is very low, then the antecedent is definitely not an actual cause of the consequent, as a given interventional setting does not result in a change to the consequent(s). If it is zero, minimality claims are evaluated by investigating the `log_prob_sum` corresponding to the antecedent preemption sites - by default, bias is set to `.1` to prefer smaller causal sets. All in all, an antecedent set is an actual cause if all its nodes and only its nodes are intervened on in the MAP (wrt. to log probs at play) counterfactual world." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# somewhat boiler-plate sample trace processing, can be skipped by a reader\n", + "\n", + "def gather_observed(value, antecedents, witnesses):\n", + " _indices = [\n", + " i for i in list(antecedents.keys()) + witnesses if i in indices_of(value, event_dim=0)\n", + " ]\n", + " _int_can = gather(\n", + " value, IndexSet(**{i: {0} for i in _indices}), event_dim=0,)\n", + " return _int_can\n", + "\n", + "def gather_intervened(value, antecedents, witnesses):\n", + " _indices = [\n", + " i for i in list(antecedents.keys()) + witnesses if i in indices_of(value, event_dim=0)\n", + " ]\n", + " _int_can = gather(\n", + " value, IndexSet(**{i: {1} for i in _indices}), event_dim=0,)\n", + " return _int_can\n", + "\n", + "\n", + "def get_table(trace, mwc, antecedents, witnesses, consequents):\n", + "\n", + " values_table = {}\n", + " nodes = trace.trace.nodes\n", + " witnesses = [key for key, _ in witnesses.items()]\n", + "\n", + " with mwc:\n", + "\n", + " for antecedent_str in antecedents.keys():\n", + " \n", + " obs_ant = gather_observed(nodes[antecedent_str][\"value\"], antecedents, witnesses)\n", + " int_ant = gather_intervened(nodes[antecedent_str][\"value\"], antecedents, witnesses)\n", + "\n", + " values_table[f\"{antecedent_str}_obs\"] = obs_ant.squeeze().tolist()\n", + " values_table[f\"{antecedent_str}_int\"] = int_ant.squeeze().tolist()\n", + " \n", + " apr_ant = nodes[f\"__antecedent_{antecedent_str}\"][\"value\"]\n", + " values_table[f\"apr_{antecedent_str}\"] = apr_ant.squeeze().tolist()\n", + " \n", + " values_table[f\"apr_{antecedent_str}_lp\"] = nodes[f\"__antecedent_{antecedent_str}\"][\"fn\"].log_prob(apr_ant)\n", + "\n", + " if witnesses:\n", + " for candidate in witnesses:\n", + " obs_candidate = gather_observed(nodes[candidate][\"value\"], antecedents, witnesses)\n", + " int_candidate = gather_intervened(nodes[candidate][\"value\"], antecedents, witnesses)\n", + " values_table[f\"{candidate}_obs\"] = obs_candidate.squeeze().tolist()\n", + " values_table[f\"{candidate}_int\"] = int_candidate.squeeze().tolist()\n", + "\n", + " wpr_con = nodes[f\"__witness_{candidate}\"][\"value\"]\n", + " values_table[f\"wpr_{candidate}\"] = wpr_con.squeeze().tolist()\n", + " \n", + "\n", + " for consequent in consequents:\n", + " obs_consequent = gather_observed(nodes[consequent][\"value\"], antecedents, witnesses)\n", + " int_consequent = gather_intervened(nodes[consequent][\"value\"], antecedents, witnesses)\n", + " con_lp = nodes[f\"__consequent_{consequent}\"]['fn'].log_prob(torch.tensor(1)) #TODO: this feels like a hack\n", + " _indices_lp = [\n", + " i for i in list(antecedents.keys()) + witnesses if i in indices_of(con_lp)]\n", + " int_con_lp = gather(con_lp, IndexSet(**{i: {1} for i in _indices_lp}), event_dim=0,) \n", + "\n", + "\n", + " values_table[f\"{consequent}_obs\"] = obs_consequent.squeeze().tolist() \n", + " values_table[f\"{consequent}_int\"] = int_consequent.squeeze().tolist()\n", + " values_table[f\"{consequent}_lp\"] = int_con_lp.squeeze().tolist() \n", + "\n", + " values_df = pd.DataFrame(values_table)\n", + "\n", + " values_df.drop_duplicates(inplace=True)\n", + "\n", + " summands = [col for col in values_df.columns if col.endswith('lp')]\n", + " values_df[\"sum_log_prob\"] = values_df[summands].sum(axis = 1)\n", + " values_df.sort_values(by = \"sum_log_prob\", ascending = False, inplace = True)\n", + "\n", + " return values_df" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# this reduces the actual causality check to checking a property of the \n", + "# resulting sums of log probabilities\n", + "# for the antecedent preemption and the consequent differs nodes\n", + "\n", + "def ac_check(trace, mwc, antecedents, witnesses, consequents):\n", + "\n", + " table = get_table(trace, mwc, antecedents, witnesses, consequents)\n", + " \n", + " if (list(table['sum_log_prob'])[0]<= -1e8):\n", + " print(\"No resulting difference to the consequent in the sample.\")\n", + " return\n", + " \n", + " winners = table[table['sum_log_prob'] == table['sum_log_prob'].max()]\n", + " \n", + "\n", + " ac_flags = []\n", + " for _, row in winners.iterrows():\n", + " active_antecedents = []\n", + " for antecedent in antecedents:\n", + " if row[f\"apr_{antecedent}\"] == 0:\n", + " active_antecedents.append(antecedent)\n", + "\n", + " ac_flags.append(set(active_antecedents) == set(antecedents))\n", + "\n", + " if not any(ac_flags):\n", + " print(\"The antecedent set is not minimal.\")\n", + " else:\n", + " print(\"The antecedent set is an actual cause.\")\n", + "\n", + " return any(ac_flags)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Examples" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Comments on example selection\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For the sake of illustration, we reconstruct a few examples, which - with one exception (friendly fire incident) - come from Halpern's book. The selection is as follows:\n", + "\n", + "- **Stone throwing:** this is a classic, simple structure in which the but-for clause fails due to over-determination, but an actual causality claim holds (p. 3 of the book).\n", + "\n", + "- **Forest fire:** one of the simplest structures illustrating conjunctions being actual causes, and how an event can be part of an actual cause without being an actual cause itself (example 2.3.1, p. 28).\n", + "\n", + "- **Doctors:** a simple example illustrating the intransitivity of actual causality (example 2.3.5, p. 37).\n", + "\n", + "- **Friendly fire incident:** a real-life example, to illustrate how the tools can be applied outside of a narrow selection of thought experiments. (a causal model developed in a real-life incident investigation, as discussed in the [Incident Reporting using SERAS® Reporter and SERAS® Analyst](https://www.causalis.com/90-publications/99-downloads/) paper)\n", + "\n", + "- **Voting:** this illustrates how on this approach a voter is only an actual cause if they can make a difference, but only part of an actual cause otherwise, which motivates reflection on responsibility and blame (example 2.3.2)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Stone-throwing" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Sally and Billy pick up stones and throw them at a bottle. Sally's stone gets there first, shattering the bottle. Both throws are perfectly accurate, so Billy's stone would have shattered the bottle had it not been preempted by Sally’s throw. (see *Actual Causality*, p. 3 and multiple further points at which the example is discussed in the book)." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "@pyro.infer.config_enumerate\n", + "def stones_model(): \n", + " prob_sally_throws = pyro.sample(\"prob_sally_throws\", dist.Beta(1, 1))\n", + " prob_bill_throws = pyro.sample(\"prob_bill_throws\", dist.Beta(1, 1))\n", + " prob_sally_hits = pyro.sample(\"prob_sally_hits\", dist.Beta(1, 1))\n", + " prob_bill_hits = pyro.sample(\"prob_bill_hits\", dist.Beta(1, 1))\n", + " prob_bottle_shatters_if_sally = pyro.sample(\"prob_bottle_shatters_if_sally\", dist.Beta(1, 1))\n", + " prob_bottle_shatters_if_bill = pyro.sample(\"prob_bottle_shatters_if_bill\", dist.Beta(1, 1))\n", + "\n", + " sally_throws = pyro.sample(\"sally_throws\", dist.Bernoulli(prob_sally_throws))\n", + " bill_throws = pyro.sample(\"bill_throws\", dist.Bernoulli(prob_bill_throws))\n", + "\n", + "\n", + " new_shp = torch.where(sally_throws == 1,prob_sally_hits, 0.0)\n", + "\n", + " sally_hits = pyro.sample(\"sally_hits\",dist.Bernoulli(new_shp))\n", + "\n", + " new_bhp = torch.where(\n", + " bill_throws.bool() & (~sally_hits.bool()),\n", + " prob_bill_hits,\n", + " torch.tensor(0.0),\n", + " )\n", + "\n", + " bill_hits = pyro.sample(\"bill_hits\", dist.Bernoulli(new_bhp))\n", + "\n", + " new_bsp = torch.where(bill_hits.bool(), prob_bottle_shatters_if_bill,\n", + " torch.where(sally_hits.bool(),prob_bottle_shatters_if_sally,torch.tensor(0.0),),)\n", + "\n", + " bottle_shatters = pyro.sample(\"bottle_shatters\", dist.Bernoulli(new_bsp))\n", + "\n", + " return {\"sally_throws\": sally_throws, \"bill_throws\": bill_throws, \"sally_hits\": sally_hits,\n", + " \"bill_hits\": bill_hits, \"bottle_shatters\": bottle_shatters,}\n", + "\n", + "stones_model.nodes = [\"sally_throws\",\"bill_throws\", \"sally_hits\", \"bill_hits\",\"bottle_shatters\",]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "def tensorize_observations(observations):\n", + " return {k: torch.as_tensor(v) for k, v in observations.items()}\n", + "\n", + "observations = {\"prob_sally_throws\": 1.0, \n", + " \"prob_bill_throws\": 1.0,\n", + " \"prob_sally_hits\": 1.0,\n", + " \"prob_bill_hits\": 1.0,\n", + " \"prob_bottle_shatters_if_sally\": 1.0,\n", + " \"prob_bottle_shatters_if_bill\": 1.0,\n", + " \"sally_throws\": 1.0, \"bill_throws\": 1.0}\n", + "\n", + "observations_tensorized = tensorize_observations(observations)\n", + "\n", + "# One way to go is to manually specify a single alternative value\n", + "# which helps if you explicitly want to use a contrastive notion of \n", + "# actual causality\n", + "antecedents = {\"sally_throws\": 0.0}\n", + "antencedent_bias = 0.1\n", + "witnesses = {\"bill_throws\": constraints.boolean, \"bill_hits\": constraints.boolean}\n", + "consequents = {\"bottle_shatters\": constraints.boolean}" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "with MultiWorldCounterfactual() as mwc:\n", + " with SearchForExplanation(antecedents = antecedents, \n", + " witnesses = witnesses, consequents = consequents,\n", + " consequent_scale= 1e-7):\n", + " with condition(data = observations_tensorized):\n", + " with pyro.plate(\"sample\", num_samples):\n", + " with pyro.poutine.trace() as tr:\n", + " stones_model()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once we process the sample trace, the table contains all the information we need to evaluate an actual causality claim. \n", + "\n", + "For any node it contains:\n", + "- `_obs` and `_int`, the observed and intervened values of that node.\n", + "\n", + "If a node is an antecedent candidate:\n", + "\n", + "- `apr_`, which marks if the node has been preempted as an antecedent; if the value is `0`, the antecedent intervention has been applied.\n", + "- `apr__lp` the log probability corresponding to the auxiliary antecedent preemption variable. Its tracking is needed to minimize the cause set.\n", + "\n", + "Moreover, for witness candidates, the table contains:\n", + "\n", + "- `wpr_`, which marks whether a node has been preempted (intervened to have the same counterfactual value as the observed value). Since for the actual causality queries, log probabilities for either value are the same, and can be safely ignored.\n", + "\n", + "For any consequent node:\n", + "\n", + "- `_lp`, which tracks in terms of log probabilities whether the counterfactual value of a consequent node differs from its observed value. If it doesn't, the value is extremely low, `-1e8` by default, and it is 0 otherwise. \n", + "\n", + "The table then sums up the relevant log probabilities in `sum_log_prob`, which effectively ranks interventional settings by whether a change to the consequent resulted and by how small the antecedent set is." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'sally_throws': 0.0}\n", + "{'bill_throws': Boolean(), 'bill_hits': Boolean()}\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sally_throws_obssally_throws_intapr_sally_throwsapr_sally_throws_lpbill_throws_obsbill_throws_intwpr_bill_throwsbill_hits_obsbill_hits_intwpr_bill_hitsbottle_shatters_obsbottle_shatters_intbottle_shatters_lpsum_log_prob
01.00.00-0.6931471.01.010.00.011.00.0-1.192093e-07-0.693147
21.00.00-0.6931471.01.000.00.011.00.0-1.192093e-07-0.693147
11.01.01-0.6931471.01.000.00.001.01.0-1.594238e+01-16.635532
31.01.01-0.6931471.01.000.00.011.01.0-1.594238e+01-16.635532
51.01.01-0.6931471.01.010.00.011.01.0-1.594238e+01-16.635532
61.01.01-0.6931471.01.010.00.001.01.0-1.594238e+01-16.635532
\n", + "
" + ], + "text/plain": [ + " sally_throws_obs sally_throws_int apr_sally_throws apr_sally_throws_lp \\\n", + "0 1.0 0.0 0 -0.693147 \n", + "2 1.0 0.0 0 -0.693147 \n", + "1 1.0 1.0 1 -0.693147 \n", + "3 1.0 1.0 1 -0.693147 \n", + "5 1.0 1.0 1 -0.693147 \n", + "6 1.0 1.0 1 -0.693147 \n", + "\n", + " bill_throws_obs bill_throws_int wpr_bill_throws bill_hits_obs \\\n", + "0 1.0 1.0 1 0.0 \n", + "2 1.0 1.0 0 0.0 \n", + "1 1.0 1.0 0 0.0 \n", + "3 1.0 1.0 0 0.0 \n", + "5 1.0 1.0 1 0.0 \n", + "6 1.0 1.0 1 0.0 \n", + "\n", + " bill_hits_int wpr_bill_hits bottle_shatters_obs bottle_shatters_int \\\n", + "0 0.0 1 1.0 0.0 \n", + "2 0.0 1 1.0 0.0 \n", + "1 0.0 0 1.0 1.0 \n", + "3 0.0 1 1.0 1.0 \n", + "5 0.0 1 1.0 1.0 \n", + "6 0.0 0 1.0 1.0 \n", + "\n", + " bottle_shatters_lp sum_log_prob \n", + "0 -1.192093e-07 -0.693147 \n", + "2 -1.192093e-07 -0.693147 \n", + "1 -1.594238e+01 -16.635532 \n", + "3 -1.594238e+01 -16.635532 \n", + "5 -1.594238e+01 -16.635532 \n", + "6 -1.594238e+01 -16.635532 " + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "print(antecedents)\n", + "print(witnesses)\n", + "stones_table = get_table(tr, mwc, antecedents, witnesses, consequents)\n", + "display(stones_table)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The antecedent set is an actual cause.\n" + ] + }, + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ac_check(tr, mwc, antecedents, witnesses, consequents)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "# If, more in the spirit of the original definition\n", + "# we want to search through all possible values of the antecedent,\n", + "# we can use a constraint instead of specifying the counterfactual value\n", + "# manually\n", + "\n", + "antecedents = {\"sally_hits\": pyro.distributions.constraints.boolean}\n", + "\n", + "with MultiWorldCounterfactual() as mwc:\n", + " with SearchForExplanation(antecedents = antecedents, \n", + " witnesses = witnesses, consequents = consequents,\n", + " consequent_scale= 1e-8):\n", + " with condition(data = observations_tensorized):\n", + " with pyro.plate(\"sample\", num_samples):\n", + " with pyro.poutine.trace() as tr:\n", + " stones_model()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sally_hits_obssally_hits_intapr_sally_hitsapr_sally_hits_lpbill_throws_obsbill_throws_intwpr_bill_throwsbill_hits_obsbill_hits_intwpr_bill_hitsbottle_shatters_obsbottle_shatters_intbottle_shatters_lpsum_log_prob
51.00.00-0.6931471.01.000.00.011.00.00.0-0.693147
01.01.01-0.6931471.01.010.00.011.01.0-inf-inf
41.01.00-0.6931471.01.010.00.011.01.0-inf-inf
71.01.01-0.6931471.01.010.00.001.01.0-inf-inf
91.01.01-0.6931471.01.000.00.001.01.0-inf-inf
\n", + "
" + ], + "text/plain": [ + " sally_hits_obs sally_hits_int apr_sally_hits apr_sally_hits_lp \\\n", + "5 1.0 0.0 0 -0.693147 \n", + "0 1.0 1.0 1 -0.693147 \n", + "4 1.0 1.0 0 -0.693147 \n", + "7 1.0 1.0 1 -0.693147 \n", + "9 1.0 1.0 1 -0.693147 \n", + "\n", + " bill_throws_obs bill_throws_int wpr_bill_throws bill_hits_obs \\\n", + "5 1.0 1.0 0 0.0 \n", + "0 1.0 1.0 1 0.0 \n", + "4 1.0 1.0 1 0.0 \n", + "7 1.0 1.0 1 0.0 \n", + "9 1.0 1.0 0 0.0 \n", + "\n", + " bill_hits_int wpr_bill_hits bottle_shatters_obs bottle_shatters_int \\\n", + "5 0.0 1 1.0 0.0 \n", + "0 0.0 1 1.0 1.0 \n", + "4 0.0 1 1.0 1.0 \n", + "7 0.0 0 1.0 1.0 \n", + "9 0.0 0 1.0 1.0 \n", + "\n", + " bottle_shatters_lp sum_log_prob \n", + "5 0.0 -0.693147 \n", + "0 -inf -inf \n", + "4 -inf -inf \n", + "7 -inf -inf \n", + "9 -inf -inf " + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# now our samples include some cases where the antecedent intervention\n", + "# was the same as the observed value; this does not change the result,\n", + "# as the __consequent_ log prob is practically -inf in these cases\n", + "\n", + "stones_table = get_table(tr, mwc, antecedents, witnesses, consequents)\n", + "display(stones_table)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The antecedent set is an actual cause.\n" + ] + }, + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# the result of the actual causality check is the same as before\n", + "\n", + "ac_check(tr, mwc, antecedents, witnesses, consequents)\n", + "\n", + "# since we're dealing with binary antecedents in this notebook,\n", + "# we'll keep using the contrastive notion in what follows" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "# in contrast, this antecedent set is not minimal\n", + "antecedents2 = {\"sally_throws\": 0.0, \"bill_throws\": 0.0}\n", + "witnesses2 = {\"bill_hits\": constraints.boolean} \n", + "\n", + "with MultiWorldCounterfactual() as mwc2:\n", + " with SearchForExplanation(antecedents = antecedents2, antecedent_bias= antencedent_bias,\n", + " witnesses = witnesses2,\n", + " consequents = consequents,\n", + " consequent_scale= 1e-8):\n", + " with condition(data = observations_tensorized):\n", + " with pyro.plate(\"sample\", num_samples):\n", + " with pyro.poutine.trace() as tr2:\n", + " stones_model()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The antecedent set is an actual cause.\n" + ] + }, + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "stones_table2 = get_table(tr2, mwc2, antecedents2, witnesses2, consequents)\n", + "ac_check(tr2, mwc2, antecedents2, witnesses2, consequents)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Forest fire" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this simplified model, a forest fire was caused by lightning or an arsonist, so we use three endogenous variables, and two exogenous variables corresponding to the two factors. In the conjunctive model,\n", + "both of the factors have to be present for the fire to start. In the disjunctive model, each of them alone is sufficient." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "def ff_conjunctive():\n", + " u_match_dropped = pyro.sample(\"u_match_dropped\", dist.Bernoulli(0.5))\n", + " u_lightning = pyro.sample(\"u_lightning\", dist.Bernoulli(0.5))\n", + "\n", + " match_dropped = pyro.deterministic(\"match_dropped\",\n", + " u_match_dropped, event_dim=0)\n", + " lightning = pyro.deterministic(\"lightning\", u_lightning, event_dim=0)\n", + " forest_fire = pyro.deterministic(\"forest_fire\", match_dropped.bool() & lightning.bool(),\n", + " event_dim=0).float()\n", + "\n", + " return {\"match_dropped\": match_dropped, \"lightning\": lightning,\n", + " \"forest_fire\": forest_fire}" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "def ff_disjunctive():\n", + " u_match_dropped = pyro.sample(\"u_match_dropped\", dist.Bernoulli(0.5))\n", + " u_lightning = pyro.sample(\"u_lightning\", dist.Bernoulli(0.5))\n", + "\n", + " match_dropped = pyro.deterministic(\"match_dropped\",\n", + " u_match_dropped, event_dim=0)\n", + " lightning = pyro.deterministic(\"lightning\", u_lightning, event_dim=0)\n", + " forest_fire = pyro.deterministic(\"forest_fire\", match_dropped.bool() | lightning.bool(), event_dim=0).float()\n", + "\n", + " return {\"match_dropped\": match_dropped, \"lightning\": lightning,\n", + " \"forest_fire\": forest_fire}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "antecedents_ff = {\"match_dropped\": 0.0}\n", + "witnesses_ff = {\"lightning\": constraints.boolean}\n", + "consequents_ff = {\"forest_fire\": constraints.boolean}\n", + "observations_ff = tensorize_observations({\"match_dropped\": 1.0, \"lightning\": 1.0})" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "with MultiWorldCounterfactual() as mwc_ff:\n", + " with SearchForExplanation(antecedents = antecedents_ff, antecedent_bias= antencedent_bias,\n", + " witnesses = witnesses_ff,\n", + " consequents = consequents_ff,\n", + " consequent_scale= 1e-8):\n", + " with condition(data = observations_ff):\n", + " with pyro.plate(\"sample\", num_samples):\n", + " with pyro.poutine.trace() as tr_ff:\n", + " ff_conjunctive()" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The antecedent set is an actual cause.\n" + ] + }, + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# In the conjunctive model \n", + "# Each of the two factors is a but-for cause\n", + "\n", + "ac_check(tr_ff, mwc_ff, antecedents_ff, witnesses_ff, consequents_ff)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "# In the disjunctive model \n", + "# there still would be fire if there was no lightning\n", + "\n", + "with MultiWorldCounterfactual() as mwc_ffd:\n", + " with SearchForExplanation(antecedents = antecedents_ff,\n", + " antecedent_bias= antencedent_bias,\n", + " witnesses = witnesses_ff,\n", + " consequents = consequents_ff,\n", + " consequent_scale = 1e-8):\n", + " with condition(data = observations_ff):\n", + " with pyro.plate(\"sample\", num_samples):\n", + " with pyro.poutine.trace() as tr_ffd:\n", + " ff_disjunctive()" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No resulting difference to the consequent in the sample.\n" + ] + } + ], + "source": [ + "ac_check(tr_ffd, mwc_ffd, antecedents_ff, witnesses_ff, consequents_ff)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "# in the disjunctive model\n", + "# the actual cause is the composition of the two factors\n", + "\n", + "antecedents_ffd2 = {\"match_dropped\": 0.0, \"lightning\":0.0}\n", + "witnesses_ffd2 = {}\n", + "\n", + "with MultiWorldCounterfactual() as mwc_ffd2:\n", + " with SearchForExplanation(antecedents = antecedents_ffd2, antecedent_bias= antencedent_bias,\n", + " witnesses = witnesses_ffd2,\n", + " consequents = consequents_ff,\n", + " consequent_scale= 1e-8):\n", + " with condition(data = observations_ff):\n", + " with pyro.plate(\"sample\", num_samples):\n", + " with pyro.poutine.trace() as tr_ffd2:\n", + " ff_disjunctive()" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No resulting difference to the consequent in the sample.\n" + ] + } + ], + "source": [ + "ac_check(tr_ffd2, mwc_ffd2, antecedents_ffd2, witnesses_ffd2, consequents_ff)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Doctors" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This example illustrates that actual causality is not, in general, transitive. One doctor is responsible for administering the medicine on Monday, and if she does, Bill recovers on Tuesday.\n", + "Another doctor is reliable and treats Bill on Tuesday if the first doctor failed to do so on Monday. If both doctors treat Bill, he is in `condition1`, dead on Wednesday. Otherwise, he is either healthy on Tuesday (`condition2`) or healthy on Wednesday (`condition3`), or did not receive any treatment and feels worse but is alive on Wednesday (`condition4`).\n", + "\n", + "Now suppose Bill did receive treatment on Monday. This is an actual cause of his not receiving treatment on Tuesday, and the latter is an actual cause of his being alive on Wednesday. However, there is nothing that the first doctor could do to cause Bill to be dead on Wednesday." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "def bc_function(mt, tt):\n", + " condition1 = (mt == 1) & (tt == 1)\n", + " condition2 = (mt == 1) & (tt == 0)\n", + " condition3 = (mt == 0) & (tt == 1)\n", + " condition4 = ~(condition1 | condition2 | condition3)\n", + "\n", + " output = torch.where(condition1, torch.tensor(3.0), torch.tensor(0.0))\n", + " output = torch.where(condition2, torch.tensor(0.0), output)\n", + " output = torch.where(condition3, torch.tensor(1.0), output)\n", + " output = torch.where(condition4, torch.tensor(2.0), output)\n", + "\n", + " return output\n", + "\n", + "\n", + "def model_doctors():\n", + " u_monday_treatment = pyro.sample(\"u_monday_treatment\", dist.Bernoulli(0.5))\n", + "\n", + " monday_treatment = pyro.deterministic(\n", + " \"monday_treatment\", u_monday_treatment, event_dim=0\n", + " )\n", + "\n", + " tuesday_treatment = pyro.deterministic(\n", + " \"tuesday_treatment\",\n", + " torch.logical_not(monday_treatment).float(),\n", + " event_dim=0,\n", + " )\n", + "\n", + " bills_condition = pyro.deterministic(\n", + " \"bills_condition\",\n", + " bc_function(monday_treatment, tuesday_treatment),\n", + " event_dim=0,\n", + " )\n", + "\n", + " bill_alive = pyro.deterministic(\n", + " \"bill_alive\", bills_condition.not_equal(3.0).float(), event_dim=0\n", + " )\n", + "\n", + " return {\n", + " \"monday_treatment\": monday_treatment,\n", + " \"tuesday_treatment\": tuesday_treatment,\n", + " \"bills_condition\": bills_condition,\n", + " \"bill_alive\": bill_alive,\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "antecedents_doc1 = {\"monday_treatment\": 0.0}\n", + "witnesses_doc = {}\n", + "consequents_doc1 = {\"tuesday_treatment\": constraints.boolean}\n", + "observations_doc = tensorize_observations({\"u_monday_treatment\": 1.0})" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The antecedent set is an actual cause.\n" + ] + }, + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# The first actual causal link holds\n", + "\n", + "with MultiWorldCounterfactual() as mwc_doc1:\n", + " with SearchForExplanation(antecedents = antecedents_doc1, antecedent_bias= antencedent_bias,\n", + " witnesses = witnesses_doc,\n", + " consequents = consequents_doc1,\n", + " consequent_scale= 1e-8):\n", + " with condition(data = observations_doc):\n", + " with pyro.plate(\"sample\", num_samples):\n", + " with pyro.poutine.trace() as tr_doc1:\n", + " model_doctors()\n", + " \n", + "ac_check(tr_doc1, mwc_doc1, antecedents_doc1, witnesses_doc, consequents_doc1)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The antecedent set is an actual cause.\n" + ] + }, + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# So does the second\n", + "\n", + "antecedents_doc2 = {\"tuesday_treatment\": 1.0}\n", + "consequents_doc2 = {\"bill_alive\": constraints.boolean}\n", + "\n", + "\n", + "with MultiWorldCounterfactual() as mwc_doc2:\n", + " with SearchForExplanation(antecedents = antecedents_doc2, antecedent_bias= antencedent_bias,\n", + " witnesses = witnesses_doc,\n", + " consequents = consequents_doc2,\n", + " consequent_scale= 1e-8):\n", + " with condition(data = observations_doc):\n", + " with pyro.plate(\"sample\", num_samples):\n", + " with pyro.poutine.trace() as tr_doc2:\n", + " model_doctors()\n", + "\n", + "ac_check(tr_doc2, mwc_doc2, antecedents_doc2, witnesses_doc, consequents_doc2)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No resulting difference to the consequent in the sample.\n" + ] + } + ], + "source": [ + "# The third does not, so transitivity fails!\n", + "\n", + "with MultiWorldCounterfactual() as mwc_doc3:\n", + " with SearchForExplanation(antecedents = antecedents_doc1, antecedent_bias= antencedent_bias,\n", + " witnesses = witnesses_doc,\n", + " consequents = consequents_doc2,\n", + " consequent_scale= 1e-8):\n", + " with condition(data = observations_doc):\n", + " with pyro.plate(\"sample\", num_samples):\n", + " with pyro.poutine.trace() as tr_doc3:\n", + " model_doctors()\n", + "\n", + "ac_check(tr_doc3, mwc_doc3, antecedents_doc1, witnesses_doc, consequents_doc2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Friendly fire\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "This comes from a causal model developed in a real-life incident investigation, as discussed in the [Incident Reporting using SERAS® Reporter and SERAS® Analyst](https://www.causalis.com/90-publications/99-downloads/) paper.\n", + "\n", + "A U.S. Special Forces air controller changing the battery on a Global Positioning System device he was using to target a Taliban outpost north of Kandahar. Three special forces soldiers were killed and 20 were injured when a 2,000-pound, satellite-guided bomb landed, not on the Taliban outpost, but on a battalion command post occupied by American forces and a group of Afghan allies, including Hamid Karzai, now the interim prime minister. The Air Force combat controller was using a Precision Lightweight GPS Receiver to calculate the Taliban's coordinates for the attack. The controller did not realize that after he changed the device's battery, the machine was programmed to automatically come back on displaying coordinates for its own location, the official said.\n", + "\n", + "Minutes before the B-52 strike, the controller had used the GPS receiver to\n", + "calculate the latitude and longitude of the Taliban position in minutes and seconds for an airstrike by a Navy F/A-18. Then, with the B-52 approaching the target, the air controller did a second calculation in “degree decimals” required by the bomber crew. The controller had performed the calculation and recorded the position, when the receiver battery died. Without realizing the machine was programmed to come back on showing the coordinates of its\n", + "own location, the controller mistakenly called in the American position to the B-52.\n", + "\n", + "Factors included in the model (will be connected in the model as specified in the original report):\n", + "\n", + "1. The air controller changed the battery on the PLGR\n", + "2. Three special forces soldiers were killed and 20 were injured\n", + "3. B-52 fired a JDAM bomb at the Allied position\n", + "4. The air controller was using the PLGR to calculate the Taliban's coordinates\n", + "5. The controller did not realize that the PLGR was programmed to automatically come back on displaying coordinates for its own location\n", + "6. The controller had used the PLGR to calculate the latitude and longitude of the Taliban position in minutes and seconds for an airstrike by a Navy F/A-18\n", + "7. The air controller did a second calculation in “degree decimals” required by the bomber crew\n", + "8. The controller had performed the calculation and recorded the position\n", + "9. The controller mistakenly called in the American position to the B-52\n", + "10. The B-52 fired a JDAM bomb at the Allied position\n", + "11. The U.S. Air Force and Army had a training problem\n", + "12. The PLRG resumed displaying the coordinates of its own location after the battery was changed\n", + "13. The battery died at the crucial time\n", + "14. The controller thought he was calling in the Taliban position\n", + "\n", + "We will encode the model and show that in such somewhat more complicated cases, answers to `ac_check` queries are also intuitive." + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def model_friendly_fire():\n", + " u_f4_PLGR_now = pyro.sample(\"u_f4_PLGR_now\", dist.Bernoulli(0.5))\n", + " u_f11_training = pyro.sample(\"u_f11_training\", dist.Bernoulli(0.5))\n", + "\n", + " f4_PLGR_now = pyro.deterministic(\"f4_PLGR_now\", u_f4_PLGR_now, event_dim=0)\n", + " f11_training = pyro.deterministic(\n", + " \"f11_training\", u_f11_training, event_dim=0\n", + " )\n", + "\n", + " f6_PLGR_before = pyro.deterministic(\n", + " \"f6_PLGR_before\", f4_PLGR_now, event_dim=0\n", + " )\n", + " f7_second_calculation = pyro.deterministic(\n", + " \"f7_second_calculation\", f4_PLGR_now, event_dim=0\n", + " )\n", + " f13_battery_died = pyro.deterministic(\n", + " \"f13_battery_died\",\n", + " f6_PLGR_before.bool() & f7_second_calculation.bool(),\n", + " event_dim=0,\n", + " )\n", + "\n", + " f1_battery_change = pyro.deterministic(\n", + " \"f1_battery_change\", f13_battery_died, event_dim=0\n", + " )\n", + "\n", + " f12_PLGR_after = pyro.deterministic(\n", + " \"f12_PLGR_after\", f1_battery_change, event_dim=0\n", + " )\n", + "\n", + " f5_unaware = pyro.deterministic(\"f5_unaware\", f11_training, event_dim=0)\n", + "\n", + " f14_wrong_position = pyro.deterministic(\n", + " \"f14_wrong_position\", f5_unaware, event_dim=0\n", + " )\n", + "\n", + " f9_mistake_call = pyro.deterministic(\n", + " \"f9_mistake_call\",\n", + " f12_PLGR_after.bool() & \n", + " f14_wrong_position.bool(),\n", + " event_dim=0,\n", + " )\n", + "\n", + " f3_fired = pyro.deterministic(\"f3_fired\", f9_mistake_call, event_dim=0)\n", + "\n", + " f10_landed = pyro.deterministic(\n", + " \"f10_landed\", f3_fired.bool() & f9_mistake_call.bool(), event_dim=0\n", + " )\n", + "\n", + " f2_killed = pyro.deterministic(\"f2_killed\", f10_landed, event_dim=0)\n", + "\n", + " return {\n", + " \"f1_battery_change\": f1_battery_change,\n", + " \"f2_killed\": f2_killed,\n", + " \"f3_fired\": f3_fired,\n", + " \"f4_PLGR_now\": f4_PLGR_now,\n", + " \"f5_unaware\": f5_unaware,\n", + " \"f6_PLGR_before\": f6_PLGR_before,\n", + " \"f7_second_calculation\": f7_second_calculation,\n", + " \"f9_mistake_call\": f9_mistake_call,\n", + " \"f10_landed\": f10_landed,\n", + " \"f11_training\": f11_training,\n", + " \"f12_PLGR_after\": f12_PLGR_after,\n", + " \"f13_battery_died\": f13_battery_died,\n", + " \"f14_wrong_position\": f14_wrong_position,\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "def convert_to_bool(list):\n", + " return {key: constraints.boolean for key in list}\n", + "\n", + "antecedents_fi1 = {\"f6_PLGR_before\": 0.0, \"f7_second_calculation\": 0.0}\n", + "consequents_fi = convert_to_bool([\"f2_killed\"])\n", + "witnesses_fi = convert_to_bool([\"f4_PLGR_now\",\"f5_unaware\", \"f11_training\", \"f14_wrong_position\"])\n", + "observations_fi = tensorize_observations({\"u_f4_PLGR_now\": 1.0, \"u_f11_training\": 1.0})" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "with MultiWorldCounterfactual() as mwc_fi1:\n", + " with SearchForExplanation(antecedents = antecedents_fi1, antecedent_bias= antencedent_bias,\n", + " witnesses = witnesses_fi,\n", + " consequents = consequents_fi,\n", + " consequent_scale= 1e-8):\n", + " with condition(data = observations_fi):\n", + " with pyro.plate(\"sample\", num_samples):\n", + " with pyro.poutine.trace() as tr_fi1:\n", + " model_friendly_fire() " + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The antecedent set is not minimal.\n" + ] + }, + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ac_check(tr_fi1, mwc_fi1, antecedents_fi1, witnesses_fi, consequents_fi)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The antecedent set is an actual cause.\n" + ] + }, + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "antecedents_fi2 = {\"f6_PLGR_before\": 0.0}\n", + "\n", + "with MultiWorldCounterfactual() as mwc_fi2:\n", + " with SearchForExplanation(antecedents = antecedents_fi2, antecedent_bias= antencedent_bias,\n", + " witnesses = witnesses_fi,\n", + " consequents = consequents_fi,\n", + " consequent_scale= 1e-8):\n", + " with condition(data = observations_fi):\n", + " with pyro.plate(\"sample\", num_samples):\n", + " with pyro.poutine.trace() as tr_fi2:\n", + " model_friendly_fire() \n", + " \n", + "ac_check(tr_fi2, mwc_fi2, antecedents_fi2, witnesses_fi, consequents_fi)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Voting\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The main reason why the voting models are interesting in this context is that we are interested in the role of particular voters in the coming about of the result. The intuition is that a voter might play are role or be blamed for not voting even if her vote is not decisive. This should be handled by the notion of responsibility. For now, we just notice that the notion of actual causality at play is not enough to capture these intuitions. Say you give one vote in a binary majority vote, `vote0`, you vote \"for\", and there are six other voters. " + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "def voting_model():\n", + " u_vote0 = pyro.sample(\"u_vote0\", dist.Bernoulli(0.6))\n", + " u_vote1 = pyro.sample(\"u_vote1\", dist.Bernoulli(0.6))\n", + " u_vote2 = pyro.sample(\"u_vote2\", dist.Bernoulli(0.6))\n", + " u_vote3 = pyro.sample(\"u_vote3\", dist.Bernoulli(0.6))\n", + " u_vote4 = pyro.sample(\"u_vote4\", dist.Bernoulli(0.6))\n", + " u_vote5 = pyro.sample(\"u_vote5\", dist.Bernoulli(0.6))\n", + "\n", + " vote0 = pyro.deterministic(\"vote0\", u_vote0, event_dim=0)\n", + " vote1 = pyro.deterministic(\"vote1\", u_vote1, event_dim=0)\n", + " vote2 = pyro.deterministic(\"vote2\", u_vote2, event_dim=0)\n", + " vote3 = pyro.deterministic(\"vote3\", u_vote3, event_dim=0)\n", + " vote4 = pyro.deterministic(\"vote4\", u_vote4, event_dim=0)\n", + " vote5 = pyro.deterministic(\"vote5\", u_vote5, event_dim=0)\n", + " \n", + " outcome = pyro.deterministic(\"outcome\", vote0 + vote1 + vote2 + vote3 + vote4 + vote5 > 3).float()\n", + " return {\"outcome\": outcome}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "antecedents_v = {\"vote0\":0.0}\n", + "outcome_v = convert_to_bool([\"outcome\"])\n", + "witnesses_v = convert_to_bool([f\"vote{i}\" for i in range(1,6)])\n", + "observations_v1 = tensorize_observations(dict(u_vote0=1., u_vote1=1., u_vote2=1.,\n", + " u_vote3=1., u_vote4=0., u_vote5=0.))" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The antecedent set is an actual cause.\n" + ] + }, + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "with MultiWorldCounterfactual() as mwc_v1:\n", + " with SearchForExplanation(antecedents = antecedents_v, antecedent_bias= antencedent_bias,\n", + " witnesses = witnesses_v,\n", + " consequents = outcome_v,\n", + " consequent_scale= 1e-8):\n", + " with condition(data = observations_v1):\n", + " with pyro.plate(\"sample\", num_samples):\n", + " with pyro.poutine.trace() as tr_v1:\n", + " voting_model()\n", + "\n", + "# if you're one of four voters who voted for, you are an actual cause\n", + "# of the outcome\n", + "\n", + "ac_check(tr_v1, mwc_v1, antecedents_v, witnesses_v, outcome_v)" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No resulting difference to the consequent in the sample.\n" + ] + } + ], + "source": [ + "# if you're one of five voters who voted for, you are not an actual cause\n", + "# of the outcome\n", + "\n", + "observations_v2 = tensorize_observations(dict(u_vote0=1., u_vote1=1., u_vote2=1.,\n", + " u_vote3=1., u_vote4=1., u_vote5=0.))\n", + "\n", + "with MultiWorldCounterfactual() as mwc_v2:\n", + " with SearchForExplanation(antecedents = antecedents_v, antecedent_bias= antencedent_bias,\n", + " witnesses = witnesses_v,\n", + " consequents = outcome_v,\n", + " consequent_scale= 1e-8):\n", + " with condition(data = observations_v2):\n", + " with pyro.plate(\"sample\", num_samples):\n", + " with pyro.poutine.trace() as tr_v2:\n", + " voting_model()\n", + " \n", + "ac_check(tr_v2, mwc_v2, antecedents_v, witnesses_v, outcome_v)" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The antecedent set is an actual cause.\n" + ] + }, + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "antecedents_v3 = {\"vote0\":0.0, \"vote1\": 0.0}\n", + "witnesses_v3 = convert_to_bool([f\"vote{i}\" for i in range(2,6)])\n", + "\n", + "with MultiWorldCounterfactual() as mwc_v3:\n", + " with SearchForExplanation(antecedents = antecedents_v3, antecedent_bias= antencedent_bias,\n", + " witnesses = witnesses_v3,\n", + " consequents = outcome_v,\n", + " consequent_scale= 1e-8):\n", + " with condition(data = observations_v2):\n", + " with pyro.plate(\"sample\", num_samples):\n", + " with pyro.poutine.trace() as tr_v3:\n", + " voting_model()\n", + "\n", + "ac_check(tr_v3, mwc_v3, antecedents_v3, witnesses_v3, outcome_v)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "chirho", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/source/counterfactual.rst b/docs/source/counterfactual.rst index b0ae0f334..0ffc0d692 100644 --- a/docs/source/counterfactual.rst +++ b/docs/source/counterfactual.rst @@ -31,10 +31,6 @@ Handlers :members: :undoc-members: -.. automodule:: chirho.counterfactual.handlers.explanation - :members: - :undoc-members: - Internals --------- diff --git a/docs/source/explainable.rst b/docs/source/explainable.rst new file mode 100644 index 000000000..5a41aa2b7 --- /dev/null +++ b/docs/source/explainable.rst @@ -0,0 +1,43 @@ +Explainable +=========== + +.. automodule:: chirho.explainable + :members: + :undoc-members: + +Operations +---------- + +.. automodule:: chirho.explainable.ops + :members: + :undoc-members: + +Handlers +-------- + +.. automodule:: chirho.explainable.handlers + :members: + :undoc-members: + +.. automodule:: chirho.explainable.handlers.components + :members: + :undoc-members: + +.. automodule:: chirho.explainable.handlers.explanation + :members: + :undoc-members: + +.. automodule:: chirho.explainable.handlers.preemptions + :members: + :undoc-members: + +Internals +--------- + +.. automodule:: chirho.explainable.internals + :members: + :undoc-members: + +.. automodule:: chirho.explainable.internals.defaults + :members: + :undoc-members: diff --git a/docs/source/index.rst b/docs/source/index.rst index aaed0b468..dd14293bd 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -31,6 +31,7 @@ Table of Contents sdid dr_learner dynamical_intro + actual_causality .. toctree:: :maxdepth: 2 @@ -41,6 +42,7 @@ Table of Contents observational indexed dynamical + explainable .. toctree:: :maxdepth: 2 diff --git a/scripts/test_notebooks.sh b/scripts/test_notebooks.sh index 72a1d86a3..6832b12e0 100755 --- a/scripts/test_notebooks.sh +++ b/scripts/test_notebooks.sh @@ -1,5 +1,5 @@ #!/bin/bash -INCLUDED_NOTEBOOKS="docs/source/tutorial_i.ipynb docs/source/backdoor.ipynb docs/source/dr_learner.ipynb docs/source/mediation.ipynb docs/source/sdid.ipynb docs/source/slc.ipynb docs/source/dynamical_intro.ipynb" +INCLUDED_NOTEBOOKS="docs/source/tutorial_i.ipynb docs/source/backdoor.ipynb docs/source/dr_learner.ipynb docs/source/mediation.ipynb docs/source/sdid.ipynb docs/source/slc.ipynb docs/source/dynamical_intro.ipynb docs/source/actual_causality.ipynb" CI=1 pytest --nbval-lax --dist loadscope -n auto $INCLUDED_NOTEBOOKS diff --git a/tests/counterfactual/test_counterfactual_handler.py b/tests/counterfactual/test_counterfactual_handler.py index de9960796..2fd3e5bee 100644 --- a/tests/counterfactual/test_counterfactual_handler.py +++ b/tests/counterfactual/test_counterfactual_handler.py @@ -15,9 +15,7 @@ SingleWorldFactual, TwinWorldCounterfactual, ) -from chirho.counterfactual.handlers.counterfactual import Preemptions from chirho.counterfactual.handlers.selection import SelectFactual -from chirho.counterfactual.ops import preempt, split from chirho.indexed.ops import IndexSet, gather, indices_of, union from chirho.interventional.handlers import do from chirho.interventional.ops import intervene @@ -650,72 +648,6 @@ def model(): assert torch.any(tr.nodes["z"]["value"] < 1) -def test_preempt_op_singleworld(): - @SingleWorldCounterfactual() - @pyro.plate("data", size=1000, dim=-1) - def model(): - x = pyro.sample("x", dist.Bernoulli(0.67)) - x = pyro.deterministic( - "x_", split(x, (torch.tensor(0.0),), name="x", event_dim=0), event_dim=0 - ) - y = pyro.sample("y", dist.Bernoulli(0.67)) - y_case = torch.tensor(1) - y = pyro.deterministic( - "y_", - preempt(y, (torch.tensor(1.0),), y_case, name="__y", event_dim=0), - event_dim=0, - ) - z = pyro.sample("z", dist.Bernoulli(0.67)) - return dict(x=x, y=y, z=z) - - tr = pyro.poutine.trace(model).get_trace() - assert torch.all(tr.nodes["x_"]["value"] == 0.0) - assert torch.all(tr.nodes["y_"]["value"] == 1.0) - - -@pytest.mark.parametrize("cf_dim", [-2, -3, None]) -@pytest.mark.parametrize("event_shape", [(), (4,), (4, 3)]) -def test_cf_handler_preemptions(cf_dim, event_shape): - event_dim = len(event_shape) - - splits = {"x": torch.tensor(0.0)} - preemptions = {"y": torch.tensor(1.0)} - - @do(actions=splits) - @pyro.plate("data", size=1000, dim=-1) - def model(): - w = pyro.sample( - "w", dist.Normal(0, 1).expand(event_shape).to_event(len(event_shape)) - ) - x = pyro.sample("x", dist.Normal(w, 1).to_event(len(event_shape))) - y = pyro.sample("y", dist.Normal(w + x, 1).to_event(len(event_shape))) - z = pyro.sample("z", dist.Normal(x + y, 1).to_event(len(event_shape))) - return dict(w=w, x=x, y=y, z=z) - - preemption_handler = Preemptions(actions=preemptions, bias=0.1, prefix="__split_") - - with MultiWorldCounterfactual(cf_dim), preemption_handler: - tr = pyro.poutine.trace(model).get_trace() - assert all(f"__split_{k}" in tr.nodes for k in preemptions.keys()) - assert indices_of(tr.nodes["w"]["value"], event_dim=event_dim) == IndexSet() - assert indices_of(tr.nodes["y"]["value"], event_dim=event_dim) == IndexSet( - x={0, 1} - ) - assert indices_of(tr.nodes["z"]["value"], event_dim=event_dim) == IndexSet( - x={0, 1} - ) - - for k in preemptions.keys(): - tst = tr.nodes[f"__split_{k}"]["value"] - assert torch.allclose( - tr.nodes[f"__split_{k}"]["fn"].log_prob(torch.tensor(0)).exp(), - torch.tensor(0.5 - 0.1), - ) - tst_0 = (tst == 0).expand(tr.nodes[k]["fn"].batch_shape) - assert torch.all(tr.nodes[k]["value"][~tst_0] == preemptions[k]) - assert torch.all(tr.nodes[k]["value"][tst_0] != preemptions[k]) - - # Define a helper function to run SVI. (Generally, Pyro users like to have more control over the training process!) def run_svi_inference(model, n_steps=1000, verbose=True, lr=0.03, **model_kwargs): guide = AutoMultivariateNormal(model) diff --git a/tests/counterfactual/test_handlers_explanation.py b/tests/counterfactual/test_handlers_explanation.py deleted file mode 100644 index 392a7aeb2..000000000 --- a/tests/counterfactual/test_handlers_explanation.py +++ /dev/null @@ -1,509 +0,0 @@ -import pyro -import pyro.distributions as dist -import pyro.infer -import pytest -import torch - -from chirho.counterfactual.handlers.counterfactual import ( - MultiWorldCounterfactual, - Preemptions, -) -from chirho.counterfactual.handlers.explanation import ( - ExplainCauses, - SearchForCause, - consequent_differs, - random_intervention, - undo_split, - uniform_proposal, -) -from chirho.counterfactual.ops import preempt, split -from chirho.indexed.ops import IndexSet, gather, indices_of -from chirho.interventional.ops import intervene -from chirho.observational.handlers.condition import Factors, condition - - -def test_undo_split(): - with MultiWorldCounterfactual(): - x_obs = torch.zeros(10) - x_cf_1 = torch.ones(10) - x_cf_2 = 2 * x_cf_1 - x_split = split(x_obs, (x_cf_1,), name="split1") - x_split = split(x_split, (x_cf_2,), name="split2") - - undo_split2 = undo_split(antecedents=["split2"]) - x_undone = undo_split2(x_split) - - assert indices_of(x_split) == indices_of(x_undone) - assert torch.all(gather(x_split, IndexSet(split2={0})) == x_undone) - - -@pytest.mark.parametrize("plate_size", [4, 50, 200]) -@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)]) -def test_undo_split_parametrized(event_shape, plate_size): - joint_dims = torch.Size([plate_size, *event_shape]) - - replace1 = torch.ones(joint_dims) - preemption_tensor = replace1 * 5 - case = torch.randint(0, 2, size=joint_dims) - - @pyro.plate("data", size=plate_size, dim=-1) - def model(): - w = pyro.sample( - "w", dist.Normal(0, 1).expand(event_shape).to_event(len(event_shape)) - ) - w = split(w, (replace1,), name="split1") - - w = pyro.deterministic( - "w_preempted", preempt(w, preemption_tensor, case, name="w_preempted") - ) - - w = pyro.deterministic("w_undone", undo_split(antecedents=["split1"])(w)) - - with MultiWorldCounterfactual() as mwc: - with pyro.poutine.trace() as tr: - model() - - nd = tr.trace.nodes - - with mwc: - assert indices_of(nd["w_undone"]["value"]) == IndexSet(split1={0, 1}) - - w_undone_shape = list(nd["w_undone"]["value"].shape) - desired_shape = list( - (2,) - + (1,) * (len(w_undone_shape) - len(event_shape) - 2) - + (plate_size,) - + event_shape - ) - assert w_undone_shape == desired_shape - - cf_values = gather(nd["w_undone"]["value"], IndexSet(split1={1})).squeeze() - observed_values = gather( - nd["w_undone"]["value"], IndexSet(split1={0}) - ).squeeze() - - preempted_values = cf_values[case == 1.0] - reverted_values = cf_values[case == 0.0] - picked_values = observed_values[case == 0.0] - - assert torch.all(preempted_values == 5.0) - assert torch.all(reverted_values == picked_values) - - -def test_undo_split_with_interaction(): - def model(): - x = pyro.sample("x", dist.Delta(torch.tensor(1.0))) - - x_split = pyro.deterministic( - "x_split", - split(x, (torch.tensor(0.5),), name="x_split", event_dim=0), - event_dim=0, - ) - - x_undone = pyro.deterministic( - "x_undone", undo_split(antecedents=["x_split"])(x_split), event_dim=0 - ) - - x_case = torch.tensor(1) - x_preempted = pyro.deterministic( - "x_preempted", - preempt( - x_undone, (torch.tensor(5.0),), x_case, name="x_preempted", event_dim=0 - ), - event_dim=0, - ) - - x_undone_2 = pyro.deterministic( - "x_undone_2", undo_split(antecedents=["x"])(x_preempted), event_dim=0 - ) - - x_split2 = pyro.deterministic( - "x_split2", - split(x_undone_2, (torch.tensor(2.0),), name="x_split2", event_dim=0), - event_dim=0, - ) - - x_undone_3 = pyro.deterministic( - "x_undone_3", - undo_split(antecedents=["x_split", "x_split2"])(x_split2), - event_dim=0, - ) - - return x_undone_3 - - with MultiWorldCounterfactual() as mwc: - with pyro.poutine.trace() as tr: - model() - - nd = tr.trace.nodes - - with mwc: - x_split_2 = nd["x_split2"]["value"] - x_00 = gather( - x_split_2, IndexSet(x_split={0}, x_split2={0}), event_dim=0 - ) # 5.0 - x_10 = gather( - x_split_2, IndexSet(x_split={1}, x_split2={0}), event_dim=0 - ) # 5.0 - x_01 = gather( - x_split_2, IndexSet(x_split={0}, x_split2={1}), event_dim=0 - ) # 2.0 - x_11 = gather( - x_split_2, IndexSet(x_split={1}, x_split2={1}), event_dim=0 - ) # 2.0 - - assert ( - nd["x_split"]["value"][0].item() == 1.0 - and nd["x_split"]["value"][1].item() == 0.5 - ) - - assert ( - nd["x_undone"]["value"][0].item() == 1.0 - and nd["x_undone"]["value"][1].item() == 1.0 - ) - - assert ( - nd["x_preempted"]["value"][0].item() == 5.0 - and nd["x_preempted"]["value"][1].item() == 5.0 - ) - - assert ( - nd["x_undone_2"]["value"][0].item() == 5.0 - and nd["x_undone_2"]["value"][1].item() == 5.0 - ) - - assert torch.all(nd["x_undone_3"]["value"] == 5.0) - - assert (x_00, x_10, x_01, x_11) == (5.0, 5.0, 2.0, 2.0) - - -@pytest.mark.parametrize("plate_size", [4, 50, 200]) -@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)], ids=str) -def test_consequent_differs(plate_size, event_shape): - factors = { - "consequent": consequent_differs( - antecedents=["split"], event_dim=len(event_shape) - ) - } - - @Factors(factors=factors) - @pyro.plate("data", size=plate_size, dim=-1) - def model_cd(): - w = pyro.sample( - "w", dist.Normal(0, 0.1).expand(event_shape).to_event(len(event_shape)) - ) - new_w = w.clone() - new_w[1::2] = 10 - w = split(w, (new_w,), name="split") - consequent = pyro.deterministic( - "consequent", w * 0.1, event_dim=len(event_shape) - ) - con_dif = pyro.deterministic( - "con_dif", consequent_differs(antecedents=["split"])(consequent) - ) - return con_dif - - with MultiWorldCounterfactual() as mwc: - with pyro.poutine.trace() as tr: - model_cd() - - tr.trace.compute_log_prob() - nd = tr.trace.nodes - - with mwc: - int_con_dif = gather( - nd["con_dif"]["value"], IndexSet(**{"split": {1}}) - ).squeeze() - - assert torch.all(int_con_dif[1::2] == 0.0) - assert torch.all(int_con_dif[0::2] == -1e8) - - assert nd["__factor_consequent"]["log_prob"].sum() < -1e2 - - -SUPPORT_CASES = [ - pyro.distributions.constraints.real, - pyro.distributions.constraints.boolean, - pyro.distributions.constraints.positive, - pyro.distributions.constraints.interval(0, 10), - pyro.distributions.constraints.interval(-5, 5), - pyro.distributions.constraints.integer_interval(0, 2), - pyro.distributions.constraints.integer_interval(0, 100), -] - - -@pytest.mark.parametrize("support", SUPPORT_CASES) -@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)], ids=str) -def test_uniform_proposal(support, event_shape): - if event_shape: - support = pyro.distributions.constraints.independent(support, len(event_shape)) - - uniform = uniform_proposal(support, event_shape=event_shape) - samples = uniform.sample((10,)) - assert torch.all(support.check(samples)) - - -@pytest.mark.parametrize("support", SUPPORT_CASES) -@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)], ids=str) -def test_random_intervention(support, event_shape): - if event_shape: - support = pyro.distributions.constraints.independent(support, len(event_shape)) - - obs_value = torch.randn(event_shape) - intervention = random_intervention(support, "samples") - - with pyro.plate("draws", 10): - samples = intervene(obs_value, intervention) - - assert torch.all(support.check(samples)) - - -def stones_bayesian_model(): - with pyro.poutine.mask(mask=False): - prob_sally_throws = pyro.sample("prob_sally_throws", dist.Beta(1, 1)) - prob_bill_throws = pyro.sample("prob_bill_throws", dist.Beta(1, 1)) - prob_sally_hits = pyro.sample("prob_sally_hits", dist.Beta(1, 1)) - prob_bill_hits = pyro.sample("prob_bill_hits", dist.Beta(1, 1)) - prob_bottle_shatters_if_sally = pyro.sample( - "prob_bottle_shatters_if_sally", dist.Beta(1, 1) - ) - prob_bottle_shatters_if_bill = pyro.sample( - "prob_bottle_shatters_if_bill", dist.Beta(1, 1) - ) - - sally_throws = pyro.sample("sally_throws", dist.Bernoulli(prob_sally_throws)) - bill_throws = pyro.sample("bill_throws", dist.Bernoulli(prob_bill_throws)) - - new_shp = torch.where(sally_throws == 1, prob_sally_hits, 0.0) - - sally_hits = pyro.sample("sally_hits", dist.Bernoulli(new_shp)) - - new_bhp = torch.where( - (bill_throws.bool() & (~sally_hits.bool())) == 1, - prob_bill_hits, - torch.tensor(0.0), - ) - - bill_hits = pyro.sample("bill_hits", dist.Bernoulli(new_bhp)) - - new_bsp = torch.where( - bill_hits.bool() == 1, - prob_bottle_shatters_if_bill, - torch.where( - sally_hits.bool() == 1, - prob_bottle_shatters_if_sally, - torch.tensor(0.0), - ), - ) - - bottle_shatters = pyro.sample("bottle_shatters", dist.Bernoulli(new_bsp)) - - return { - "sally_throws": sally_throws, - "bill_throws": bill_throws, - "sally_hits": sally_hits, - "bill_hits": bill_hits, - "bottle_shatters": bottle_shatters, - } - - -def test_SearchForCause_single_layer(): - observations = { - "prob_sally_throws": 1.0, - "prob_bill_throws": 1.0, - "prob_sally_hits": 1.0, - "prob_bill_hits": 1.0, - "prob_bottle_shatters_if_sally": 1.0, - "prob_bottle_shatters_if_bill": 1.0, - } - - observations_conditioning = condition( - data={k: torch.as_tensor(v) for k, v in observations.items()} - ) - - with MultiWorldCounterfactual() as mwc: - with SearchForCause({"sally_throws": 0.0}, bias=0.0): - with observations_conditioning: - with pyro.poutine.trace() as tr: - stones_bayesian_model() - - tr = tr.trace.nodes - - with mwc: - preempt_sally_throws = gather( - tr["__cause_split_sally_throws"]["value"], - IndexSet(**{"sally_throws": {0}}), - event_dim=0, - ) - - int_sally_hits = gather( - tr["sally_hits"]["value"], IndexSet(**{"sally_throws": {1}}), event_dim=0 - ) - - obs_bill_hits = gather( - tr["bill_hits"]["value"], IndexSet(**{"sally_throws": {0}}), event_dim=0 - ) - - int_bill_hits = gather( - tr["bill_hits"]["value"], IndexSet(**{"sally_throws": {1}}), event_dim=0 - ) - - int_bottle_shatters = gather( - tr["bottle_shatters"]["value"], - IndexSet(**{"sally_throws": {1}}), - event_dim=0, - ) - - outcome = { - "preempt_sally_throws": preempt_sally_throws.item(), - "int_sally_hits": int_sally_hits.item(), - "obs_bill_hits": obs_bill_hits.item(), - "int_bill_hits": int_bill_hits.item(), - "intervened_bottle_shatters": int_bottle_shatters.item(), - } - - assert list(outcome.values()) == [0, 0.0, 0.0, 1.0, 1.0] or list( - outcome.values() - ) == [1, 1.0, 0.0, 0.0, 1.0] - - -def test_SearchForCause_two_layers(): - observations = { - "prob_sally_throws": 1.0, - "prob_bill_throws": 1.0, - "prob_sally_hits": 1.0, - "prob_bill_hits": 1.0, - "prob_bottle_shatters_if_sally": 1.0, - "prob_bottle_shatters_if_bill": 1.0, - } - - observations_conditioning = condition( - data={k: torch.as_tensor(v) for k, v in observations.items()} - ) - - actions = {"sally_throws": 0.0} - - pinned_preemption_variables = { - "preempt_sally_throws": torch.tensor(0), - "witness_preempt_bill_hits": torch.tensor(1), - } - preemption_conditioning = condition(data=pinned_preemption_variables) - - witness_preemptions = {"bill_hits": undo_split(antecedents=actions.keys())} - witness_preemptions_handler: Preemptions = Preemptions( - actions=witness_preemptions, prefix="witness_preempt_" - ) - - with MultiWorldCounterfactual() as mwc: - with SearchForCause(actions=actions, bias=0.1, prefix="preempt_"): - with preemption_conditioning, witness_preemptions_handler: - with observations_conditioning: - with pyro.poutine.trace() as tr: - stones_bayesian_model() - - tr = tr.trace.nodes - - with mwc: - obs_bill_hits = gather( - tr["bill_hits"]["value"], - IndexSet(**{"sally_throws": {0}}), - event_dim=0, - ).item() - int_bill_hits = gather( - tr["bill_hits"]["value"], - IndexSet(**{"sally_throws": {1}}), - event_dim=0, - ).item() - int_bottle_shatters = gather( - tr["bottle_shatters"]["value"], - IndexSet(**{"sally_throws": {1}}), - event_dim=0, - ).item() - - assert obs_bill_hits == 0.0 and int_bill_hits == 0.0 and int_bottle_shatters == 0.0 - - -def test_ExplainCauses(): - observations = { - "prob_sally_throws": 1.0, - "prob_bill_throws": 1.0, - "prob_sally_hits": 1.0, - "prob_bill_hits": 1.0, - "prob_bottle_shatters_if_sally": 1.0, - "prob_bottle_shatters_if_bill": 1.0, - } - - observations_conditioning = condition( - data={k: torch.as_tensor(v) for k, v in observations.items()} - ) - - antecedents = {"sally_throws": 0.0} - witnesses = ["bill_throws", "bill_hits"] - consequents = ["bottle_shatters"] - - with MultiWorldCounterfactual() as mwc: - with ExplainCauses( - antecedents=antecedents, - witnesses=witnesses, - consequents=consequents, - antecedent_bias=0.1, - ): - with observations_conditioning: - with pyro.plate("sample", 200): - with pyro.poutine.trace() as tr: - stones_bayesian_model() - - tr.trace.compute_log_prob() - tr = tr.trace.nodes - - with mwc: - log_probs = ( - gather( - tr["__consequent_bottle_shatters"]["log_prob"], - IndexSet(**{i: {1} for i in antecedents.keys()}), - event_dim=0, - ) - .squeeze() - .tolist() - ) - - st_obs = ( - gather( - tr["sally_throws"]["value"], - IndexSet(**{i: {0} for i in antecedents.keys()}), - event_dim=0, - ) - .squeeze() - .tolist() - ) - - st_int = ( - gather( - tr["sally_throws"]["value"], - IndexSet(**{i: {1} for i in antecedents.keys()}), - event_dim=0, - ) - .squeeze() - .tolist() - ) - - bh_int = ( - gather( - tr["bill_hits"]["value"], - IndexSet(**{i: {1} for i in antecedents.keys()}), - event_dim=0, - ) - .squeeze() - .tolist() - ) - - st_ant = tr["__antecedent_sally_throws"]["value"].squeeze().tolist() - - assert all(lp == -1e8 or lp == 0.0 for lp in log_probs) - - for step in range(200): - bottle_will_shatter = ( - st_obs[step] != st_int[step] and st_ant == 0.0 - ) or bh_int[step] == 1.0 - if bottle_will_shatter: - assert log_probs[step] == -1e8 diff --git a/tests/explainable/test_handlers_components.py b/tests/explainable/test_handlers_components.py new file mode 100644 index 000000000..34683b436 --- /dev/null +++ b/tests/explainable/test_handlers_components.py @@ -0,0 +1,325 @@ +import pyro +import pyro.distributions as dist +import pyro.distributions.constraints as constraints +import pytest +import torch + +from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual +from chirho.counterfactual.ops import split +from chirho.explainable.handlers import random_intervention +from chirho.explainable.handlers.components import ( + ExtractSupports, + consequent_differs, + undo_split, +) +from chirho.explainable.ops import preempt +from chirho.indexed.ops import IndexSet, gather, indices_of +from chirho.interventional.ops import intervene +from chirho.observational.handlers.condition import Factors + +SUPPORT_CASES = [ + pyro.distributions.constraints.real, + pyro.distributions.constraints.boolean, + pyro.distributions.constraints.positive, + pyro.distributions.constraints.interval(0, 10), + pyro.distributions.constraints.interval(-5, 5), + pyro.distributions.constraints.integer_interval(0, 2), + pyro.distributions.constraints.integer_interval(0, 100), +] + + +@pytest.mark.parametrize("support", SUPPORT_CASES) +@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)], ids=str) +def test_random_intervention(support, event_shape): + if event_shape: + support = pyro.distributions.constraints.independent(support, len(event_shape)) + + obs_value = torch.randn(event_shape) + intervention = random_intervention(support, "samples") + + with pyro.plate("draws", 10): + samples = intervene(obs_value, intervention) + + assert torch.all(support.check(samples)) + + +def test_undo_split(): + with MultiWorldCounterfactual(): + x_obs = torch.zeros(10) + x_cf_1 = torch.ones(10) + x_cf_2 = 2 * x_cf_1 + x_split = split(x_obs, (x_cf_1,), name="split1", event_dim=1) + x_split = split(x_split, (x_cf_2,), name="split2", event_dim=1) + + undo_split2 = undo_split( + support=constraints.independent(constraints.real, 1), antecedents=["split2"] + ) + x_undone = undo_split2(x_split) + + assert indices_of(x_split, event_dim=1) == indices_of(x_undone, event_dim=1) + assert torch.all(gather(x_split, IndexSet(split2={0}), event_dim=1) == x_undone) + + +@pytest.mark.parametrize("plate_size", [4, 50, 200]) +@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)]) +def test_undo_split_parametrized(event_shape, plate_size): + joint_dims = torch.Size([plate_size, *event_shape]) + + replace1 = torch.ones(joint_dims) + preemption_tensor = replace1 * 5 + case = torch.randint(0, 2, size=(plate_size,)) + + @pyro.plate("data", size=plate_size, dim=-1) + def model(): + w = pyro.sample( + "w", dist.Normal(0, 1).expand(event_shape).to_event(len(event_shape)) + ) + w = split(w, (replace1,), name="split1", event_dim=len(event_shape)) + + w = pyro.deterministic( + "w_preempted", + preempt( + w, + preemption_tensor, + case, + name="w_preempted", + event_dim=len(event_shape), + ), + event_dim=len(event_shape), + ) + + w = pyro.deterministic( + "w_undone", + undo_split( + support=constraints.independent(constraints.real, len(event_shape)), + antecedents=["split1"], + )(w), + event_dim=len(event_shape), + ) + + with MultiWorldCounterfactual() as mwc: + with pyro.poutine.trace() as tr: + model() + + nd = tr.trace.nodes + + with mwc: + assert indices_of( + nd["w_undone"]["value"], event_dim=len(event_shape) + ) == IndexSet(split1={0, 1}) + + w_undone_shape = list(nd["w_undone"]["value"].shape) + desired_shape = list( + (2,) + + (1,) * (len(w_undone_shape) - len(event_shape) - 2) + + (plate_size,) + + event_shape + ) + assert w_undone_shape == desired_shape + + cf_values = gather( + nd["w_undone"]["value"], IndexSet(split1={1}), event_dim=len(event_shape) + ).squeeze() + observed_values = gather( + nd["w_undone"]["value"], IndexSet(split1={0}), event_dim=len(event_shape) + ).squeeze() + + preempted_values = cf_values[case == 1.0] + reverted_values = cf_values[case == 0.0] + picked_values = observed_values[case == 0.0] + + assert torch.all(preempted_values == 5.0) + assert torch.all(reverted_values == picked_values) + + +def test_undo_split_with_interaction(): + def model(): + x = pyro.sample("x", dist.Delta(torch.tensor(1.0))) + + x_split = pyro.deterministic( + "x_split", + split(x, (torch.tensor(0.5),), name="x_split", event_dim=0), + event_dim=0, + ) + + x_undone = pyro.deterministic( + "x_undone", + undo_split(support=constraints.real, antecedents=["x_split"])(x_split), + event_dim=0, + ) + + x_case = torch.tensor(1) + x_preempted = pyro.deterministic( + "x_preempted", + preempt( + x_undone, (torch.tensor(5.0),), x_case, name="x_preempted", event_dim=0 + ), + event_dim=0, + ) + + x_undone_2 = pyro.deterministic( + "x_undone_2", + undo_split(support=constraints.real, antecedents=["x"])(x_preempted), + event_dim=0, + ) + + x_split2 = pyro.deterministic( + "x_split2", + split(x_undone_2, (torch.tensor(2.0),), name="x_split2", event_dim=0), + event_dim=0, + ) + + x_undone_3 = pyro.deterministic( + "x_undone_3", + undo_split(support=constraints.real, antecedents=["x_split", "x_split2"])( + x_split2 + ), + event_dim=0, + ) + + return x_undone_3 + + with MultiWorldCounterfactual() as mwc: + with pyro.poutine.trace() as tr: + model() + + nd = tr.trace.nodes + + with mwc: + x_split_2 = nd["x_split2"]["value"] + x_00 = gather( + x_split_2, IndexSet(x_split={0}, x_split2={0}), event_dim=0 + ) # 5.0 + x_10 = gather( + x_split_2, IndexSet(x_split={1}, x_split2={0}), event_dim=0 + ) # 5.0 + x_01 = gather( + x_split_2, IndexSet(x_split={0}, x_split2={1}), event_dim=0 + ) # 2.0 + x_11 = gather( + x_split_2, IndexSet(x_split={1}, x_split2={1}), event_dim=0 + ) # 2.0 + + assert ( + nd["x_split"]["value"][0].item() == 1.0 + and nd["x_split"]["value"][1].item() == 0.5 + ) + + assert ( + nd["x_undone"]["value"][0].item() == 1.0 + and nd["x_undone"]["value"][1].item() == 1.0 + ) + + assert ( + nd["x_preempted"]["value"][0].item() == 5.0 + and nd["x_preempted"]["value"][1].item() == 5.0 + ) + + assert ( + nd["x_undone_2"]["value"][0].item() == 5.0 + and nd["x_undone_2"]["value"][1].item() == 5.0 + ) + + assert torch.all(nd["x_undone_3"]["value"] == 5.0) + + assert (x_00, x_10, x_01, x_11) == (5.0, 5.0, 2.0, 2.0) + + +# @pytest.mark.parametrize("plate_size", [4]) +# @pytest.mark.parametrize("event_shape", [()], ids=str) + + +@pytest.mark.parametrize("plate_size", [4, 50, 200]) +@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)], ids=str) +def test_consequent_differs(plate_size, event_shape): + factors = { + "consequent": consequent_differs( + antecedents=["split"], + support=constraints.independent(constraints.real, len(event_shape)), + ) + } + + @Factors(factors=factors) + @pyro.plate("data", size=plate_size, dim=-1) + def model_cd(): + w = pyro.sample( + "w", dist.Normal(0, 0.1).expand(event_shape).to_event(len(event_shape)) + ) + new_w = w.clone() + new_w[1::2] = 10 + w = split(w, (new_w,), name="split", event_dim=len(event_shape)) + consequent = pyro.deterministic( + "consequent", w * 0.1, event_dim=len(event_shape) + ) + con_dif = pyro.deterministic( + "con_dif", + consequent_differs( + support=constraints.independent(constraints.real, len(event_shape)), + antecedents=["split"], + )(consequent), + event_dim=0, + ) + + return con_dif + + with MultiWorldCounterfactual() as mwc: + with pyro.poutine.trace() as tr: + model_cd() + + tr.trace.compute_log_prob() + nd = tr.trace.nodes + + with mwc: + int_con_dif = gather(nd["con_dif"]["value"], IndexSet(**{"split": {1}})) + + assert "split" not in indices_of(int_con_dif) + assert not indices_of(int_con_dif) + + assert int_con_dif.squeeze().shape == nd["w"]["fn"].batch_shape + assert nd["__factor_consequent"]["log_prob"].sum() < -1e2 + + +options = [ + None, + [], + ["uniform_var"], + ["uniform_var", "normal_var", "bernoulli_var"], + {}, + {"uniform_var": 5.0, "bernoulli_var": 5.0}, + { + "uniform_var": constraints.interval(1, 10), + "bernoulli_var": constraints.interval(0, 1), + }, # misspecified on purpose, should make no damage +] + + +@pytest.mark.parametrize("event_shape", [(), (3, 2)], ids=str) +@pytest.mark.parametrize("plate_size", [4, 50]) +def test_ExtractSupports(event_shape, plate_size): + @pyro.plate("data", size=plate_size, dim=-1) + def mixed_supports_model(): + uniform_var = pyro.sample( + "uniform_var", + dist.Uniform(1, 10).expand(event_shape).to_event(len(event_shape)), + ) + normal_var = pyro.sample( + "normal_var", + dist.Normal(3, 15).expand(event_shape).to_event(len(event_shape)), + ) + bernoulli_var = pyro.sample("bernoulli_var", dist.Bernoulli(0.5)) + positive_var = pyro.sample( + "positive_var", + dist.LogNormal(0, 1).expand(event_shape).to_event(len(event_shape)), + ) + + return uniform_var, normal_var, bernoulli_var, positive_var + + with ExtractSupports() as s: + mixed_supports_model() + + assert list(s.supports.keys()) == [ + "uniform_var", + "normal_var", + "bernoulli_var", + "positive_var", + ] diff --git a/tests/explainable/test_handlers_explanation.py b/tests/explainable/test_handlers_explanation.py new file mode 100644 index 000000000..8acb4039b --- /dev/null +++ b/tests/explainable/test_handlers_explanation.py @@ -0,0 +1,295 @@ +import math + +import pyro +import pyro.distributions as dist +import pyro.distributions.constraints as constraints +import torch + +from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual +from chirho.explainable.handlers.components import undo_split +from chirho.explainable.handlers.explanation import SearchForExplanation, SplitSubsets +from chirho.explainable.handlers.preemptions import Preemptions +from chirho.indexed.ops import IndexSet, gather +from chirho.observational.handlers.condition import condition + + +def stones_bayesian_model(): + with pyro.poutine.mask(mask=False): + prob_sally_throws = pyro.sample("prob_sally_throws", dist.Beta(1, 1)) + prob_bill_throws = pyro.sample("prob_bill_throws", dist.Beta(1, 1)) + prob_sally_hits = pyro.sample("prob_sally_hits", dist.Beta(1, 1)) + prob_bill_hits = pyro.sample("prob_bill_hits", dist.Beta(1, 1)) + prob_bottle_shatters_if_sally = pyro.sample( + "prob_bottle_shatters_if_sally", dist.Beta(1, 1) + ) + prob_bottle_shatters_if_bill = pyro.sample( + "prob_bottle_shatters_if_bill", dist.Beta(1, 1) + ) + + sally_throws = pyro.sample("sally_throws", dist.Bernoulli(prob_sally_throws)) + bill_throws = pyro.sample("bill_throws", dist.Bernoulli(prob_bill_throws)) + + new_shp = torch.where(sally_throws == 1, prob_sally_hits, 0.0) + + sally_hits = pyro.sample("sally_hits", dist.Bernoulli(new_shp)) + + new_bhp = torch.where( + (bill_throws.bool() & (~sally_hits.bool())) == 1, + prob_bill_hits, + torch.tensor(0.0), + ) + + bill_hits = pyro.sample("bill_hits", dist.Bernoulli(new_bhp)) + + new_bsp = torch.where( + bill_hits.bool() == 1, + prob_bottle_shatters_if_bill, + torch.where( + sally_hits.bool() == 1, + prob_bottle_shatters_if_sally, + torch.tensor(0.0), + ), + ) + + bottle_shatters = pyro.sample("bottle_shatters", dist.Bernoulli(new_bsp)) + + return { + "sally_throws": sally_throws, + "bill_throws": bill_throws, + "sally_hits": sally_hits, + "bill_hits": bill_hits, + "bottle_shatters": bottle_shatters, + } + + +def test_SearchForExplanation(): + observations = { + "prob_sally_throws": 1.0, + "prob_bill_throws": 1.0, + "prob_sally_hits": 1.0, + "prob_bill_hits": 1.0, + "prob_bottle_shatters_if_sally": 1.0, + "prob_bottle_shatters_if_bill": 1.0, + } + + observations_conditioning = condition( + data={k: torch.as_tensor(v) for k, v in observations.items()} + ) + + antecedents = {"sally_throws": 0.0} + witnesses = {"bill_throws": constraints.boolean, "bill_hits": constraints.boolean} + consequents = {"bottle_shatters": constraints.boolean} + + with MultiWorldCounterfactual() as mwc: + with SearchForExplanation( + antecedents=antecedents, + witnesses=witnesses, + consequents=consequents, + antecedent_bias=0.1, + consequent_scale=1e-8, + ): + with observations_conditioning: + with pyro.plate("sample", 200): + with pyro.poutine.trace() as tr: + stones_bayesian_model() + + tr.trace.compute_log_prob() + tr = tr.trace.nodes + + with mwc: + log_probs = ( + gather( + tr["__consequent_bottle_shatters"]["log_prob"], + IndexSet(**{i: {1} for i in antecedents.keys()}), + event_dim=0, + ) + .squeeze() + .tolist() + ) + + st_obs = ( + gather( + tr["sally_throws"]["value"], + IndexSet(**{i: {0} for i in antecedents.keys()}), + event_dim=0, + ) + .squeeze() + .tolist() + ) + + st_int = ( + gather( + tr["sally_throws"]["value"], + IndexSet(**{i: {1} for i in antecedents.keys()}), + event_dim=0, + ) + .squeeze() + .tolist() + ) + + bh_int = ( + gather( + tr["bill_hits"]["value"], + IndexSet(**{i: {1} for i in antecedents.keys()}), + event_dim=0, + ) + .squeeze() + .tolist() + ) + + st_ant = tr["__antecedent_sally_throws"]["value"].squeeze().tolist() + + assert all(lp <= -1e5 or lp > math.log(0.5) for lp in log_probs) + + for step in range(200): + bottle_will_shatter = ( + st_obs[step] != st_int[step] and st_ant == 0.0 + ) or bh_int[step] == 1.0 + if bottle_will_shatter: + assert log_probs[step] <= -1e5 + + witnesses = {} + with MultiWorldCounterfactual(): + with SearchForExplanation( + antecedents=antecedents, + witnesses=witnesses, + consequents=consequents, + antecedent_bias=0.1, + consequent_scale=1e-8, + ): + with observations_conditioning: + with pyro.plate("sample", 200): + with pyro.poutine.trace() as tr_empty: + stones_bayesian_model() + + assert tr_empty.trace.nodes + + +test_SearchForExplanation() + + +def test_SplitSubsets_single_layer(): + observations = { + "prob_sally_throws": 1.0, + "prob_bill_throws": 1.0, + "prob_sally_hits": 1.0, + "prob_bill_hits": 1.0, + "prob_bottle_shatters_if_sally": 1.0, + "prob_bottle_shatters_if_bill": 1.0, + } + + observations_conditioning = condition( + data={k: torch.as_tensor(v) for k, v in observations.items()} + ) + + with MultiWorldCounterfactual() as mwc: + with SplitSubsets( + supports={"sally_throws": constraints.boolean}, + actions={"sally_throws": 0.0}, + bias=0.0, + ): + with observations_conditioning: + with pyro.poutine.trace() as tr: + stones_bayesian_model() + + tr = tr.trace.nodes + + with mwc: + preempt_sally_throws = gather( + tr["__cause_split_sally_throws"]["value"], + IndexSet(**{"sally_throws": {0}}), + event_dim=0, + ) + + int_sally_hits = gather( + tr["sally_hits"]["value"], IndexSet(**{"sally_throws": {1}}), event_dim=0 + ) + + obs_bill_hits = gather( + tr["bill_hits"]["value"], IndexSet(**{"sally_throws": {0}}), event_dim=0 + ) + + int_bill_hits = gather( + tr["bill_hits"]["value"], IndexSet(**{"sally_throws": {1}}), event_dim=0 + ) + + int_bottle_shatters = gather( + tr["bottle_shatters"]["value"], + IndexSet(**{"sally_throws": {1}}), + event_dim=0, + ) + + outcome = { + "preempt_sally_throws": preempt_sally_throws.item(), + "int_sally_hits": int_sally_hits.item(), + "obs_bill_hits": obs_bill_hits.item(), + "int_bill_hits": int_bill_hits.item(), + "intervened_bottle_shatters": int_bottle_shatters.item(), + } + + assert list(outcome.values()) == [0, 0.0, 0.0, 1.0, 1.0] or list( + outcome.values() + ) == [1, 1.0, 0.0, 0.0, 1.0] + + +def test_SplitSubsets_two_layers(): + observations = { + "prob_sally_throws": 1.0, + "prob_bill_throws": 1.0, + "prob_sally_hits": 1.0, + "prob_bill_hits": 1.0, + "prob_bottle_shatters_if_sally": 1.0, + "prob_bottle_shatters_if_bill": 1.0, + } + + observations_conditioning = condition( + data={k: torch.as_tensor(v) for k, v in observations.items()} + ) + + actions = {"sally_throws": 0.0} + + pinned_preemption_variables = { + "preempt_sally_throws": torch.tensor(0), + "witness_preempt_bill_hits": torch.tensor(1), + } + preemption_conditioning = condition(data=pinned_preemption_variables) + + witness_preemptions = { + "bill_hits": undo_split(constraints.boolean, antecedents=actions.keys()) + } + witness_preemptions_handler: Preemptions = Preemptions( + actions=witness_preemptions, prefix="witness_preempt_" + ) + + with MultiWorldCounterfactual() as mwc: + with SplitSubsets( + supports={"sally_throws": constraints.boolean}, + actions=actions, + bias=0.1, + prefix="preempt_", + ): + with preemption_conditioning, witness_preemptions_handler: + with observations_conditioning: + with pyro.poutine.trace() as tr: + stones_bayesian_model() + + tr = tr.trace.nodes + + with mwc: + obs_bill_hits = gather( + tr["bill_hits"]["value"], + IndexSet(**{"sally_throws": {0}}), + event_dim=0, + ).item() + int_bill_hits = gather( + tr["bill_hits"]["value"], + IndexSet(**{"sally_throws": {1}}), + event_dim=0, + ).item() + int_bottle_shatters = gather( + tr["bottle_shatters"]["value"], + IndexSet(**{"sally_throws": {1}}), + event_dim=0, + ).item() + + assert obs_bill_hits == 0.0 and int_bill_hits == 0.0 and int_bottle_shatters == 0.0 diff --git a/tests/explainable/test_internals.py b/tests/explainable/test_internals.py new file mode 100644 index 000000000..00c427f9c --- /dev/null +++ b/tests/explainable/test_internals.py @@ -0,0 +1,169 @@ +import math + +import pyro.distributions.constraints as constraints +import pytest +import torch + +from chirho.explainable.internals.defaults import soft_eq, soft_neq, uniform_proposal + +SUPPORT_CASES = [ + constraints.real, + constraints.boolean, + constraints.positive, + constraints.interval(0, 10), + constraints.interval(-5, 5), + constraints.integer_interval(0, 2), + constraints.integer_interval(0, 100), +] + + +@pytest.mark.parametrize("support", SUPPORT_CASES) +@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)], ids=str) +def test_uniform_proposal(support, event_shape): + if event_shape: + support = constraints.independent(support, len(event_shape)) + + uniform = uniform_proposal(support, event_shape=event_shape) + samples = uniform.sample((10,)) + assert torch.all(support.check(samples)) + + +def test_soft_boolean(): + support = constraints.boolean + scale = 1e-1 + + boolean_tensor_1 = torch.tensor([True, False, True, False]) + boolean_tensor_2 = torch.tensor([True, True, False, False]) + + log_boolean_eq = soft_eq(support, boolean_tensor_1, boolean_tensor_2, scale=scale) + log_boolean_neq = soft_neq(support, boolean_tensor_1, boolean_tensor_2, scale=scale) + + real_tensor_1 = torch.tensor([1.0, 0.0, 1.0, 0.0]) + real_tensor_2 = torch.tensor([1.0, 1.0, 0.0, 0.0]) + + real_boolean_eq = soft_eq(support, real_tensor_1, real_tensor_2, scale=scale) + real_boolean_neq = soft_neq(support, real_tensor_1, real_tensor_2, scale=scale) + + logp, log1mp = math.log(scale), math.log(1 - scale) + assert torch.equal(log_boolean_eq, real_boolean_eq) and torch.allclose( + real_boolean_eq, torch.tensor([log1mp, logp, logp, log1mp]) + ) + + assert torch.equal(log_boolean_neq, real_boolean_neq) and torch.allclose( + real_boolean_neq, torch.tensor([logp, log1mp, log1mp, logp]) + ) + + +def test_soft_interval(): + scale = 1.0 + t1 = torch.arange(0.5, 7.5, 0.1) + t2 = t1 + 1 + t2b = t1 + 2 + + inter_eq = soft_eq(constraints.interval(0, 10), t1, t2, scale=scale) + inter_eq_b = soft_eq(constraints.interval(0, 10), t1, t2b, scale=scale) + + inter_neq = soft_neq(constraints.interval(0, 10), t1, t2, scale=scale) + inter_neq_b = soft_neq(constraints.interval(0, 10), t1, t2b, scale=scale) + + assert torch.all( + inter_eq_b < inter_eq + ), "soft_eq is not monotonic in the absolute distance between the two original values" + + assert torch.all( + inter_neq_b > inter_neq + ), "soft_neq is not monotonic in the absolute distance between the two original values" + assert ( + soft_neq( + constraints.interval(0, 10), + torch.tensor(0.0), + torch.tensor(10.0), + scale=scale, + ) + == 0 + ), "soft_neq is not zero at maximal difference" + + +def test_soft_eq_tavares_relaxation(): + # these test cases are for our counterpart + # of conditions (i)-(iii) of predicate relaxation + # from "Predicate exchange..." by Tavares et al. + + # condition i: when a tends to zero, soft_eq tends to the true only for + # true identity and to negative infinity otherwise + support = constraints.real + assert ( + soft_eq(support, torch.tensor(1.0), torch.tensor(1.001), scale=1e-10) < 1e-10 + ), "soft_eq does not tend to negative infinity for false identities as a tends to zero" + + # condition ii: approaching true answer as scale goes to infty + scales = [1e6, 1e10] + for scale in scales: + score_diff = soft_eq(support, torch.tensor(1.0), torch.tensor(2.0), scale=scale) + score_id = soft_eq(support, torch.tensor(1.0), torch.tensor(1.0), scale=scale) + assert ( + torch.abs(score_diff - score_id) < 1e-10 + ), "soft_eq does not approach true answer as scale approaches infinity" + + # condition iii: 0 just in case true identity + true_identity = soft_eq(support, torch.tensor(1.0), torch.tensor(1.0)) + false_identity = soft_eq(support, torch.tensor(1.0), torch.tensor(1.001)) + + # assert true_identity == 0, "soft_eq does not yield zero on identity" + assert true_identity > false_identity, "soft_eq does not penalize difference" + + +def test_soft_neq_tavares_relaxation(): + support = constraints.real + + min_scale = 1 / math.sqrt(2 * math.pi) + + # condition i: when a tends to allowed minimum (1 / math.sqrt(2 * math.pi)), + # the difference in outcomes between identity and non-identity tends to negative infinity + diff = soft_neq( + support, torch.tensor(1.0), torch.tensor(1.0), scale=min_scale + 0.0001 + ) - soft_neq( + support, torch.tensor(1.0), torch.tensor(1.001), scale=min_scale + 0.0001 + ) + + assert diff < -1e8, "condition i failed" + + # condition ii: as scale goes to infinity + # the score tends to that of identity + x = torch.tensor(0.0) + y = torch.arange(-100, 100, 0.1) + indentity_score = soft_neq( + support, torch.tensor(1.0), torch.tensor(1.0), scale=1e10 + ) + scaled = soft_neq(support, x, y, scale=1e10) + + assert torch.allclose(indentity_score, scaled), "condition ii failed" + + # condition iii: for any scale, the score tends to zero + # as difference tends to infinity + # and to its minimum as it tends to zero + # and doesn't equal to minimum for non-zero difference + scales = [0.4, 1, 5, 50, 500] + x = torch.tensor(0.0) + y = torch.arange(-100, 100, 0.1) + + for scale in scales: + z = torch.tensor([-1e10 * scale, 1e10 * scale]) + + identity_score = soft_neq( + support, torch.tensor(1.0), torch.tensor(1.0), scale=scale + ) + scaled_y = soft_neq(support, x, y, scale=scale) + scaled_z = soft_neq(support, x, z, scale=scale) + + assert torch.allclose( + identity_score, torch.min(scaled_y) + ), "condition iii failed" + lower = 1 + scale * 1e-3 + assert torch.all( + soft_neq( + support, torch.tensor(1.0), torch.arange(lower, 2, 0.001), scale=scale + ) + > identity_score + ) + assert torch.allclose(scaled_z, torch.tensor(0.0)), "condition iii failed" diff --git a/tests/explainable/test_ops.py b/tests/explainable/test_ops.py new file mode 100644 index 000000000..aeb7bae4e --- /dev/null +++ b/tests/explainable/test_ops.py @@ -0,0 +1,81 @@ +import pyro +import pyro.distributions as dist +import pyro.infer +import pytest +import torch + +from chirho.counterfactual.handlers import ( + MultiWorldCounterfactual, + SingleWorldCounterfactual, +) +from chirho.counterfactual.ops import split +from chirho.explainable.handlers import Preemptions +from chirho.explainable.ops import preempt +from chirho.indexed.ops import IndexSet, indices_of +from chirho.interventional.handlers import do + + +def test_preempt_op_singleworld(): + @SingleWorldCounterfactual() + @pyro.plate("data", size=1000, dim=-1) + def model(): + x = pyro.sample("x", dist.Bernoulli(0.67)) + x = pyro.deterministic( + "x_", split(x, (torch.tensor(0.0),), name="x", event_dim=0), event_dim=0 + ) + y = pyro.sample("y", dist.Bernoulli(0.67)) + y_case = torch.tensor(1) + y = pyro.deterministic( + "y_", + preempt(y, (torch.tensor(1.0),), y_case, name="__y", event_dim=0), + event_dim=0, + ) + z = pyro.sample("z", dist.Bernoulli(0.67)) + return dict(x=x, y=y, z=z) + + tr = pyro.poutine.trace(model).get_trace() + assert torch.all(tr.nodes["x_"]["value"] == 0.0) + assert torch.all(tr.nodes["y_"]["value"] == 1.0) + + +@pytest.mark.parametrize("cf_dim", [-2, -3, None]) +@pytest.mark.parametrize("event_shape", [(), (4,), (4, 3)]) +def test_cf_handler_preemptions(cf_dim, event_shape): + event_dim = len(event_shape) + + splits = {"x": torch.tensor(0.0)} + preemptions = {"y": torch.tensor(1.0)} + + @do(actions=splits) + @pyro.plate("data", size=1000, dim=-1) + def model(): + w = pyro.sample( + "w", dist.Normal(0, 1).expand(event_shape).to_event(len(event_shape)) + ) + x = pyro.sample("x", dist.Normal(w, 1).to_event(len(event_shape))) + y = pyro.sample("y", dist.Normal(w + x, 1).to_event(len(event_shape))) + z = pyro.sample("z", dist.Normal(x + y, 1).to_event(len(event_shape))) + return dict(w=w, x=x, y=y, z=z) + + preemption_handler = Preemptions(actions=preemptions, bias=0.1, prefix="__split_") + + with MultiWorldCounterfactual(cf_dim), preemption_handler: + tr = pyro.poutine.trace(model).get_trace() + assert all(f"__split_{k}" in tr.nodes for k in preemptions.keys()) + assert indices_of(tr.nodes["w"]["value"], event_dim=event_dim) == IndexSet() + assert indices_of(tr.nodes["y"]["value"], event_dim=event_dim) == IndexSet( + x={0, 1} + ) + assert indices_of(tr.nodes["z"]["value"], event_dim=event_dim) == IndexSet( + x={0, 1} + ) + + for k in preemptions.keys(): + tst = tr.nodes[f"__split_{k}"]["value"] + assert torch.allclose( + tr.nodes[f"__split_{k}"]["fn"].log_prob(torch.tensor(0)).exp(), + torch.tensor(0.5 - 0.1), + ) + tst_0 = (tst == 0).expand(tr.nodes[k]["fn"].batch_shape) + assert torch.all(tr.nodes[k]["value"][~tst_0] == preemptions[k]) + assert torch.all(tr.nodes[k]["value"][tst_0] != preemptions[k])