Skip to content

Commit 4f57c50

Browse files
committed
Fix independent_rvs determination in vectorize_over_posterior
1 parent 0960323 commit 4f57c50

File tree

2 files changed

+34
-7
lines changed

2 files changed

+34
-7
lines changed

pymc/sampling/forward.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,14 +1059,11 @@ def vectorize_over_posterior(
10591059
for rv in general_toposort( # type: ignore[call-overload]
10601060
all_rvs, lambda x: x.owner.inputs if x.owner is not None else None
10611061
)
1062-
if rv in all_rvs
1062+
if rv in all_rvs and rv not in needed_rvs
10631063
]:
1064-
rv_ancestors = ancestors([rv], blockers=[*needed_rvs, *independent_rvs, *outputs])
1065-
if (
1066-
rv not in needed_rvs
1067-
and not ({*outputs, *independent_rvs} & set(rv_ancestors))
1068-
and {var for var in rv_ancestors if var in all_rvs} <= {rv, *needed_rvs}
1069-
):
1064+
blockers = [*needed_rvs, *independent_rvs, *outputs]
1065+
rv_ancestors = ancestors([rv], blockers=blockers)
1066+
if not (set(blockers) & set(rv_ancestors)):
10701067
independent_rvs.append(rv)
10711068
for rv in independent_rvs:
10721069
replace_dict[rv] = change_dist_size(rv, new_size=batch_shape, expand=True)

tests/sampling/test_forward.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1958,3 +1958,33 @@ def test_vectorize_over_posterior_matches_sample():
19581958
atol=0.6 / np.sqrt(10000),
19591959
)
19601960
assert np.all(np.abs(vect_obs - x_posterior[..., None]) < 1)
1961+
1962+
1963+
def test_vectorize_over_posterior_with_intermediate_rvs():
1964+
with pm.Model() as model:
1965+
a = pm.Normal("a")
1966+
b = pm.Normal.dist(a)
1967+
c = b + 1
1968+
d = pm.Normal.dist(c)
1969+
idata = pm.sample_prior_predictive(100, var_names=["a"])
1970+
idata.add_groups({"posterior": idata.prior})
1971+
_, _, vectorized_no_intermediate = vectorize_over_posterior(
1972+
outputs=[b, c, d],
1973+
posterior=idata.posterior,
1974+
input_rvs=[a],
1975+
allow_rvs_in_graph=True,
1976+
)
1977+
[vectorized_intermediate_rvs] = vectorize_over_posterior(
1978+
outputs=[d],
1979+
posterior=idata.posterior,
1980+
input_rvs=[a],
1981+
allow_rvs_in_graph=True,
1982+
)
1983+
assert vectorized_no_intermediate.type.shape == (1, 100)
1984+
assert vectorized_no_intermediate.type.shape == vectorized_intermediate_rvs.type.shape
1985+
[a_ancestor1] = get_var_by_name([vectorized_no_intermediate], "a")
1986+
[a_ancestor2] = get_var_by_name([vectorized_intermediate_rvs], "a")
1987+
assert isinstance(a_ancestor1, TensorConstant)
1988+
assert np.array_equiv(a_ancestor1.eval(), idata.posterior.a.data)
1989+
assert isinstance(a_ancestor2, TensorConstant)
1990+
assert np.array_equiv(a_ancestor2.eval(), idata.posterior.a.data)

0 commit comments

Comments
 (0)