From 7ba67b611243fa4e15c1f3ac80e98fe9b61dbfab Mon Sep 17 00:00:00 2001 From: Luciano Paz Date: Sat, 21 Dec 2024 12:27:09 +0100 Subject: [PATCH] Add regression test mcmc seeding with Generators Co-authored-by: ricardoV94 <28983449+ricardov94@users.noreply.github.com> --- tests/sampling/test_parallel.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/tests/sampling/test_parallel.py b/tests/sampling/test_parallel.py index 8c71bcac001..fb0ed528234 100644 --- a/tests/sampling/test_parallel.py +++ b/tests/sampling/test_parallel.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2025 The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -228,3 +228,23 @@ def logp(x, mu): with warnings.catch_warnings(): warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn") + + +@pytest.mark.parametrize("cores", (1, 2)) +def test_sampling_with_random_generator_matches(cores): + # Regression test for https://github.com/pymc-devs/pymc/issues/7612 + kwargs = { + "chains": 2, + "cores": cores, + "tune": 10, + "draws": 10, + "compute_convergence_checks": False, + "progress_bar": False, + } + with pm.Model() as m: + x = pm.Normal("x") + + post1 = pm.sample(random_seed=np.random.default_rng(42), **kwargs).posterior + post2 = pm.sample(random_seed=np.random.default_rng(42), **kwargs).posterior + + assert post1.equals(post2), (post1["x"].mean().item(), post2["x"].mean().item())