Add estimation of context-sensitive probability of sufficiency to the explainable reasoning module (#536)

* rename consequent_differs to consequent_neq

* add consequent_eq

* lint

* added sufficiency_intervention

* generalized hints for sufficiency_intervention

* add event_dim to indices_of within sufficiency intervention

* added antecedents in sufficiency_intervention

* test sufficiency_intervention with split

* format, lint

* expanded antecedents in scribbles.ipynb

* fixed sufficiency_intervention and test thereof

* consequents to consequents_nec in scribbles

* working eq and neq

* fixed the event dim bug

* added consequent_eq_neq

* format lint

* added naming to do within components WIP

* debugged consequent_eq_neq to handle upstream interventions

* generalizing across event shapes WIP

* debugging events WIP

* generalizing to SearchForNS WIP

* working forest fire example

* extract distros from the ff example

* state before reverting

* reverted on sufficiency_intervention

* reinstate counterfactual

* reverted on SplitSubsets

* simplified test consequent eq_neq, moving on

* SearchForNS into explainable

* test SearchForNS

* tests passing

* revert counterfactual

* pulled code changes from the tutorial branch

* removed outdated comment
rfl-urbaniak authored May 8, 2024
1 parent 71b2e4f commit 8ad5b56
Showing 5 changed files with 531 additions and 26 deletions.
10 changes: 7 additions & 3 deletions chirho/explainable/handlers/__init__.py
@@ -1,4 +1,8 @@
from .components import random_intervention # noqa: F401
from .components import ExtractSupports, undo_split # noqa: F401
from .explanation import SearchForExplanation, SplitSubsets # noqa: F401
from .components import (  # noqa: F401
    ExtractSupports,
    random_intervention,
    sufficiency_intervention,
    undo_split,
)
from .explanation import SearchForExplanation, SearchForNS, SplitSubsets # noqa: F401
from .preemptions import Preemptions # noqa: F401
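
For orientation (not part of the diff), the handler surface re-exported above can be imported in one statement; this sketch uses only the names listed in this __init__.py:

```python
# Sketch: the public names re-exported by chirho.explainable.handlers after this change.
from chirho.explainable.handlers import (
    ExtractSupports,
    Preemptions,
    SearchForExplanation,
    SearchForNS,
    SplitSubsets,
    random_intervention,
    sufficiency_intervention,
    undo_split,
)
```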
190 changes: 183 additions & 7 deletions chirho/explainable/handlers/components.py
@@ -8,16 +8,61 @@
from chirho.counterfactual.handlers.selection import get_factual_indices
from chirho.explainable.internals import uniform_proposal
from chirho.indexed.ops import IndexSet, gather, indices_of, scatter_n
from chirho.observational.handlers import soft_neq

# from chirho.interventional.ops import intervene
from chirho.observational.handlers import soft_eq, soft_neq

S = TypeVar("S")
T = TypeVar("T")


def sufficiency_intervention(
    support: constraints.Constraint,
    antecedents: Iterable[str] = [],
) -> Callable[[T], T]:
    """
    Creates a sufficiency intervention for a single sample site, intervening to keep
    the value as in the factual world with respect to the antecedents.

    :param support: The support constraint for the site.
    :param antecedents: A list of names of upstream intervened sites with respect to
        which the factual value is preserved.

    :return: A function that takes a `torch.Tensor` as input
        and returns the factual value at the site as a tensor.

    Example::

        >>> with MultiWorldCounterfactual() as mwc:
        >>>     value = pyro.sample("value", proposal_dist)
        >>>     intervention = sufficiency_intervention(support)
        >>>     value = intervene(value, intervention)
    """

    def _sufficiency_intervention(value: T) -> T:

        indices = IndexSet(
            **{
                name: ind
                for name, ind in get_factual_indices().items()
                if name in antecedents
            }
        )

        factual_value = gather(
            value,
            indices,
            event_dim=support.event_dim,
        )
        return factual_value

    return _sufficiency_intervention
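
For orientation (not part of the diff), the docstring example above can be fleshed out roughly as follows; the distribution, support, and site name are illustrative assumptions, and the sketch assumes the usual chirho counterfactual and interventional entry points:

```python
# A minimal sketch of using sufficiency_intervention inside a counterfactual handler.
# The model fragment and names are hypothetical, not taken from this commit.
import pyro
import pyro.distributions as dist
import torch
from torch.distributions import constraints

from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual
from chirho.explainable.handlers import sufficiency_intervention
from chirho.interventional.ops import intervene

with MultiWorldCounterfactual():
    value = pyro.sample("value", dist.Bernoulli(0.7))
    # The intervention pins "value" to its factual value in the counterfactual
    # world created by the intervene call below.
    intervention = sufficiency_intervention(constraints.boolean, antecedents=["value"])
    value = intervene(value, intervention, name="value")
```

In the pipeline this commit adds, such a sufficiency intervention is paired with a necessity intervention on the same antecedent, so that one counterfactual world alters the antecedent while another keeps it at its factual value (see consequent_eq_neq below).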


def random_intervention(
    support: constraints.Constraint,
    name: str,
) -> Callable[[torch.Tensor], torch.Tensor]:
) -> Callable[[T], T]:
    """
    Creates a random-valued intervention for a single sample site, determined by
    the distribution support and site name.
@@ -38,8 +83,10 @@ def random_intervention(
        >>> assert x != 2
    """

    def _random_intervention(value: torch.Tensor) -> torch.Tensor:
        event_shape = value.shape[len(value.shape) - support.event_dim :]
    def _random_intervention(value: T) -> T:

        event_shape = value.shape[len(value.shape) - support.event_dim :]  # type: ignore

        proposal_dist = uniform_proposal(
            support,
            event_shape=event_shape,
@@ -92,7 +139,43 @@ def _undo_split(value: T) -> T:
    return _undo_split


def consequent_differs(
def consequent_eq(
    support: constraints.Constraint,
    antecedents: Iterable[str] = [],
    **kwargs,
) -> Callable[[T], torch.Tensor]:
    """
    A helper function for assessing whether values at a site are close to their observed values, assigning
    a small negative value close to zero if a value is close to its observed state and a large negative value otherwise.

    :param support: The support constraint for the consequent site.
    :param antecedents: A list of names of upstream intervened sites to consider when assessing similarity.

    :return: A callable which applied to a site value object (``consequent``), returns a tensor where each
        element indicates the extent to which the corresponding element of ``consequent``
        is close to its factual value.
    """

    def _consequent_eq(consequent: T) -> torch.Tensor:
        indices = IndexSet(
            **{
                name: ind
                for name, ind in get_factual_indices().items()
                if name in antecedents
            }
        )
        eq = soft_eq(
            support,
            consequent,
            gather(consequent, indices, event_dim=support.event_dim),
            **kwargs,
        )
        return eq

    return _consequent_eq
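
As a rough standalone illustration (a sketch, not part of the diff) of the comparison consequent_eq delegates to: it calls soft_eq between a site's value and its factual value. Here the counterfactual indexing is bypassed and soft_eq is applied to two hand-written boolean tensors, assuming soft_eq's default tolerance:

```python
# Sketch: soft_eq yields a log-factor near zero where values agree and a strongly
# negative log-factor where they differ. Tensors here are illustrative.
import torch
from torch.distributions import constraints

from chirho.observational.handlers import soft_eq

factual = torch.tensor([1.0, 0.0, 1.0])
counterfactual = torch.tensor([1.0, 1.0, 0.0])

log_factor = soft_eq(constraints.boolean, counterfactual, factual)
```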


def consequent_neq(
    support: constraints.Constraint,
    antecedents: Iterable[str] = [],
    **kwargs,
@@ -109,7 +192,7 @@ def consequent_differs(
        element indicates whether the corresponding element of ``consequent`` differs from its factual value.
    """

    def _consequent_differs(consequent: T) -> torch.Tensor:
    def _consequent_neq(consequent: T) -> torch.Tensor:
        indices = IndexSet(
            **{
                name: ind
@@ -125,7 +208,100 @@ def _consequent_differs(consequent: T) -> torch.Tensor:
            )
        return diff

    return _consequent_differs
    return _consequent_neq


def consequent_eq_neq(
    support: constraints.Constraint,
    antecedents: Iterable[str] = [],
    **kwargs,
) -> Callable[[T], torch.Tensor]:
    """
    A helper function for obtaining the joint log probability of necessity and sufficiency. Assumes that
    the necessity intervention has been applied in counterfactual world 1 and the sufficiency intervention
    in counterfactual world 2 (these can be passed as kwargs).

    :param support: The support constraint for the consequent site.
    :param antecedents: A list of names of upstream intervened sites to consider when composing the joint log prob.

    :return: A callable which applied to a site value object (``consequent``), returns a tensor with log prob sums
        of values resulting from necessity and sufficiency interventions, in appropriate counterfactual worlds.
    """

    def _consequent_eq_neq(consequent: T) -> torch.Tensor:

        factual_indices = IndexSet(
            **{
                name: ind
                for name, ind in get_factual_indices().items()
                if name in antecedents
            }
        )

        necessity_world = kwargs.get("necessity_world", 1)
        sufficiency_world = kwargs.get("sufficiency_world", 2)

        necessity_indices = IndexSet(
            **{
                name: {necessity_world}
                for name in indices_of(consequent, event_dim=support.event_dim).keys()
                if name in antecedents
            }
        )

        sufficiency_indices = IndexSet(
            **{
                name: {sufficiency_world}
                for name in indices_of(consequent, event_dim=support.event_dim).keys()
                if name in antecedents
            }
        )

        factual_value = gather(consequent, factual_indices, event_dim=support.event_dim)
        necessity_value = gather(
            consequent, necessity_indices, event_dim=support.event_dim
        )
        sufficiency_value = gather(
            consequent, sufficiency_indices, event_dim=support.event_dim
        )

        necessity_log_probs = soft_neq(
            support, necessity_value, factual_value, **kwargs
        )
        sufficiency_log_probs = soft_eq(
            support, sufficiency_value, factual_value, **kwargs
        )

        # nec_suff_log_probs = torch.add(necessity_log_probs, sufficiency_log_probs)

        FACTUAL_NEC_SUFF = torch.zeros_like(sufficiency_log_probs)
        # TODO reflect on this, do we want zeros?

        nec_suff_log_probs_partitioned = {
            **{
                factual_indices: FACTUAL_NEC_SUFF,
            },
            **{
                IndexSet(**{antecedent: {ind}}): log_prob
                for antecedent in (
                    set(antecedents)
                    & set(indices_of(consequent, event_dim=support.event_dim))
                )
                for ind, log_prob in zip(
                    [necessity_world, sufficiency_world],
                    [necessity_log_probs, sufficiency_log_probs],
                )
            },
        }

        new_value = scatter_n(
            nec_suff_log_probs_partitioned,
            event_dim=0,
        )

        return new_value

    return _consequent_eq_neq
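
Schematically (a plain-dict sketch rather than chirho's IndexSet-keyed partition, with illustrative tensors and default tolerances), the quantity consequent_eq_neq scatters across worlds pairs a zero log-factor in the factual world with soft_neq in the necessity world and soft_eq in the sufficiency world:

```python
# Sketch of the per-world log-factors that consequent_eq_neq assembles; the keys
# here are plain strings standing in for the IndexSets used above.
import torch
from torch.distributions import constraints

from chirho.observational.handlers import soft_eq, soft_neq

factual = torch.tensor([1.0, 0.0])            # consequent in the factual world
necessity_value = torch.tensor([0.0, 0.0])    # consequent under the necessity intervention
sufficiency_value = torch.tensor([1.0, 1.0])  # consequent under the sufficiency intervention

per_world_log_factor = {
    "factual": torch.zeros_like(factual),                                    # pinned to zero
    "necessity": soft_neq(constraints.boolean, necessity_value, factual),    # rewards a changed outcome
    "sufficiency": soft_eq(constraints.boolean, sufficiency_value, factual), # rewards an unchanged outcome
}
```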


class ExtractSupports(pyro.poutine.messenger.Messenger):
