feat: Add transforming adaptation
aseyboldt committed Oct 17, 2024
1 parent b8f7e74 commit e1fdfff
Showing 7 changed files with 588 additions and 193 deletions.
30 changes: 21 additions & 9 deletions python/nutpie/compile_pymc.py
@@ -293,14 +293,24 @@ def _compile_pymc_model_jax(model, *, gradient_backend=None, **kwargs):
     logp_fn = logp_fn_pt.vm.jit_fn
     expand_fn = expand_fn_pt.vm.jit_fn
 
+    logp_shared_names = [var.name for var in logp_fn_pt.get_shared()]
+    expand_shared_names = [var.name for var in expand_fn_pt.get_shared()]
+
     if gradient_backend == "jax":
         orig_logp_fn = logp_fn._fun
 
-        @jax.jit
         def logp_fn_jax_grad(x, *shared):
             return jax.value_and_grad(lambda x: orig_logp_fn(x, *shared)[0])(x)
 
+        static_argnums = list(range(1, len(logp_shared_names) + 1))
+        logp_fn_jax_grad = jax.jit(
+            logp_fn_jax_grad,
+            # static_argnums=static_argnums,
+        )
+
         logp_fn = logp_fn_jax_grad
+    else:
+        orig_logp_fn = None
 
     shared_data = {}
     shared_vars = {}
@@ -312,9 +322,6 @@ def logp_fn_jax_grad(x, *shared):
         shared_vars[val.name] = val
         seen.add(val)
 
-    logp_shared_names = [var.name for var in logp_fn_pt.get_shared()]
-    expand_shared_names = [var.name for var in expand_fn_pt.get_shared()]
-
     def make_logp_func():
         def logp(x, **shared):
             logp, grad = logp_fn(x, *[shared[name] for name in logp_shared_names])
@@ -323,7 +330,8 @@ def logp(x, **shared):
             return logp
 
     names, slices, shapes = shape_info
-    dtypes = [np.float64] * len(names)
+    # TODO do not cast to float64
+    dtypes = [np.dtype("float64")] * len(names)
 
     def make_expand_func(seed1, seed2, chain):
         # TODO handle seeds
@@ -348,6 +356,7 @@ def expand(x, **shared):
         shared_data=shared_data,
         dims=dims,
         coords=coords,
+        raw_logp_fn=orig_logp_fn,
     )

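The JAX branch above uses jax.value_and_grad to get the log-density and its gradient from a single compiled call. A minimal self-contained sketch of that pattern, using a stand-in standard-normal logp rather than the model density nutpie generates:

    import jax
    import jax.numpy as jnp

    def logp(x):
        # Stand-in log-density: standard normal, up to an additive constant.
        return -0.5 * jnp.sum(x**2)

    # One jitted function returning a (logp, grad) pair per call,
    # matching what the sampler consumes.
    logp_and_grad = jax.jit(jax.value_and_grad(logp))

    value, grad = logp_and_grad(jnp.ones(3))
    # value = -1.5, grad = [-1., -1., -1.]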

@@ -397,7 +406,7 @@ def compile_pymc_model(
         raise ValueError(f"Backend must be one of numba and jax. Got {backend}")
 
 
-def _compute_shapes(model):
+def _compute_shapes(model) -> dict[str, tuple[int, ...]]:
     import pytensor
     from pymc.initial_point import make_initial_point_fn
 
@@ -426,9 +435,12 @@ def _compute_shapes(model):
     return dict(zip(trace_vars.keys(), shape_func()))
 
 
-def _make_functions(model, *, mode, compute_grad, join_expanded):
+def _make_functions(
+    model, *, mode, compute_grad, join_expanded
+) -> tuple[
+    int, int, Callable, Callable, tuple[list[str], list[slice], list[tuple[int, ...]]]
+]:
     import pytensor
-    import pytensor.link.numba.dispatch
     import pytensor.tensor as pt
     from pymc.pytensorf import compile_pymc
 
Expand Down Expand Up @@ -513,7 +525,7 @@ def _make_functions(model, *, mode, compute_grad, join_expanded):

for var in remaining_rvs:
all_names.append(var.name)
shape = shapes[var.name]
shape = cast(tuple[int, ...], shapes[var.name])
all_shapes.append(shape)
length = prod(shape)
all_slices.append(slice(count, count + length))
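The slice bookkeeping in this loop packs every variable into one flat vector: each variable claims the next prod(shape) entries, and the recorded slice later recovers it. A small illustration with made-up shapes (not nutpie code):

    from math import prod

    import numpy as np

    shapes = {"a": (2, 3), "b": (4,)}
    slices, count = {}, 0
    for name, shape in shapes.items():
        length = prod(shape)
        slices[name] = slice(count, count + length)
        count += length

    flat = np.arange(count)  # stand-in for a packed draw of length 10
    a = flat[slices["a"]].reshape(shapes["a"])  # entries 0..5 as a (2, 3) array
    b = flat[slices["b"]]                       # entries 6..9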
16 changes: 14 additions & 2 deletions python/nutpie/compiled_pyfunc.py
@@ -17,6 +17,8 @@ class PyFuncModel(CompiledModel):
     _n_dim: int
     _variables: list[_lib.PyVariable]
     _coords: dict[str, Any]
+    _make_transform_adapter: Callable | None
+    _raw_logp_fn: Callable | None
 
     @property
     def shapes(self) -> dict[str, tuple[int, ...]]:
@@ -57,11 +59,19 @@ def make_expand_func(seed1, seed2, chain):
             expand_fn = self._make_expand_func(seed1, seed2, chain)
             return partial(expand_fn, **self._shared_data)
 
+        if self._make_transform_adapter is not None:
+            make_transform_adapter = partial(
+                self._make_transform_adapter, logp_fn=self._raw_logp_fn
+            )
+        else:
+            make_transform_adapter = None
+
         return _lib.PyModel(
             make_logp_func,
             make_expand_func,
             self._variables,
             self.n_dim,
+            make_transform_adapter,
         )


@@ -77,6 +87,8 @@ def from_pyfunc(
     coords: dict[str, Any] | None = None,
     dims: dict[str, tuple[str, ...]] | None = None,
     shared_data: dict[str, Any] | None = None,
+    make_transform_adapter=None,
+    raw_logp_fn=None,
 ):
     variables = []
     for name, shape, dtype in zip(
@@ -98,8 +110,6 @@ def from_pyfunc(
     if shared_data is None:
         shared_data = {}
 
-    if shared_data is None:
-        shared_data = dict()
     return PyFuncModel(
         _n_dim=ndim,
         dims=dims,
@@ -108,4 +118,6 @@ def from_pyfunc(
         _make_expand_func=make_expand_fn,
         _variables=variables,
         _shared_data=shared_data,
+        _make_transform_adapter=make_transform_adapter,
+        _raw_logp_fn=raw_logp_fn,
     )
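The adapter wiring above is plain partial application: from_pyfunc stores a factory plus the raw logp function, and PyFuncModel binds the two with functools.partial so the Rust sampler only has to supply per-chain arguments. A stripped-down sketch of that pattern with hypothetical names (not the nutpie API):

    from functools import partial

    def make_adapter(chain, *, logp_fn):
        # Hypothetical factory: logp_fn is bound up front,
        # chain arrives when the sampler constructs the adapter.
        return ("adapter for chain", chain, logp_fn)

    def raw_logp(x):
        return -0.5 * x * x

    bound = partial(make_adapter, logp_fn=raw_logp)
    adapter = bound(0)  # sampler-side call: only the chain index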
14 changes: 14 additions & 0 deletions python/nutpie/sample.py
@@ -461,6 +461,7 @@ def sample(
     save_warmup: bool,
     progress_bar: bool,
     low_rank_modified_mass_matrix: bool = False,
+    transform_adapt: bool = False,
     init_mean: Optional[np.ndarray],
     return_raw_trace: bool,
     blocking: Literal[True],
@@ -480,6 +481,7 @@ def sample(
     save_warmup: bool,
     progress_bar: bool,
     low_rank_modified_mass_matrix: bool = False,
+    transform_adapt: bool = False,
     init_mean: Optional[np.ndarray],
     return_raw_trace: bool,
     blocking: Literal[False],
@@ -498,6 +500,7 @@ def sample(
     save_warmup: bool = True,
     progress_bar: bool = True,
     low_rank_modified_mass_matrix: bool = False,
+    transform_adapt: bool = False,
     init_mean: Optional[np.ndarray] = None,
     return_raw_trace: bool = False,
     blocking: bool = True,
@@ -585,6 +588,9 @@ def sample(
     mass_matrix_gamma: float > 0, default=1e-5
         Regularisation parameter for the eigenvalues. Only
         applicable with low_rank_modified_mass_matrix=True.
+    transform_adapt: bool, default=False
+        Use the experimental transform adaptation algorithm
+        during tuning.
     **kwargs
         Pass additional arguments to nutpie._lib.PySamplerArgs
@@ -594,10 +600,18 @@
         An ArviZ ``InferenceData`` object that contains the samples.
     """
 
+    if low_rank_modified_mass_matrix and transform_adapt:
+        raise ValueError(
+            "Specify only one of `low_rank_modified_mass_matrix` and `transform_adapt`"
+        )
+
     if low_rank_modified_mass_matrix:
         settings = _lib.PyNutsSettings.LowRank(seed)
+    elif transform_adapt:
+        settings = _lib.PyNutsSettings.Transform(seed)
     else:
         settings = _lib.PyNutsSettings.Diag(seed)
+
     settings.num_tune = tune
     settings.num_draws = draws
     settings.num_chains = chains
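Taken together, the sample() changes expose the new adapter behind a single flag. A hypothetical end-to-end sketch (assuming an existing PyMC model named model; the transform adapter needs the raw JAX logp function, which this commit only captures when gradient_backend="jax"):

    import nutpie

    compiled = nutpie.compile_pymc_model(
        model, backend="jax", gradient_backend="jax"
    )
    # Opt in to the experimental transform adaptation during tuning.
    trace = nutpie.sample(compiled, transform_adapt=True)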
[Diffs for the remaining four changed files did not load in this view.]
