@@ -1958,3 +1958,33 @@ def test_vectorize_over_posterior_matches_sample():
1958
1958
atol = 0.6 / np .sqrt (10000 ),
1959
1959
)
1960
1960
assert np .all (np .abs (vect_obs - x_posterior [..., None ]) < 1 )
1961
+
1962
+
1963
+ def test_vectorize_over_posterior_with_intermediate_rvs ():
1964
+ with pm .Model () as model :
1965
+ a = pm .Normal ("a" )
1966
+ b = pm .Normal .dist (a )
1967
+ c = b + 1
1968
+ d = pm .Normal .dist (c )
1969
+ idata = pm .sample_prior_predictive (100 , var_names = ["a" ])
1970
+ idata .add_groups ({"posterior" : idata .prior })
1971
+ _ , _ , vectorized_no_intermediate = vectorize_over_posterior (
1972
+ outputs = [b , c , d ],
1973
+ posterior = idata .posterior ,
1974
+ input_rvs = [a ],
1975
+ allow_rvs_in_graph = True ,
1976
+ )
1977
+ [vectorized_intermediate_rvs ] = vectorize_over_posterior (
1978
+ outputs = [d ],
1979
+ posterior = idata .posterior ,
1980
+ input_rvs = [a ],
1981
+ allow_rvs_in_graph = True ,
1982
+ )
1983
+ assert vectorized_no_intermediate .type .shape == (1 , 100 )
1984
+ assert vectorized_no_intermediate .type .shape == vectorized_intermediate_rvs .type .shape
1985
+ [a_ancestor1 ] = get_var_by_name ([vectorized_no_intermediate ], "a" )
1986
+ [a_ancestor2 ] = get_var_by_name ([vectorized_intermediate_rvs ], "a" )
1987
+ assert isinstance (a_ancestor1 , TensorConstant )
1988
+ assert np .array_equiv (a_ancestor1 .eval (), idata .posterior .a .data )
1989
+ assert isinstance (a_ancestor2 , TensorConstant )
1990
+ assert np .array_equiv (a_ancestor2 .eval (), idata .posterior .a .data )
0 commit comments