docs: Add documentation template
aseyboldt committed Oct 17, 2024
1 parent e1fdfff commit c22b438
Showing 7 changed files with 324 additions and 6 deletions.
6 changes: 6 additions & 0 deletions book.toml
@@ -0,0 +1,6 @@
[book]
authors = ["Adrian Seyboldt"]
language = "en"
multilingual = false
src = "docs"
title = "nutpie"
8 changes: 8 additions & 0 deletions docs/SUMMARY.md
@@ -0,0 +1,8 @@
# Summary

[Introduction](../README.md)

- [Usage with PyMC](./pymc-usage.md)
- [Usage with Stan](./stan-usage.md)
- [Adaptation with normalizing flows](./nf-adapt.md)
- [Benchmarks](./benchmarks.md)
1 change: 1 addition & 0 deletions docs/benchmarks.md
@@ -0,0 +1 @@
# Benchmarks
3 changes: 3 additions & 0 deletions docs/nf-adapt.md
@@ -0,0 +1,3 @@
# Adaptation with normalizing flows

**Experimental**
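
A hypothetical sketch, based on the factory this commit adds in `python/nutpie/compiled_pyfunc.py` (how the sampler consumes the returned callable is internal and may change):

```python
from nutpie.compiled_pyfunc import make_transform_adapter

# Returns a partial application of TransformAdapter; the sampler fills in
# seed, initial position/gradient, chain index, and logp_fn per chain.
make_adapter = make_transform_adapter(verbose=True, window_size=1000)
```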
1 change: 1 addition & 0 deletions docs/pymc-usage.md
@@ -0,0 +1 @@
# Usage with PyMC models
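
A minimal example of the intended content, following nutpie's README-style API (the toy data is illustrative):

```python
import pymc as pm
import nutpie

with pm.Model() as model:
    mu = pm.Normal("mu", mu=0.0, sigma=1.0)
    pm.Normal("y", mu=mu, sigma=1.0, observed=[0.1, -0.3, 0.2])

compiled = nutpie.compile_pymc_model(model)
trace = nutpie.sample(compiled)
```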
3 changes: 3 additions & 0 deletions docs/stan-usage.md
@@ -0,0 +1,3 @@
# Usage with Stan models

foobar
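
Until this stub is written, a minimal example in the style of nutpie's README (assumes a working Stan toolchain via bridgestan; the model and data are illustrative):

```python
import nutpie

code = """
data {
    real mu;
}
parameters {
    real x;
}
model {
    x ~ normal(mu, 1);
}
"""

compiled = nutpie.compile_stan_model(code=code)
trace = nutpie.sample(compiled.with_data(mu=3.0))
```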
308 changes: 302 additions & 6 deletions python/nutpie/compiled_pyfunc.py
@@ -9,6 +9,304 @@
from nutpie.sample import CompiledModel


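# Builds a per-chain transformation adapter: it fits a normalizing flow
# (via flowjax) to recent draws and gradients so that NUTS can run in a
# transformed space where the posterior is close to a standard normal.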
def make_transform_adapter(*, verbose=False, window_size=2000):
import jax
import equinox as eqx
import jax.numpy as jnp
import flowjax
import flowjax.train
import flowjax.flows
import optax
import traceback

class FisherLoss:
@eqx.filter_jit
def __call__(
self,
params,
static,
x,
condition=None,
key=None,
):
flow = flowjax.train.losses.unwrap(
eqx.combine(params, static, is_leaf=eqx.is_inexact_array)
)

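# inverse_gradient_and_val pulls a draw and its score back to the
# flow's base space. If the base distribution were exactly standard
# normal, the score there would be -z, so draw + grad would vanish;
# the mean squared residual is a Fisher-divergence-style objective.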
def compute_loss(bijection, draw, grad):
draw, grad, _ = bijection.inverse_gradient_and_val(
draw, grad, jnp.array(0.0)
)
return ((draw + grad) ** 2).sum()

assert x.shape[1] == 2
draws = x[:, 0, :]
grads = x[:, 1, :]
return jnp.log(
jax.vmap(compute_loss, [None, 0, 0])(
flow.bijection, draws, grads
).mean()
)

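# Pull model-space draws and gradients back to the flow's base space:
# invert the flow for the draw, and transport the gradient with a
# vector-Jacobian product.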
def _get_untransformed(bijection, draw_trafo, grad_trafo):
bijection = flowjax.train.losses.unwrap(bijection)
draw = bijection.inverse(draw_trafo)
_, pull_grad_fn = jax.vjp(bijection.transform_and_log_det, draw)
(grad,) = pull_grad_fn((grad_trafo, 1.0))
return draw, grad

pull_points = eqx.filter_jit(jax.vmap(_get_untransformed, [None, 0, 0]))

def fit_flow(key, bijection, positions, gradients, **kwargs):
flow = flowjax.flows.Transformed(
flowjax.distributions.StandardNormal(bijection.shape), bijection
)

points = jnp.transpose(jnp.array([positions, gradients]), [1, 0, 2])

key, train_key = jax.random.split(key)

fit, losses, opt_state = flowjax.train.fit_to_data(
key=train_key,
dist=flow,
x=points,
loss_fn=FisherLoss(),
**kwargs,
)

draws_pulled, grads_pulled = pull_points(fit.bijection, positions, gradients)
final_cost = np.log(((draws_pulled + grads_pulled) ** 2).sum(1).mean(0))
return fit, final_cost, opt_state

def make_flow(seed, positions, gradients, *, n_layers):
positions = np.array(positions)
gradients = np.array(gradients)

n_draws, n_dim = positions.shape

if n_dim < 2:
n_layers = 0

assert positions.shape == gradients.shape

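# With very few draws, fall back to a diagonal affine transformation,
# scaled by the square root of the ratio of draw and gradient spreads
# (akin to a diagonal mass-matrix estimate).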
if n_draws == 0:
raise ValueError("No draws")
elif n_draws == 1:
diag = 1 / jnp.abs(gradients[0])
mean = jnp.zeros_like(diag)
else:
diag = jnp.sqrt(positions.std(0) / gradients.std(0))
mean = positions.mean(0) + diag * gradients.mean(0)

key = jax.random.PRNGKey(seed % (2**63))

flows = [
flowjax.flows.Affine(loc=mean, scale=diag),
]

for layer in range(n_layers):
key, key_couple, key_permute, key_init = jax.random.split(key, 4)

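# Parameterize the coupling scale as exp(asinh(x)) = x + sqrt(x**2 + 1):
# always positive, but only asymptotically linear in x, so the scale
# cannot blow up exponentially during training.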
scale = flowjax.wrappers.Parameterize(
lambda x: jnp.exp(jnp.arcsinh(x)), jnp.array(0.0)
)
affine = eqx.tree_at(
where=lambda aff: aff.scale,
pytree=flowjax.bijections.Affine(),
replace=scale,
)

coupling = flowjax.bijections.coupling.Coupling(
key_couple,
transformer=affine,
untransformed_dim=n_dim // 2,
dim=n_dim,
nn_activation=jax.nn.gelu,
nn_width=n_dim // 2,
nn_depth=1,
)

if layer == n_layers - 1:
flow = coupling
else:
flow = flowjax.flows._add_default_permute(coupling, n_dim, key_permute)

flows.append(flow)

return flowjax.bijections.Chain(flows[::-1])

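# From a base-space (transformed) position: map forward through the
# flow, evaluate logp and its gradient in model space, and pull the
# gradient back to base space with a vector-Jacobian product.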
@eqx.filter_jit
def _init_from_transformed_position(logp_fn, bijection, transformed_position):
bijection = flowjax.train.losses.unwrap(bijection)
(untransformed_position, logdet), pull_grad = jax.vjp(
bijection.transform_and_log_det, transformed_position
)
logp, untransformed_gradient = jax.value_and_grad(lambda x: logp_fn(x)[0])(
untransformed_position
)
(transformed_gradient,) = pull_grad((untransformed_gradient, 1.0))
return (
logp,
logdet,
untransformed_position,
untransformed_gradient,
transformed_gradient,
)

@eqx.filter_jit
def _init_from_untransformed_position(logp_fn, bijection, untransformed_position):
logp, untransformed_gradient = jax.value_and_grad(lambda x: logp_fn(x)[0])(
untransformed_position
)
logdet, transformed_position, transformed_gradient = _inv_transform(
bijection, untransformed_position, untransformed_gradient
)
return (
logp,
logdet,
untransformed_gradient,
transformed_position,
transformed_gradient,
)

@eqx.filter_jit
def _inv_transform(bijection, untransformed_position, untransformed_gradient):
bijection = flowjax.train.losses.unwrap(bijection)
transformed_position, transformed_gradient, logdet = (
bijection.inverse_gradient_and_val(
untransformed_position, untransformed_gradient, 0.0
)
)
return -logdet, transformed_position, transformed_gradient

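# Per-chain adapter driven by the sampler: starts from a cheap affine
# transformation and periodically refits a flow on a window of draws.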
class TransformAdapter:
def __init__(
self,
seed,
position,
gradient,
chain,
*,
logp_fn,
make_flow_fn,
verbose=False,
window_size=2000,
):
self._logp_fn = logp_fn
self._make_flow_fn = make_flow_fn
self._chain = chain
self._verbose = verbose
self._window_size = window_size
try:
self._bijection = make_flow_fn(seed, [position], [gradient], n_layers=0)
except Exception as e:
print("make_flow", e)
print(traceback.format_exc())
raise
self.index = 0

@property
def transformation_id(self):
return self.index

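# Refit schedule: the first ten updates keep an affine transformation
# fit to the most recent draws; afterwards an 8-layer coupling flow is
# trained on up to window_size points, falling back to affine if the
# final cost is not finite.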
def update(self, seed, positions, gradients):
self.index += 1
if self._verbose:
print(f"Chain {self._chain}: Total available points: {len(positions)}")
n_draws = len(positions)
if n_draws == 0:
return
try:
if self.index <= 10:
self._bijection = self._make_flow_fn(
seed, positions[-10:], gradients[-10:], n_layers=0
)
return

positions = np.array(positions[-self._window_size :])
gradients = np.array(gradients[-self._window_size :])

assert np.isfinite(positions).all()
assert np.isfinite(gradients).all()

if len(self._bijection.bijections) == 1:
self._bijection = self._make_flow_fn(
seed, positions, gradients, n_layers=8
)

# make_flow might still return only a single transformation for 1-d problems
if len(self._bijection.bijections) == 1:
return

# TODO don't reuse seed
key = jax.random.PRNGKey(seed % (2**63))
fit, final_cost, _ = fit_flow(
key,
self._bijection,
positions,
gradients,
show_progress=self._verbose,
optimizer=optax.adabelief(1e-3),
batch_size=128,
)
if self._verbose:
print(f"Chain {self._chain}: final cost {final_cost}")
if np.isfinite(final_cost).all():
self._bijection = fit.bijection
else:
self._bijection = self._make_flow_fn(
seed, positions, gradients, n_layers=0
)
except Exception as e:
print("update error:", e)
print(traceback.format_exc())

def init_from_transformed_position(self, transformed_position):
try:
logp, logdet, *arrays = _init_from_transformed_position(
self._logp_fn,
self._bijection,
jnp.array(transformed_position),
)
return float(logp), float(logdet), *[np.array(val) for val in arrays]
except Exception as e:
print(e)
print(traceback.format_exc())
raise

def init_from_untransformed_position(self, untransformed_position):
try:
logp, logdet, *arrays = _init_from_untransformed_position(
self._logp_fn,
self._bijection,
jnp.array(untransformed_position),
)
return float(logp), float(logdet), *[np.array(val) for val in arrays]
except Exception as e:
print(e)
print(traceback.format_exc())
raise

def inv_transform(self, position, gradient):
try:
logdet, *arrays = _inv_transform(
self._bijection, jnp.array(position), jnp.array(gradient)
)
return logdet, *[np.array(val) for val in arrays]
except Exception as e:
print(e)
print(traceback.format_exc())
raise

return partial(
TransformAdapter,
verbose=verbose,
window_size=window_size,
make_flow_fn=make_flow,
)


@dataclass(frozen=True)
class PyFuncModel(CompiledModel):
_make_logp_func: Callable
@@ -59,19 +357,17 @@ def make_expand_func(seed1, seed2, chain):
expand_fn = self._make_expand_func(seed1, seed2, chain)
return partial(expand_fn, **self._shared_data)

if self._raw_logp_fn is not None:
    make_adapter = partial(make_transform_adapter(), logp_fn=self._raw_logp_fn)
else:
    make_adapter = None

return _lib.PyModel(
make_logp_func,
make_expand_func,
self._variables,
self.n_dim,
make_adapter,
)


