Skip to content

Commit

Permalink
Explainability edge cases (#557)
Browse files Browse the repository at this point in the history
* tests for indepepdent and correct

* added print

* extra case

* debugged reverse

* debug consequen_eq_neq

* fixed test_consequent_eq_neq

* fixed the test with dimensions

* consequent_eq_neq

* three variable model

* testing three dependent

* debugging

* minimal example for three independent variables

* more three variable models

* diverge

* debugged

* notebook tested three variable models

* three variable test cases aded

* clean up

* test for factual log probs

* more clean up

* fixed a lint error

* lint clean

* reverted metadata

* gather and compute_log_prob fixed

* scratch_notebook for components_test

* consequent_eq_neq event_shape logic fixed

* segregated two variable tests

* added event_shapes to tests

* clean up and lint

---------

Co-authored-by: rfl-urbaniak <[email protected]>
  • Loading branch information
PoorvaGarg and rfl-urbaniak authored Aug 14, 2024
1 parent b4b96ec commit 43c766c
Show file tree
Hide file tree
Showing 3 changed files with 302 additions and 40 deletions.
28 changes: 5 additions & 23 deletions chirho/explainable/handlers/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,15 +230,6 @@ def consequent_eq_neq(
"""

def _consequent_eq_neq(consequent: T) -> torch.Tensor:

factual_indices = IndexSet(
**{
name: ind
for name, ind in get_factual_indices().items()
if name in antecedents
}
)

necessity_world = kwargs.get("necessity_world", 1)
sufficiency_world = kwargs.get("sufficiency_world", 2)

Expand All @@ -249,7 +240,6 @@ def _consequent_eq_neq(consequent: T) -> torch.Tensor:
if name in antecedents
}
)

sufficiency_indices = IndexSet(
**{
name: {sufficiency_world}
Expand All @@ -265,9 +255,6 @@ def _consequent_eq_neq(consequent: T) -> torch.Tensor:
consequent, sufficiency_indices, event_dim=support.event_dim
)

# compare to proposed consequent if provided
# as then the sufficiency value can be different
# due to witness preemption
necessity_log_probs = (
soft_neq(
support,
Expand All @@ -283,7 +270,6 @@ def _consequent_eq_neq(consequent: T) -> torch.Tensor:
**kwargs,
)
)

sufficiency_log_probs = (
soft_eq(support, sufficiency_value, proposed_consequent, **kwargs)
if proposed_consequent is not None
Expand All @@ -292,16 +278,13 @@ def _consequent_eq_neq(consequent: T) -> torch.Tensor:

FACTUAL_NEC_SUFF = torch.zeros_like(sufficiency_log_probs)

index_keys = set(antecedents)
null_index = IndexSet(**{name: {0} for name in index_keys})

nec_suff_log_probs_partitioned = {
**{null_index: FACTUAL_NEC_SUFF},
**{
factual_indices: FACTUAL_NEC_SUFF,
},
**{
IndexSet(**{antecedent: {ind}}): log_prob
for antecedent in (
set(antecedents)
& set(indices_of(consequent, event_dim=support.event_dim))
)
IndexSet(**{antecedent: {ind} for antecedent in index_keys}): log_prob
for ind, log_prob in zip(
[necessity_world, sufficiency_world],
[necessity_log_probs, sufficiency_log_probs],
Expand All @@ -313,7 +296,6 @@ def _consequent_eq_neq(consequent: T) -> torch.Tensor:
nec_suff_log_probs_partitioned,
event_dim=0,
)

return new_value

return _consequent_eq_neq
Expand Down
53 changes: 37 additions & 16 deletions tests/explainable/test_handlers_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,48 +377,69 @@ def test_consequent_eq_neq(plate_size, event_shape):
factors = {
"consequent": consequent_eq_neq(
support=constraints.independent(constraints.real, len(event_shape)),
proposed_consequent=torch.Tensor([0.1]), # added this
proposed_consequent=torch.tensor(0.01).expand(event_shape),
antecedents=["w"],
)
}

@Factors(factors=factors)
@pyro.plate("data", size=plate_size, dim=-1)
@pyro.plate("data", size=plate_size, dim=-4)
def model_ce():
w = pyro.sample(
"w", dist.Normal(0, 0.1).expand(event_shape).to_event(len(event_shape))
)
consequent = pyro.deterministic(
"consequent", w * 0.1, event_dim=len(event_shape)
"consequent", w * torch.tensor(0.1), event_dim=len(event_shape)
)

return consequent
assert w.shape == consequent.shape

antecedents = {
"w": (
torch.tensor(5.0).expand(event_shape),
torch.tensor(0.1).expand(event_shape),
sufficiency_intervention(
constraints.independent(constraints.real, len(event_shape)), ["w"]
),
)
}

with MultiWorldCounterfactual() as mwc:
with MultiWorldCounterfactual() as mwc_ce:
with do(actions=antecedents):
with pyro.poutine.trace() as tr:
with pyro.poutine.trace() as trace_ce:
model_ce()
with pyro.poutine.trace() as tr:
model_ce()

tr.trace.compute_log_prob()
nd = tr.trace.nodes
trace_ce.trace.compute_log_prob()
nd = trace_ce.trace.nodes
with mwc_ce:
eq_neq_log_probs_fact = gather(
nd["__factor_consequent"]["fn"].log_factor, IndexSet(**{"w": {0}})
)

with mwc:
eq_neq_log_probs = gather(
nd["__factor_consequent"]["log_prob"], IndexSet(**{"w": {1}})
eq_neq_log_probs_nec = gather(
nd["__factor_consequent"]["fn"].log_factor, IndexSet(**{"w": {1}})
)

consequent_suff = gather(
nd["consequent"]["value"],
IndexSet(**{"w": {2}}),
event_dim=len(event_shape),
)
eq_neq_log_probs_suff = gather(
nd["__factor_consequent"]["fn"].log_factor, IndexSet(**{"w": {2}})
)

assert eq_neq_log_probs.sum() == 0
assert torch.equal(
eq_neq_log_probs_fact, torch.zeros(eq_neq_log_probs_fact.shape)
)

result = dist.Normal(0.0, 0.1).log_prob(consequent_suff - torch.tensor(0.01))
for _ in range(len(event_shape)):
result = torch.sum(result, dim=-1)

assert torch.allclose(
eq_neq_log_probs_suff.squeeze(),
result.squeeze(),
)
assert eq_neq_log_probs_nec.sum().exp().item() == 0


options = [
Expand Down
Loading

0 comments on commit 43c766c

Please sign in to comment.