From 31bf864933c541b50bf3c10d7e009c16e3caf023 Mon Sep 17 00:00:00 2001
From: Goose
Date: Wed, 11 Dec 2024 12:41:13 +1000
Subject: [PATCH 01/11] use jaxified logp for initial point evaluation when
 sampling via Jax

---
 pymc/initial_point.py      | 13 ++++++--
 pymc/model/core.py         |  4 +--
 pymc/sampling/jax.py       | 63 ++++++++++++++++++++++++++------------
 pymc/sampling/mcmc.py      | 18 ++++++++---
 tests/sampling/test_jax.py | 12 ++++++--
 5 files changed, 79 insertions(+), 31 deletions(-)

diff --git a/pymc/initial_point.py b/pymc/initial_point.py
index 241409f6834..675b98319b8 100644
--- a/pymc/initial_point.py
+++ b/pymc/initial_point.py
@@ -67,7 +67,7 @@ def make_initial_point_fns_per_chain(
     overrides: StartDict | Sequence[StartDict | None] | None,
     jitter_rvs: set[TensorVariable] | None = None,
     chains: int,
-) -> list[Callable]:
+) -> list[Callable[[int], PointType]]:
     """Create an initial point function for each chain, as defined by initvals.

     If a single initval dictionary is passed, the function is replicated for each
@@ -82,6 +82,11 @@ def make_initial_point_fns_per_chain(
         Random variable tensors for which U(-1, 1) jitter shall be applied.
         (To the transformed space if applicable.)

+    Returns
+    -------
+    ipfns : list[Callable[[int], dict[str, np.ndarray]]]
+        list of functions that return initial points for each chain.
+
     Raises
     ------
     ValueError
@@ -124,7 +129,7 @@ def make_initial_point_fn(
     jitter_rvs: set[TensorVariable] | None = None,
     default_strategy: str = "support_point",
     return_transformed: bool = True,
-) -> Callable:
+) -> Callable[[int], PointType]:
     """Create seeded function that computes initial values for all free model variables.

     Parameters
@@ -138,6 +143,10 @@ def make_initial_point_fn(
         Initial value (strategies) to use instead of what's specified in `Model.initial_values`.
     return_transformed : bool
         If `True` the returned variables will correspond to transformed initial values.
+
+    Returns
+    -------
+    initial_point_fn : Callable[[int], dict[str, np.ndarray]]
     """
     sdict_overrides = convert_str_to_rv_dict(model, overrides or {})
     initval_strats = {
diff --git a/pymc/model/core.py b/pymc/model/core.py
index 99711e566ed..9b6c48506d1 100644
--- a/pymc/model/core.py
+++ b/pymc/model/core.py
@@ -19,7 +19,7 @@
 import types
 import warnings

-from collections.abc import Iterable, Sequence
+from collections.abc import Callable, Iterable, Sequence
 from typing import (
     Literal,
     cast,
@@ -585,7 +585,7 @@ def compile_logp(
         jacobian: bool = True,
         sum: bool = True,
         **compile_kwargs,
-    ) -> PointFunc:
+    ) -> Callable[[PointType], np.ndarray]:
         """Compiled log probability density function.

         Parameters
diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py
index 43e1baa87fa..704c0832594 100644
--- a/pymc/sampling/jax.py
+++ b/pymc/sampling/jax.py
@@ -144,13 +144,15 @@ def get_jaxified_graph(
     return jax_funcify(fgraph)


-def get_jaxified_logp(model: Model, negative_logp=True) -> Callable:
+def get_jaxified_logp(
+    model: Model, negative_logp=True
+) -> Callable[[Sequence[np.ndarray]], np.ndarray]:
     model_logp = model.logp()
     if not negative_logp:
         model_logp = -model_logp
     logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])

-    def logp_fn_wrap(x):
+    def logp_fn_wrap(x: Sequence[np.ndarray]) -> np.ndarray:
         return logp_fn(*x)[0]

     return logp_fn_wrap
@@ -211,23 +213,39 @@ def _get_batched_jittered_initial_points(
     chains: int,
     initvals: StartDict | Sequence[StartDict | None] | None,
     random_seed: RandomSeed,
+    logp_fn: Callable[[Sequence[np.ndarray]], np.ndarray],
     jitter: bool = True,
     jitter_max_retries: int = 10,
 ) -> np.ndarray | list[np.ndarray]:
-    """Get jittered initial point in format expected by NumPyro MCMC kernel.
+    """Get jittered initial point in format expected by Jax MCMC kernel.
+
+    Parameters
+    ----------
+    logp_fn : Callable[[Sequence[np.ndarray]], np.ndarray]
+        Jaxified logp function

     Returns
     -------
-    out: list of ndarrays
+    out: list[np.ndarray]
         list with one item per variable and number of chains as batch dimension.
         Each item has shape `(chains, *var.shape)`
     """
+
+    def eval_logp_initial_point(point: dict[str, np.ndarray]) -> np.ndarray:
+        """Wrap logp_fn to conform to _init_jitter logic.
+
+        Wraps jaxified logp function to accept a dict of
+        {model_variable: np.array} key:value pairs.
+        """
+        return logp_fn(point.values())
+
     initial_points = _init_jitter(
         model,
         initvals,
         seeds=_get_seeds_per_chain(random_seed, chains),
         jitter=jitter,
         jitter_max_retries=jitter_max_retries,
+        logp_fn=eval_logp_initial_point,
     )
     initial_points_values = [list(initial_point.values()) for initial_point in initial_points]
     if chains == 1:
@@ -236,7 +254,7 @@ def _get_batched_jittered_initial_points(


 def _blackjax_inference_loop(
-    seed, init_position, logprob_fn, draws, tune, target_accept, **adaptation_kwargs
+    seed, init_position, logp_fn, draws, tune, target_accept, **adaptation_kwargs
 ):
     import blackjax

@@ -252,13 +270,13 @@ def _blackjax_inference_loop(

     adapt = blackjax.window_adaptation(
         algorithm=algorithm,
-        logdensity_fn=logprob_fn,
+        logdensity_fn=logp_fn,
         target_acceptance_rate=target_accept,
         adaptation_info_fn=get_filter_adapt_info_fn(),
         **adaptation_kwargs,
     )
     (last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune)
-    kernel = algorithm(logprob_fn, **tuned_params).step
+    kernel = algorithm(logp_fn, **tuned_params).step

     def _one_step(state, xs):
         _, rng_key = xs
@@ -292,8 +310,9 @@ def _sample_blackjax_nuts(
     chain_method: str | None,
     progressbar: bool,
     random_seed: int,
-    initial_points,
+    initial_points: np.ndarray | list[np.ndarray],
     nuts_kwargs,
+    logp_fn: Callable[[Sequence[np.ndarray]], np.ndarray] | None = None,
 ) -> az.InferenceData:
     """
     Draw samples from the posterior using the NUTS method from the ``blackjax`` library.
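For orientation, the call contract introduced by these hunks can be exercised directly. The sketch below is illustrative only: it uses a hypothetical single-variable model, and `get_jaxified_logp` and `_get_batched_jittered_initial_points` are private helpers, so this mirrors the tests rather than any public API.

import numpy as np
import pymc as pm

from pymc.sampling.jax import (
    _get_batched_jittered_initial_points,
    get_jaxified_logp,
)

with pm.Model() as model:
    x = pm.Normal("x", mu=0.0, sigma=1.0, shape=3)

# With the default negative_logp=True this returns the log-density itself
# (the blackjax convention); the numpyro path passes negative_logp=False.
logp_fn = get_jaxified_logp(model)

# One array per free variable, with chains as the leading batch dimension.
ips = _get_batched_jittered_initial_points(
    model=model, chains=2, initvals=None, random_seed=1, logp_fn=logp_fn
)
assert ips[0].shape == (2, 3)

# The wrapper unpacks a sequence of value-variable arrays, so a single
# chain's point is passed as a list:
print(logp_fn([ips[0][0]]))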
@@ -366,7 +385,8 @@ def _sample_blackjax_nuts(
     if chains == 1:
         initial_points = [np.stack(init_state) for init_state in zip(initial_points)]

-    logprob_fn = get_jaxified_logp(model)
+    if logp_fn is None:
+        logp_fn = get_jaxified_logp(model)

     seed = jax.random.PRNGKey(random_seed)
     keys = jax.random.split(seed, chains)
@@ -374,7 +394,7 @@ def _sample_blackjax_nuts(
     nuts_kwargs["progress_bar"] = progressbar
     get_posterior_samples = partial(
         _blackjax_inference_loop,
-        logprob_fn=logprob_fn,
+        logp_fn=logp_fn,
         tune=tune,
         draws=draws,
         target_accept=target_accept,
@@ -415,14 +435,16 @@ def _sample_numpyro_nuts(
     chain_method: str | None,
     progressbar: bool,
     random_seed: int,
-    initial_points,
+    initial_points: np.ndarray | list[np.ndarray],
     nuts_kwargs: dict[str, Any],
+    logp_fn: Callable | None = None,
 ):
     import numpyro

     from numpyro.infer import MCMC, NUTS

-    logp_fn = get_jaxified_logp(model, negative_logp=False)
+    if logp_fn is None:
+        logp_fn = get_jaxified_logp(model, negative_logp=False)

     nuts_kwargs.setdefault("adapt_step_size", True)
     nuts_kwargs.setdefault("adapt_mass_matrix", True)
@@ -590,6 +612,15 @@ def sample_jax_nuts(
         get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
     )

+    if nuts_sampler == "numpyro":
+        sampler_fn = _sample_numpyro_nuts
+        logp_fn = get_jaxified_logp(model, negative_logp=False)
+    elif nuts_sampler == "blackjax":
+        sampler_fn = _sample_blackjax_nuts
+        logp_fn = get_jaxified_logp(model)
+    else:
+        raise ValueError(f"{nuts_sampler=} not recognized")
+
     (random_seed,) = _get_seeds_per_chain(random_seed, 1)

     initial_points = _get_batched_jittered_initial_points(
@@ -598,15 +629,9 @@ def sample_jax_nuts(
         initvals=initvals,
         random_seed=random_seed,
         jitter=jitter,
+        logp_fn=logp_fn,
     )

-    if nuts_sampler == "numpyro":
-        sampler_fn = _sample_numpyro_nuts
-    elif nuts_sampler == "blackjax":
-        sampler_fn = _sample_blackjax_nuts
-    else:
-        raise ValueError(f"{nuts_sampler=} not recognized")
-
     tic1 = datetime.now()
     raw_mcmc_samples, sample_stats, library = sampler_fn(
         model=model,
diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index bc3e3475d10..25615122e66 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -1339,6 +1339,7 @@ def _init_jitter(
     jitter: bool,
     jitter_max_retries: int,
     logp_dlogp_func=None,
+    logp_fn: Callable[[PointType], np.ndarray] | None = None,
 ) -> list[PointType]:
     """Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain.

@@ -1353,11 +1354,13 @@ def _init_jitter(
         Whether to apply jitter or not.
     jitter_max_retries : int
         Maximum number of repeated attempts at initializing values (per chain).
+    logp_fn: Callable[[dict[str, np.ndarray]], np.ndarray]
+        Jaxified logp function that takes the output of the initial point functions as input.

     Returns
     -------
-    start : ``pymc.model.Point``
-        Starting point for sampler
+    initial_points : list[dict[str, np.ndarray]]
+        List of starting points for the sampler
     """
     ipfns = make_initial_point_fns_per_chain(
         model=model,
@@ -1369,12 +1372,17 @@ def _init_jitter(
     if not jitter:
         return [ipfn(seed) for ipfn, seed in zip(ipfns, seeds)]

-    model_logp_fn: Callable
+    model_logp_fn: Callable[[PointType], np.ndarray]
     if logp_dlogp_func is None:
-        model_logp_fn = model.compile_logp()
+        if logp_fn is None:
+            # pymc NUTS path
+            model_logp_fn = model.compile_logp()
+        else:
+            # Jax path
+            model_logp_fn = logp_fn
     else:

-        def model_logp_fn(ip):
+        def model_logp_fn(ip: PointType) -> np.ndarray:
             q, _ = DictToArrayBijection.map(ip)
             return logp_dlogp_func([q], extra_vars={})[0]

diff --git a/tests/sampling/test_jax.py b/tests/sampling/test_jax.py
index d6a8d1021b7..e17a924243b 100644
--- a/tests/sampling/test_jax.py
+++ b/tests/sampling/test_jax.py
@@ -333,20 +333,26 @@ def test_get_batched_jittered_initial_points():
     with pm.Model() as model:
         x = pm.MvNormal("x", mu=np.zeros(3), cov=np.eye(3), shape=(2, 3), initval=np.zeros((2, 3)))

+    logp_fn = get_jaxified_logp(model)
+
     # No jitter
     ips = _get_batched_jittered_initial_points(
-        model=model, chains=1, random_seed=1, initvals=None, jitter=False
+        model=model, chains=1, random_seed=1, initvals=None, jitter=False, logp_fn=logp_fn
     )
     assert np.all(ips[0] == 0)

     # Single chain
-    ips = _get_batched_jittered_initial_points(model=model, chains=1, random_seed=1, initvals=None)
+    ips = _get_batched_jittered_initial_points(
+        model=model, chains=1, random_seed=1, initvals=None, logp_fn=logp_fn
+    )

     assert ips[0].shape == (2, 3)
     assert np.all(ips[0] != 0)

     # Multiple chains
-    ips = _get_batched_jittered_initial_points(model=model, chains=2, random_seed=1, initvals=None)
+    ips = _get_batched_jittered_initial_points(
+        model=model, chains=2, random_seed=1, initvals=None, logp_fn=logp_fn
+    )

     assert ips[0].shape == (2, 2, 3)
     assert np.all(ips[0][0] != ips[0][1])

From 3996a067a6e63a124f5024a170e520184dc114fb Mon Sep 17 00:00:00 2001
From: Goose
Date: Wed, 11 Dec 2024 21:02:09 +1000
Subject: [PATCH 02/11] correcting initial point type hinting

---
 pymc/initial_point.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/pymc/initial_point.py b/pymc/initial_point.py
index 675b98319b8..da55fc38105 100644
--- a/pymc/initial_point.py
+++ b/pymc/initial_point.py
@@ -26,6 +26,7 @@
 from pymc.logprob.transforms import Transform

 from pymc.pytensorf import (
+    SeedSequenceSeed,
     compile,
     find_rng_nodes,
     replace_rng_nodes,
@@ -67,7 +68,7 @@ def make_initial_point_fns_per_chain(
     overrides: StartDict | Sequence[StartDict | None] | None,
     jitter_rvs: set[TensorVariable] | None = None,
     chains: int,
-) -> list[Callable[[int], PointType]]:
+) -> list[Callable[[SeedSequenceSeed], PointType]]:
     """Create an initial point function for each chain, as defined by initvals.

     If a single initval dictionary is passed, the function is replicated for each
@@ -84,7 +85,7 @@ def make_initial_point_fns_per_chain(

     Returns
     -------
-    ipfns : list[Callable[[int], dict[str, np.ndarray]]]
+    ipfns : list[Callable[[SeedSequenceSeed], dict[str, np.ndarray]]]
         list of functions that return initial points for each chain.

     Raises
     ------
     ValueError
@@ -124,7 +130,7 @@ def make_initial_point_fn(
     jitter_rvs: set[TensorVariable] | None = None,
     default_strategy: str = "support_point",
     return_transformed: bool = True,
-) -> Callable[[int], PointType]:
+) -> Callable[[SeedSequenceSeed], PointType]:
     """Create seeded function that computes initial values for all free model variables.

     Parameters
@@ -146,7 +147,7 @@ def make_initial_point_fn(

     Returns
     -------
-    initial_point_fn : Callable[[int], dict[str, np.ndarray]]
+    initial_point_fn : Callable[[SeedSequenceSeed], dict[str, np.ndarray]]
     """
     sdict_overrides = convert_str_to_rv_dict(model, overrides or {})
     initval_strats = {

From 1fb9df15d1d72327d935af1f011addab50561a23 Mon Sep 17 00:00:00 2001
From: Goose
Date: Wed, 11 Dec 2024 21:26:18 +1000
Subject: [PATCH 03/11] refactor init_jitter inputs

---
 pymc/sampling/jax.py  | 18 +++++++++++-------
 pymc/sampling/mcmc.py | 26 +++++++++++---------------
 2 files changed, 22 insertions(+), 22 deletions(-)

diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py
index 704c0832594..2df2656261c 100644
--- a/pymc/sampling/jax.py
+++ b/pymc/sampling/jax.py
@@ -213,7 +213,7 @@ def _get_batched_jittered_initial_points(
     chains: int,
     initvals: StartDict | Sequence[StartDict | None] | None,
     random_seed: RandomSeed,
-    logp_fn: Callable[[Sequence[np.ndarray]], np.ndarray],
+    logp_fn: Callable[[Sequence[np.ndarray]], np.ndarray] | None = None,
     jitter: bool = True,
     jitter_max_retries: int = 10,
 ) -> np.ndarray | list[np.ndarray]:
@@ -230,14 +230,18 @@ def _get_batched_jittered_initial_points(
         list with one item per variable and number of chains as batch dimension.
         Each item has shape `(chains, *var.shape)`
     """
+    if logp_fn is None:
+        eval_logp_initial_point = None
+
+    else:

-    def eval_logp_initial_point(point: dict[str, np.ndarray]) -> np.ndarray:
-        """Wrap logp_fn to conform to _init_jitter logic.
+        def eval_logp_initial_point(point: dict[str, np.ndarray]) -> np.ndarray:
+            """Wrap logp_fn to conform to _init_jitter logic.

-        Wraps jaxified logp function to accept a dict of
-        {model_variable: np.array} key:value pairs.
-        """
-        return logp_fn(point.values())
+            Wraps jaxified logp function to accept a dict of
+            {model_variable: np.array} key:value pairs.
+            """
+            return logp_fn(point.values())

     initial_points = _init_jitter(
         model,
diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index 25615122e66..3dc64ab3448 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -1338,7 +1338,6 @@ def _init_jitter(
     seeds: Sequence[int] | np.ndarray,
     jitter: bool,
     jitter_max_retries: int,
-    logp_dlogp_func=None,
     logp_fn: Callable[[PointType], np.ndarray] | None = None,
 ) -> list[PointType]:
     """Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain.
@@ -1353,8 +1352,9 @@ def _init_jitter(
         Whether to apply jitter or not.
     jitter_max_retries : int
         Maximum number of repeated attempts at initializing values (per chain).
-    logp_fn: Callable[[dict[str, np.ndarray]], np.ndarray]
+    logp_fn: Callable[[dict[str, np.ndarray]], np.ndarray] | None
         Jaxified logp function that takes the output of the initial point functions as input.
+        If None, will use the results of model.compile_logp().

     Returns
     -------
@@ -1372,19 +1372,10 @@ def _init_jitter(
     if not jitter:
         return [ipfn(seed) for ipfn, seed in zip(ipfns, seeds)]

-    model_logp_fn: Callable[[PointType], np.ndarray]
-    if logp_dlogp_func is None:
-        if logp_fn is None:
-            # pymc NUTS path
-            model_logp_fn = model.compile_logp()
-        else:
-            # Jax path
-            model_logp_fn = logp_fn
+    if logp_fn is None:
+        model_logp_fn = model.compile_logp()
     else:
-
-        def model_logp_fn(ip: PointType) -> np.ndarray:
-            q, _ = DictToArrayBijection.map(ip)
-            return logp_dlogp_func([q], extra_vars={})[0]
+        model_logp_fn = logp_fn

     initial_points = []
     for ipfn, seed in zip(ipfns, seeds):
@@ -1509,13 +1500,18 @@ def init_nuts(

     logp_dlogp_func = model.logp_dlogp_function(ravel_inputs=True, **compile_kwargs)
     logp_dlogp_func.trust_input = True
+
+    def model_logp_fn(ip: PointType) -> np.ndarray:
+        q, _ = DictToArrayBijection.map(ip)
+        return logp_dlogp_func([q], extra_vars={})[0]
+
     initial_points = _init_jitter(
         model,
         initvals,
         seeds=random_seed_list,
         jitter="jitter" in init,
         jitter_max_retries=jitter_max_retries,
-        logp_dlogp_func=logp_dlogp_func,
+        logp_fn=model_logp_fn,
     )

     apoints = [DictToArrayBijection.map(point) for point in initial_points]

From f71aedc38b76a83dc64c94ed9b303d0fcaea714f Mon Sep 17 00:00:00 2001
From: Goose
Date: Wed, 11 Dec 2024 21:28:55 +1000
Subject: [PATCH 04/11] revert changes to tests

---
 tests/sampling/test_jax.py | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/tests/sampling/test_jax.py b/tests/sampling/test_jax.py
index e17a924243b..a8db78d5988 100644
--- a/tests/sampling/test_jax.py
+++ b/tests/sampling/test_jax.py
@@ -337,22 +337,18 @@ def test_get_batched_jittered_initial_points():

     # No jitter
     ips = _get_batched_jittered_initial_points(
-        model=model, chains=1, random_seed=1, initvals=None, jitter=False, logp_fn=logp_fn
+        model=model, chains=1, random_seed=1, initvals=None, jitter=False
     )
     assert np.all(ips[0] == 0)

     # Single chain
-    ips = _get_batched_jittered_initial_points(
-        model=model, chains=1, random_seed=1, initvals=None, logp_fn=logp_fn
-    )
+    ips = _get_batched_jittered_initial_points(model=model, chains=1, random_seed=1, initvals=None)

     assert ips[0].shape == (2, 3)
     assert np.all(ips[0] != 0)

     # Multiple chains
-    ips = _get_batched_jittered_initial_points(
-        model=model, chains=2, random_seed=1, initvals=None, logp_fn=logp_fn
-    )
+    ips = _get_batched_jittered_initial_points(model=model, chains=2, random_seed=1, initvals=None)

     assert ips[0].shape == (2, 2, 3)
     assert np.all(ips[0][0] != ips[0][1])

From 2855587faf5d378f587746d2336ec5cc00d8a6d8 Mon Sep 17 00:00:00 2001
From: Goose
Date: Wed, 11 Dec 2024 21:31:08 +1000
Subject: [PATCH 05/11] removed redundant line from test

---
 tests/sampling/test_jax.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tests/sampling/test_jax.py b/tests/sampling/test_jax.py
index a8db78d5988..d6a8d1021b7 100644
--- a/tests/sampling/test_jax.py
+++ b/tests/sampling/test_jax.py
@@ -333,8 +333,6 @@ def test_get_batched_jittered_initial_points():
     with pm.Model() as model:
         x = pm.MvNormal("x", mu=np.zeros(3), cov=np.eye(3), shape=(2, 3), initval=np.zeros((2, 3)))

-    logp_fn = get_jaxified_logp(model)
-
     # No jitter
     ips = _get_batched_jittered_initial_points(
         model=model, chains=1, random_seed=1, initvals=None, jitter=False

From 85996a1c838cf23e427b0386c35ce7dc665daeb5 Mon Sep 17 00:00:00 2001
From: Goose
Date: Thu, 12 Dec 2024 12:55:39 +1000
Subject: [PATCH 06/11] correct type annotations related to jaxified logp func

---
 pymc/sampling/jax.py  | 13 ++++++-------
 pymc/sampling/mcmc.py |  4 ++--
 2 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py
index 2df2656261c..6427f3bae9a 100644
--- a/pymc/sampling/jax.py
+++ b/pymc/sampling/jax.py
@@ -28,6 +28,7 @@
 from arviz.data.base import make_attrs
 from jax.lax import scan
+from numpy.typing import ArrayLike
 from pytensor.compile import SharedVariable, Supervisor, mode
 from pytensor.graph.basic import graph_inputs
 from pytensor.graph.fg import FunctionGraph
@@ -121,7 +122,7 @@ def _replace_shared_variables(graph: list[TensorVariable]) -> list[TensorVariabl
 def get_jaxified_graph(
     inputs: list[TensorVariable] | None = None,
     outputs: list[TensorVariable] | None = None,
-) -> list[TensorVariable]:
+) -> Callable[[list[TensorVariable]], list[TensorVariable]]:
     """Compile a PyTensor graph into an optimized JAX function."""
     graph = _replace_shared_variables(outputs) if outputs is not None else None

@@ -144,15 +145,13 @@ def get_jaxified_graph(
     return jax_funcify(fgraph)


-def get_jaxified_logp(
-    model: Model, negative_logp=True
-) -> Callable[[Sequence[np.ndarray]], np.ndarray]:
+def get_jaxified_logp(model: Model, negative_logp: bool = True) -> Callable[[ArrayLike], jax.Array]:
     model_logp = model.logp()
     if not negative_logp:
         model_logp = -model_logp
     logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])

-    def logp_fn_wrap(x: Sequence[np.ndarray]) -> np.ndarray:
+    def logp_fn_wrap(x: ArrayLike) -> jax.Array:
         return logp_fn(*x)[0]

     return logp_fn_wrap
@@ -213,7 +212,7 @@ def _get_batched_jittered_initial_points(
     chains: int,
     initvals: StartDict | Sequence[StartDict | None] | None,
     random_seed: RandomSeed,
-    logp_fn: Callable[[Sequence[np.ndarray]], np.ndarray] | None = None,
+    logp_fn: Callable[[ArrayLike], jax.Array] | None = None,
     jitter: bool = True,
     jitter_max_retries: int = 10,
 ) -> np.ndarray | list[np.ndarray]:
@@ -235,7 +234,7 @@ def _get_batched_jittered_initial_points(

     else:

-        def eval_logp_initial_point(point: dict[str, np.ndarray]) -> np.ndarray:
+        def eval_logp_initial_point(point: dict[str, np.ndarray]) -> jax.Array:
             """Wrap logp_fn to conform to _init_jitter logic.

             Wraps jaxified logp function to accept a dict of
diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index 3dc64ab3448..c2294699928 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -1353,8 +1353,8 @@ def _init_jitter(
         Whether to apply jitter or not.
     jitter_max_retries : int
         Maximum number of repeated attempts at initializing values (per chain).
-    logp_fn: Callable[[dict[str, np.ndarray]], np.ndarray] | None
-        Jaxified logp function that takes the output of the initial point functions as input.
+    logp_fn: Callable[[dict[str, np.ndarray]], np.ndarray | jax.Array] | None
+        logp function that takes the output of initial point functions as input.
         If None, will use the results of model.compile_logp().

     Returns

From 2e9d7dba320e8fddedc3b28e6f6d22ab30b16e6a Mon Sep 17 00:00:00 2001
From: Goose
Date: Thu, 12 Dec 2024 13:47:55 +1000
Subject: [PATCH 07/11] improved docstrings & type annotations

---
 pymc/sampling/jax.py | 112 ++++++++++++++++++++++++++++++-------------
 1 file changed, 80 insertions(+), 32 deletions(-)

diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py
index 6427f3bae9a..41b11cddc9b 100644
--- a/pymc/sampling/jax.py
+++ b/pymc/sampling/jax.py
@@ -18,7 +18,8 @@
 from collections.abc import Callable, Sequence
 from datetime import datetime
 from functools import partial
-from typing import Any, Literal
+from types import ModuleType
+from typing import TYPE_CHECKING, Any, Literal

 import arviz as az
 import jax
@@ -69,6 +70,9 @@
     "sample_numpyro_nuts",
 )

+if TYPE_CHECKING:
+    from numpyro.infer import MCMC
+

 @jax_funcify.register(Assert)
 @jax_funcify.register(CheckParameterValue)
@@ -310,50 +314,48 @@ def _sample_blackjax_nuts(
     tune: int,
     draws: int,
     chains: int,
-    chain_method: str | None,
+    chain_method: Literal["parallel", "vectorized"],
     progressbar: bool,
     random_seed: int,
     initial_points: np.ndarray | list[np.ndarray],
     nuts_kwargs,
-    logp_fn: Callable[[Sequence[np.ndarray]], np.ndarray] | None = None,
-) -> az.InferenceData:
+    logp_fn: Callable[[ArrayLike], jax.Array] | None = None,
+) -> tuple[Any, dict[str, Any], ModuleType]:
     """
     Draw samples from the posterior using the NUTS method from the ``blackjax`` library.

     Parameters
     ----------
-    draws : int, default 1000
-        The number of samples to draw. The number of tuned samples are discarded by
-        default.
+    model : Model, optional
+        Model to sample from. The model needs to have free random variables. When inside
+        a ``with`` model context, it defaults to that model, otherwise the model must be
+        passed explicitly.
+    target_accept : float in [0, 1].
+        The step size is tuned such that we approximate this acceptance rate. Higher
+        values like 0.9 or 0.95 often work better for problematic posteriors.
     tune : int, default 1000
         Number of iterations to tune. Samplers adjust the step sizes, scalings or
         similar during tuning. Tuning samples will be drawn in addition to the number
         specified in the ``draws`` argument.
+    draws : int, default 1000
+        The number of samples to draw. The number of tuned samples are discarded by
+        default.
     chains : int, default 4
         The number of chains to sample.
-    target_accept : float in [0, 1].
-        The step size is tuned such that we approximate this acceptance rate. Higher
-        values like 0.9 or 0.95 often work better for problematic posteriors.
+    chain_method : str, default "parallel"
+        Specify how samples should be drawn. The choices include "parallel", and
+        "vectorized".
+    progressbar : bool
+        Whether to show progressbar or not during sampling.
     random_seed : int, RandomState or Generator, optional
         Random seed used by the sampling steps.
-    initvals: StartDict or Sequence[Optional[StartDict]], optional
-        Initial values for random variables provided as a dictionary (or sequence of
-        dictionaries) mapping the random variable (by name or reference) to desired
-        starting values.
-    jitter: bool, default True
-        If True, add jitter to initial points.
-    model : Model, optional
-        Model to sample from. The model needs to have free random variables. When inside
-        a ``with`` model context, it defaults to that model, otherwise the model must be
-        passed explicitly.
+    initial_points : np.ndarray | list[np.ndarray]
+        Initial point(s) for sampler to begin sampling from.
var_names : sequence of str, optional Names of variables for which to compute the posterior samples. Defaults to all variables in the posterior. keep_untransformed : bool, default False Include untransformed variables in the posterior samples. Defaults to False. - chain_method : str, default "parallel" - Specify how samples should be drawn. The choices include "parallel", and - "vectorized". postprocessing_backend: Optional[Literal["cpu", "gpu"]], default None, Specify how postprocessing should be computed. gpu or cpu postprocessing_vectorize: Literal["vmap", "scan"], default "scan" @@ -365,13 +367,17 @@ def _sample_blackjax_nuts( ``observed_data``, ``constant_data``, ``coords``, and ``dims`` are inferred from the ``model`` argument if not provided in ``idata_kwargs``. If ``coords`` and ``dims`` are provided, they are used to update the inferred dictionaries. + logp_fn : Callable[[ArrayLike], jax.Array] | None: + jaxified logp function. If not passed in it will compute it here. Returns ------- - InferenceData - ArviZ ``InferenceData`` object that contains the posterior samples, together - with their respective sample stats and pointwise log likeihood values (unless - skipped with ``idata_kwargs``). + Tuple containing: + raw_mcmc_samples + Datastructure containing raw mcmc samples + sample_stats : dict[str, Any] + Dictionary containing sample stats + Module("blackjax") """ import blackjax @@ -409,7 +415,7 @@ def _sample_blackjax_nuts( # Adopted from arviz numpyro extractor -def _numpyro_stats_to_dict(posterior): +def _numpyro_stats_to_dict(posterior: MCMC) -> dict[str, Any]: """Extract sample_stats from NumPyro posterior.""" rename_key = { "potential_energy": "lp", @@ -440,8 +446,50 @@ def _sample_numpyro_nuts( random_seed: int, initial_points: np.ndarray | list[np.ndarray], nuts_kwargs: dict[str, Any], - logp_fn: Callable | None = None, -): + logp_fn: Callable[[ArrayLike], jax.Array] | None = None, +) -> tuple[Any, dict[str, Any], ModuleType]: + """ + Draw samples from the posterior using the NUTS method from the ``numpyro`` library. + + Parameters + ---------- + model : Model, optional + Model to sample from. The model needs to have free random variables. When inside + a ``with`` model context, it defaults to that model, otherwise the model must be + passed explicitly. + target_accept : float in [0, 1]. + The step size is tuned such that we approximate this acceptance rate. Higher + values like 0.9 or 0.95 often work better for problematic posteriors. + tune : int, default 1000 + Number of iterations to tune. Samplers adjust the step sizes, scalings or + similar during tuning. Tuning samples will be drawn in addition to the number + specified in the ``draws`` argument. + draws : int, default 1000 + The number of samples to draw. The number of tuned samples are discarded by + default. + chains : int, default 4 + The number of chains to sample. + chain_method : str, default "parallel" + Specify how samples should be drawn. The choices include "parallel", and + "vectorized". + progressbar : bool + Whether to show progressbar or not during sampling. + random_seed : int, RandomState or Generator, optional + Random seed used by the sampling steps. + initial_points : np.ndarray | list[np.ndarray] + Initial point(s) for sampler to begin sampling from. + logp_fn : Callable[[ArrayLike], jax.Array] | None: + jaxified logp function. If not passed in it will compute it here. 
+
+    Returns
+    -------
+    Tuple containing:
+        raw_mcmc_samples
+            Datastructure containing raw mcmc samples
+        sample_stats : dict[str, Any]
+            Dictionary containing sample stats
+        Module("numpyro")
+    """
     import numpyro

     from numpyro.infer import MCMC, NUTS

From e50b5b429f40fb39c816678dbc923404324f3af0 Mon Sep 17 00:00:00 2001
From: Goose
Date: Thu, 12 Dec 2024 19:29:09 +1000
Subject: [PATCH 08/11] fix docstring generation error

---
 pymc/sampling/jax.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py
index 41b11cddc9b..5de155d81f4 100644
--- a/pymc/sampling/jax.py
+++ b/pymc/sampling/jax.py
@@ -377,7 +377,7 @@ def _sample_blackjax_nuts(
             Datastructure containing raw mcmc samples
         sample_stats : dict[str, Any]
             Dictionary containing sample stats
-        Module("blackjax")
+        blackjax : ModuleType["blackjax"]
     """
     import blackjax

@@ -488,7 +488,7 @@ def _sample_numpyro_nuts(
             Datastructure containing raw mcmc samples
         sample_stats : dict[str, Any]
             Dictionary containing sample stats
-        Module("numpyro")
+        numpyro : ModuleType["numpyro"]
     """
     import numpyro

From f6fbb0a4aba558f4dae64328305bddc813af3f73 Mon Sep 17 00:00:00 2001
From: Goose
Date: Thu, 12 Dec 2024 23:15:48 +1000
Subject: [PATCH 09/11] removed type_checking to fix sphinx-autodoc

---
 pymc/sampling/jax.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py
index 5de155d81f4..5226742dbcc 100644
--- a/pymc/sampling/jax.py
+++ b/pymc/sampling/jax.py
@@ -19,7 +19,7 @@
 from datetime import datetime
 from functools import partial
 from types import ModuleType
-from typing import TYPE_CHECKING, Any, Literal
+from typing import Any, Literal

 import arviz as az
 import jax
@@ -70,9 +70,6 @@
     "sample_numpyro_nuts",
 )

-if TYPE_CHECKING:
-    from numpyro.infer import MCMC
-

 @jax_funcify.register(Assert)
 @jax_funcify.register(CheckParameterValue)
@@ -412,7 +412,7 @@


 # Adopted from arviz numpyro extractor
-def _numpyro_stats_to_dict(posterior: MCMC) -> dict[str, Any]:
+def _numpyro_stats_to_dict(posterior) -> dict[str, Any]:
     """Extract sample_stats from NumPyro posterior."""
     rename_key = {
         "potential_energy": "lp",

From deea64c6c66e82619aad87593b5de133e7c1e7bb Mon Sep 17 00:00:00 2001
From: Goose
Date: Fri, 17 Jan 2025 11:29:24 +1000
Subject: [PATCH 10/11] implement docstring and type hinting feedback

---
 pymc/model/core.py    |  4 ++--
 pymc/sampling/jax.py  | 12 +++++++-----
 pymc/sampling/mcmc.py |  2 +-
 3 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/pymc/model/core.py b/pymc/model/core.py
index 9b6c48506d1..99711e566ed 100644
--- a/pymc/model/core.py
+++ b/pymc/model/core.py
@@ -19,7 +19,7 @@
 import types
 import warnings

-from collections.abc import Callable, Iterable, Sequence
+from collections.abc import Iterable, Sequence
 from typing import (
     Literal,
     cast,
@@ -585,7 +585,7 @@ def compile_logp(
         jacobian: bool = True,
         sum: bool = True,
         **compile_kwargs,
-    ) -> Callable[[PointType], np.ndarray]:
+    ) -> PointFunc:
         """Compiled log probability density function.

         Parameters
diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py
index 5226742dbcc..5ded61a50da 100644
--- a/pymc/sampling/jax.py
+++ b/pymc/sampling/jax.py
@@ -226,7 +226,7 @@ def _get_batched_jittered_initial_points(

     Returns
     -------
-    out: list[np.ndarray]
+    out: list of ndarrays
         list with one item per variable and number of chains as batch dimension.
         Each item has shape `(chains, *var.shape)`
     """
@@ -321,6 +321,8 @@ def _sample_blackjax_nuts(
     """
     Draw samples from the posterior using the NUTS method from the ``blackjax`` library.

+    Note the default parameter values listed below are provided by the calling function `sample_jax_nuts`.
+
     Parameters
     ----------
     model : Model, optional
@@ -365,11 +367,10 @@ def _sample_blackjax_nuts(
         the ``model`` argument if not provided in ``idata_kwargs``. If ``coords`` and
         ``dims`` are provided, they are used to update the inferred dictionaries.
     logp_fn : Callable[[ArrayLike], jax.Array] | None:
-        jaxified logp function. If not passed in it will compute it here.
+        jaxified logp function. If not passed in it will be created anew.

     Returns
     -------
-    Tuple containing:
         raw_mcmc_samples
             Datastructure containing raw mcmc samples
         sample_stats : dict[str, Any]
             Dictionary containing sample stats
@@ -448,6 +449,8 @@ def _sample_numpyro_nuts(
     """
     Draw samples from the posterior using the NUTS method from the ``numpyro`` library.

+    Note the default parameter values listed below are provided by the calling function `sample_jax_nuts`.
+
     Parameters
     ----------
     model : Model, optional
@@ -476,11 +479,10 @@ def _sample_numpyro_nuts(
     initial_points : np.ndarray | list[np.ndarray]
         Initial point(s) for sampler to begin sampling from.
     logp_fn : Callable[[ArrayLike], jax.Array] | None:
-        jaxified logp function. If not passed in it will compute it here.
+        jaxified logp function. If not passed in it will be created anew.
     Returns
     -------
-    Tuple containing:
         raw_mcmc_samples
             Datastructure containing raw mcmc samples
         sample_stats : dict[str, Any]
             Dictionary containing sample stats
diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index c2294699928..89ebd44972d 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -1373,7 +1373,7 @@ def _init_jitter(
         return [ipfn(seed) for ipfn, seed in zip(ipfns, seeds)]

     if logp_fn is None:
-        model_logp_fn = model.compile_logp()
+        model_logp_fn: Callable[[PointType], np.ndarray] = model.compile_logp()
     else:
         model_logp_fn = logp_fn

From ae0fb9638e33bb570f1e897ef4c9118b71385870 Mon Sep 17 00:00:00 2001
From: Goose
Date: Fri, 17 Jan 2025 22:58:14 +1000
Subject: [PATCH 11/11] update docstring to more accurately mirror function
 signature

---
 pymc/sampling/jax.py | 80 ++++++++++++++++----------------------------
 1 file changed, 28 insertions(+), 52 deletions(-)

diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py
index 5ded61a50da..89adbf233ba 100644
--- a/pymc/sampling/jax.py
+++ b/pymc/sampling/jax.py
@@ -315,58 +315,38 @@ def _sample_blackjax_nuts(
     progressbar: bool,
     random_seed: int,
     initial_points: np.ndarray | list[np.ndarray],
-    nuts_kwargs,
+    nuts_kwargs: dict[str, Any],
     logp_fn: Callable[[ArrayLike], jax.Array] | None = None,
 ) -> tuple[Any, dict[str, Any], ModuleType]:
     """
     Draw samples from the posterior using the NUTS method from the ``blackjax`` library.

-    Note the default parameter values listed below are provided by the calling function `sample_jax_nuts`.
-
     Parameters
     ----------
-    model : Model, optional
-        Model to sample from. The model needs to have free random variables. When inside
-        a ``with`` model context, it defaults to that model, otherwise the model must be
-        passed explicitly.
+    model : Model
+        Model to sample from. The model needs to have free random variables.
     target_accept : float in [0, 1].
         The step size is tuned such that we approximate this acceptance rate. Higher
         values like 0.9 or 0.95 often work better for problematic posteriors.
-    tune : int, default 1000
+    tune : int
         Number of iterations to tune. Samplers adjust the step sizes, scalings or
         similar during tuning. Tuning samples will be drawn in addition to the number
         specified in the ``draws`` argument.
-    draws : int, default 1000
-        The number of samples to draw. The number of tuned samples are discarded by
-        default.
-    chains : int, default 4
+    draws : int
+        The number of samples to draw. The number of tuned samples are discarded by default.
+    chains : int
         The number of chains to sample.
-    chain_method : str, default "parallel"
-        Specify how samples should be drawn. The choices include "parallel", and
-        "vectorized".
+    chain_method : "parallel" or "vectorized"
+        Specify how samples should be drawn.
     progressbar : bool
         Whether to show progressbar or not during sampling.
-    random_seed : int, RandomState or Generator, optional
+    random_seed : int, RandomState or Generator
         Random seed used by the sampling steps.
-    initial_points : np.ndarray | list[np.ndarray]
+    initial_points : np.ndarray or list[np.ndarray]
         Initial point(s) for sampler to begin sampling from.
-    var_names : sequence of str, optional
-        Names of variables for which to compute the posterior samples. Defaults to all
-        variables in the posterior.
-    keep_untransformed : bool, default False
-        Include untransformed variables in the posterior samples. Defaults to False.
-    postprocessing_backend: Optional[Literal["cpu", "gpu"]], default None,
-        Specify how postprocessing should be computed.
gpu or cpu - postprocessing_vectorize: Literal["vmap", "scan"], default "scan" - How to vectorize the postprocessing: vmap or sequential scan - idata_kwargs : dict, optional - Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as - value for the ``log_likelihood`` key to indicate that the pointwise log - likelihood should not be included in the returned object. Values for - ``observed_data``, ``constant_data``, ``coords``, and ``dims`` are inferred from - the ``model`` argument if not provided in ``idata_kwargs``. If ``coords`` and - ``dims`` are provided, they are used to update the inferred dictionaries. - logp_fn : Callable[[ArrayLike], jax.Array] | None: + nuts_kwargs : dict + Keyword arguments for the blackjax nuts sampler + logp_fn : Callable[[ArrayLike], jax.Array], optional, default None jaxified logp function. If not passed in it will be created anew. Returns @@ -439,7 +419,7 @@ def _sample_numpyro_nuts( tune: int, draws: int, chains: int, - chain_method: str | None, + chain_method: Literal["parallel", "vectorized"], progressbar: bool, random_seed: int, initial_points: np.ndarray | list[np.ndarray], @@ -449,36 +429,32 @@ def _sample_numpyro_nuts( """ Draw samples from the posterior using the NUTS method from the ``numpyro`` library. - Note the default parameter values listed below are provided by the calling function `sample_jax_nuts`. - Parameters ---------- - model : Model, optional - Model to sample from. The model needs to have free random variables. When inside - a ``with`` model context, it defaults to that model, otherwise the model must be - passed explicitly. + model : Model + Model to sample from. The model needs to have free random variables. target_accept : float in [0, 1]. The step size is tuned such that we approximate this acceptance rate. Higher values like 0.9 or 0.95 often work better for problematic posteriors. - tune : int, default 1000 + tune : int Number of iterations to tune. Samplers adjust the step sizes, scalings or similar during tuning. Tuning samples will be drawn in addition to the number specified in the ``draws`` argument. - draws : int, default 1000 - The number of samples to draw. The number of tuned samples are discarded by - default. - chains : int, default 4 + draws : int + The number of samples to draw. The number of tuned samples are discarded by default. + chains : int The number of chains to sample. - chain_method : str, default "parallel" - Specify how samples should be drawn. The choices include "parallel", and - "vectorized". + chain_method : "parallel" or "vectorized" + Specify how samples should be drawn. progressbar : bool Whether to show progressbar or not during sampling. - random_seed : int, RandomState or Generator, optional + random_seed : int, RandomState or Generator Random seed used by the sampling steps. - initial_points : np.ndarray | list[np.ndarray] + initial_points : np.ndarray or list[np.ndarray] Initial point(s) for sampler to begin sampling from. - logp_fn : Callable[[ArrayLike], jax.Array] | None: + nuts_kwargs : dict + Keyword arguments for the underlying numpyro nuts sampler + logp_fn : Callable[[ArrayLike], jax.Array], optional, default None jaxified logp function. If not passed in it will be created anew. Returns
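Taken together, the series has `sample_jax_nuts` build the jaxified logp once and reuse it for both the jittered initial-point check and whichever sampler backend is selected. As a minimal end-to-end sketch of the public entry point (hypothetical toy model; assumes the numpyro backend is installed):

import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu", mu=0.0, sigma=1.0)
    pm.Normal("obs", mu=mu, sigma=1.0, observed=[0.1, -0.3, 0.2])

    # nuts_sampler="numpyro" routes through sample_jax_nuts, which now
    # threads a single jaxified logp into _get_batched_jittered_initial_points
    # and _sample_numpyro_nuts.
    idata = pm.sample(
        draws=500,
        tune=500,
        chains=2,
        nuts_sampler="numpyro",
        random_seed=1,
    )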