Skip to content

Commit

Permalink
Add observe operation and new condition handler (#175)
Browse files Browse the repository at this point in the history
* Separate observe and condition

* Split up files and create observational handlers folder

* imports

* lint

* rename test

* add test about commutativity of do and condition

* doc

* union

* fix particle test case

* fix bug

* chirho
  • Loading branch information
eb8680 authored Jul 10, 2023
1 parent c0178b3 commit e06de84
Show file tree
Hide file tree
Showing 8 changed files with 247 additions and 8 deletions.
9 changes: 5 additions & 4 deletions chirho/interventional/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,14 @@ def __init__(self, actions: Mapping[Hashable, AtomicIntervention[T]]):
super().__init__()

def _pyro_post_sample(self, msg):
try:
action = self.actions[msg["name"]]
except KeyError:
if msg["name"] not in self.actions or msg["infer"].get(
"_do_not_intervene", None
):
return

msg["value"] = intervene(
msg["value"],
action,
self.actions[msg["name"]],
event_dim=len(msg["fn"].event_shape),
name=msg["name"],
)
Expand Down
1 change: 1 addition & 0 deletions chirho/observational/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
import chirho.observational.internals # noqa: F401
1 change: 1 addition & 0 deletions chirho/observational/handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .condition import condition # noqa: F401
63 changes: 63 additions & 0 deletions chirho/observational/handlers/condition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from typing import Generic, Hashable, Mapping, TypeVar

import pyro

from chirho.observational.internals import ObserveNameMessenger
from chirho.observational.ops import AtomicObservation, observe

T = TypeVar("T")


class ConditionMessenger(Generic[T], ObserveNameMessenger):
"""
Condition on values in a probabilistic program.
Can be used as a drop-in replacement for :func:`pyro.condition` that supports
a richer set of observational data types and enables counterfactual inference.
"""

def __init__(self, data: Mapping[Hashable, AtomicObservation[T]]):
self.data = data
super().__init__()

def _pyro_sample(self, msg):
if pyro.poutine.util.site_is_subsample(msg) or pyro.poutine.util.site_is_factor(
msg
):
return

if msg["name"] not in self.data or msg["infer"].get("_do_not_observe", None):
if (
"_markov_scope" in msg["infer"]
and getattr(self, "_current_site", None) is not None
):
msg["infer"]["_markov_scope"].pop(self._current_site, None)
return

msg["stop"] = True
msg["done"] = True

# flags to guarantee commutativity of condition, intervene, trace
msg["mask"] = False
msg["is_observed"] = False
msg["infer"]["is_auxiliary"] = True
msg["infer"]["_do_not_trace"] = True
msg["infer"]["_do_not_intervene"] = True
msg["infer"]["_do_not_observe"] = True

with pyro.poutine.infer_config(
config_fn=lambda msg_: {
"_do_not_observe": msg["name"] == msg_["name"]
or msg_["infer"].get("_do_not_observe", False)
}
):
try:
self._current_site = msg["name"]
msg["value"] = observe(
msg["fn"], self.data[msg["name"]], name=msg["name"], **msg["kwargs"]
)
finally:
self._current_site = None


condition = pyro.poutine.handlers._make_handler(ConditionMessenger)[1]
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

T = TypeVar("T")


Kernel = Callable[[T, T], torch.Tensor]


Expand Down
47 changes: 47 additions & 0 deletions chirho/observational/internals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Optional, TypeVar

import pyro
import pyro.distributions
import torch

from chirho.observational.ops import AtomicObservation, observe

T = TypeVar("T")


@observe.register(int)
@observe.register(float)
@observe.register(bool)
@observe.register(torch.Tensor)
def _observe_deterministic(rv: T, obs: Optional[AtomicObservation[T]] = None, **kwargs):
"""
Observe a tensor in a probabilistic program.
"""
rv_dist = pyro.distributions.Delta(
torch.as_tensor(rv), event_dim=kwargs.pop("event_dim", 0)
)
return observe(rv_dist, obs, **kwargs)


@observe.register(pyro.distributions.Distribution)
@pyro.poutine.runtime.effectful(type="observe")
def _observe_distribution(
rv: pyro.distributions.Distribution,
obs: Optional[AtomicObservation[T]] = None,
*,
name: Optional[str] = None,
**kwargs,
) -> T:
if name is None:
raise ValueError("name must be specified when observing a distribution")

if callable(obs):
raise NotImplementedError("Dependent observations are not yet supported")

return pyro.sample(name, rv, obs=obs, **kwargs)


class ObserveNameMessenger(pyro.poutine.messenger.Messenger):
def _pyro_observe(self, msg):
if "name" not in msg["kwargs"]:
msg["kwargs"]["name"] = msg["name"]
18 changes: 18 additions & 0 deletions chirho/observational/ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import functools
from typing import Callable, Hashable, Mapping, Optional, TypeVar, Union

T = TypeVar("T")

AtomicObservation = Union[T, Callable[..., T]] # TODO add support for more atomic types
CompoundObservation = Union[
Mapping[Hashable, AtomicObservation[T]], Callable[..., AtomicObservation[T]]
]
Observation = Union[AtomicObservation[T], CompoundObservation[T]]


@functools.singledispatch
def observe(rv, obs: Optional[Observation[T]] = None, **kwargs) -> T:
"""
Observe a random value in a probabilistic program.
"""
raise NotImplementedError(f"observe not implemented for type {type(rv)}")
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
TwinWorldCounterfactual,
)
from chirho.interventional.handlers import do
from chirho.observational.handlers import (
from chirho.observational.handlers import condition
from chirho.observational.handlers.soft_conditioning import (
AutoSoftConditioning,
KernelSoftConditionReparam,
RBFKernel,
Expand Down Expand Up @@ -69,7 +70,7 @@ def test_soft_conditioning_smoke_continuous_1(
}
with pyro.poutine.trace() as tr, pyro.poutine.reparam(
config=reparam_config
), pyro.condition(data=data):
), condition(data=data):
continuous_scm_1()

tr.trace.compute_log_prob()
Expand Down Expand Up @@ -110,7 +111,7 @@ def test_soft_conditioning_smoke_discrete_1(
}
with pyro.poutine.trace() as tr, pyro.poutine.reparam(
config=reparam_config
), pyro.condition(data=data):
), condition(data=data):
discrete_scm_1()

