From 33e2d0c7d1662923586e464e10fd2cf0829b1ac9 Mon Sep 17 00:00:00 2001 From: Luciano Paz Date: Thu, 7 Nov 2024 13:23:32 +0100 Subject: [PATCH] Ensure parallel sampling does not lose BitGenerator state --- pymc/sampling/parallel.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 4edc80433de..67417e0d8f1 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -33,7 +33,13 @@ from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError -from pymc.util import CustomProgress, default_progress_theme +from pymc.util import ( + CustomProgress, + RandomGeneratorState, + default_progress_theme, + get_state_from_generator, + random_generator_from_state, +) logger = logging.getLogger(__name__) @@ -96,13 +102,12 @@ def __init__( shared_point, draws: int, tune: int, - rng: np.random.Generator, - seed_seq: np.random.SeedSequence, + rng_state: RandomGeneratorState, blas_cores, ): # For some strange reason, spawn multiprocessing doesn't copy the rng # seed sequence, so we have to rebuild it from scratch - rng = np.random.Generator(type(rng.bit_generator)(seed_seq)) + rng = random_generator_from_state(rng_state) self._msg_pipe = msg_pipe self._step_method = step_method self._step_method_is_pickled = step_method_is_pickled @@ -263,8 +268,7 @@ def __init__( self._shared_point, draws, tune, - rng, - rng.bit_generator.seed_seq, + get_state_from_generator(rng), blas_cores, ), )