Commit

Added test to check output of estimating counterfactual with AutoSoftConditioning (#224)

* added test to check output of cf inference

* removed local module setting configuration

---------

Co-authored-by: Raj Agrawal <[email protected]>
Co-authored-by: Andy Zane <[email protected]>
3 people authored Jul 31, 2023
1 parent a654357 commit 7508a0a
Showing 1 changed file with 65 additions and 0 deletions.
65 changes: 65 additions & 0 deletions tests/counterfactual/test_counterfactual_handler.py
@@ -6,6 +6,7 @@
import pyro.infer
import pytest
import torch
from pyro.infer.autoguide import AutoMultivariateNormal

import chirho.interventional.handlers # noqa: F401
from chirho.counterfactual.handlers import ( # TwinWorldCounterfactual,
@@ -19,6 +20,7 @@
from chirho.interventional.handlers import do
from chirho.interventional.ops import intervene
from chirho.observational.handlers import condition
from chirho.observational.handlers.soft_conditioning import AutoSoftConditioning
from chirho.observational.ops import observe

pyro.settings.set(module_local_params=True)
@@ -644,3 +646,66 @@ def model():

    assert torch.any(tr.nodes["z"]["value"] > 0)
    assert torch.any(tr.nodes["z"]["value"] < 1)


# Define a helper function to run SVI. (Generally, Pyro users like to have more control over the training process!)
def run_svi_inference(model, n_steps=1000, verbose=True, lr=0.03, **model_kwargs):
    guide = AutoMultivariateNormal(model)
    elbo = pyro.infer.Trace_ELBO()(model, guide)
    # initialize parameters by evaluating the ELBO once
    elbo(**model_kwargs)
    adam = torch.optim.Adam(elbo.parameters(), lr=lr)
    # do gradient steps
    for step in range(1, n_steps + 1):
        adam.zero_grad()
        loss = elbo(**model_kwargs)
        loss.backward()
        adam.step()
        if verbose and ((step % 100 == 0) or (step == 1)):
            print("[iteration %04d] loss: %.4f" % (step, loss))
    return guide


def test_cf_inference_with_soft_conditioner():
    def model():
        z = pyro.sample("z", dist.Normal(0, 1), obs=torch.tensor(0.1))
        u_x = pyro.sample("u_x", dist.Normal(0, 1))
        x = pyro.deterministic("x", z + u_x, event_dim=0)
        u_y = pyro.sample("u_y", dist.Normal(0, 1))
        y = pyro.deterministic("y", x + z + u_y, event_dim=0)
        return dict(x=x, y=y, z=z)
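    # In this model, x = z + u_x and y = x + z + u_y, with z observed at 0.1.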

    h_cond = condition(data={"x": torch.tensor(0.0), "y": torch.tensor(1.0)})
    h_do = do(actions={"z": torch.tensor(0.0)})
    scale = 0.01
    reparam_config = AutoSoftConditioning(scale=scale, alpha=0.5)
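    # scale controls how tightly the soft conditioning enforces x=0 and y=1 on the
    # deterministic sites; the assertions below allow 5 * scale of slack.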

    def model_cf():
        with pyro.poutine.reparam(config=reparam_config):
            with TwinWorldCounterfactual(), h_do, h_cond:
                model()

    def model_conditioned():
        with pyro.poutine.reparam(config=reparam_config):
            with h_cond:
                model()
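    # model_conditioned is used to infer the posterior over u_x and u_y from the factual
    # observations; model_cf reuses that posterior to predict outcomes under do(z=0).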

    # Run SVI inference
    guide = run_svi_inference(model_conditioned, n_steps=2500, verbose=False)
    est_u_x = guide.median()["u_x"]
    est_u_y = guide.median()["u_y"]

    assert torch.allclose(
        est_u_x, torch.tensor(-0.1), atol=5 * scale
    )  # p(u_x | z=.1, x=0, y=1) is a point mass at -0.1
    assert torch.allclose(
        est_u_y, torch.tensor(0.9), atol=5 * scale
    )  # p(u_y | z=.1, x=0, y=1) is a point mass at 0.9
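    # Under do(z=0), the counterfactual outcomes are x_cf = 0 + u_x = -0.1 and
    # y_cf = x_cf + 0 + u_y = 0.8, which the predictive samples below should recover.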

    # Compute counterfactuals
    cf_samps = pyro.infer.Predictive(model_cf, guide=guide, num_samples=100)()
    avg_x_cf = cf_samps["x"].squeeze()[:, 1].mean()
    avg_y_cf = cf_samps["y"].squeeze()[:, 1].mean()
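    # Index 1 selects the counterfactual world introduced by TwinWorldCounterfactual
    # (index 0 is the factual world).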

    assert torch.allclose(avg_x_cf, torch.tensor(-0.1), atol=5 * scale)
    assert torch.allclose(avg_y_cf, torch.tensor(0.8), atol=5 * scale)
