From c22b438ee0d19aee7cfdff491ac1ad8e949cc600 Mon Sep 17 00:00:00 2001
From: Adrian Seyboldt
Date: Thu, 17 Oct 2024 11:44:00 +0200
Subject: [PATCH] docs: Add documentation template

---
 book.toml                        |   6 +
 docs/SUMMARY.md                  |   8 +
 docs/benchmarks.md               |   1 +
 docs/nf-adapt.md                 |   3 +
 docs/pymc-usage.md               |   1 +
 docs/stan-usage.md               |   3 +
 python/nutpie/compiled_pyfunc.py | 308 ++++++++++++++++++++++++++++++-
 7 files changed, 324 insertions(+), 6 deletions(-)
 create mode 100644 book.toml
 create mode 100644 docs/SUMMARY.md
 create mode 100644 docs/benchmarks.md
 create mode 100644 docs/nf-adapt.md
 create mode 100644 docs/pymc-usage.md
 create mode 100644 docs/stan-usage.md

diff --git a/book.toml b/book.toml
new file mode 100644
index 0000000..9835ea7
--- /dev/null
+++ b/book.toml
@@ -0,0 +1,6 @@
+[book]
+authors = ["Adrian Seyboldt"]
+language = "en"
+multilingual = false
+src = "docs"
+title = "nutpie"
diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md
new file mode 100644
index 0000000..279cdd3
--- /dev/null
+++ b/docs/SUMMARY.md
@@ -0,0 +1,8 @@
+# Summary
+
+[Introduction](../README.md)
+
+- [Usage with PyMC](./pymc-usage.md)
+- [Usage with Stan](./stan-usage.md)
+- [Adaptation with normalizing flows](./nf-adapt.md)
+- [Benchmarks](./benchmarks.md)
diff --git a/docs/benchmarks.md b/docs/benchmarks.md
new file mode 100644
index 0000000..680d565
--- /dev/null
+++ b/docs/benchmarks.md
@@ -0,0 +1 @@
+# Benchmarks
diff --git a/docs/nf-adapt.md b/docs/nf-adapt.md
new file mode 100644
index 0000000..d9d0325
--- /dev/null
+++ b/docs/nf-adapt.md
@@ -0,0 +1,3 @@
+# Adaptation with normalizing flows
+
+**Experimental**
diff --git a/docs/pymc-usage.md b/docs/pymc-usage.md
new file mode 100644
index 0000000..e66c098
--- /dev/null
+++ b/docs/pymc-usage.md
@@ -0,0 +1 @@
+# Usage with PyMC models
diff --git a/docs/stan-usage.md b/docs/stan-usage.md
new file mode 100644
index 0000000..78259ef
--- /dev/null
+++ b/docs/stan-usage.md
@@ -0,0 +1,3 @@
+# Usage with Stan models
+
+foobar
diff --git a/python/nutpie/compiled_pyfunc.py b/python/nutpie/compiled_pyfunc.py
index 3e4e605..dca71ed 100644
--- a/python/nutpie/compiled_pyfunc.py
+++ b/python/nutpie/compiled_pyfunc.py
@@ -9,6 +9,304 @@
 from nutpie.sample import CompiledModel
 
 
+def make_transform_adapter(*, verbose=False, window_size=2000):
+    import jax
+    import equinox as eqx
+    import jax.numpy as jnp
+    import flowjax
+    import flowjax.train
+    import flowjax.flows
+    import optax
+    import traceback
+
+    class FisherLoss:
+        @eqx.filter_jit
+        def __call__(
+            self,
+            params,
+            static,
+            x,
+            condition=None,
+            key=None,
+        ):
+            flow = flowjax.train.losses.unwrap(
+                eqx.combine(params, static, is_leaf=eqx.is_inexact_array)
+            )
+
+            def compute_loss(bijection, draw, grad):
+                draw, grad, _ = bijection.inverse_gradient_and_val(
+                    draw, grad, jnp.array(0.0)
+                )
+                return ((draw + grad) ** 2).sum()
+
+            assert x.shape[1] == 2
+            draws = x[:, 0, :]
+            grads = x[:, 1, :]
+            return jnp.log(
+                jax.vmap(compute_loss, [None, 0, 0])(
+                    flow.bijection, draws, grads
+                ).mean()
+            )
+
+    def _get_untransformed(bijection, draw_trafo, grad_trafo):
+        bijection = flowjax.train.losses.unwrap(bijection)
+        draw = bijection.inverse(draw_trafo)
+        _, pull_grad_fn = jax.vjp(bijection.transform_and_log_det, draw)
+        (grad,) = pull_grad_fn((grad_trafo, 1.0))
+        return draw, grad
+
+    pull_points = eqx.filter_jit(jax.vmap(_get_untransformed, [None, 0, 0]))
+
+    def fit_flow(key, bijection, positions, gradients, **kwargs):
+        flow = flowjax.flows.Transformed(
+            flowjax.distributions.StandardNormal(bijection.shape), bijection
+        )
+
+        points = jnp.transpose(jnp.array([positions, gradients]), [1, 0, 2])
+
+        key, train_key = jax.random.split(key)
+
+        fit, losses, opt_state = flowjax.train.fit_to_data(
+            key=train_key,
+            dist=flow,
+            x=points,
+            loss_fn=FisherLoss(),
+            **kwargs,
+        )
+
+        draws_pulled, grads_pulled = pull_points(fit.bijection, positions, gradients)
+        final_cost = np.log(((draws_pulled + grads_pulled) ** 2).sum(1).mean(0))
+        return fit, final_cost, opt_state
+
+    def make_flow(seed, positions, gradients, *, n_layers):
+        positions = np.array(positions)
+        gradients = np.array(gradients)
+
+        n_draws, n_dim = positions.shape
+
+        if n_dim < 2:
+            n_layers = 0
+
+        assert positions.shape == gradients.shape
+        assert n_draws > 0
+
+        if n_draws == 0:
+            raise ValueError("No draws")
+        elif n_draws == 1:
+            diag = 1 / jnp.abs(gradients[0])
+            mean = jnp.zeros_like(diag)
+        else:
+            diag = jnp.sqrt(positions.std(0) / gradients.std(0))
+            mean = positions.mean(0) + diag * gradients.mean(0)
+
+        key = jax.random.PRNGKey(seed % (2**63))
+
+        flows = [
+            flowjax.flows.Affine(loc=mean, scale=diag),
+        ]
+
+        for layer in range(n_layers):
+            key, key_couple, key_permute, key_init = jax.random.split(key, 4)
+
+            scale = flowjax.wrappers.Parameterize(
+                lambda x: jnp.exp(jnp.arcsinh(x)), jnp.array(0.0)
+            )
+            affine = eqx.tree_at(
+                where=lambda aff: aff.scale,
+                pytree=flowjax.bijections.Affine(),
+                replace=scale,
+            )
+
+            coupling = flowjax.bijections.coupling.Coupling(
+                key_couple,
+                transformer=affine,
+                untransformed_dim=n_dim // 2,
+                dim=n_dim,
+                nn_activation=jax.nn.gelu,
+                nn_width=n_dim // 2,
+                nn_depth=1,
+            )
+
+            if layer == n_layers - 1:
+                flow = coupling
+            else:
+                flow = flowjax.flows._add_default_permute(coupling, n_dim, key_permute)
+
+            flows.append(flow)
+
+        return flowjax.bijections.Chain(flows[::-1])
+
+    @eqx.filter_jit
+    def _init_from_transformed_position(logp_fn, bijection, transformed_position):
+        bijection = flowjax.train.losses.unwrap(bijection)
+        (untransformed_position, logdet), pull_grad = jax.vjp(
+            bijection.transform_and_log_det, transformed_position
+        )
+        logp, untransformed_gradient = jax.value_and_grad(lambda x: logp_fn(x)[0])(
+            untransformed_position
+        )
+        (transformed_gradient,) = pull_grad((untransformed_gradient, 1.0))
+        return (
+            logp,
+            logdet,
+            untransformed_position,
+            untransformed_gradient,
+            transformed_gradient,
+        )
+
+    @eqx.filter_jit
+    def _init_from_untransformed_position(logp_fn, bijection, untransformed_position):
+        logp, untransformed_gradient = jax.value_and_grad(lambda x: logp_fn(x)[0])(
+            untransformed_position
+        )
+        logdet, transformed_position, transformed_gradient = _inv_transform(
+            bijection, untransformed_position, untransformed_gradient
+        )
+        return (
+            logp,
+            logdet,
+            untransformed_gradient,
+            transformed_position,
+            transformed_gradient,
+        )
+
+    @eqx.filter_jit
+    def _inv_transform(bijection, untransformed_position, untransformed_gradient):
+        bijection = flowjax.train.losses.unwrap(bijection)
+        transformed_position, transformed_gradient, logdet = (
+            bijection.inverse_gradient_and_val(
+                untransformed_position, untransformed_gradient, 0.0
+            )
+        )
+        return -logdet, transformed_position, transformed_gradient
+
+    class TransformAdapter:
+        def __init__(
+            self,
+            seed,
+            position,
+            gradient,
+            chain,
+            *,
+            logp_fn,
+            make_flow_fn,
+            verbose=False,
+            window_size=2000,
+        ):
+            self._logp_fn = logp_fn
+            self._make_flow_fn = make_flow_fn
+            self._chain = chain
+            self._verbose = verbose
+            self._window_size = window_size
+            try:
+                self._bijection = make_flow_fn(
+                    seed, [position], [gradient], n_layers=0
+                )
+            except Exception as e:
+                print("make_flow", e)
+                print(traceback.format_exc())
+                raise
+            self.index = 0
+
+        @property
+        def transformation_id(self):
+            return self.index
+
+        def update(self, seed, positions, gradients):
+            self.index += 1
+            if self._verbose:
+                print(f"Chain {self._chain}: Total available points: {len(positions)}")
+            n_draws = len(positions)
+            if n_draws == 0:
+                return
+            try:
+                if self.index <= 10:
+                    self._bijection = self._make_flow_fn(
+                        seed, positions[-10:], gradients[-10:], n_layers=0
+                    )
+                    return
+
+                positions = np.array(positions[-self._window_size :])
+                gradients = np.array(gradients[-self._window_size :])
+
+                assert np.isfinite(positions).all()
+                assert np.isfinite(gradients).all()
+
+                if len(self._bijection.bijections) == 1:
+                    self._bijection = self._make_flow_fn(
+                        seed, positions, gradients, n_layers=8
+                    )
+
+                # make_flow might still only return a single trafo for 1d problems
+                if len(self._bijection.bijections) == 1:
+                    return
+
+                # TODO don't reuse seed
+                key = jax.random.PRNGKey(seed % (2**63))
+                fit, final_cost, _ = fit_flow(
+                    key,
+                    self._bijection,
+                    positions,
+                    gradients,
+                    show_progress=self._verbose,
+                    optimizer=optax.adabelief(1e-3),
+                    batch_size=128,
+                )
+                if self._verbose:
+                    print(f"Chain {self._chain}: final cost {final_cost}")
+                if np.isfinite(final_cost).all():
+                    self._bijection = fit.bijection
+                else:
+                    self._bijection = self._make_flow_fn(
+                        seed, positions, gradients, n_layers=0
+                    )
+            except Exception as e:
+                print("update error:", e)
+                print(traceback.format_exc())
+
+        def init_from_transformed_position(self, transformed_position):
+            try:
+                logp, logdet, *arrays = _init_from_transformed_position(
+                    self._logp_fn,
+                    self._bijection,
+                    jnp.array(transformed_position),
+                )
+                return float(logp), float(logdet), *[np.array(val) for val in arrays]
+            except Exception as e:
+                print(e)
+                print(traceback.format_exc())
+                raise
+
+        def init_from_untransformed_position(self, untransformed_position):
+            try:
+                logp, logdet, *arrays = _init_from_untransformed_position(
+                    self._logp_fn,
+                    self._bijection,
+                    jnp.array(untransformed_position),
+                )
+                return float(logp), float(logdet), *[np.array(val) for val in arrays]
+            except Exception as e:
+                print(e)
+                print(traceback.format_exc())
+                raise
+
+        def inv_transform(self, position, gradient):
+            try:
+                logdet, *arrays = _inv_transform(
+                    self._bijection, jnp.array(position), jnp.array(gradient)
+                )
+                return logdet, *[np.array(val) for val in arrays]
+            except Exception as e:
+                print(e)
+                print(traceback.format_exc())
+                raise
+
+    return partial(
+        TransformAdapter,
+        verbose=verbose,
+        window_size=window_size,
+        make_flow_fn=make_flow,
+    )
+
+
 @dataclass(frozen=True)
 class PyFuncModel(CompiledModel):
     _make_logp_func: Callable
@@ -59,19 +357,17 @@ def make_expand_func(seed1, seed2, chain):
             expand_fn = self._make_expand_func(seed1, seed2, chain)
             return partial(expand_fn, **self._shared_data)
 
-        if self._make_transform_adapter is not None:
-            make_transform_adapter = partial(
-                self._make_transform_adapter, logp_fn=self._raw_logp_fn
-            )
+        if self._raw_logp_fn is not None:
+            make_adapter = partial(make_transform_adapter(), logp_fn=self._raw_logp_fn)
         else:
-            make_transform_adapter = None
+            make_adapter = None
 
         return _lib.PyModel(
             make_logp_func,
             make_expand_func,
             self._variables,
             self.n_dim,
-            make_transform_adapter,
+            make_adapter,
         )
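
Note (not part of the patch above): FisherLoss scores a candidate bijection by pulling each draw and its score gradient back through the flow and taking the log of the mean of ((draw + grad) ** 2).sum(). The intuition is that a well-fitted flow maps the posterior to a standard normal, where grad log p(z) = -z, so the pulled-back residual draw + grad vanishes. Below is a minimal plain-JAX sketch of that criterion, using an assumed diagonal-affine stand-in for the flow and a toy N(0, sigma^2) target; the names pullback_loss and sigma are purely illustrative and do not appear in the patch.

    import jax
    import jax.numpy as jnp

    def pullback_loss(log_scale, draws, grads):
        # Diagonal affine stand-in for the flow: x = exp(log_scale) * z.
        scale = jnp.exp(log_scale)
        z = draws / scale        # pull the draws back to z-space
        z_grad = grads * scale   # chain rule: d/dz log p = scale * d/dx log p
        # Same criterion as FisherLoss: log of the mean squared residual.
        return jnp.log(((z + z_grad) ** 2).sum(axis=1).mean())

    key = jax.random.PRNGKey(0)
    sigma = 3.0  # toy target: independent N(0, sigma^2) in each dimension
    draws = sigma * jax.random.normal(key, (256, 4))
    grads = -draws / sigma**2  # gradient of log p(x) for that target

    print(pullback_loss(jnp.log(1.0), draws, grads))    # mismatched scale: large
    print(pullback_loss(jnp.log(sigma), draws, grads))  # matched scale: residual is exactly zero

The criterion is smallest when the learned scale matches the target scale; this is the same signal fit_flow uses (via final_cost) to decide whether to keep a fitted flow or fall back to the plain affine transformation.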