-
Notifications
You must be signed in to change notification settings - Fork 48
Implements a simple Nutpie style adaptation (using both positions and gradients, but not changing the schedule). #473
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
Gonna ask some questions before I move on: Currently, the way I change the used mass matrix adaptor feels a bit hacky, reproduced below: adaptor = AdvancedHMC.StanHMCAdaptor(
AdvancedHMC.Adaptation.NutpieVar(size(metric); var=copy(metric.M⁻¹)),
AdvancedHMC.StepSizeAdaptor(spl.δ, integrator)
)
h, t = AdvancedHMC.sample_init(rng, hamiltonian, initial_params)
# Using the below uses Nutpie (as in position and gradients)
initial_state = AdvancedHMC.HMCState(0, t, metric, κ, adaptor)
# Using the below uses Stan (as in only positions)
# initial_state = nothing
@time samples = AdvancedHMC.sample(
rng,
model,
spl,
n_adapts + n_samples;
n_adapts=n_adapts, initial_state,
progress=true,
)Is there currently no easier way to specify what kind of adaptation to use, ideally just via some (keyword) argument to the sample function? Gonna also tag @penelopeysm and @mhauru who might know or have opinions on how to change the public API :) |
|
After chatting with or at @penelopeysm I've opened #475 and think that this PR should only implement what it's currently doing. I don't know whether we even want to export the defined struct currently - maybe. The main thing where I might need help is to figure out whether the needed changes to the |
|
Reproducing the code to demo the changes in this PR at the end of this comment. There's maybe one thing I'm unhappy with in this PR: There's a bunch of code duplication for the I needed to pass the position+gradient information through to the mass matrix adaptor, and the easiest way to do that was to allow a using AdvancedHMC, PosteriorDB, StanLogDensityProblems, Random, MCMCDiagnosticTools
if !@isdefined pdb
const pdb = PosteriorDB.database()
end
stan_problem(path, data) = StanProblem(
path, data;
nan_on_error=true,
make_args=["STAN_THREADS=TRUE"],
warn=false
)
stan_problem(posterior_name::AbstractString) = stan_problem(
PosteriorDB.path(PosteriorDB.implementation(PosteriorDB.model(PosteriorDB.posterior(pdb, (posterior_name))), "stan")),
PosteriorDB.load(PosteriorDB.dataset(PosteriorDB.posterior(pdb, (posterior_name))), String)
)
begin
lpdf = stan_problem("radon_mn-radon_county_intercept")
n_adapts = n_samples = 1000
rng = Xoshiro(2)
spl = NUTS(0.8)
initial_params = nothing
model = AdvancedHMC.AbstractMCMC._model(lpdf)
(;logdensity) = model
metric = AdvancedHMC.make_metric(spl, logdensity)
hamiltonian = AdvancedHMC.Hamiltonian(metric, model)
initial_params = AdvancedHMC.make_initial_params(rng, spl, logdensity, initial_params)
ϵ = AdvancedHMC.make_step_size(rng, spl, hamiltonian, initial_params)
integrator = AdvancedHMC.make_integrator(spl, ϵ)
κ = AdvancedHMC.make_kernel(spl, integrator)
adaptor = AdvancedHMC.StanHMCAdaptor(
AdvancedHMC.Adaptation.NutpieVar(size(metric); var=copy(metric.M⁻¹)),
AdvancedHMC.StepSizeAdaptor(spl.δ, integrator)
)
h, t = AdvancedHMC.sample_init(rng, hamiltonian, initial_params)
performances = map((;nutpie=AdvancedHMC.HMCState(0, t, metric, κ, adaptor), stan=nothing)) do initial_state
dt = @elapsed samples = AdvancedHMC.sample(
rng,
model,
spl,
n_adapts + n_samples;
n_adapts=n_adapts, initial_state,
progress=true,
)
ess(reshape(mapreduce(sample->sample.z.θ , hcat, samples[n_adapts+1:end])', (n_samples, 1, :))) |> minimum |> Base.Fix2(/, dt)
end
@info (;performances)
end |
|
Hm - I'm pretty sure the failings tests are not due to my changes - what's up with that? |
WIP that partially addresses #311 and supersedes #312.
There's a demo in
tmp/demo.jl, which certainly does something that finishes quicker than the current default.There are currently no additional tests, and I'm sure a few things are currently broken due to my changes.
Gonna tag @sethaxen, @aseyboldt, @svilupp, and maybe @yebai.