Skip to content

Commit

Permalink
Revival of MCMCTempering.jl (#158)
Browse files Browse the repository at this point in the history
* Bump breaking releases for project deps

* Updated `init_params` kwarg which is now `initial_params` after AbstractMCMC.jl v5

* Updated more `init_params` to `initial_params`

* Fixed incorrect references to stuff in AdvancedMH.jl and AdvancedHMC.jl

* Added fix for AdvancedHMC.jl needing `n_adapts` as kwarg

* Replaced DynamicPPL.jl model since the inverse gamma conjugate model
has very high variance when working with so few observations + using
initial parameters in constrained space rather than unconstrained
space since this has been addressed in `Turing.externalsampler`

* Relaxed `significance` in `atol_for_chain`

* Relaxed one test slightly

* Updated docs

* Use slightly better initialization for Turing.jl tests

* Added some debugging statement for tests

* Use `median` instead of `mean` in a test to lower the variance of the
estimate a bit

* Further relaxation of tests

* Narrowed `significance` in `atol_for_chain` but added `min_atol` kwarg
to allow us to specify a minimum atol that we're happy with

* Use elementwise comparison for a mean rather than vector-based approx

* Made the test suite a bit more consistent by using explicit RNGs here
and there

* Forgot an RNG in one `sample`
  • Loading branch information
torfjelde authored Sep 30, 2024
1 parent 33414c0 commit 600b82c
Show file tree
Hide file tree
Showing 11 changed files with 164 additions and 71 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MCMCTempering"
uuid = "ce233488-44ea-4441-b732-192676ce2298"
authors = ["Harrison Wilde <[email protected]> and contributors"]
version = "0.3.2"
version = "0.4.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -17,7 +17,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"

[compat]
AbstractMCMC = "3.2, 4"
AbstractMCMC = "5"
ConcreteStructs = "0.2"
Distributions = "0.24, 0.25"
DocStringExtensions = "0.8, 0.9"
Expand Down
3 changes: 2 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ makedocs(
sitename = "MCMCTempering",
format = Documenter.HTML(),
modules = [MCMCTempering],
pages=["Home" => "index.md", "getting-started.md", "api.md"],
pages = ["Home" => "index.md", "getting-started.md", "api.md"],
warnonly = true
)

# Deploy!
Expand Down
11 changes: 11 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,17 @@ MCMCTempering.swap_step

## Other samplers

To make a sampler work with MCMCTempering.jl, the sampler needs to implement a few methods:

```@docs
MCMCTempering.getparams
MCMCTempering.getlogprob
MCMCTempering.getparams_and_logprob
MCMCTempering.setparams_and_logprob!!
```

Other useful methods are:

```@docs
MCMCTempering.saveall
```
Expand Down
16 changes: 8 additions & 8 deletions docs/src/getting-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,13 @@ Let's instead try to use a _tempered_ version of `RWMH`. _But_ before we do that

To do that we need to implement two methods. First we need to tell MCMCTempering how to extract the parameters, and potentially the log-probabilities, from an `AdvancedMH.Transition`:

```@docs
```@docs; canonical=false
MCMCTempering.getparams_and_logprob
```

And similarly, we need a way to _update_ the parameters and the log-probabilities of an `AdvancedMH.Transition`:

```@docs
```@docs; canonical=false
MCMCTempering.setparams_and_logprob!!
```

Expand Down Expand Up @@ -159,10 +159,7 @@ using AdvancedHMC: AdvancedHMC
using ForwardDiff: ForwardDiff # for automatic differentiation of the logdensity
# Creation of the sampler.
metric = AdvancedHMC.DiagEuclideanMetric(1)
integrator = AdvancedHMC.Leapfrog(0.1)
proposal = AdvancedHMC.StaticTrajectory(integrator, 8)
sampler = AdvancedHMC.HMCSampler(proposal, metric)
sampler = AdvancedHMC.HMC(0.1, 8)
sampler_tempered = MCMCTempering.TemperedSampler(sampler, inverse_temperatures)
# Sample!
Expand All @@ -172,6 +169,7 @@ chain = sample(
target_model, sampler, num_iterations;
chain_type=MCMCChains.Chains,
param_names=["x"],
n_adapts=0, # HACK: need this to make AdvancedHMC.jl happy :/
)
plot(chain, size=figsize)
```
Expand Down Expand Up @@ -205,7 +203,8 @@ chain_tempered_all = sample(
StableRNG(42),
target_model, sampler_tempered, num_iterations;
chain_type=Vector{MCMCChains.Chains},
param_names=["x"]
param_names=["x"],
n_adapts=0, # HACK: need this to make AdvancedHMC.jl happy :/
);
```

Expand Down Expand Up @@ -289,7 +288,8 @@ chain_tempered_all = sample(
StableRNG(42),
target_model, sampler_tempered, num_iterations;
chain_type=Vector{MCMCChains.Chains},
param_names=["x"]
param_names=["x"],
n_adapts=0, # HACK: need this to make AdvancedHMC.jl happy :/
);
```

Expand Down
10 changes: 5 additions & 5 deletions src/samplers/multi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,16 @@ function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::MultiModel,
sampler::MultiSampler;
init_params=nothing,
initial_params=nothing,
kwargs...
)
@assert length(model.models) == length(sampler.samplers) "Number of models and samplers must be equal"

# TODO: Handle `init_params` properly. Make sure that they respect the container-types used in
# TODO: Handle `initial_params` properly. Make sure that they respect the container-types used in
# `MultiModel` and `MultiSampler`.
init_params_multi = initparams(model, init_params)
transition_and_states = asyncmap(model.models, sampler.samplers, init_params_multi) do model, sampler, init_params
AbstractMCMC.step(rng, model, sampler; init_params, kwargs...)
initial_params_multi = initparams(model, initial_params)
transition_and_states = asyncmap(model.models, sampler.samplers, initial_params_multi) do model, sampler, initial_params
AbstractMCMC.step(rng, model, sampler; initial_params, kwargs...)
end

return MultipleTransitions(map(first, transition_and_states)), MultipleStates(map(last, transition_and_states))
Expand Down
10 changes: 5 additions & 5 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
AbstractMCMC = "3.2, 4"
AdvancedHMC = "0.4"
AdvancedMH = "0.7"
Bijectors = "0.10"
AbstractMCMC = "5"
AdvancedHMC = "0.6"
AdvancedMH = "0.8"
Bijectors = "0.13"
Distributions = "0.24, 0.25"
LogDensityProblems = "2"
LogDensityProblemsAD = "1"
MCMCChains = "6"
Turing = "0.24"
Turing = "0.34"
julia = "1"
9 changes: 6 additions & 3 deletions test/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@

@testset "SwapSampler" begin
# SwapSampler without tempering (i.e. in a composition and using `MultiModel`, etc.)
init_params = [[5.0], [5.0]]
initial_params = [[5.0], [5.0]]
mdl1 = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), DistributionLogDensity(Normal(4.9999, 1)))
mdl2 = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), DistributionLogDensity(Normal(5.0001, 1)))
spl1 = RWMH(MvNormal(Zeros(dimension(mdl1)), I))
Expand All @@ -181,7 +181,7 @@

# Sample!
rng = Random.default_rng()
transition, state = AbstractMCMC.step(rng, product_model, spl_full; init_params)
transition, state = AbstractMCMC.step(rng, product_model, spl_full; initial_params)
transitions = typeof(transition)[]

# A bit of warm-up.
Expand Down Expand Up @@ -222,7 +222,10 @@

# Check that means are roughly okay.
params_resolved = map(first ∘ MCMCTempering.getparams, transitions_resolved_inner)
@test vec(mean(params_resolved; dims=2)) ≈ [5.0, 5.0] atol = 0.3
mean_tmp = vec(median(params_resolved; dims=2))
for i in 1:2
@test mean_tmp[i] ≈ 5.0 atol = 0.5
end

# A composition of `SwapSampler` and `MultiSampler` has special `AbstractMCMC.bundle_samples`.
@testset "bundle_samples with Vector" begin
Expand Down
6 changes: 5 additions & 1 deletion test/compat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@ MCMCTempering.getparams_and_logprob(transition::AdvancedMH.GradientTransition) =
function MCMCTempering.setparams_and_logprob!!(model, transition::AdvancedMH.GradientTransition, params, lp)
# NOTE: We have to re-compute the gradient here because this will be used in the subsequent `step` for
# the MALA sampler.
return AdvancedMH.GradientTransition(params, AdvancedMH.logdensity_and_gradient(model, params)...)
return AdvancedMH.GradientTransition(
params,
AdvancedMH.logdensity_and_gradient(model, params)...,
transition.accepted
)
end

# AdvancedHMC.jl
Expand Down
90 changes: 63 additions & 27 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Several properties of the tempered sampler are tested before returning:
- `adapt_atol`: The absolute tolerance for the check of average swap acceptance rate and target swap acceptance rate. Defaults to `0.05`.
- `mean_swap_rate_bound`: A bound on the acceptance rate of swaps performed, e.g. if set to `0.1` and `compare_mean_swap_rate=≥` then at least 10% of attempted swaps should be accepted. Defaults to `0.1`.
- `compare_mean_swap_rate`: a binary function for comparing average swap rate against `mean_swap_rate_bound`. Defaults to `≥`.
- `init_params`: The initial parameters to use for the sampler. Defaults to `nothing`.
- `initial_params`: The initial parameters to use for the sampler. Defaults to `nothing`.
- `param_names`: The names of the parameters in the chain; used to construct the resulting chain. Defaults to `missing`.
- `progress`: Whether to show a progress bar. Defaults to `false`.
"""
Expand All @@ -41,10 +41,12 @@ function test_and_sample_model(
adapt_target=0.234,
adapt_rtol=0.1,
adapt_atol=0.05,
init_params=nothing,
initial_params=nothing,
param_names=missing,
progress=false,
minimum_roundtrips=nothing
minimum_roundtrips=nothing,
rng=make_rng(),
kwargs...
)
# Make the tempered sampler.
sampler_tempered = tempered(
Expand All @@ -64,8 +66,9 @@ function test_and_sample_model(

# Sample.
samples_tempered = AbstractMCMC.sample(
model, sampler_tempered, num_iterations;
callback=callback, progress=progress, init_params=init_params
rng, model, sampler_tempered, num_iterations;
callback=callback, progress=progress, initial_params=initial_params,
kwargs...
)

if !isnothing(minimum_roundtrips)
Expand Down Expand Up @@ -245,7 +248,7 @@ end
[1.0, 1e-3], # extreme temperatures -> don't expect much swapping to occur
num_iterations=num_iterations,
adapt=false,
init_params=[[0.0], [1000.0]], # initialized far apart
initial_params=[[0.0], [1000.0]], # initialized far apart
# At MOST 1% of swaps should be successful.
mean_swap_rate_bound=0.01,
compare_mean_swap_rate=≤,
Expand Down Expand Up @@ -302,7 +305,7 @@ end
# Set up our sampler with a joint multivariate Normal proposal.
sampler = RWMH(MvNormal(zeros(d), Diagonal(σ_true .^ 2)))
# Sample for the non-tempered model for comparison.
samples = AbstractMCMC.sample(model, sampler, num_iterations; progress=false)
samples = AbstractMCMC.sample(make_rng(), model, sampler, num_iterations; progress=false)
chain = AbstractMCMC.bundle_samples(
samples, MCMCTempering.maybe_wrap_model(model), sampler, samples[1], MCMCChains.Chains
)
Expand All @@ -326,19 +329,41 @@ end
adapt=false,
# Make sure we have _some_ roundtrips.
minimum_roundtrips=10,
rng=make_rng(),
)

compare_chains(chain, chain_tempered, rtol=0.1, compare_ess=true)
# Some swap strategies are not great.
ess_slack_ratio = if swap_strategy isa Union{MCMCTempering.SingleRandomSwap,MCMCTempering.SingleSwap}
0.25
else
0.5
end
compare_chains(chain, chain_tempered, rtol=0.1, compare_ess=true, compare_ess_slack=ess_slack_ratio)
end
end

@testset "Turing.jl" begin
# Let's make a default seed we can `deepcopy` throughout to get reproducible results.
seed = 42

# And let's set the seed explicitly for reproducibility.
Random.seed!(seed)

# Instantiate model.
model_dppl = DynamicPPL.TestUtils.demo_assume_dot_observe()
DynamicPPL.@model function demo_model(x)
s ~ Exponential()
m ~ Normal(0, s)
x .~ Normal(m, s)
end
xs_true = rand(Normal(2, 1), 100)
model_dppl = demo_model(xs_true)

# Move to unconstrained space.
vi = DynamicPPL.VarInfo(model_dppl)
# Move to unconstrained space.
vi = DynamicPPL.link!!(DynamicPPL.VarInfo(model_dppl), model_dppl)
vi = DynamicPPL.link!!(vi, model_dppl)
# Get some initial values in unconstrained space.
init_params = copy(vi[:])
initial_params = rand(length(vi[:]))
# Get the parameter names.
param_names = map(Symbol, DynamicPPL.TestUtils.varnames(model_dppl))
# Get bijector so we can get back to unconstrained space afterwards.
Expand All @@ -349,9 +374,9 @@ end
@testset "Tempering of models" begin
beta = 0.5
model_tempered = MCMCTempering.make_tempered_model(model, beta)
@test logdensity(model_tempered, init_params) ≈ beta * logdensity(model, init_params)
@test last(logdensity_and_gradient(model_tempered, init_params)) ≈
beta .* last(logdensity_and_gradient(model, init_params))
@test logdensity(model_tempered, initial_params) ≈ beta * logdensity(model, initial_params)
@test last(logdensity_and_gradient(model_tempered, initial_params)) ≈
beta .* last(logdensity_and_gradient(model, initial_params))
end

@testset "AdvancedHMC.jl" begin
Expand All @@ -360,12 +385,19 @@ end
# Set up HMC sampler.
initial_ϵ = 0.1
integrator = AdvancedHMC.Leapfrog(initial_ϵ)
proposal = AdvancedHMC.NUTS{AdvancedHMC.MultinomialTS, AdvancedHMC.GeneralisedNoUTurn}(integrator)
proposal = AdvancedHMC.HMCKernel(AdvancedHMC.Trajectory{AdvancedHMC.MultinomialTS}(
integrator, AdvancedHMC.GeneralisedNoUTurn()
))
metric = AdvancedHMC.DiagEuclideanMetric(LogDensityProblems.dimension(model))
sampler_hmc = AdvancedHMC.HMCSampler(proposal, metric)
sampler_hmc = AdvancedHMC.HMCSampler(proposal, metric, AdvancedHMC.NoAdaptation())

# Sample using HMC.
samples_hmc = sample(model, sampler_hmc, num_iterations; init_params=copy(init_params), progress=false)
samples_hmc = sample(
make_rng(seed), model, sampler_hmc, num_iterations;
n_adapts=0, # FIXME(torfjelde): Remove once AHMC.jl has fixed.
initial_params=copy(initial_params),
progress=false
)
chain_hmc = AbstractMCMC.bundle_samples(
samples_hmc, MCMCTempering.maybe_wrap_model(model), sampler_hmc, samples_hmc[1], MCMCChains.Chains;
param_names=param_names,
Expand All @@ -375,8 +407,9 @@ end
# Make sure that we get the "same" result when only using the inverse temperature 1.
sampler_tempered = MCMCTempering.TemperedSampler(sampler_hmc, [1])
chain_tempered = sample(
model, sampler_tempered, num_iterations;
init_params=copy(init_params),
make_rng(seed), model, sampler_tempered, num_iterations;
n_adapts=0, # FIXME(torfjelde): Remove once AHMC.jl has fixed.
initial_params=copy(initial_params),
chain_type=MCMCChains.Chains,
param_names=param_names,
progress=false,
Expand All @@ -398,9 +431,11 @@ end
num_iterations=num_iterations,
adapt=false,
mean_swap_rate_bound=0.1,
init_params=copy(init_params),
initial_params=copy(initial_params),
param_names=param_names,
progress=false
progress=false,
n_adapts=0, # FIXME(torfjelde): Remove once AHMC.jl has fixed.
rng=make_rng(seed),
)
map_parameters!(b, chain_tempered)
compare_chains(
Expand All @@ -421,8 +456,8 @@ end

# Sample using MALA.
chain_mh = AbstractMCMC.sample(
model, sampler_mh, num_iterations;
init_params=copy(init_params),
make_rng(), model, sampler_mh, num_iterations;
initial_params=copy(initial_params),
progress=false,
chain_type=MCMCChains.Chains,
param_names=param_names,
Expand All @@ -432,8 +467,8 @@ end
# Make sure that we get the "same" result when only using the inverse temperature 1.
sampler_tempered = MCMCTempering.TemperedSampler(sampler_mh, [1])
chain_tempered = sample(
model, sampler_tempered, num_iterations;
init_params=copy(init_params),
make_rng(), model, sampler_tempered, num_iterations;
initial_params=copy(initial_params),
chain_type=MCMCChains.Chains,
param_names=param_names,
progress=false,
Expand All @@ -455,8 +490,9 @@ end
num_iterations=num_iterations,
adapt=false,
mean_swap_rate_bound=0.1,
init_params=copy(init_params),
param_names=param_names
initial_params=copy(initial_params),
param_names=param_names,
rng=make_rng(),
)
map_parameters!(b, chain_tempered)

Expand Down
Loading

0 comments on commit 600b82c

Please sign in to comment.