Skip to content

Commit

Permalink
Ensure parallel sampling does not lose BitGenerator state
Browse files Browse the repository at this point in the history
  • Loading branch information
lucianopaz authored and ricardoV94 committed Jan 7, 2025
1 parent 0a99547 commit e8c22f1
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions pymc/sampling/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,13 @@

from pymc.blocking import DictToArrayBijection
from pymc.exceptions import SamplingError
from pymc.util import CustomProgress, default_progress_theme
from pymc.util import (
CustomProgress,
RandomGeneratorState,
default_progress_theme,
get_state_from_generator,
random_generator_from_state,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -96,13 +102,12 @@ def __init__(
shared_point,
draws: int,
tune: int,
rng: np.random.Generator,
seed_seq: np.random.SeedSequence,
rng_state: RandomGeneratorState,
blas_cores,
):
# For some strange reason, spawn multiprocessing doesn't copy the rng
# seed sequence, so we have to rebuild it from scratch
rng = np.random.Generator(type(rng.bit_generator)(seed_seq))
rng = random_generator_from_state(rng_state)
self._msg_pipe = msg_pipe
self._step_method = step_method
self._step_method_is_pickled = step_method_is_pickled
Expand Down Expand Up @@ -263,8 +268,7 @@ def __init__(
self._shared_point,
draws,
tune,
rng,
rng.bit_generator.seed_seq,
get_state_from_generator(rng),
blas_cores,
),
)
Expand Down

0 comments on commit e8c22f1

Please sign in to comment.