Skip to content

Commit

Permalink
Add an effect handler for inserting factor statements (#238)
Browse files Browse the repository at this point in the history
* Add a Factors handler for inserting new factors

* docs and test

* lint
  • Loading branch information
eb8680 authored Sep 1, 2023
1 parent cbe8bdb commit 7c2f163
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 1 deletion.
48 changes: 47 additions & 1 deletion chirho/observational/handlers/condition.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,57 @@
from typing import Generic, Hashable, Mapping, TypeVar
from typing import Callable, Generic, Hashable, Mapping, TypeVar, Union

import pyro
import torch

from chirho.observational.internals import ObserveNameMessenger
from chirho.observational.ops import AtomicObservation, observe

T = TypeVar("T")
R = Union[float, torch.Tensor]


class Factors(Generic[T], pyro.poutine.messenger.Messenger):
"""
Effect handler that adds new log-factors to the unnormalized
joint log-density of a probabilistic program.
After a :func:`pyro.sample` site whose name appears in ``factors``,
this handler inserts a new :func:`pyro.factor` site
whose name is prefixed with the string ``prefix``
and whose log-weight is the result of applying the corresponding function
to the value of the sample site. ::
>>> with Factors(factors={"x": lambda x: -(x - 1) ** 2}, prefix="__factor_"):
... with pyro.poutine.trace() as tr:
... x = pyro.sample("x", dist.Normal(0, 1))
... tr.trace.compute_log_prob()
>>> assert {"x", "__factor_x"} <= set(tr.trace.nodes.keys())
>>> assert torch.all(tr.trace.nodes["x"]["log_prob"] == -(x - 1) ** 2)
:param factors: A mapping from sample site names to log-factor functions.
:param prefix: The prefix to use for the names of the factor sites.
"""

factors: Mapping[str, Callable[[T], R]]
prefix: str

def __init__(
self,
factors: Mapping[str, Callable[[T], R]],
*,
prefix: str = "__factor_",
):
self.factors = factors
self.prefix = prefix
super().__init__()

def _pyro_post_sample(self, msg: dict) -> None:
try:
factor = self.factors[msg["name"]]
except KeyError:
return

pyro.factor(f"{self.prefix}{msg['name']}", factor(msg["value"]))


class ConditionMessenger(Generic[T], ObserveNameMessenger):
Expand Down
34 changes: 34 additions & 0 deletions tests/observational/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)
from chirho.interventional.handlers import do
from chirho.observational.handlers import condition
from chirho.observational.handlers.condition import Factors
from chirho.observational.handlers.soft_conditioning import (
AutoSoftConditioning,
KernelSoftConditionReparam,
Expand Down Expand Up @@ -284,3 +285,36 @@ def model():
assert torch.allclose(node["value"], tr3.trace.nodes[name]["value"])
assert torch.allclose(node["log_prob"], tr2.trace.nodes[name]["log_prob"])
assert torch.allclose(node["log_prob"], tr3.trace.nodes[name]["log_prob"])


def test_factors_handler():
def model():
z = pyro.sample("z", dist.Normal(0, 1), obs=torch.tensor(0.1))
with pyro.plate("data", 2):
x = pyro.sample("x", dist.Normal(z, 1))
y = pyro.sample("y", dist.Normal(x + z, 1))
return z, x, y

prefix = "__factor_"
factors = {
"z": lambda z: -((z - 1.5) ** 2),
"x": lambda x: -((x - 1) ** 2),
}

with Factors[torch.Tensor](factors=factors, prefix=prefix):
with pyro.poutine.trace() as tr:
model()

tr.trace.compute_log_prob()

for name in factors:
assert name in tr.trace.nodes
assert f"{prefix}{name}" in tr.trace.nodes
assert (
tr.trace.nodes[name]["fn"].batch_shape
== tr.trace.nodes[f"{prefix}{name}"]["fn"].batch_shape
)
assert torch.allclose(
tr.trace.nodes[f"{prefix}{name}"]["log_prob"],
factors[name](tr.trace.nodes[name]["value"]),
)

0 comments on commit 7c2f163

Please sign in to comment.