Skip to content

Commit

Permalink
Switch to local PyroParams in tests (#225)
Browse files Browse the repository at this point in the history
  • Loading branch information
eb8680 authored Jul 14, 2023
1 parent 035a84b commit 0a68fa2
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 48 deletions.
61 changes: 34 additions & 27 deletions tests/counterfactual/test_counterfactual_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from chirho.observational.handlers import condition
from chirho.observational.ops import observe

pyro.settings.set(module_local_params=True)

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -309,35 +311,35 @@ def model():
}


def hmm_model(data: Iterable, use_condition: bool):
transition_probs = pyro.param(
"transition_probs",
torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
constraint=dist.constraints.simplex,
)
emission_probs = pyro.sample(
"emission_probs",
dist.Dirichlet(torch.tensor([0.5, 0.5])).expand([2]).to_event(1),
)
x = pyro.sample("x", dist.Categorical(torch.tensor([0.5, 0.5])))
logger.debug(f"-1\t{tuple(x.shape)}")
for t, y in pyro.markov(enumerate(data)):
x = pyro.sample(
f"x_{t}",
dist.Categorical(pyro.ops.indexing.Vindex(transition_probs)[..., x, :]),
)
class HMM(pyro.nn.PyroModule):
@pyro.nn.PyroParam(constraint=dist.constraints.simplex)
def trans_probs(self):
return torch.tensor([[0.75, 0.25], [0.25, 0.75]])

