Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix MCMC non-deterministic seeding with Generators #7637

Merged
merged 8 commits into from
Jan 7, 2025
20 changes: 10 additions & 10 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,16 @@ repos:
hooks:
- id: sphinx-lint
args: ["."]
- repo: https://github.com/lucianopaz/head_of_apache
rev: "0.0.3"
hooks:
- id: head_of_apache
args:
- --author=The PyMC Developers
- --exclude=docs/
- --exclude=scripts/
- --exclude=binder/
- --exclude=versioneer.py
#- repo: https://github.com/lucianopaz/head_of_apache
# rev: "0.0.3"
# hooks:
# - id: head_of_apache
# args:
# - --author=The PyMC Developers
# - --exclude=docs/
# - --exclude=scripts/
# - --exclude=binder/
# - --exclude=versioneer.py
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.4
hooks:
Expand Down
2 changes: 2 additions & 0 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,7 +1068,9 @@ def _sample_many(
step: function
Step function
"""
initial_step_state = step.sampling_state
for i in range(chains):
step.sampling_state = initial_step_state
_sample(
draws=draws,
chain=i,
Expand Down
16 changes: 10 additions & 6 deletions pymc/sampling/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,13 @@

from pymc.blocking import DictToArrayBijection
from pymc.exceptions import SamplingError
from pymc.util import CustomProgress, default_progress_theme
from pymc.util import (
CustomProgress,
RandomGeneratorState,
default_progress_theme,
get_state_from_generator,
random_generator_from_state,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -96,13 +102,12 @@ def __init__(
shared_point,
draws: int,
tune: int,
rng: np.random.Generator,
seed_seq: np.random.SeedSequence,
rng_state: RandomGeneratorState,
blas_cores,
):
# For some strange reason, spawn multiprocessing doesn't copy the rng
# seed sequence, so we have to rebuild it from scratch
rng = np.random.Generator(type(rng.bit_generator)(seed_seq))
rng = random_generator_from_state(rng_state)
self._msg_pipe = msg_pipe
self._step_method = step_method
self._step_method_is_pickled = step_method_is_pickled
Expand Down Expand Up @@ -263,8 +268,7 @@ def __init__(
self._shared_point,
draws,
tune,
rng,
rng.bit_generator.seed_seq,
get_state_from_generator(rng),
blas_cores,
),
)
Expand Down
9 changes: 7 additions & 2 deletions pymc/step_methods/compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@

from pymc.blocking import PointType, StatDtype, StatsDict, StatShape, StatsType
from pymc.model import modelcontext
from pymc.step_methods.state import DataClassState, WithSamplingState, dataclass_state
from pymc.step_methods.state import (
DataClassState,
RandomGeneratorState,
WithSamplingState,
dataclass_state,
)
from pymc.util import RandomGenerator, get_random_generator

__all__ = ("Competence", "CompoundStep")
Expand Down Expand Up @@ -91,7 +96,7 @@ def infer_warn_stats_info(

@dataclass_state
class StepMethodState(DataClassState):
rng: np.random.Generator
rng: RandomGeneratorState


class BlockedStep(ABC, WithSamplingState):
Expand Down
16 changes: 11 additions & 5 deletions pymc/step_methods/hmc/quadpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@
from scipy.sparse import issparse

from pymc.pytensorf import floatX
from pymc.step_methods.state import DataClassState, WithSamplingState, dataclass_state
from pymc.step_methods.state import (
DataClassState,
RandomGeneratorState,
WithSamplingState,
dataclass_state,
)
from pymc.util import RandomGenerator, get_random_generator

__all__ = [
Expand Down Expand Up @@ -105,7 +110,7 @@ def __str__(self):

@dataclass_state
class PotentialState(DataClassState):
rng: np.random.Generator
rng: RandomGeneratorState


class QuadPotential(WithSamplingState):
Expand Down Expand Up @@ -476,9 +481,8 @@ def current_mean(self, out=None):
class QuadPotentialDiagAdaptExpState(QuadPotentialDiagAdaptState):
_alpha: float
_stop_adaptation: float
_variance_estimator: ExpWeightedVarianceState

_variance_estimator_grad: ExpWeightedVarianceState | None = None
_variance_estimator: ExpWeightedVarianceState | None
_variance_estimator_grad: ExpWeightedVarianceState | None


class QuadPotentialDiagAdaptExp(QuadPotentialDiagAdapt):
Expand Down Expand Up @@ -524,6 +528,8 @@ def __init__(self, *args, alpha, use_grads=False, stop_adaptation=None, rng=None
if stop_adaptation is None:
stop_adaptation = np.inf
self._stop_adaptation = stop_adaptation
self._variance_estimator = None
self._variance_estimator_grad = None

def update(self, sample, grad, tune):
if tune and self._n_samples < self._stop_adaptation:
Expand Down
7 changes: 7 additions & 0 deletions pymc/step_methods/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import numpy as np

from pymc.util import RandomGeneratorState, get_state_from_generator, random_generator_from_state

dataclass_state = dataclass(kw_only=True)


Expand Down Expand Up @@ -66,8 +68,11 @@ def sampling_state(self) -> DataClassState:
kwargs = {}
for field in fields(state_class):
val = getattr(self, field.name)
_val: Any
if isinstance(val, WithSamplingState):
_val = val.sampling_state
elif isinstance(val, np.random.Generator):
_val = get_state_from_generator(val)
else:
_val = val
kwargs[field.name] = deepcopy(_val)
Expand All @@ -81,6 +86,8 @@ def sampling_state(self, state: DataClassState):
), f"Encountered invalid state class '{state.__class__}'. State must be '{state_class}'"
for field in fields(state_class):
state_val = deepcopy(getattr(state, field.name))
if isinstance(state_val, RandomGeneratorState):
state_val = random_generator_from_state(state_val)
self_val = getattr(self, field.name)
is_frozen = field.metadata.get("frozen", False)
if is_frozen:
Expand Down
32 changes: 31 additions & 1 deletion pymc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import functools
import warnings

from collections import namedtuple
from collections.abc import Sequence
from copy import deepcopy
from typing import NewType, cast
Expand Down Expand Up @@ -601,6 +602,31 @@ def update(
return None


RandomGeneratorState = namedtuple("RandomGeneratorState", ["bit_generator_state", "seed_seq_state"])


def get_state_from_generator(
rng: np.random.Generator | np.random.BitGenerator,
) -> RandomGeneratorState:
assert isinstance(rng, (np.random.Generator | np.random.BitGenerator))
bit_gen: np.random.BitGenerator = (
rng.bit_generator if isinstance(rng, np.random.Generator) else rng
)

return RandomGeneratorState(
bit_generator_state=bit_gen.state,
seed_seq_state=bit_gen.seed_seq.state, # type: ignore[attr-defined]
)


def random_generator_from_state(state: RandomGeneratorState) -> np.random.Generator:
seed_seq = np.random.SeedSequence(**state.seed_seq_state)
bit_generator_class = getattr(np.random, state.bit_generator_state["bit_generator"])
bit_generator = bit_generator_class(seed_seq)
bit_generator.state = state.bit_generator_state
return np.random.Generator(bit_generator)


def get_random_generator(
seed: RandomGenerator | np.random.RandomState = None, copy: bool = True
) -> np.random.Generator:
Expand Down Expand Up @@ -645,6 +671,10 @@ def get_random_generator(
# In the former case, it will return seed, in the latter it will return
# a new Generator object that has the same BitGenerator. This would potentially
# make the new generator be shared across many users. To avoid this, we
# deepcopy by default.
# copy by default.
# Also, because of https://github.com/numpy/numpy/issues/27727, we can't use
# deepcopy. We must rebuild a Generator without losing the SeedSequence information
if isinstance(seed, np.random.Generator | np.random.BitGenerator):
return random_generator_from_state(get_state_from_generator(seed))
seed = deepcopy(seed)
return np.random.default_rng(seed)
4 changes: 2 additions & 2 deletions tests/logprob/test_transform_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import pytest
import scipy as sp

from numdifftools import Jacobian
from numdifftools import Derivative, Jacobian
from pytensor import scan
from pytensor import tensor as pt
from pytensor.compile.builders import OpFromGraph
Expand Down Expand Up @@ -279,7 +279,7 @@ def a_backward_fn_(x):

exp_log_jac_val = jacobian_estimate(a_trans_value)
else:
jacobian_val = np.atleast_2d(sp.misc.derivative(a_backward_fn, a_trans_value, dx=1e-6))
jacobian_val = np.atleast_2d(Derivative(a_backward_fn, step=1e-6)(a_trans_value))
exp_log_jac_val = np.linalg.slogdet(jacobian_val)[-1]

log_jac_val = log_jac_fn(a_trans_value)
Expand Down
53 changes: 28 additions & 25 deletions tests/sampling/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,32 +148,35 @@ def test_sample_does_not_rely_on_external_global_seeding(self):
assert np.all(idata12["x"] != idata22["x"])
assert np.all(idata13["x"] != idata23["x"])

def test_sample_init(self):
@pytest.mark.parametrize(
"init",
(
"advi",
"advi_map",
"map",
"adapt_diag",
"jitter+adapt_diag",
"jitter+adapt_diag_grad",
"adapt_full",
"jitter+adapt_full",
),
)
def test_sample_init(self, init):
with self.model:
for init in (
"advi",
"advi_map",
"map",
"adapt_diag",
"jitter+adapt_diag",
"jitter+adapt_diag_grad",
"adapt_full",
"jitter+adapt_full",
):
kwargs = {
"init": init,
"tune": 120,
"n_init": 1000,
"draws": 50,
"random_seed": 20160911,
}
with warnings.catch_warnings(record=True) as rec:
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
if init.endswith("adapt_full"):
with pytest.warns(UserWarning, match="experimental feature"):
pm.sample(**kwargs)
else:
pm.sample(**kwargs)
kwargs = {
"init": init,
"tune": 120,
"n_init": 1000,
"draws": 50,
"random_seed": 20160911,
}
with warnings.catch_warnings(record=True) as rec:
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
if init.endswith("adapt_full"):
with pytest.warns(UserWarning, match="experimental feature"):
pm.sample(**kwargs, cores=1)
else:
pm.sample(**kwargs, cores=1)

def test_sample_args(self):
with self.model:
Expand Down
20 changes: 20 additions & 0 deletions tests/sampling/test_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,23 @@ def logp(x, mu):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn")


@pytest.mark.parametrize("cores", (1, 2))
def test_sampling_with_random_generator_matches(cores):
    """Check that sampling twice with identically seeded Generators gives equal posteriors.

    Runs once with ``cores=1`` and once with ``cores=2``.
    NOTE(review): presumably these exercise the sequential and the
    multiprocess sampling paths respectively -- confirm against ``pm.sample``.
    """
    # Regression test for https://github.com/pymc-devs/pymc/issues/7612
    kwargs = {
        "chains": 2,
        "cores": cores,
        "tune": 10,
        "draws": 10,
        "compute_convergence_checks": False,
        "progress_bar": False,
    }
    with pm.Model() as m:
        x = pm.Normal("x")

        # Each call gets its own fresh Generator built from the same seed (42).
        post1 = pm.sample(random_seed=np.random.default_rng(42), **kwargs).posterior
        post2 = pm.sample(random_seed=np.random.default_rng(42), **kwargs).posterior

    # On failure, the assertion message shows the mean of "x" from each run.
    assert post1.equals(post2), (post1["x"].mean().item(), post2["x"].mean().item())
5 changes: 1 addition & 4 deletions tests/step_methods/test_metropolis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

import warnings

from copy import deepcopy

import arviz as az
import numpy as np
import numpy.testing as npt
Expand Down Expand Up @@ -406,8 +404,7 @@ def test_sampling_state(step_method, model_fn):
sampler = step_method(model.value_vars)
if hasattr(sampler, "link_population"):
sampler.link_population([initial_point] * 100, 0)
sampler_orig = deepcopy(sampler)
state_orig = sampler_orig.sampling_state
state_orig = sampler.sampling_state

sample1, stat1 = sampler.step(initial_point)
sampler.tune = False
Expand Down
Loading