Skip to content

Commit

Permalink
Add regression test mcmc seeding with Generators
Browse files Browse the repository at this point in the history
Co-authored-by: ricardoV94 <[email protected]>
  • Loading branch information
lucianopaz and ricardoV94 committed Jan 7, 2025
1 parent 4e5dc28 commit 7ba67b6
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion tests/sampling/test_parallel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 The PyMC Developers
# Copyright 2025 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -228,3 +228,23 @@ def logp(x, mu):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn")


@pytest.mark.parametrize("cores", (1, 2))
def test_sampling_with_random_generator_matches(cores):
# Regression test for https://github.com/pymc-devs/pymc/issues/7612
kwargs = {
"chains": 2,
"cores": cores,
"tune": 10,
"draws": 10,
"compute_convergence_checks": False,
"progress_bar": False,
}
with pm.Model() as m:
x = pm.Normal("x")

post1 = pm.sample(random_seed=np.random.default_rng(42), **kwargs).posterior
post2 = pm.sample(random_seed=np.random.default_rng(42), **kwargs).posterior

assert post1.equals(post2), (post1["x"].mean().item(), post2["x"].mean().item())

0 comments on commit 7ba67b6

Please sign in to comment.