if use_condition:
pyro.sample(
f"y_{t}",
dist.Categorical(pyro.ops.indexing.Vindex(emission_probs)[..., x, :]),
)
else:
observe(
dist.Categorical(pyro.ops.indexing.Vindex(emission_probs)[..., x, :]),
y,
name=f"y_{t}",
def forward(self, data: Iterable, use_condition: bool):
emit_probs = pyro.sample(
"emission_probs",
dist.Dirichlet(torch.tensor([0.5, 0.5])).expand([2]).to_event(1),
)
x = pyro.sample("x", dist.Categorical(torch.tensor([0.5, 0.5])))
logger.debug(f"-1\t{tuple(x.shape)}")
for t, y in pyro.markov(enumerate(data)):
x = pyro.sample(
f"x_{t}",
dist.Categorical(pyro.ops.indexing.Vindex(self.trans_probs)[..., x, :]),
)

if use_condition:
pyro.sample(
f"y_{t}",
dist.Categorical(pyro.ops.indexing.Vindex(emit_probs)[..., x, :]),
)
else:
observe(
dist.Categorical(pyro.ops.indexing.Vindex(emit_probs)[..., x, :]),
y,
name=f"y_{t}",
)
logger.debug(f"{t}\t{tuple(x.shape)}")


Expand All @@ -352,6 +354,7 @@ def test_smoke_cf_enumerate_hmm_elbo(
num_steps, use_condition, Elbo, use_guide, max_plate_nesting, cf_dim, num_particles
):
data = dist.Categorical(torch.tensor([0.5, 0.5])).sample((num_steps,))
hmm_model = HMM()

@do(actions={"x_0": torch.tensor(0), "x_1": torch.tensor(0)})
def model(data):
Expand Down Expand Up @@ -397,6 +400,7 @@ def test_smoke_cf_enumerate_hmm_compute_marginals(
num_steps, use_condition, max_plate_nesting, cf_dim
):
data = dist.Categorical(torch.tensor([0.5, 0.5])).sample((num_steps,))
hmm_model = HMM()

@do(actions={"x_0": torch.tensor(0), "x_1": torch.tensor(0)})
@condition(data={"x": torch.as_tensor(0)})
Expand Down Expand Up @@ -437,6 +441,7 @@ def test_smoke_cf_enumerate_hmm_infer_discrete(
num_steps, use_condition, max_plate_nesting, cf_dim, num_particles
):
data = dist.Categorical(torch.tensor([0.5, 0.5])).sample((num_steps,))
hmm_model = HMM()

@do(actions={"x_0": torch.tensor(0), "x_1": torch.tensor(0)})
@condition(data={"x": torch.as_tensor(0)})
Expand Down Expand Up @@ -473,6 +478,7 @@ def test_smoke_cf_enumerate_hmm_mcmc(
num_steps, use_condition, max_plate_nesting, Kernel, cf_dim
):
data = dist.Categorical(torch.tensor([0.5, 0.5])).sample((num_steps,))
hmm_model = HMM()

@do(actions={"x_0": torch.tensor(0), "x_1": torch.tensor(0)})
@condition(data={"x": torch.as_tensor(0)})
Expand Down Expand Up @@ -560,6 +566,7 @@ def model():
@pytest.mark.parametrize("num_steps", [3, 4, 5, 10])
def test_mode_cf_enumerate_hmm_infer_discrete(num_steps, cf_dim):
data = dist.Categorical(torch.tensor([0.5, 0.5])).sample((num_steps,))
hmm_model = HMM()

pin_tr = pyro.poutine.trace(hmm_model).get_trace(data, True)
pinned = {
Expand Down
45 changes: 24 additions & 21 deletions tests/observational/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
SoftEqKernel,
)

pyro.settings.set(module_local_params=True)

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -177,29 +179,29 @@ def test_soft_conditioning_counterfactual_continuous_1(
assert f"{name}_approx_log_prob" not in tr.trace.nodes


def hmm_model(data):
transition_probs = pyro.param(
"transition_probs",
torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
constraint=dist.constraints.simplex,
)
emission_probs = pyro.sample(
"emission_probs",
dist.Dirichlet(torch.tensor([0.5, 0.5])).expand([2]).to_event(1),
)
x = pyro.sample("x", dist.Categorical(torch.tensor([0.5, 0.5])))
logger.debug(f"-1\t{tuple(x.shape)}")
for t, y in pyro.markov(enumerate(data)):
x = pyro.sample(
f"x_{t}",
dist.Categorical(pyro.ops.indexing.Vindex(transition_probs)[..., x, :]),
)
class HMM(pyro.nn.PyroModule):
@pyro.nn.PyroParam(constraint=dist.constraints.simplex)
def trans_probs(self):
return torch.tensor([[0.75, 0.25], [0.25, 0.75]])

pyro.sample(
f"y_{t}",
dist.Categorical(pyro.ops.indexing.Vindex(emission_probs)[..., x, :]),
def forward(self, data):
emission_probs = pyro.sample(
"emission_probs",
dist.Dirichlet(torch.tensor([0.5, 0.5])).expand([2]).to_event(1),
)
logger.debug(f"{t}\t{tuple(x.shape)}")
x = pyro.sample("x", dist.Categorical(torch.tensor([0.5, 0.5])))
logger.debug(f"-1\t{tuple(x.shape)}")
for t, y in pyro.markov(enumerate(data)):
x = pyro.sample(
f"x_{t}",
dist.Categorical(pyro.ops.indexing.Vindex(self.trans_probs)[..., x, :]),
)

pyro.sample(
f"y_{t}",
dist.Categorical(pyro.ops.indexing.Vindex(emission_probs)[..., x, :]),
)
logger.debug(f"{t}\t{tuple(x.shape)}")


@pytest.mark.parametrize("num_particles", [1, 10])
Expand All @@ -211,6 +213,7 @@ def test_smoke_condition_enumerate_hmm_elbo(
num_steps, Elbo, use_guide, max_plate_nesting, num_particles
):
data = dist.Categorical(torch.tensor([0.5, 0.5])).sample((num_steps,))
hmm_model = HMM()

assert issubclass(Elbo, pyro.infer.elbo.ELBO)
elbo = Elbo(
Expand Down

0 comments on commit 0a68fa2

Please sign in to comment.