Skip to content

Commit

Permalink
Add regression test mcmc seeding with Generators
Browse files Browse the repository at this point in the history
Co-authored-by: ricardoV94 <[email protected]>
  • Loading branch information
lucianopaz and ricardoV94 committed Jan 7, 2025
1 parent fe4d304 commit 9dcd50a
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions tests/sampling/test_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,23 @@ def logp(x, mu):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn")


@pytest.mark.parametrize("cores", (1, 2))
def test_sampling_with_random_generator_matches(cores):
# Regression test for https://github.com/pymc-devs/pymc/issues/7612
kwargs = {
"chains": 2,
"cores": cores,
"tune": 10,
"draws": 10,
"compute_convergence_checks": False,
"progress_bar": False,
}
with pm.Model() as m:
x = pm.Normal("x")

post1 = pm.sample(random_seed=np.random.default_rng(42), **kwargs).posterior
post2 = pm.sample(random_seed=np.random.default_rng(42), **kwargs).posterior

assert post1.equals(post2), (post1["x"].mean().item(), post2["x"].mean().item())

0 comments on commit 9dcd50a

Please sign in to comment.