diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 4edc80433d..67417e0d8f 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -33,7 +33,13 @@ from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError -from pymc.util import CustomProgress, default_progress_theme +from pymc.util import ( + CustomProgress, + RandomGeneratorState, + default_progress_theme, + get_state_from_generator, + random_generator_from_state, +) logger = logging.getLogger(__name__) @@ -96,13 +102,12 @@ def __init__( shared_point, draws: int, tune: int, - rng: np.random.Generator, - seed_seq: np.random.SeedSequence, + rng_state: RandomGeneratorState, blas_cores, ): # For some strange reason, spawn multiprocessing doesn't copy the rng # seed sequence, so we have to rebuild it from scratch - rng = np.random.Generator(type(rng.bit_generator)(seed_seq)) + rng = random_generator_from_state(rng_state) self._msg_pipe = msg_pipe self._step_method = step_method self._step_method_is_pickled = step_method_is_pickled @@ -263,8 +268,7 @@ def __init__( self._shared_point, draws, tune, - rng, - rng.bit_generator.seed_seq, + get_state_from_generator(rng), blas_cores, ), )