diff --git a/chirho/explainable/handlers/__init__.py b/chirho/explainable/handlers/__init__.py index d1baf7ae0..be1e9bee5 100644 --- a/chirho/explainable/handlers/__init__.py +++ b/chirho/explainable/handlers/__init__.py @@ -1,4 +1,8 @@ -from .components import random_intervention # noqa: F401 -from .components import ExtractSupports, undo_split # noqa: F401 -from .explanation import SearchForExplanation, SplitSubsets # noqa: F401 +from .components import ( # noqa: F401 + ExtractSupports, + random_intervention, + sufficiency_intervention, + undo_split, +) +from .explanation import SearchForExplanation, SearchForNS, SplitSubsets # noqa: F401 from .preemptions import Preemptions # noqa: F401 diff --git a/chirho/explainable/handlers/components.py b/chirho/explainable/handlers/components.py index e0555532a..1e85b3e65 100644 --- a/chirho/explainable/handlers/components.py +++ b/chirho/explainable/handlers/components.py @@ -8,16 +8,61 @@ from chirho.counterfactual.handlers.selection import get_factual_indices from chirho.explainable.internals import uniform_proposal from chirho.indexed.ops import IndexSet, gather, indices_of, scatter_n -from chirho.observational.handlers import soft_neq + +# from chirho.interventional.ops import intervene +from chirho.observational.handlers import soft_eq, soft_neq S = TypeVar("S") T = TypeVar("T") +def sufficiency_intervention( + support: constraints.Constraint, + antecedents: Iterable[str] = [], +) -> Callable[[T], T]: + """ + Creates a sufficiency intervention for a single sample site, determined by + the site's support, intervening to keep the value as in the factual world with + respect to the antecedents. + + :param support: The support constraint for the site. + :param antecedents: The names of the antecedent sites whose factual world is gathered. + + :return: A function that takes a site value as input + and returns that value gathered from the factual world of the antecedents.
+ + Example:: + + >>> with MultiWorldCounterfactual() as mwc: + >>> value = pyro.sample("value", proposal_dist) + >>> intervention = sufficiency_intervention(support) + >>> value = intervene(value, intervention) + """ + + def _sufficiency_intervention(value: T) -> T: + + indices = IndexSet( + **{ + name: ind + for name, ind in get_factual_indices().items() + if name in antecedents + } + ) + + factual_value = gather( + value, + indices, + event_dim=support.event_dim, + ) + return factual_value + + return _sufficiency_intervention + + def random_intervention( support: constraints.Constraint, name: str, -) -> Callable[[torch.Tensor], torch.Tensor]: +) -> Callable[[T], T]: """ Creates a random-valued intervention for a single sample site, determined by by the distribution support, and site name. @@ -38,8 +83,10 @@ def random_intervention( >>> assert x != 2 """ - def _random_intervention(value: torch.Tensor) -> torch.Tensor: - event_shape = value.shape[len(value.shape) - support.event_dim :] + def _random_intervention(value: T) -> T: + + event_shape = value.shape[len(value.shape) - support.event_dim :] # type: ignore + proposal_dist = uniform_proposal( support, event_shape=event_shape, @@ -92,7 +139,43 @@ def _undo_split(value: T) -> T: return _undo_split -def consequent_differs( +def consequent_eq( + support: constraints.Constraint, + antecedents: Iterable[str] = [], + **kwargs, +) -> Callable[[T], torch.Tensor]: + """ + A helper function for assessing whether values at a site are close to their observed values, assigning + a small negative value close to zero if a value is close to its observed state and a large negative value otherwise. + + :param support: The support constraint for the consequent site. + :param antecedents: A list of names of upstream intervened sites to consider when assessing similarity. 
+ + :return: A callable which applied to a site value object (``consequent``), returns a tensor where each + element indicates the extent to which the corresponding element of ``consequent`` + is close to its factual value. + """ + + def _consequent_eq(consequent: T) -> torch.Tensor: + indices = IndexSet( + **{ + name: ind + for name, ind in get_factual_indices().items() + if name in antecedents + } + ) + eq = soft_eq( + support, + consequent, + gather(consequent, indices, event_dim=support.event_dim), + **kwargs, + ) + return eq + + return _consequent_eq + + +def consequent_neq( support: constraints.Constraint, antecedents: Iterable[str] = [], **kwargs, @@ -109,7 +192,7 @@ def consequent_differs( element indicates whether the corresponding element of ``consequent`` differs from its factual value. """ - def _consequent_differs(consequent: T) -> torch.Tensor: + def _consequent_neq(consequent: T) -> torch.Tensor: indices = IndexSet( **{ name: ind @@ -125,7 +208,100 @@ def _consequent_differs(consequent: T) -> torch.Tensor: ) return diff - return _consequent_differs + return _consequent_neq + + +def consequent_eq_neq( + support: constraints.Constraint, + antecedents: Iterable[str] = [], + **kwargs, +) -> Callable[[T], torch.Tensor]: + """ + A helper function for obtaining the joint log prob of necessity and sufficiency. Assumes that + the necessity intervention has been applied in counterfactual world 1 and the sufficiency intervention in + counterfactual world 2 (override these with the ``necessity_world`` and ``sufficiency_world`` kwargs). + + :param support: The support constraint for the consequent site. + :param antecedents: A list of names of upstream intervened sites to consider when composing the joint log prob. + + :return: A callable which, applied to a site value object (``consequent``), returns a tensor holding zeros in the + factual world, ``soft_neq`` log probs in the necessity world and ``soft_eq`` log probs in the sufficiency world.
+ """ + + def _consequent_eq_neq(consequent: T) -> torch.Tensor: + + factual_indices = IndexSet( + **{ + name: ind + for name, ind in get_factual_indices().items() + if name in antecedents + } + ) + + necessity_world = kwargs.get("necessity_world", 1) + sufficiency_world = kwargs.get("sufficiency_world", 2) + + necessity_indices = IndexSet( + **{ + name: {necessity_world} + for name in indices_of(consequent, event_dim=support.event_dim).keys() + if name in antecedents + } + ) + + sufficiency_indices = IndexSet( + **{ + name: {sufficiency_world} + for name in indices_of(consequent, event_dim=support.event_dim).keys() + if name in antecedents + } + ) + + factual_value = gather(consequent, factual_indices, event_dim=support.event_dim) + necessity_value = gather( + consequent, necessity_indices, event_dim=support.event_dim + ) + sufficiency_value = gather( + consequent, sufficiency_indices, event_dim=support.event_dim + ) + + necessity_log_probs = soft_neq( + support, necessity_value, factual_value, **kwargs + ) + sufficiency_log_probs = soft_eq( + support, sufficiency_value, factual_value, **kwargs + ) + + # nec_suff_log_probs = torch.add(necessity_log_probs, sufficiency_log_probs) + + FACTUAL_NEC_SUFF = torch.zeros_like(sufficiency_log_probs) + # TODO reflect on this, do we want zeros? 
+ + nec_suff_log_probs_partitioned = { + **{ + factual_indices: FACTUAL_NEC_SUFF, + }, + **{ + IndexSet(**{antecedent: {ind}}): log_prob + for antecedent in ( + set(antecedents) + & set(indices_of(consequent, event_dim=support.event_dim)) + ) + for ind, log_prob in zip( + [necessity_world, sufficiency_world], + [necessity_log_probs, sufficiency_log_probs], + ) + }, + } + + new_value = scatter_n( + nec_suff_log_probs_partitioned, + event_dim=0, + ) + + return new_value + + return _consequent_eq_neq class ExtractSupports(pyro.poutine.messenger.Messenger): diff --git a/chirho/explainable/handlers/explanation.py b/chirho/explainable/handlers/explanation.py index 4387499e2..907b0d513 100644 --- a/chirho/explainable/handlers/explanation.py +++ b/chirho/explainable/handlers/explanation.py @@ -4,9 +4,11 @@ import pyro.distributions.constraints as constraints import torch -from chirho.explainable.handlers.components import ( - consequent_differs, +from chirho.explainable.handlers.components import ( # sufficiency_intervention, + consequent_eq_neq, + consequent_neq, random_intervention, + sufficiency_intervention, undo_split, ) from chirho.explainable.handlers.preemptions import Preemptions @@ -21,7 +23,10 @@ @contextlib.contextmanager def SplitSubsets( supports: Mapping[str, constraints.Constraint], - actions: Mapping[str, Intervention[T]], + actions: Mapping[str, Intervention[T]], # , Union[, Tuple[Intervention[T], ...]]], + # TODO deal with type-related linting errors + # which have to do with random_intervention typed with tensors + # and sufficiency_intervention typed with T *, bias: float = 0.0, prefix: str = "__cause_split_", @@ -120,7 +125,7 @@ def SearchForExplanation( constraints.Constraint, ): consequents = { - c: consequent_differs( + c: consequent_neq( support=s, antecedents=list(antecedents.keys()), scale=consequent_scale, @@ -155,3 +160,138 @@ def SearchForExplanation( with antecedent_handler, witness_handler, consequent_handler: yield + + 
+@contextlib.contextmanager +def SearchForNS( + antecedents: Union[ + Mapping[str, Intervention[T]], + Mapping[str, constraints.Constraint], + ], + witnesses: Union[ + Mapping[str, Intervention[T]], Mapping[str, constraints.Constraint] + ], + consequents: Union[ + Mapping[str, Callable[[T], Union[float, torch.Tensor]]], + Mapping[str, constraints.Constraint], + ], + *, + antecedent_bias: float = 0.0, + witness_bias: float = 0.0, + consequent_scale: float = 1e-2, + antecedent_prefix: str = "__antecedent_", + witness_prefix: str = "__witness_", + consequent_prefix: str = "__consequent_", +): + """ + Effect handler used for a joint necessity-and-sufficiency causal explanation search. On each run: + + 1. The antecedent nodes are intervened on with the values in ``antecedents`` \ + using :func:`~chirho.counterfactual.ops.split` . \ + Unless alternative interventions are provided, \ + counterfactual values are uniformly sampled for each antecedent node \ + using :func:`~chirho.explainable.internals.uniform_proposal` \ + given its support as a :class:`~pyro.distributions.constraints.Constraint`. + In another counterfactual world, the antecedent nodes are intervened to be + at their factual values. The former will be used for probability-of-necessity-like + calculations, while the latter will be used for the + probability-of-sufficiency-like ones. + + 2. These interventions are randomly :func:`~chirho.explainable.ops.preempt`-ed \ + using :func:`~chirho.explainable.handlers.undo_split` \ + by a :func:`~chirho.explainable.handlers.SplitSubsets` handler. + + 3. The witness nodes are randomly :func:`~chirho.explainable.ops.preempt`-ed \ + to be kept at the values given in ``witnesses``. + + 4. A :func:`~pyro.factor` node is added tracking whether the consequent nodes differ \ + from their factual values in the necessity world and match them in the sufficiency world. + + :param antecedents: A mapping from antecedent names to interventions or to constraints. + :param witnesses: A mapping from witness names to interventions or to constraints.
+ :param consequents: A mapping from consequent names to factor functions or to constraints. + """ + antecedents_list = list(antecedents.keys()) + + if antecedents and isinstance( + next(iter(antecedents.values())), + constraints.Constraint, + ): + + antecedents_supports = {a: s for a, s in antecedents.items()} + + _antecedents: Mapping[ + str, + Intervention[T], + ] = { + a: ( + random_intervention(s, name=f"{antecedent_prefix}_proposal_{a}"), + sufficiency_intervention(s, antecedents.keys()), + ) # type: ignore + for a, s in antecedents_supports.items() + } + + else: + + antecedents_supports = {a: constraints.boolean for a in antecedents.keys()} + # TODO generalize to non-scalar antecedents + # comment: how about extracting supports? + + _antecedents = { + a: (antecedents[a], sufficiency_intervention(s, antecedents.keys())) # type: ignore + # TODO fix type error + for a, s in antecedents_supports.items() + } + + if witnesses and isinstance( + next(iter(witnesses.values())), + constraints.Constraint, + ): + witnesses = { + w: undo_split(s, antecedents=list(antecedents.keys())) + for w, s in witnesses.items() + } + + if consequents and isinstance( + next(iter(consequents.values())), + constraints.Constraint, + ): + + consequents_eq_neq = { + c: consequent_eq_neq( + support=s, + antecedents=antecedents_list, + scale=consequent_scale, # TODO allow for different scales for neq and eq + ) + for c, s in consequents.items() + } + + if len(consequents) == 0: + raise ValueError("must have at least one consequent") + + if len(antecedents) == 0: + raise ValueError("must have at least one antecedent") + + if set(consequents.keys()) & set(antecedents.keys()): + raise ValueError("consequents and possible antecedents must be disjoint") + + if set(consequents.keys()) & set(witnesses.keys()): + raise ValueError("consequents and possible witnesses must be disjoint") + + antecedent_handler = SplitSubsets( + supports=antecedents_supports, + actions=_antecedents, + 
bias=antecedent_bias, + prefix=antecedent_prefix, + ) + + witness_handler: Preemptions = Preemptions( + actions=witnesses, bias=witness_bias, prefix=witness_prefix + ) + + consequent_eq_neq_handler: Factors = Factors( + factors=consequents_eq_neq, prefix=f"{consequent_prefix}_eq_neq_" + ) + + with antecedent_handler, witness_handler, consequent_eq_neq_handler: + yield diff --git a/tests/explainable/test_handlers_components.py b/tests/explainable/test_handlers_components.py index 34683b436..615bd5453 100644 --- a/tests/explainable/test_handlers_components.py +++ b/tests/explainable/test_handlers_components.py @@ -6,14 +6,18 @@ from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual from chirho.counterfactual.ops import split -from chirho.explainable.handlers import random_intervention -from chirho.explainable.handlers.components import ( +from chirho.explainable.handlers import random_intervention, sufficiency_intervention +from chirho.explainable.handlers.components import ( # consequent_eq_neq, ExtractSupports, - consequent_differs, + consequent_eq, + consequent_eq_neq, + consequent_neq, undo_split, ) +from chirho.explainable.internals import uniform_proposal from chirho.explainable.ops import preempt from chirho.indexed.ops import IndexSet, gather, indices_of +from chirho.interventional.handlers import do from chirho.interventional.ops import intervene from chirho.observational.handlers.condition import Factors @@ -28,6 +32,42 @@ ] +@pytest.mark.parametrize("support", SUPPORT_CASES) +@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)], ids=str) +def test_sufficiency_intervention(support, event_shape): + + with MultiWorldCounterfactual(): + + if event_shape: + support = pyro.distributions.constraints.independent( + support, len(event_shape) + ) + + proposal_dist = uniform_proposal( + support, + event_shape=event_shape, + ) + + value = pyro.sample("value", proposal_dist) + + intervention = sufficiency_intervention(support, 
indices_of(value).keys()) + + value = intervene(value, intervention, event_dim=0) + + indices = indices_of(value) + observed = gather( + value, + IndexSet(**{index: {0} for index in indices}), + event_dim=0, + ) + intervened = gather( + value, IndexSet(**{index: {1} for index in indices}), event_dim=0 + ) + + assert torch.allclose(observed, intervened) + assert torch.all(support.check(value)) + + @pytest.mark.parametrize("support", SUPPORT_CASES) @pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)], ids=str) def test_random_intervention(support, event_shape): @@ -225,15 +265,11 @@ def model(): assert (x_00, x_10, x_01, x_11) == (5.0, 5.0, 2.0, 2.0) -# @pytest.mark.parametrize("plate_size", [4]) -# @pytest.mark.parametrize("event_shape", [()], ids=str) - - @pytest.mark.parametrize("plate_size", [4, 50, 200]) @pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)], ids=str) -def test_consequent_differs(plate_size, event_shape): +def test_consequent_neq(plate_size, event_shape): factors = { - "consequent": consequent_differs( + "consequent": consequent_neq( antecedents=["split"], support=constraints.independent(constraints.real, len(event_shape)), ) @@ -253,7 +289,7 @@ def model_cd(): ) con_dif = pyro.deterministic( "con_dif", - consequent_differs( + consequent_neq( support=constraints.independent(constraints.real, len(event_shape)), antecedents=["split"], )(consequent), @@ -279,6 +315,111 @@ def model_cd(): assert nd["__factor_consequent"]["log_prob"].sum() < -1e2 +# potentially, the following test could be merged with the previous one +# as they share quite a bit of code +# but despite some repeated code left separate to test two functionalities +# in isolation + + +@pytest.mark.parametrize("plate_size", [4, 50, 200]) +@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)], ids=str) +def test_consequent_eq(plate_size, event_shape): + factors = { + "consequent": consequent_eq( + antecedents=["split"], + support=constraints.independent(constraints.real, 
len(event_shape)), + ) + } + + @Factors(factors=factors) + @pyro.plate("data", size=plate_size, dim=-1) + def model_ce(): + w = pyro.sample( + "w", dist.Normal(0, 0.1).expand(event_shape).to_event(len(event_shape)) + ) + new_w = w.clone() + new_w[1::2] = 10 + w = split(w, (new_w,), name="split", event_dim=len(event_shape)) + consequent = pyro.deterministic( + "consequent", w * 0.1, event_dim=len(event_shape) + ) + con_eq = pyro.deterministic( + "con_eq", + consequent_eq( + support=constraints.independent(constraints.real, len(event_shape)), + antecedents=["split"], + )(consequent), + event_dim=0, + ) + + return con_eq + + with MultiWorldCounterfactual() as mwc: + with pyro.poutine.trace() as tr: + model_ce() + + tr.trace.compute_log_prob() + nd = tr.trace.nodes + + with mwc: + int_con_eq = gather(nd["con_eq"]["value"], IndexSet(**{"split": {1}})) + + assert "split" not in indices_of(int_con_eq) + assert not indices_of(int_con_eq) + + assert int_con_eq.squeeze().shape == nd["w"]["fn"].batch_shape + assert nd["__factor_consequent"]["log_prob"].sum() < -10 + + +@pytest.mark.parametrize("plate_size", [4, 50, 200]) +@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)], ids=str) +def test_consequent_eq_neq(plate_size, event_shape): + factors = { + "consequent": consequent_eq_neq( + support=constraints.independent(constraints.real, len(event_shape)), + antecedents=["w"], + ) + } + + @Factors(factors=factors) + @pyro.plate("data", size=plate_size, dim=-1) + def model_ce(): + w = pyro.sample( + "w", dist.Normal(0, 0.1).expand(event_shape).to_event(len(event_shape)) + ) + consequent = pyro.deterministic( + "consequent", w * 0.1, event_dim=len(event_shape) + ) + + return consequent + + antecedents = { + "w": ( + torch.tensor(5.0).expand(event_shape), + sufficiency_intervention( + constraints.independent(constraints.real, len(event_shape)), ["w"] + ), + ) + } + + with MultiWorldCounterfactual() as mwc: + with do(actions=antecedents): + with pyro.poutine.trace() as tr: + 
model_ce() + with pyro.poutine.trace() as tr: + model_ce() + + tr.trace.compute_log_prob() + nd = tr.trace.nodes + + with mwc: + eq_neq_log_probs = gather( + nd["__factor_consequent"]["log_prob"], IndexSet(**{"w": {1}}) + ) + + assert eq_neq_log_probs.sum() == 0 + + options = [ None, [], diff --git a/tests/explainable/test_handlers_explanation.py b/tests/explainable/test_handlers_explanation.py index 8acb4039b..73b97740b 100644 --- a/tests/explainable/test_handlers_explanation.py +++ b/tests/explainable/test_handlers_explanation.py @@ -7,7 +7,11 @@ from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual from chirho.explainable.handlers.components import undo_split -from chirho.explainable.handlers.explanation import SearchForExplanation, SplitSubsets +from chirho.explainable.handlers.explanation import ( + SearchForExplanation, + SearchForNS, + SplitSubsets, +) from chirho.explainable.handlers.preemptions import Preemptions from chirho.indexed.ops import IndexSet, gather from chirho.observational.handlers.condition import condition @@ -165,7 +169,47 @@ def test_SearchForExplanation(): assert tr_empty.trace.nodes -test_SearchForExplanation() +def test_SearchForNS(): + observations = { + "prob_sally_throws": 1.0, + "prob_bill_throws": 1.0, + "prob_sally_hits": 1.0, + "prob_bill_hits": 1.0, + "prob_bottle_shatters_if_sally": 1.0, + "prob_bottle_shatters_if_bill": 1.0, + } + + observations_conditioning = condition( + data={k: torch.as_tensor(v) for k, v in observations.items()} + ) + + antecedents = {"sally_throws": 0.0} + witnesses = {"bill_throws": constraints.boolean, "bill_hits": constraints.boolean} + consequents = {"bottle_shatters": constraints.boolean} + + with MultiWorldCounterfactual() as mwc: + with SearchForNS( + antecedents=antecedents, + witnesses=witnesses, + consequents=consequents, + antecedent_bias=0.1, + consequent_scale=1e-8, + ): + with observations_conditioning: + with pyro.plate("sample", 200): + with 
pyro.poutine.trace() as tr: + stones_bayesian_model() + + tr.trace.compute_log_prob() + tr = tr.trace.nodes + + with mwc: + eq_neq_logs = gather( + tr["__consequent__eq_neq_bottle_shatters"]["log_prob"], + IndexSet(**{"sally_throws": {1}}), + ) + + assert eq_neq_logs.shape[-1] == 200 def test_SplitSubsets_single_layer():