A NumPyro-compatible JAX implementation of AutoStep samplers
pip install "autostep @ git+https://github.com/UBC-Stat-ML/autostep.git"We apply autoHMC to the classic toy eight schools problem. We use all default
settings (32 leapfrog steps, DeterministicSymmetricSelector for the step
size adaptation critetion), except for the preconditioner. Since the problem
is low dimensional, we can afford to use a full dense mass matrix to drastically
improve the conditioning of the target.
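For reference, the defaults described above could also be passed explicitly. The sketch below assumes an `autostep.selectors` module and `n_leapfrog_steps`/`selector` keyword names; these are guesses about the API rather than documented signatures:

```python
# Illustrative sketch only: the keyword names and the selectors module
# path are assumptions, not confirmed API.
from autostep import preconditioning, selectors
from autostep.autohmc import AutoHMC

kernel = AutoHMC(
    eight_schools,                                        # model defined below
    n_leapfrog_steps=32,                                  # default per the text
    selector=selectors.DeterministicSymmetricSelector(),  # default per the text
    preconditioner=preconditioning.FixedDensePreconditioner(),
)
```

The complete example: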
```python
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC
from autostep import preconditioning
from autostep.autohmc import AutoHMC
from autostep import utils
# define model
y = jnp.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = jnp.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
def eight_schools(sigma, y=None):
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    with numpyro.plate('J', len(sigma)):
        theta = numpyro.sample('theta', dist.Normal(mu, tau))
        numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)
# instantiate sampler and run
n_rounds = 12
n_warmup, n_keep = utils.split_n_rounds(n_rounds) # translate rounds to warmup/keep
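# Assuming the usual round-based doubling schedule (round r runs 2^r
# iterations and only the last round is kept), n_rounds = 12 gives
# n_warmup = 2 + 4 + ... + 2^11 = 4094 and n_keep = 2^12 = 4096,
# matching the 8190 total iterations in the progress bar below.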
kernel = AutoHMC(
    eight_schools,
    preconditioner=preconditioning.FixedDensePreconditioner()
)
mcmc = MCMC(kernel, num_warmup=n_warmup, num_samples=n_keep)
mcmc.run(random.key(9), sigma, y=y)
mcmc.print_summary()
```

Output:

```
sample: 100%|███| 8190/8190 [00:13<00:00, 614.36it/s, base_step 1.01e-01, rev_rate=0.96, acc_prob=0.90]

                mean       std    median      5.0%     95.0%     n_eff     r_hat
        mu      4.38      3.40      4.36     -1.31      9.54  11296.80      1.00
       tau      3.67      3.23      2.73      0.08      7.85    224.57      1.00
  theta[0]      6.28      5.75      5.74     -2.11     15.37   1661.16      1.00
  theta[1]      4.87      4.76      4.91     -4.28     11.59   7753.07      1.00
  theta[2]      3.90      5.48      4.18     -4.87     12.24   4461.17      1.00
  theta[3]      4.83      4.99      4.83     -3.90     11.89   6988.49      1.00
  theta[4]      3.60      4.72      3.76     -4.05     11.02   3519.48      1.00
  theta[5]      3.92      4.88      4.14     -3.75     11.89   4680.06      1.00
  theta[6]      6.30      5.13      5.85     -1.38     14.61   1916.97      1.00
  theta[7]      4.89      5.22      4.81     -3.07     13.13   5879.55      1.00
```
In less than 15 seconds, the sampler achieves r_hat ≈ 1 for every latent
variable, with a minimum effective sample size of 224.57 (attained by tau).
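The kept draws can then be inspected with the standard NumPyro MCMC API, for example:

```python
samples = mcmc.get_samples()    # dict mapping site names to arrays of draws
print(samples["theta"].shape)   # (n_keep, 8): one row per kept sample
print(jnp.mean(samples["mu"]))  # posterior mean of mu; ~4.38 per the summary
```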
TODO:

- autoHMC with a randomized number of leapfrog steps (RHMC)
- Re-implement the `MixDiagonalPreconditioner` in the new framework
References:

- Biron-Lattes, M., Surjanovic, N., Syed, S., Campbell, T., & Bouchard-Côté, A. (2024). autoMALA: Locally adaptive Metropolis-adjusted Langevin algorithm. Proceedings of The 27th International Conference on Artificial Intelligence and Statistics, PMLR 238:4600-4608.
- Liu, T., Surjanovic, N., Biron-Lattes, M., Bouchard-Côté, A., & Campbell, T. (2025). AutoStep: Locally adaptive involutive MCMC. ICML 2025.