tr.trace.compute_log_prob()
Expand Down Expand Up @@ -154,7 +155,7 @@ def test_soft_conditioning_counterfactual_continuous_1(

with pyro.poutine.trace() as tr, pyro.poutine.reparam(
config=reparam_config
), cf_class(cf_dim), do(actions=actions), pyro.condition(data=data):
), cf_class(cf_dim), do(actions=actions), condition(data=data):
continuous_scm_1()

tr.trace.compute_log_prob()
Expand All @@ -174,3 +175,109 @@ def test_soft_conditioning_counterfactual_continuous_1(
else:
assert AutoSoftConditioning.site_is_deterministic(tr.trace.nodes[name])
assert f"{name}_approx_log_prob" not in tr.trace.nodes


def hmm_model(data):
transition_probs = pyro.param(
"transition_probs",
torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
constraint=dist.constraints.simplex,
)
emission_probs = pyro.sample(
"emission_probs",
dist.Dirichlet(torch.tensor([0.5, 0.5])).expand([2]).to_event(1),
)
x = pyro.sample("x", dist.Categorical(torch.tensor([0.5, 0.5])))
logger.debug(f"-1\t{tuple(x.shape)}")
for t, y in pyro.markov(enumerate(data)):
x = pyro.sample(
f"x_{t}",
dist.Categorical(pyro.ops.indexing.Vindex(transition_probs)[..., x, :]),
)

pyro.sample(
f"y_{t}",
dist.Categorical(pyro.ops.indexing.Vindex(emission_probs)[..., x, :]),
)
logger.debug(f"{t}\t{tuple(x.shape)}")


@pytest.mark.parametrize("num_particles", [1, 10])
@pytest.mark.parametrize("max_plate_nesting", [3, float("inf")])
@pytest.mark.parametrize("use_guide", [False, True])
@pytest.mark.parametrize("num_steps", [2, 3, 4, 5, 6])
@pytest.mark.parametrize("Elbo", [pyro.infer.TraceEnum_ELBO, pyro.infer.TraceTMC_ELBO])
def test_smoke_condition_enumerate_hmm_elbo(
num_steps, Elbo, use_guide, max_plate_nesting, num_particles
):
data = dist.Categorical(torch.tensor([0.5, 0.5])).sample((num_steps,))

assert issubclass(Elbo, pyro.infer.elbo.ELBO)
elbo = Elbo(
max_plate_nesting=max_plate_nesting,
num_particles=num_particles,
vectorize_particles=(num_particles > 1),
)

model = condition(data={f"y_{t}": y for t, y in enumerate(data)})(hmm_model)

if use_guide:
guide = pyro.infer.config_enumerate(default="parallel")(
pyro.infer.autoguide.AutoDiscreteParallel(
pyro.poutine.block(expose=["x"])(condition(data={})(model))
)
)
model = pyro.infer.config_enumerate(default="parallel")(model)
else:
model = pyro.infer.config_enumerate(default="parallel")(model)
model = condition(model, data={"x": torch.as_tensor(0)})

def guide(data):
pass

# smoke test
elbo.differentiable_loss(model, guide, data)


def test_condition_commutes():
def model():
z = pyro.sample("z", dist.Normal(0, 1), obs=torch.tensor(0.1))
with pyro.plate("data", 2):
x = pyro.sample("x", dist.Normal(z, 1))
y = pyro.sample("y", dist.Normal(x + z, 1))
return z, x, y

h_cond = condition(
data={"x": torch.tensor([0.0, 1.0]), "y": torch.tensor([1.0, 2.0])}
)
h_do = do(actions={"z": torch.tensor(0.0), "x": torch.tensor([0.3, 0.4])})

# case 1
with pyro.poutine.trace() as tr1:
with h_cond, h_do:
model()

# case 2
with pyro.poutine.trace() as tr2:
with h_do, h_cond:
model()

# case 3
with h_cond, pyro.poutine.trace() as tr3:
with h_do:
model()

tr1.trace.compute_log_prob()
tr2.trace.compute_log_prob()
tr3.trace.compute_log_prob()

assert set(tr1.trace.nodes) == set(tr2.trace.nodes) == set(tr3.trace.nodes)
assert (
tr1.trace.log_prob_sum() == tr2.trace.log_prob_sum() == tr3.trace.log_prob_sum()
)
for name, node in tr1.trace.nodes.items():
if node["type"] == "sample" and not pyro.poutine.util.site_is_subsample(node):
assert torch.allclose(node["value"], tr2.trace.nodes[name]["value"])
assert torch.allclose(node["value"], tr3.trace.nodes[name]["value"])
assert torch.allclose(node["log_prob"], tr2.trace.nodes[name]["log_prob"])
assert torch.allclose(node["log_prob"], tr3.trace.nodes[name]["log_prob"])

0 comments on commit e06de84

Please sign in to comment.