Skip to content

Commit 29fd732

Browse files
authored
Add test to compare draws when var_names is used in pm.sample() (#7287)
1 parent a74c03f commit 29fd732

File tree

1 file changed

+33
-6
lines changed

1 file changed

+33
-6
lines changed

tests/sampling/test_mcmc.py

+33-6
Original file line numberDiff line numberDiff line change
@@ -695,12 +695,39 @@ def test_no_init_nuts_compound(caplog):
695695

696696

697697
def test_sample_var_names():
698-
with pm.Model() as model:
699-
a = pm.Normal("a")
700-
b = pm.Deterministic("b", a**2)
701-
idata = pm.sample(10, tune=10, var_names=["a"])
702-
assert "a" in idata.posterior
703-
assert "b" not in idata.posterior
698+
# Generate data
699+
seed = 1234
700+
rng = np.random.default_rng(seed)
701+
702+
group = rng.choice(list("ABCD"), size=100)
703+
x = rng.normal(size=100)
704+
y = rng.normal(size=100)
705+
706+
group_values, group_idx = np.unique(group, return_inverse=True)
707+
708+
coords = {"group": group_values}
709+
710+
# Create model
711+
with pm.Model(coords=coords) as model:
712+
b_group = pm.Normal("b_group", dims="group")
713+
b_x = pm.Normal("b_x")
714+
mu = pm.Deterministic("mu", b_group[group_idx] + b_x * x)
715+
sigma = pm.HalfNormal("sigma")
716+
pm.Normal("y", mu=mu, sigma=sigma, observed=y)
717+
718+
# Sample with and without var_names, but always with the same seed
719+
with model:
720+
idata_1 = pm.sample(tune=100, draws=100, random_seed=seed)
721+
idata_2 = pm.sample(
722+
tune=100, draws=100, var_names=["b_group", "b_x", "sigma"], random_seed=seed
723+
)
724+
725+
assert "mu" in idata_1.posterior
726+
assert "mu" not in idata_2.posterior
727+
728+
assert np.all(idata_1.posterior["b_group"] == idata_2.posterior["b_group"]).item()
729+
assert np.all(idata_1.posterior["b_x"] == idata_2.posterior["b_x"]).item()
730+
assert np.all(idata_1.posterior["sigma"] == idata_2.posterior["sigma"]).item()
704731

705732

706733
class TestAssignStepMethods:

0 commit comments

Comments
 (0)