diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f523706d68c..10fd36fd944 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,16 +37,16 @@ repos: hooks: - id: sphinx-lint args: ["."] -- repo: https://github.com/lucianopaz/head_of_apache - rev: "0.0.3" - hooks: - - id: head_of_apache - args: - - --author=The PyMC Developers - - --exclude=docs/ - - --exclude=scripts/ - - --exclude=binder/ - - --exclude=versioneer.py +#- repo: https://github.com/lucianopaz/head_of_apache +# rev: "0.0.3" +# hooks: +# - id: head_of_apache +# args: +# - --author=The PyMC Developers +# - --exclude=docs/ +# - --exclude=scripts/ +# - --exclude=binder/ +# - --exclude=versioneer.py - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.8.4 hooks: diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index bc3e3475d10..fce64e3b38a 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -1068,7 +1068,9 @@ def _sample_many( step: function Step function """ + initial_step_state = step.sampling_state for i in range(chains): + step.sampling_state = initial_step_state _sample( draws=draws, chain=i, diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 4edc80433de..67417e0d8f1 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -33,7 +33,13 @@ from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError -from pymc.util import CustomProgress, default_progress_theme +from pymc.util import ( + CustomProgress, + RandomGeneratorState, + default_progress_theme, + get_state_from_generator, + random_generator_from_state, +) logger = logging.getLogger(__name__) @@ -96,13 +102,12 @@ def __init__( shared_point, draws: int, tune: int, - rng: np.random.Generator, - seed_seq: np.random.SeedSequence, + rng_state: RandomGeneratorState, blas_cores, ): # For some strange reason, spawn multiprocessing doesn't copy the rng # seed sequence, so we have to rebuild it from scratch - rng = np.random.Generator(type(rng.bit_generator)(seed_seq)) + rng = random_generator_from_state(rng_state) self._msg_pipe = msg_pipe self._step_method = step_method self._step_method_is_pickled = step_method_is_pickled @@ -263,8 +268,7 @@ def __init__( self._shared_point, draws, tune, - rng, - rng.bit_generator.seed_seq, + get_state_from_generator(rng), blas_cores, ), ) diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index 253e0bd0447..d0393afd570 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -31,7 +31,12 @@ from pymc.blocking import PointType, StatDtype, StatsDict, StatShape, StatsType from pymc.model import modelcontext -from pymc.step_methods.state import DataClassState, WithSamplingState, dataclass_state +from pymc.step_methods.state import ( + DataClassState, + RandomGeneratorState, + WithSamplingState, + dataclass_state, +) from pymc.util import RandomGenerator, get_random_generator __all__ = ("Competence", "CompoundStep") @@ -91,7 +96,7 @@ def infer_warn_stats_info( @dataclass_state class StepMethodState(DataClassState): - rng: np.random.Generator + rng: RandomGeneratorState class BlockedStep(ABC, WithSamplingState): diff --git a/pymc/step_methods/hmc/quadpotential.py b/pymc/step_methods/hmc/quadpotential.py index 33f3571eda5..2c1b500cc63 100644 --- a/pymc/step_methods/hmc/quadpotential.py +++ b/pymc/step_methods/hmc/quadpotential.py @@ -26,7 +26,12 @@ from scipy.sparse import issparse from pymc.pytensorf import floatX -from pymc.step_methods.state import DataClassState, WithSamplingState, dataclass_state +from pymc.step_methods.state import ( + DataClassState, + RandomGeneratorState, + WithSamplingState, + dataclass_state, +) from pymc.util import RandomGenerator, get_random_generator __all__ = [ @@ -105,7 +110,7 @@ def __str__(self): @dataclass_state class PotentialState(DataClassState): - rng: np.random.Generator + rng: RandomGeneratorState class QuadPotential(WithSamplingState): @@ -476,9 +481,8 @@ def current_mean(self, out=None): class QuadPotentialDiagAdaptExpState(QuadPotentialDiagAdaptState): _alpha: float _stop_adaptation: float - _variance_estimator: ExpWeightedVarianceState - - _variance_estimator_grad: ExpWeightedVarianceState | None = None + _variance_estimator: ExpWeightedVarianceState | None + _variance_estimator_grad: ExpWeightedVarianceState | None class QuadPotentialDiagAdaptExp(QuadPotentialDiagAdapt): @@ -524,6 +528,8 @@ def __init__(self, *args, alpha, use_grads=False, stop_adaptation=None, rng=None if stop_adaptation is None: stop_adaptation = np.inf self._stop_adaptation = stop_adaptation + self._variance_estimator = None + self._variance_estimator_grad = None def update(self, sample, grad, tune): if tune and self._n_samples < self._stop_adaptation: diff --git a/pymc/step_methods/state.py b/pymc/step_methods/state.py index 9b85d7784bb..e24276cf143 100644 --- a/pymc/step_methods/state.py +++ b/pymc/step_methods/state.py @@ -17,6 +17,8 @@ import numpy as np +from pymc.util import RandomGeneratorState, get_state_from_generator, random_generator_from_state + dataclass_state = dataclass(kw_only=True) @@ -66,8 +68,11 @@ def sampling_state(self) -> DataClassState: kwargs = {} for field in fields(state_class): val = getattr(self, field.name) + _val: Any if isinstance(val, WithSamplingState): _val = val.sampling_state + elif isinstance(val, np.random.Generator): + _val = get_state_from_generator(val) else: _val = val kwargs[field.name] = deepcopy(_val) @@ -81,6 +86,8 @@ def sampling_state(self, state: DataClassState): ), f"Encountered invalid state class '{state.__class__}'. State must be '{state_class}'" for field in fields(state_class): state_val = deepcopy(getattr(state, field.name)) + if isinstance(state_val, RandomGeneratorState): + state_val = random_generator_from_state(state_val) self_val = getattr(self, field.name) is_frozen = field.metadata.get("frozen", False) if is_frozen: diff --git a/pymc/util.py b/pymc/util.py index 8ec8aa84dea..8a059d7e0d6 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -15,6 +15,7 @@ import functools import warnings +from collections import namedtuple from collections.abc import Sequence from copy import deepcopy from typing import NewType, cast @@ -601,6 +602,31 @@ def update( return None +RandomGeneratorState = namedtuple("RandomGeneratorState", ["bit_generator_state", "seed_seq_state"]) + + +def get_state_from_generator( + rng: np.random.Generator | np.random.BitGenerator, +) -> RandomGeneratorState: + assert isinstance(rng, (np.random.Generator | np.random.BitGenerator)) + bit_gen: np.random.BitGenerator = ( + rng.bit_generator if isinstance(rng, np.random.Generator) else rng + ) + + return RandomGeneratorState( + bit_generator_state=bit_gen.state, + seed_seq_state=bit_gen.seed_seq.state, # type: ignore[attr-defined] + ) + + +def random_generator_from_state(state: RandomGeneratorState) -> np.random.Generator: + seed_seq = np.random.SeedSequence(**state.seed_seq_state) + bit_generator_class = getattr(np.random, state.bit_generator_state["bit_generator"]) + bit_generator = bit_generator_class(seed_seq) + bit_generator.state = state.bit_generator_state + return np.random.Generator(bit_generator) + + def get_random_generator( seed: RandomGenerator | np.random.RandomState = None, copy: bool = True ) -> np.random.Generator: @@ -645,6 +671,10 @@ def get_random_generator( # In the former case, it will return seed, in the latter it will return # a new Generator object that has the same BitGenerator. This would potentially # make the new generator be shared across many users. To avoid this, we - # deepcopy by default. + # copy by default. + # Also, because of https://github.com/numpy/numpy/issues/27727, we can't use + # deepcopy. We must rebuild a Generator without losing the SeedSequence information + if isinstance(seed, np.random.Generator | np.random.BitGenerator): + return random_generator_from_state(get_state_from_generator(seed)) seed = deepcopy(seed) return np.random.default_rng(seed) diff --git a/tests/logprob/test_transform_value.py b/tests/logprob/test_transform_value.py index c2529ddb964..491a38086cc 100644 --- a/tests/logprob/test_transform_value.py +++ b/tests/logprob/test_transform_value.py @@ -19,7 +19,7 @@ import pytest import scipy as sp -from numdifftools import Jacobian +from numdifftools import Derivative, Jacobian from pytensor import scan from pytensor import tensor as pt from pytensor.compile.builders import OpFromGraph @@ -279,7 +279,7 @@ def a_backward_fn_(x): exp_log_jac_val = jacobian_estimate(a_trans_value) else: - jacobian_val = np.atleast_2d(sp.misc.derivative(a_backward_fn, a_trans_value, dx=1e-6)) + jacobian_val = np.atleast_2d(Derivative(a_backward_fn, step=1e-6)(a_trans_value)) exp_log_jac_val = np.linalg.slogdet(jacobian_val)[-1] log_jac_val = log_jac_fn(a_trans_value) diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index 3219d45b76b..41b068e0427 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -148,32 +148,35 @@ def test_sample_does_not_rely_on_external_global_seeding(self): assert np.all(idata12["x"] != idata22["x"]) assert np.all(idata13["x"] != idata23["x"]) - def test_sample_init(self): + @pytest.mark.parametrize( + "init", + ( + "advi", + "advi_map", + "map", + "adapt_diag", + "jitter+adapt_diag", + "jitter+adapt_diag_grad", + "adapt_full", + "jitter+adapt_full", + ), + ) + def test_sample_init(self, init): with self.model: - for init in ( - "advi", - "advi_map", - "map", - "adapt_diag", - "jitter+adapt_diag", - "jitter+adapt_diag_grad", - "adapt_full", - "jitter+adapt_full", - ): - kwargs = { - "init": init, - "tune": 120, - "n_init": 1000, - "draws": 50, - "random_seed": 20160911, - } - with warnings.catch_warnings(record=True) as rec: - warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) - if init.endswith("adapt_full"): - with pytest.warns(UserWarning, match="experimental feature"): - pm.sample(**kwargs) - else: - pm.sample(**kwargs) + kwargs = { + "init": init, + "tune": 120, + "n_init": 1000, + "draws": 50, + "random_seed": 20160911, + } + with warnings.catch_warnings(record=True) as rec: + warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) + if init.endswith("adapt_full"): + with pytest.warns(UserWarning, match="experimental feature"): + pm.sample(**kwargs, cores=1) + else: + pm.sample(**kwargs, cores=1) def test_sample_args(self): with self.model: diff --git a/tests/sampling/test_parallel.py b/tests/sampling/test_parallel.py index 8c71bcac001..c16489610fd 100644 --- a/tests/sampling/test_parallel.py +++ b/tests/sampling/test_parallel.py @@ -228,3 +228,23 @@ def logp(x, mu): with warnings.catch_warnings(): warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn") + + +@pytest.mark.parametrize("cores", (1, 2)) +def test_sampling_with_random_generator_matches(cores): + # Regression test for https://github.com/pymc-devs/pymc/issues/7612 + kwargs = { + "chains": 2, + "cores": cores, + "tune": 10, + "draws": 10, + "compute_convergence_checks": False, + "progress_bar": False, + } + with pm.Model() as m: + x = pm.Normal("x") + + post1 = pm.sample(random_seed=np.random.default_rng(42), **kwargs).posterior + post2 = pm.sample(random_seed=np.random.default_rng(42), **kwargs).posterior + + assert post1.equals(post2), (post1["x"].mean().item(), post2["x"].mean().item()) diff --git a/tests/step_methods/test_metropolis.py b/tests/step_methods/test_metropolis.py index a01e75506b7..0a81797b3c6 100644 --- a/tests/step_methods/test_metropolis.py +++ b/tests/step_methods/test_metropolis.py @@ -14,8 +14,6 @@ import warnings -from copy import deepcopy - import arviz as az import numpy as np import numpy.testing as npt @@ -406,8 +404,7 @@ def test_sampling_state(step_method, model_fn): sampler = step_method(model.value_vars) if hasattr(sampler, "link_population"): sampler.link_population([initial_point] * 100, 0) - sampler_orig = deepcopy(sampler) - state_orig = sampler_orig.sampling_state + state_orig = sampler.sampling_state sample1, stat1 = sampler.step(initial_point) sampler.tune = False