Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use jaxified logp for initial point evaluation when sampling via Jax #7610

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions pymc/initial_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from pymc.logprob.transforms import Transform
from pymc.pytensorf import (
SeedSequenceSeed,
compile,
find_rng_nodes,
replace_rng_nodes,
Expand Down Expand Up @@ -67,7 +68,7 @@ def make_initial_point_fns_per_chain(
overrides: StartDict | Sequence[StartDict | None] | None,
jitter_rvs: set[TensorVariable] | None = None,
chains: int,
) -> list[Callable]:
) -> list[Callable[[SeedSequenceSeed], PointType]]:
"""Create an initial point function for each chain, as defined by initvals.

If a single initval dictionary is passed, the function is replicated for each
Expand All @@ -82,6 +83,11 @@ def make_initial_point_fns_per_chain(
Random variable tensors for which U(-1, 1) jitter shall be applied.
(To the transformed space if applicable.)

Returns
-------
ipfns : list[Callable[[SeedSequenceSeed], dict[str, np.ndarray]]]
list of functions that return initial points for each chain.

Raises
------
ValueError
Expand Down Expand Up @@ -124,7 +130,7 @@ def make_initial_point_fn(
jitter_rvs: set[TensorVariable] | None = None,
default_strategy: str = "support_point",
return_transformed: bool = True,
) -> Callable:
) -> Callable[[SeedSequenceSeed], PointType]:
"""Create seeded function that computes initial values for all free model variables.

Parameters
Expand All @@ -138,6 +144,10 @@ def make_initial_point_fn(
Initial value (strategies) to use instead of what's specified in `Model.initial_values`.
return_transformed : bool
If `True` the returned variables will correspond to transformed initial values.

Returns
-------
initial_point_fn : Callable[[SeedSequenceSeed], dict[str, np.ndarray]]
"""
sdict_overrides = convert_str_to_rv_dict(model, overrides or {})
initval_strats = {
Expand Down
4 changes: 2 additions & 2 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import types
import warnings

from collections.abc import Iterable, Sequence
from collections.abc import Callable, Iterable, Sequence
from typing import (
Literal,
cast,
Expand Down Expand Up @@ -585,7 +585,7 @@ def compile_logp(
jacobian: bool = True,
sum: bool = True,
**compile_kwargs,
) -> PointFunc:
) -> Callable[[PointType], np.ndarray]:
"""Compiled log probability density function.

Parameters
Expand Down
Loading