Use jaxified logp for initial point evaluation when sampling via Jax #7610

Open · wants to merge 9 commits into base: main
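For context, a hypothetical usage sketch (not part of the diff) of the code path this PR changes: sampling through the JAX-based NUTS backends, where initial point evaluation now reuses the jaxified logp:

```python
import pymc as pm

with pm.Model():
    x = pm.Normal("x", mu=0.0, sigma=1.0)
    # Sampling via the JAX backends ("numpyro" or "blackjax") triggers the
    # jittered initial point logic that this PR rewires to use the jaxified
    # logp instead of a separately compiled PyTensor logp function.
    idata = pm.sample(nuts_sampler="numpyro", chains=2, tune=500, draws=500)
```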
13 changes: 11 additions & 2 deletions pymc/initial_point.py
@@ -67,7 +67,7 @@ def make_initial_point_fns_per_chain(
overrides: StartDict | Sequence[StartDict | None] | None,
jitter_rvs: set[TensorVariable] | None = None,
chains: int,
-) -> list[Callable]:
+) -> list[Callable[[int], PointType]]:
"""Create an initial point function for each chain, as defined by initvals.

If a single initval dictionary is passed, the function is replicated for each
@@ -82,6 +82,11 @@ def make_initial_point_fns_per_chain(
Random variable tensors for which U(-1, 1) jitter shall be applied.
(To the transformed space if applicable.)

+Returns
+-------
+ipfns : list[Callable[[int], dict[str, np.ndarray]]]
+    list of functions that return initial points for each chain.

Raises
------
ValueError
@@ -124,7 +129,7 @@ def make_initial_point_fn(
jitter_rvs: set[TensorVariable] | None = None,
default_strategy: str = "support_point",
return_transformed: bool = True,
-) -> Callable:
+) -> Callable[[int], PointType]:
"""Create seeded function that computes initial values for all free model variables.

Parameters
@@ -138,6 +143,10 @@ def make_initial_point_fn(
Initial value (strategies) to use instead of what's specified in `Model.initial_values`.
return_transformed : bool
If `True` the returned variables will correspond to transformed initial values.

+Returns
+-------
+initial_point_fn : Callable[[int], dict[str, np.ndarray]]
"""
sdict_overrides = convert_str_to_rv_dict(model, overrides or {})
initval_strats = {
4 changes: 2 additions & 2 deletions pymc/model/core.py
@@ -19,7 +19,7 @@
import types
import warnings

-from collections.abc import Iterable, Sequence
+from collections.abc import Callable, Iterable, Sequence
from typing import (
Literal,
cast,
@@ -585,7 +585,7 @@ def compile_logp(
jacobian: bool = True,
sum: bool = True,
**compile_kwargs,
-) -> PointFunc:
+) -> Callable[[PointType], np.ndarray]:
"""Compiled log probability density function.

Parameters
63 changes: 44 additions & 19 deletions pymc/sampling/jax.py
@@ -144,13 +144,15 @@ def get_jaxified_graph(
return jax_funcify(fgraph)


-def get_jaxified_logp(model: Model, negative_logp=True) -> Callable:
+def get_jaxified_logp(
+    model: Model, negative_logp=True
+) -> Callable[[Sequence[np.ndarray]], np.ndarray]:
model_logp = model.logp()
if not negative_logp:
model_logp = -model_logp
logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])

-def logp_fn_wrap(x):
+def logp_fn_wrap(x: Sequence[np.ndarray]) -> np.ndarray:
Member (review comment):

This is not correct; it takes jax arrays and outputs jax arrays.

@nataziel (Author), Dec 12, 2024:

I don't think that's 100% true. Checking with the interactive debugger confirms that the return type is jax.Array, but the initial point functions return a dict[str, np.ndarray], and we can successfully pass the .values() of that dict into the jaxified function. So it can seemingly accept anything that's coercible to an array. Maybe it's more correct to annotate it like this:

def logp_fn_wrap(x: ArrayLike) -> jax.Array:

ArrayLike is from numpy.typing: https://numpy.org/devdocs/reference/typing.html#numpy.typing.ArrayLike
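A minimal self-contained sketch of that coercion behaviour (hypothetical, not part of this PR; assumes jax and numpy are installed):

```python
import jax
import jax.numpy as jnp
import numpy as np
from numpy.typing import ArrayLike


def logp_like(x: ArrayLike) -> jax.Array:
    # jnp.asarray accepts numpy arrays, lists, and scalars, which is why
    # passing the np.ndarray values from the initial point dict works.
    return -0.5 * jnp.sum(jnp.square(jnp.asarray(x)))


out = logp_like(np.ones(3))  # np.ndarray input is accepted
print(type(out))             # a jax.Array, not an np.ndarray
```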

@nataziel (Author):

I've just pushed a commit to improve this; it's a bit tricky to annotate at the interface with _init_jitter given that jax is an optional dependency. I've left the type annotation as returning an np.ndarray but noted in the docstring that it may return a jax.Array.
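For reference, one common way to annotate against an optional dependency is the `TYPE_CHECKING` guard; a sketch of that pattern (an assumption for illustration, not what this diff does):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    # Imported only during static type checking, so jax stays an
    # optional dependency at runtime.
    import jax


def eval_logp(point: dict[str, np.ndarray]) -> np.ndarray | jax.Array:
    # With `from __future__ import annotations` the return annotation is
    # never evaluated at runtime, so the bare `jax` name is safe here.
    ...
```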

return logp_fn(*x)[0]

return logp_fn_wrap
@@ -211,23 +213,39 @@ def _get_batched_jittered_initial_points(
chains: int,
initvals: StartDict | Sequence[StartDict | None] | None,
random_seed: RandomSeed,
+logp_fn: Callable[[Sequence[np.ndarray]], np.ndarray],
jitter: bool = True,
jitter_max_retries: int = 10,
) -> np.ndarray | list[np.ndarray]:
"""Get jittered initial point in format expected by NumPyro MCMC kernel.
"""Get jittered initial point in format expected by Jax MCMC kernel.

+Parameters
+----------
+logp_fn : Callable[[Sequence[np.ndarray]], np.ndarray]
+    Jaxified logp function

Returns
-------
-out: list of ndarrays
+out: list[np.ndarray]
list with one item per variable and number of chains as batch dimension.
Each item has shape `(chains, *var.shape)`
"""

+def eval_logp_initial_point(point: dict[str, np.ndarray]) -> np.ndarray:
+    """Wrap logp_fn to conform to _init_jitter logic.
+
+    Wraps jaxified logp function to accept a dict of
+    {model_variable: np.array} key:value pairs.
+    """
+    return logp_fn(point.values())

initial_points = _init_jitter(
model,
initvals,
seeds=_get_seeds_per_chain(random_seed, chains),
jitter=jitter,
jitter_max_retries=jitter_max_retries,
+logp_fn=eval_logp_initial_point,
)
initial_points_values = [list(initial_point.values()) for initial_point in initial_points]
if chains == 1:
@@ -236,7 +254,7 @@


def _blackjax_inference_loop(
-    seed, init_position, logprob_fn, draws, tune, target_accept, **adaptation_kwargs
+    seed, init_position, logp_fn, draws, tune, target_accept, **adaptation_kwargs
):
import blackjax

@@ -252,13 +270,13 @@

adapt = blackjax.window_adaptation(
algorithm=algorithm,
-logdensity_fn=logprob_fn,
+logdensity_fn=logp_fn,
target_acceptance_rate=target_accept,
adaptation_info_fn=get_filter_adapt_info_fn(),
**adaptation_kwargs,
)
(last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune)
-kernel = algorithm(logprob_fn, **tuned_params).step
+kernel = algorithm(logp_fn, **tuned_params).step

def _one_step(state, xs):
_, rng_key = xs
@@ -292,8 +310,9 @@ def _sample_blackjax_nuts(
chain_method: str | None,
progressbar: bool,
random_seed: int,
-initial_points,
+initial_points: np.ndarray | list[np.ndarray],
nuts_kwargs,
+logp_fn: Callable[[Sequence[np.ndarray]], np.ndarray] | None = None,
) -> az.InferenceData:
"""
Draw samples from the posterior using the NUTS method from the ``blackjax`` library.
@@ -366,15 +385,16 @@
if chains == 1:
initial_points = [np.stack(init_state) for init_state in zip(initial_points)]

-logprob_fn = get_jaxified_logp(model)
+if logp_fn is None:
+    logp_fn = get_jaxified_logp(model)

seed = jax.random.PRNGKey(random_seed)
keys = jax.random.split(seed, chains)

nuts_kwargs["progress_bar"] = progressbar
get_posterior_samples = partial(
_blackjax_inference_loop,
-logprob_fn=logprob_fn,
+logp_fn=logp_fn,
tune=tune,
draws=draws,
target_accept=target_accept,
@@ -415,14 +435,16 @@ def _sample_numpyro_nuts(
chain_method: str | None,
progressbar: bool,
random_seed: int,
-initial_points,
+initial_points: np.ndarray | list[np.ndarray],
nuts_kwargs: dict[str, Any],
+logp_fn: Callable | None = None,
):
import numpyro

from numpyro.infer import MCMC, NUTS

-logp_fn = get_jaxified_logp(model, negative_logp=False)
+if logp_fn is None:
+    logp_fn = get_jaxified_logp(model, negative_logp=False)

nuts_kwargs.setdefault("adapt_step_size", True)
nuts_kwargs.setdefault("adapt_mass_matrix", True)
@@ -590,6 +612,15 @@ def sample_jax_nuts(
get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
)

if nuts_sampler == "numpyro":
sampler_fn = _sample_numpyro_nuts
logp_fn = get_jaxified_logp(model, negative_logp=False)
elif nuts_sampler == "blackjax":
sampler_fn = _sample_blackjax_nuts
logp_fn = get_jaxified_logp(model)
else:
raise ValueError(f"{nuts_sampler=} not recognized")

(random_seed,) = _get_seeds_per_chain(random_seed, 1)

initial_points = _get_batched_jittered_initial_points(
@@ -598,15 +629,9 @@
initvals=initvals,
random_seed=random_seed,
jitter=jitter,
+logp_fn=logp_fn,
)

if nuts_sampler == "numpyro":
sampler_fn = _sample_numpyro_nuts
elif nuts_sampler == "blackjax":
sampler_fn = _sample_blackjax_nuts
else:
raise ValueError(f"{nuts_sampler=} not recognized")

tic1 = datetime.now()
raw_mcmc_samples, sample_stats, library = sampler_fn(
model=model,
18 changes: 13 additions & 5 deletions pymc/sampling/mcmc.py
@@ -1339,6 +1339,7 @@ def _init_jitter(
jitter: bool,
jitter_max_retries: int,
logp_dlogp_func=None,
+logp_fn: Callable[[PointType], np.ndarray] | None = None,
) -> list[PointType]:
"""Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain.

@@ -1353,11 +1354,13 @@
Whether to apply jitter or not.
jitter_max_retries : int
Maximum number of repeated attempts at initializing values (per chain).
+logp_fn : Callable[[dict[str, np.ndarray]], np.ndarray]
+    Jaxified logp function that takes the output of the initial point functions as input.

Returns
-------
-start : ``pymc.model.Point``
-    Starting point for sampler
+initial_points : list[dict[str, np.ndarray]]
+    List of starting points for the sampler
"""
ipfns = make_initial_point_fns_per_chain(
model=model,
@@ -1369,12 +1372,17 @@
if not jitter:
return [ipfn(seed) for ipfn, seed in zip(ipfns, seeds)]

-model_logp_fn: Callable
+model_logp_fn: Callable[[PointType], np.ndarray]
if logp_dlogp_func is None:
-model_logp_fn = model.compile_logp()
+if logp_fn is None:
+    # pymc NUTS path
+    model_logp_fn = model.compile_logp()
+else:
+    # Jax path
+    model_logp_fn = logp_fn
else:

-def model_logp_fn(ip):
+def model_logp_fn(ip: PointType) -> np.ndarray:
q, _ = DictToArrayBijection.map(ip)
return logp_dlogp_func([q], extra_vars={})[0]

12 changes: 9 additions & 3 deletions tests/sampling/test_jax.py
@@ -333,20 +333,26 @@ def test_get_batched_jittered_initial_points():
with pm.Model() as model:
x = pm.MvNormal("x", mu=np.zeros(3), cov=np.eye(3), shape=(2, 3), initval=np.zeros((2, 3)))

+logp_fn = get_jaxified_logp(model)

# No jitter
ips = _get_batched_jittered_initial_points(
-    model=model, chains=1, random_seed=1, initvals=None, jitter=False
+    model=model, chains=1, random_seed=1, initvals=None, jitter=False, logp_fn=logp_fn
)
assert np.all(ips[0] == 0)

# Single chain
-ips = _get_batched_jittered_initial_points(model=model, chains=1, random_seed=1, initvals=None)
+ips = _get_batched_jittered_initial_points(
+    model=model, chains=1, random_seed=1, initvals=None, logp_fn=logp_fn
+)

assert ips[0].shape == (2, 3)
assert np.all(ips[0] != 0)

# Multiple chains
-ips = _get_batched_jittered_initial_points(model=model, chains=2, random_seed=1, initvals=None)
+ips = _get_batched_jittered_initial_points(
+    model=model, chains=2, random_seed=1, initvals=None, logp_fn=logp_fn
+)

assert ips[0].shape == (2, 2, 3)
assert np.all(ips[0][0] != ips[0][1])