-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor counterfactuals to use observe and condition (#176)
* Separate observe and condition * Split up files and create observational handlers folder * imports * lint * rename test * add test about commutativity of do and condition * doc * union * Refactor counterfactuals to use observe * appease mypy * Vindex fixes particle errors * update backdoor * update slc * fix particle test case * add cf commutativity test * fix bug * revert slc handler order * add predictive smoke test * nit * elbo * reorder test * Add a stronger infer_discrete test * move notebooks to separate branch * test * chirho * merge fail * Update and re-run example notebooks with new condition (#178) * Update and re-run backdoor and SLC notebooks * deepscm * cevae * import * mediation * merge * update notebooks * merge * merge 2 * toc * populate autodoc * tweak * Restores (via cherry-pick) Notebook Link and Formatting Changes (#205) * fixed outline rendering * fixes outline links in mediation notebook. * fixes outline and links for backdoor notebook. * fixes outline links in cevae notebook. * fixes slc notebook outline links. * adds outline back into deep scm notebook. * address remaining reference issues, building now with now warnings --------- Co-authored-by: Sam Witty <[email protected]> --------- Co-authored-by: Andy Zane <[email protected]> Co-authored-by: Sam Witty <[email protected]>
- Loading branch information
1 parent
2af3f73
commit f659fa8
Showing
15 changed files
with
769 additions
and
413 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,29 @@ | ||
import pyro.infer.reparam | ||
import torch | ||
from typing import Any, Dict | ||
|
||
from chirho.indexed.ops import indices_of, union | ||
|
||
def expand_obs_value_inplace_(msg: pyro.infer.reparam.reparam.ReparamMessage) -> None: | ||
|
||
def site_is_ambiguous(msg: Dict[str, Any]) -> bool: | ||
""" | ||
Helper function used with :func:`observe` to determine | ||
whether a site is observed or ambiguous. | ||
A sample site is ambiguous if it is marked observed, is downstream of an intervention, | ||
and the observed value's index variables are a strict subset of the distribution's | ||
indices and hence require clarification of which entries of the random variable | ||
are fixed/observed (as opposed to random/unobserved). | ||
""" | ||
rv, obs = msg["args"][:2] | ||
value_indices = indices_of(obs, event_dim=len(rv.event_shape)) | ||
dist_indices = indices_of(rv) | ||
return ( | ||
bool(union(value_indices, dist_indices)) and value_indices != dist_indices | ||
) or not msg["infer"].get("_specified_conditioning", True) | ||
|
||
|
||
def no_ambiguity(msg: Dict[str, Any]) -> Dict[str, Any]: | ||
""" | ||
Slightly gross workaround that mutates the msg in place | ||
to avoid triggering overzealous validation logic in | ||
:class:~`pyro.poutine.reparam.ReparamMessenger` | ||
that uses cheaper tensor shape and identity equality checks as | ||
a conservative proxy for an expensive tensor value equality check. | ||
(see https://github.com/pyro-ppl/pyro/blob/685c7adee65bbcdd6bd6c84c834a0a460f2224eb/pyro/poutine/reparam_messenger.py#L99) # noqa: E501 | ||
This workaround is correct because these reparameterizers do not change | ||
the observed entries, it just packs counterfactual values around them; | ||
the equality check being approximated by that logic would still pass. | ||
Helper function used with :func:`pyro.poutine.infer_config` to inform | ||
:class:`FactualConditioningMessenger` that all ambiguity in the current | ||
context has been resolved. | ||
""" | ||
msg["value"] = torch.as_tensor(msg["value"]) | ||
msg["infer"]["orig_shape"] = msg["value"].shape | ||
_custom_init = getattr(msg["value"], "_pyro_custom_init", False) | ||
msg["value"] = msg["value"].expand( | ||
torch.broadcast_shapes( | ||
msg["fn"].batch_shape + msg["fn"].event_shape, | ||
msg["value"].shape, | ||
) | ||
) | ||
setattr(msg["value"], "_pyro_custom_init", _custom_init) | ||
return {"_specified_conditioning": True} |
Oops, something went wrong.