@@ -695,12 +695,39 @@ def test_no_init_nuts_compound(caplog):
695
695
696
696
697
697
def test_sample_var_names ():
698
- with pm .Model () as model :
699
- a = pm .Normal ("a" )
700
- b = pm .Deterministic ("b" , a ** 2 )
701
- idata = pm .sample (10 , tune = 10 , var_names = ["a" ])
702
- assert "a" in idata .posterior
703
- assert "b" not in idata .posterior
698
+ # Generate data
699
+ seed = 1234
700
+ rng = np .random .default_rng (seed )
701
+
702
+ group = rng .choice (list ("ABCD" ), size = 100 )
703
+ x = rng .normal (size = 100 )
704
+ y = rng .normal (size = 100 )
705
+
706
+ group_values , group_idx = np .unique (group , return_inverse = True )
707
+
708
+ coords = {"group" : group_values }
709
+
710
+ # Create model
711
+ with pm .Model (coords = coords ) as model :
712
+ b_group = pm .Normal ("b_group" , dims = "group" )
713
+ b_x = pm .Normal ("b_x" )
714
+ mu = pm .Deterministic ("mu" , b_group [group_idx ] + b_x * x )
715
+ sigma = pm .HalfNormal ("sigma" )
716
+ pm .Normal ("y" , mu = mu , sigma = sigma , observed = y )
717
+
718
+ # Sample with and without var_names, but always with the same seed
719
+ with model :
720
+ idata_1 = pm .sample (tune = 100 , draws = 100 , random_seed = seed )
721
+ idata_2 = pm .sample (
722
+ tune = 100 , draws = 100 , var_names = ["b_group" , "b_x" , "sigma" ], random_seed = seed
723
+ )
724
+
725
+ assert "mu" in idata_1 .posterior
726
+ assert "mu" not in idata_2 .posterior
727
+
728
+ assert np .all (idata_1 .posterior ["b_group" ] == idata_2 .posterior ["b_group" ]).item ()
729
+ assert np .all (idata_1 .posterior ["b_x" ] == idata_2 .posterior ["b_x" ]).item ()
730
+ assert np .all (idata_1 .posterior ["sigma" ] == idata_2 .posterior ["sigma" ]).item ()
704
731
705
732
706
733
class TestAssignStepMethods :
0 commit comments