From 4d48a35f43bbfcee385aeccc539c3631cc459ce9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 11 Sep 2021 01:46:54 +0100 Subject: [PATCH 01/51] additional deps and test deps --- Project.toml | 9 ++++++++- src/MCMCTempering.jl | 5 ++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index fa54387..102bf22 100644 --- a/Project.toml +++ b/Project.toml @@ -5,8 +5,10 @@ version = "0.1.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" +ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" [compat] AbstractMCMC = "3.2" @@ -14,7 +16,12 @@ Distributions = "0.24, 0.25" julia = "1" [extras] +AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170" +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test"] +test = ["Test", "AdvancedMH", "MCMCChains", "Bijectors", "StatsPlots", "LinearAlgebra"] diff --git a/src/MCMCTempering.jl b/src/MCMCTempering.jl index 36041f8..54833af 100644 --- a/src/MCMCTempering.jl +++ b/src/MCMCTempering.jl @@ -4,12 +4,15 @@ import AbstractMCMC import Distributions import Random +using ConcreteStructs: @concrete +using Setfield: @set, @set! + include("adaptation.jl") +include("swapping.jl") include("tempered.jl") include("ladders.jl") include("stepping.jl") include("model.jl") -include("swapping.jl") include("plotting.jl") export tempered, TemperedSampler, plot_swaps, plot_ladders, make_tempered_model, get_tempered_loglikelihoods_and_params, make_tempered_loglikelihood, get_params From 2df4a2b449899ba43e67729cf959bf7bc19ca353 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 4 Oct 2021 16:31:20 +0100 Subject: [PATCH 02/51] updated stepping code to use AbstractSwapStrategy and made TemperedSampler concrete --- src/stepping.jl | 109 +++++++++++++++++++++++++++++------------------- 1 file changed, 66 insertions(+), 43 deletions(-) diff --git a/src/stepping.jl b/src/stepping.jl index 3521e81..e9d3ce9 100644 --- a/src/stepping.jl +++ b/src/stepping.jl @@ -41,19 +41,16 @@ Chains: chain_index: Δ_index: | | | | 2 3 1 4 3 1 2 4 | | | | """ -mutable struct TemperedState - states :: Array{Any} - Δ :: Vector{<:Real} - Δ_index :: Vector{<:Integer} - chain_index :: Vector{<:Integer} - step_counter :: Integer - total_steps :: Integer - Δ_history :: Array{<:Real, 2} - Δ_index_history :: Array{<:Integer, 2} - Ρ :: Vector{AdaptiveState} +@concrete struct TemperedState + states + Δ + Δ_index + chain_index + step_counter + total_steps + Ρ end - """ For each `β` in `Δ`, carry out a step with a tempered model at the corresponding `β` inverse temperature, resulting in a list of transitions and states, the transition associated with `β₀ = 1` is then returned with the @@ -63,6 +60,7 @@ function AbstractMCMC.step( rng::Random.AbstractRNG, model, spl::TemperedSampler; + init_params=nothing, kwargs... ) states = [ @@ -70,14 +68,15 @@ function AbstractMCMC.step( rng, make_tempered_model(model, spl.Δ[spl.Δ_init[i]]), spl.internal_sampler; + init_params=init_params !== nothing ? init_params[i] : nothing, kwargs... ) for i in 1:length(spl.Δ) ] return ( - states[sortperm(spl.Δ_init)[1]][1], + first(states[argmax(spl.Δ_init)]), TemperedState( - states,spl.Δ, spl.Δ_init, sortperm(spl.Δ_init), 1, 1, Array{Real, 2}(spl.Δ'), Array{Integer, 2}(spl.Δ_init'), spl.Ρ + states, spl.Δ, spl.Δ_init, sortperm(spl.Δ_init), 1, 1, spl.Ρ ) ) end @@ -90,30 +89,29 @@ function AbstractMCMC.step( ) if ts.step_counter == spl.N_swap ts = swap_step(rng, model, spl, ts) - ts.step_counter = 0 + @set! ts.step_counter = 0 else - ts.states = [ + @set! ts.states = [ AbstractMCMC.step( rng, make_tempered_model(model, ts.Δ[ts.Δ_index[i]]), spl.internal_sampler, - ts.states[i][2]; + ts.states[ts.chain_index[i]][2]; kwargs... ) for i in 1:length(ts.Δ) ] - ts.step_counter += 1 + @set! ts.step_counter += 1 end - ts.Δ_history = vcat(ts.Δ_history, Array{Real, 2}(ts.Δ')) - ts.Δ_index_history = vcat(ts.Δ_index_history, Array{Integer, 2}(ts.Δ_index')) - ts.total_steps += 1 - return ts.states[ts.chain_index[1]][1], ts # Use chain_index[1] to ensure the sample from the target is always returned for the step + @set! ts.total_steps += 1 + # Use `chain_index[1]` to ensure the sample from the target is always returned for the step. + return ts.states[ts.chain_index[1]][1], ts end """ - swap_step(rng, model, spl, ts) + swap_step([strategy::SwapStrategy, ]rng, model, spl, ts) Uses the internals of the passed `TemperedSampler` - `spl` - and `TemperedState` - `ts` - to perform a "swap step" between temperatures, in accordance with the relevant @@ -122,35 +120,60 @@ swap strategy. function swap_step( rng::Random.AbstractRNG, model, - spl::TemperedSampler, + sampler::TemperedSampler, + ts::TemperedState +) + return swap_step(swapstrategy(sampler), rng, model, sampler, ts) +end + +function swap_step( + strategy::StandardSwap, + rng::Random.AbstractRNG, + model, + sampler::TemperedSampler, ts::TemperedState ) L = length(ts.Δ) - 1 - sampler = spl.internal_sampler + k = rand(rng, 1:L) + return swap_attempt(rng, model, sampler.internal_sampler, ts, k, sampler.adapt, ts.total_steps / L) +end - if spl.swap_strategy == :standard +function swap_step( + strategy::RandomPermutationSwap, + rng::Random.AbstractRNG, + model, + sampler::TemperedSampler, + ts::TemperedState +) + L = length(ts.Δ) - 1 + levels = Vector{Int}(undef, L) + Random.randperm!(rng, levels) - k = rand(rng, Distributions.Categorical(L)) # Pick randomly from 1, 2, ..., k - 1 - ts = swap_attempt(model, sampler, ts, k, spl.adapt, ts.total_steps / L) + # Iterate through all levels and attempt swaps. + for k in levels + ts = swap_attempt(rng, model, sampler.internal_sampler, ts, k, sampler.adapt, ts.total_steps) + end + return ts +end +function swap_step( + strategy::NonReversibleSwap, + rng::Random.AbstractRNG, + model, + sampler::TemperedSampler, + ts::TemperedState +) + L = length(ts.Δ) - 1 + # Alternate between swapping odds and evens. + levels = if ts.total_steps % (2 * sampler.N_swap) == 0 + 1:2:L else + 2:2:L + end - # Define a vector to populate with levels at which to propose swaps according to swap_strategy - levels = Vector{Int}(undef, L) - if spl.swap_strategy == :nonrev - if ts.step_counter % (2 * spl.N_swap) == 0 - levels = 1:2:L - else - levels = 2:2:L - end - elseif spl.swap_strategy == :randperm - randperm!(rng, levels) - end - - for k in levels - ts = swap_attempt(model, sampler, ts, k, spl.adapt, ts.total_steps) - end - + # Iterate through all levels and attempt swaps. + for k in levels + ts = swap_attempt(rng, model, sampler.internal_sampler, ts, k, sampler.adapt, ts.total_steps) end return ts end From 83c5c49719e4094a5a681e774aa67780ec93851f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 4 Oct 2021 16:32:41 +0100 Subject: [PATCH 03/51] made TemperedSampler concretely typed and fixed soem docs --- src/tempered.jl | 45 ++++++++++++++++++++------------------------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/src/tempered.jl b/src/tempered.jl index 12cbf59..85f13fd 100644 --- a/src/tempered.jl +++ b/src/tempered.jl @@ -13,48 +13,42 @@ A `TemperedSampler` struct wraps an `internal_sampler` (could just be an algorit - The number of steps between each temperature swap attempt `N_swap` - The `swap_strategy` defining how these swaps should be carried out """ -struct TemperedSampler{A} <: AbstractMCMC.AbstractSampler +struct TemperedSampler{A,TΔ,TP,TSwap} <: AbstractMCMC.AbstractSampler internal_sampler :: A - Δ :: Vector{<:Real} - Δ_init :: Vector{<:Integer} + Δ :: TΔ + Δ_init :: Vector{Int} N_swap :: Integer - swap_strategy :: Symbol + swap_strategy :: TSwap adapt :: Bool - Ρ :: Vector{AdaptiveState} + Ρ :: TP end +swapstrategy(sampler::TemperedSampler) = sampler.swap_strategy + """ - tempered(internal_sampler, Δ; ) + tempered(internal_sampler, Δ; kwargs...) OR - tempered(internal_sampler, Nt; ) + tempered(internal_sampler, Nt; kwargs...) # Arguments - `internal_sampler` is an algorithm or sampler object to be used for underlying sampling and to apply tempering to - The temperature schedule can be defined either explicitly or just as an integer number of temperatures, i.e. as: - - `Δ::Vector{<:Real}` containing a sequence of 'inverse temperatures' {β₀, ..., βₙ} where 0 ≤ βₙ < ... < β₁ < β₀ = 1 + - `Δ` containing a sequence of 'inverse temperatures' {β₀, ..., βₙ} where 0 ≤ βₙ < ... < β₁ < β₀ = 1 OR - - `Nt::Integer`, specifying the number of inverse temperatures to include in a generated `Δ` -- `swap_strategy::Symbol` is the way in which temperature swaps are made, one of: - `:standard` as in original proposed algorithm, a single randomly picked swap is proposed - `:nonrev` alternate even/odd swaps as in Syed, Bouchard-Côté, Deligiannidis, Doucet, arXiv:1905.02939 such that a reverse swap cannot be made in immediate succession - `:randperm` generates a permutation in order to swap in a random order + - `Nt::Integer`, specifying the number of inverse temperatures to include in a generated `Δ` +- `swap_strategy::AbstractSwapStrategy` is the way in which temperature swaps are made, one of: + `:standard` as in original proposed algorithm, a single randomly picked swap is proposed + `:nonrev` alternate even/odd swaps as in Syed, Bouchard-Côté, Deligiannidis, Doucet, arXiv:1905.02939 such that a reverse swap cannot be made in immediate succession + `:randperm` generates a permutation in order to swap in a random order - `Δ_init::Vector{<:Integer}` is a list containing a sequence including the integers `1:length(Δ)` and determines the starting temperature of each chain - i.e. [3, 1, 2, 4] across temperatures [1.0, 0.1, 0.01, 0.001] would mean the first chain starts at temperature 0.01, second starts at 1.0, etc. + i.e. [3, 1, 2, 4] across temperatures [1.0, 0.1, 0.01, 0.001] would mean the first chain starts at temperature 0.01, second starts at 1.0, etc. - `N_swap::Integer` steps are carried out between each tempering swap step attempt """ function tempered( internal_sampler, - Δ::Vector{<:Real}; - swap_strategy::Symbol = :standard, - kwargs... -) - return tempered(internal_sampler, check_Δ(Δ), swap_strategy; kwargs...) -end -function tempered( - internal_sampler, - Nt::Integer; - swap_strategy::Symbol = :standard, + Nt::Integer, + swap_strategy::AbstractSwapStrategy = StandardSwap(); kwargs... ) return tempered(internal_sampler, generate_Δ(Nt, swap_strategy), swap_strategy; kwargs...) @@ -62,7 +56,7 @@ end function tempered( internal_sampler, Δ::Vector{<:Real}, - swap_strategy::Symbol; + swap_strategy::AbstractSwapStrategy; Δ_init::Vector{<:Integer} = collect(1:length(Δ)), N_swap::Integer = 1, adapt::Bool = true, @@ -71,6 +65,7 @@ function tempered( adapt_step::Real = 0.66, kwargs... ) + Δ = check_Δ(Δ) length(Δ) > 1 || error("More than one inverse temperatures must be provided.") N_swap >= 1 || error("This must be a positive integer.") Ρ = init_adaptation(Δ, adapt_target, adapt_scale, adapt_step) From 9e9153c3aa8de9b21c51e3bc021667458539b5b7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 4 Oct 2021 16:32:58 +0100 Subject: [PATCH 04/51] introduced AbstractSwapStrategy and removed get_params and make_tempered_loglikelihood --- src/swapping.jl | 121 ++++++++++++++++++++++++++++++------------------ 1 file changed, 77 insertions(+), 44 deletions(-) diff --git a/src/swapping.jl b/src/swapping.jl index 0e1c3c1..9d6d058 100644 --- a/src/swapping.jl +++ b/src/swapping.jl @@ -1,3 +1,50 @@ +""" + AbstractSwapStrategy + +Represents a strategy for swapping between parallel chains. + +A concrete subtype is expected to implement the method [`swap_step`](@ref). +""" +abstract type AbstractSwapStrategy end + +""" + StandardSwap <: AbstractSwapStrategy + +At every swap step taken, this strategy samples a single chain index `i` and proposes +a swap between chains `i` and `i + 1`. + +This approach goes under a number of names, e.g. Parallel Tempering (PT) MCMC and Replica-Exchange MCMC.[^PTPH05] + +The sampling of the chain index ensures reversibility/detailed balance is satisfied. + +# References +[^PTPH05]: Earl, D. J., & Deem, M. W., Parallel tempering: theory, applications, and new perspectives, Physical Chemistry Chemical Physics, 7(23), 3910–3916 (2005). +""" +struct StandardSwap <: AbstractSwapStrategy end + +""" + RandomPermutationSwap <: AbstractSwapStrategy + +At every swap step taken, this strategy randomly shuffles all the chain indices +and then iterates through them, proposing swaps for neighboring chains. + +The shuffling of chain indices ensures reversibility/detailed balance is satisfied. +""" +struct RandomPermutationSwap <: AbstractSwapStrategy end + + +""" + NonReversibleSwap <: AbstractSwapStrategy + +At every swap step taken, this strategy _deterministically_ traverses first the +odd chain indices, proposing swaps between neighbors, and then in the _next_ swap step +taken traverses even chain indices, proposing swaps between neighbors. + +Note that this method is _not_ reversible, and does not satisfy detailed balance. +As a result, this method is asymptotically biased. +""" +struct NonReversibleSwap <: AbstractSwapStrategy end + """ swap_betas(chain_index, k) @@ -9,49 +56,24 @@ function swap_betas(chain_index, k) return sortperm(chain_index), chain_index end -function make_tempered_loglikelihood end -function get_params end - """ - get_tempered_loglikelihoods_and_params(model, sampler, states, k, Δ, chain_index) + compute_tempered_logdensities(model, sampler, transition, transition_other, β) -Temper the `model`'s density using the `k`th and `k + 1`th temperatures -selected via `Δ` and `chain_index`. Then retrieve the parameters using the chains' -current transitions extracted from the collection of `states`. +Return `(logπ(transition, β), logπ(transition_other, β))` where `logπ(x, β)` denotes the +log-density for `model` with inverse-temperature `β`. """ -function get_tempered_loglikelihoods_and_params( - model, - sampler::AbstractMCMC.AbstractSampler, - states, - k::Integer, - Δ::Vector{Real}, - chain_index::Vector{<:Integer} -) - - logπk = make_tempered_loglikelihood(model, Δ[k]) - logπkp1 = make_tempered_loglikelihood(model, Δ[k + 1]) - - θk = get_params(states[chain_index[k]][1]) - θkp1 = get_params(states[chain_index[k + 1]][1]) - - return logπk, logπkp1, θk, θkp1 -end - +function compute_tempered_logdensities end """ - swap_acceptance_pt(logπk, logπkp1, θk, θkp1) + swap_acceptance_pt(logπk, logπkp1) Calculates and returns the swap acceptance ratio for swapping the temperature of two chains. Using tempered likelihoods `logπk` and `logπkp1` at the chains' -current state parameters `θk` and `θkp1`. +current state parameters. """ -function swap_acceptance_pt(logπk, logπkp1, θk, θkp1) - return min( - 1, - exp(logπkp1(θk) + logπk(θkp1)) / exp(logπk(θk) + logπkp1(θkp1)) - # exp(abs(βk - βkp1) * abs(AdvancedMH.logdensity(model, samplek) - AdvancedMH.logdensity(model, samplekp1))) - ) +function swap_acceptance_pt(logπk_θk, logπk_θkp1, logπkp1_θk, logπkp1_θkp1) + return (logπkp1_θk + logπk_θkp1) - (logπk_θk + logπkp1_θkp1) end @@ -61,21 +83,32 @@ end Attempt to swap the temperatures of two chains by tempering the densities and calculating the swap acceptance ratio; then swapping if it is accepted. """ -function swap_attempt(model, sampler, ts, k, adapt, n) - - logπk, logπkp1, θk, θkp1 = get_tempered_loglikelihoods_and_params(model, sampler, ts.states, k, ts.Δ, ts.chain_index) +function swap_attempt(rng, model, sampler, ts, k, adapt, n) + # Extract the relevant transitions. + transitionk = first(ts.states[ts.chain_index[k]]) + transitionkp1 = first(ts.states[ts.chain_index[k + 1]]) + # Evaluate logdensity for both parameters for each tempered density. + logπk_θk, logπk_θkp1 = compute_tempered_logdensities( + model, sampler, transitionk, transitionkp1, ts.Δ[k] + ) + logπkp1_θkp1, logπkp1_θk = compute_tempered_logdensities( + model, sampler, transitionkp1, transitionk, ts.Δ[k + 1] + ) - swap_ar = swap_acceptance_pt(logπk, logπkp1, θk, θkp1) - U = rand(Distributions.Uniform(0, 1)) - - # If the proposed temperature swap is accepted according to swap_ar and U, swap the temperatures for future steps - if U ≤ swap_ar - ts.Δ_index, ts.chain_index = swap_betas(ts.chain_index, k) + # If the proposed temperature swap is accepted according `logα`, + # swap the temperatures for future steps. + logα = swap_acceptance_pt(logπk_θk, logπk_θkp1, logπkp1_θk, logπkp1_θkp1) + if -Random.randexp(rng) ≤ logα + Δ_index, chain_index = swap_betas(ts.chain_index, k) + @set! ts.Δ_index = Δ_index + @set! ts.chain_index = chain_index end # Adaptation steps affects Ρ and Δ, as the Ρ is adapted before a new Δ is generated and returned if adapt - ts.Ρ, ts.Δ = adapt_ladder(ts.Ρ, ts.Δ, k, swap_ar, n) + P, Δ = adapt_ladder(ts.Ρ, ts.Δ, k, min(one(logα), exp(logα)), n) + @set! ts.Ρ = P + @set! ts.Δ = Δ end return ts -end \ No newline at end of file +end From 26f12f10eb3439920f7ff74e91c03a49db9b9f57 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 4 Oct 2021 16:33:39 +0100 Subject: [PATCH 05/51] updated stepping.jl --- src/stepping.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/stepping.jl b/src/stepping.jl index e9d3ce9..a3affe2 100644 --- a/src/stepping.jl +++ b/src/stepping.jl @@ -66,7 +66,7 @@ function AbstractMCMC.step( states = [ AbstractMCMC.step( rng, - make_tempered_model(model, spl.Δ[spl.Δ_init[i]]), + make_tempered_model(spl, model, spl.Δ[spl.Δ_init[i]]), spl.internal_sampler; init_params=init_params !== nothing ? init_params[i] : nothing, kwargs... @@ -94,7 +94,7 @@ function AbstractMCMC.step( @set! ts.states = [ AbstractMCMC.step( rng, - make_tempered_model(model, ts.Δ[ts.Δ_index[i]]), + make_tempered_model(spl, model, ts.Δ[ts.Δ_index[i]]), spl.internal_sampler, ts.states[ts.chain_index[i]][2]; kwargs... @@ -111,7 +111,7 @@ end """ - swap_step([strategy::SwapStrategy, ]rng, model, spl, ts) + swap_step([strategy::AbstractSwapStrategy, ]rng, model, spl, ts) Uses the internals of the passed `TemperedSampler` - `spl` - and `TemperedState` - `ts` - to perform a "swap step" between temperatures, in accordance with the relevant From aa88ae443aa27f474498c9636a1c9229ae71cfe0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 4 Oct 2021 16:33:49 +0100 Subject: [PATCH 06/51] added docstring for make_tempered_model --- src/model.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/model.jl b/src/model.jl index 4e1a5ca..eb1c53a 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,8 +1,9 @@ +""" + make_tempered_model(sampler, model, args...) -# struct TemperedModel <: AbstractPPL.AbstractProbabilisticProgram -# model :: DynamicPPL.Model -# β :: AbstractFloat -# end +Return an instance representing a model. +The return-type depends on it's usage in [`compute_tempered_logdensities`](@ref). +""" function make_tempered_model end From 2523a9d716ca159946380abb71a52b9762033d2e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 4 Oct 2021 16:34:00 +0100 Subject: [PATCH 07/51] updated adaptation.jl and made structs concrete --- src/adaptation.jl | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/src/adaptation.jl b/src/adaptation.jl index 31927b7..fc211f1 100644 --- a/src/adaptation.jl +++ b/src/adaptation.jl @@ -1,20 +1,16 @@ - -struct PolynomialStep - η :: Real - c :: Real +@concrete struct PolynomialStep + η + c end function get(step::PolynomialStep, k::Real) step.c * (k + 1.) ^ (-step.η) end -struct AdaptiveState - swap_target_ar :: Real - scale :: Base.RefValue{<:Real} - step :: PolynomialStep -end -function AdaptiveState(swap_target::Real, scale::Real, step::PolynomialStep) - AdaptiveState(swap_target, Ref(log(scale)), step) +struct AdaptiveState{T1<:Real,T2<:Real,P<:PolynomialStep} + swap_target_ar :: T1 + logscale :: T2 + step :: P end @@ -26,7 +22,7 @@ function init_adaptation( ) Nt = length(Δ) step = PolynomialStep(γ, Nt - 1) - Ρ = [AdaptiveState(swap_target, scale, step) for _ in 1:(Nt - 1)] + Ρ = [AdaptiveState(swap_target, log(scale), step) for _ in 1:(Nt - 1)] return Ρ end @@ -34,7 +30,7 @@ end function rhos_to_ladder(Ρ, Δ) β′ = Δ[1] for i in 1:length(Ρ) - β′ += exp(Ρ[i].scale[]) + β′ += exp(Ρ[i].logscale) Δ[i + 1] = Δ[1] / β′ end return Δ @@ -48,7 +44,9 @@ function adapt_rho(ρ::AdaptiveState, swap_ar, n) end -function adapt_ladder(Ρ, Δ, k, swap_ar, n) - Ρ[k].scale[] += adapt_rho(Ρ[k], swap_ar, n) - return Ρ, rhos_to_ladder(Ρ, Δ) -end \ No newline at end of file +function adapt_ladder(P, Δ, k, swap_ar, n) + P[k] = let Pk = P[k] + @set Pk.logscale += adapt_rho(Pk, swap_ar, n) + end + return P, rhos_to_ladder(P, Δ) +end From 7462ebfca234c78318b271e601acd1b9b28227eb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 4 Oct 2021 16:34:24 +0100 Subject: [PATCH 08/51] updated ladders.jl --- src/ladders.jl | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/src/ladders.jl b/src/ladders.jl index 2017764..722cfaf 100644 --- a/src/ladders.jl +++ b/src/ladders.jl @@ -1,34 +1,26 @@ +# Why these? """ get_scaling_val(Nt, swap_strategy) -Calculates the correct scaling factor for polynomial step size between temperatures +Calculates the correct scaling factor for polynomial step size between temperatures. """ -function get_scaling_val(Nt, swap_strategy) - # Why these? - if swap_strategy == :standard - scaling_val = Nt - 1 - elseif swap_strategy == :nonrev - scaling_val = 2 - else - scaling_val = 1 - end - return scaling_val -end - +get_scaling_val(Nt, ::StandardSwap) = Nt - 1 +get_scaling_val(Nt, ::NonReversibleSwap) = 2 +get_scaling_val(Nt, ::RandomPermutationSwap) = 1 """ generate_Δ(Nt, swap_strategy) Returns a temperature ladder `Δ` containing `Nt` temperatures, -generated in accordance with the chosen `swap_strategy` +generated in accordance with the chosen `swap_strategy`. """ function generate_Δ(Nt, swap_strategy) scaling_val = get_scaling_val(Nt, swap_strategy) - Δ = zeros(Real, Nt) + Δ = zeros(Nt) Δ[1] = 1.0 β′ = Δ[1] for i ∈ 1:(Nt - 1) - β′ += exp(scaling_val) + β′ += scaling_val Δ[i + 1] = Δ[1] / β′ end return Δ From 3fb13e9677f7d754c0758e231ed6490d56f65447 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 13 Oct 2021 15:56:20 +0100 Subject: [PATCH 09/51] addressed some comments --- src/stepping.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/stepping.jl b/src/stepping.jl index a3affe2..4afb52a 100644 --- a/src/stepping.jl +++ b/src/stepping.jl @@ -63,7 +63,7 @@ function AbstractMCMC.step( init_params=nothing, kwargs... ) - states = [ + transitions_and_states = [ AbstractMCMC.step( rng, make_tempered_model(spl, model, spl.Δ[spl.Δ_init[i]]), @@ -74,9 +74,10 @@ function AbstractMCMC.step( for i in 1:length(spl.Δ) ] return ( - first(states[argmax(spl.Δ_init)]), + # Get the left-most `(transition, state)` pair, then get the `transition`. + first(first(transitions_and_states)), TemperedState( - states, spl.Δ, spl.Δ_init, sortperm(spl.Δ_init), 1, 1, spl.Ρ + transitions_and_states, spl.Δ, spl.Δ_init, sortperm(spl.Δ_init), 1, 1, spl.Ρ ) ) end From 90e60f2fd69e66a2d739249890eb611c653b99f2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 13 Oct 2021 15:56:37 +0100 Subject: [PATCH 10/51] added tests --- test/compat.jl | 1 + test/compat/advancedmh.jl | 25 ++++++++++++ test/runtests.jl | 83 ++++++++++++++++++++++++--------------- test/utils.jl | 24 +++++++++++ 4 files changed, 102 insertions(+), 31 deletions(-) create mode 100644 test/compat.jl create mode 100644 test/compat/advancedmh.jl create mode 100644 test/utils.jl diff --git a/test/compat.jl b/test/compat.jl new file mode 100644 index 0000000..4c624bc --- /dev/null +++ b/test/compat.jl @@ -0,0 +1 @@ +include("compat/advancedmh.jl") diff --git a/test/compat/advancedmh.jl b/test/compat/advancedmh.jl new file mode 100644 index 0000000..2f88b85 --- /dev/null +++ b/test/compat/advancedmh.jl @@ -0,0 +1,25 @@ +########################################## +### Make compatible with AdvancedMH.jl ### +########################################## +# Makes the first step possible. +# This constructs the model that are passed to the respective samplers. +function MCMCTempering.make_tempered_model(sampler, m::DensityModel, β) + return DensityModel(Base.Fix1(*, β) ∘ m.logdensity) +end + +# Now we need to make swapping possible. +# This should return a callable which evaluates to the temperered logdensity. +function MCMCTempering.compute_tempered_logdensities( + model::DensityModel, + sampler, + transition::AdvancedMH.Transition, + transition_other::AdvancedMH.Transition, + β +) + # Just re-use computation from transition. + # lp = transition.lp + lp = β * AdvancedMH.logdensity(model, transition.params) + # Compute for the other. + lp_other = β * AdvancedMH.logdensity(model, transition_other.params) + return lp, lp_other +end diff --git a/test/runtests.jl b/test/runtests.jl index 94a02c7..7fd8128 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,39 +1,60 @@ using MCMCTempering using Test using Distributions -using Plots using AdvancedMH using MCMCChains +using Bijectors +using LinearAlgebra +using AbstractMCMC -@testset "MCMCTempering.jl" begin - - # θᵣ = [-1., 1., 2., 1., 15., 2., 90., 1.5] - # γs = [0.15, 0.25, 0.3, 0.3] - - # Δ = check_Δ([0, 0.01, 0.1, 0.25, 0.5, 1]) - - # modelᵣ = MixtureModel(Distributions.Normal.(eachrow(reshape(θᵣ, (2,4)))...), γs) - # # xrange = -10:0.1:100 - # # tempered_densities = pdf.(modelᵣ, xrange) .^ Δ' - # # norm_const = sum(tempered_densities[:,1]) - # # for i in 2:length(Δ) - # # tempered_densities[:,i] = (tempered_densities[:,i] ./ sum(tempered_densities[:,i])) .* norm_const - # # end - # # plot(xrange, tempered_densities, label = Δ') - - # data = rand(modelᵣ, 100) - - # insupport(θ) = all(reshape(θ, (2,4))[2,:] .≥ 0) - # dist(θ) = MixtureModel(Distributions.Normal.(eachrow(reshape(θ, (2,4)))...), γs) - # density(θ) = insupport(θ) ? sum(logpdf.(dist(θ), data)) : -Inf - - # # Construct a DensityModel. - # model = DensityModel(density) - - # # Set up our sampler with a joint multivariate Normal proposal. - # spl = RWMH(MvNormal(8,1)) - - # @test chain, temps = SimulatedTempering(model, spl, Δ, chain_type=Chains) - # @test chains, temps = ParallelTempering(model, spl, Δ, chain_type=Chains) +include("utils.jl") +include("compat.jl") +@testset "MCMCTempering.jl" begin + @testset "MvNormal 2D" begin + d = 2 + nsamples = 20_000 + function logdensity(x) + logpdf(MvNormal(ones(length(x)), I), x) + end + + # Sampler parameters. + Δ = MCMCTempering.check_Δ(vcat(0.25:0.1:0.9, 0.91:0.005:1.0)) + + # Construct a DensityModel. + model = DensityModel(logdensity) + + # Set up our sampler with a joint multivariate Normal proposal. + spl_inner = RWMH(MvNormal(zeros(d), 1e-1I)) + spl = tempered(spl_inner, Δ, MCMCTempering.StandardSwap(); adapt=false, N_swap=2) + + # Useful for analysis. + states = [] + callback = StateHistoryCallback(states) + + # Sample. + samples = AbstractMCMC.sample(model, spl, nsamples; callback=callback); + states + + # Extract the history of chain indices. + Δ_index_history_list = map(states) do state + state.Δ_index + end + Δ_index_history = permutedims(reduce(hcat, Δ_index_history_list), (2, 1)) + + # Get example state. + state = states[end] + chain = if spl isa MCMCTempering.TemperedSampler + AbstractMCMC.bundle_samples( + samples, model, spl.internal_sampler, state.states[first(state.chain_index)][2], MCMCChains.Chains + ) + else + AbstractMCMC.bundle_samples(samples, model, spl, state, MCMCChains.Chains) + end; + + + μ = mean(chain[length(chain) ÷ 2 + 1:10:end]).nt.mean + # HACK: This is quite a large threshold. + @test norm(μ - ones(length(μ))) ≤ 2e-1 + end end diff --git a/test/utils.jl b/test/utils.jl new file mode 100644 index 0000000..5947852 --- /dev/null +++ b/test/utils.jl @@ -0,0 +1,24 @@ +""" + StateHistoryCallback + +Defines a callable which simply pushes the `state` onto the `states` container.! + +Example usage when used with AbstractMCMC.jl: +```julia +# 1. Create empty container for state-history. +state_history = [] +# 2. Sample. +AbstractMCMC.sample(model, sampler, N; callback=StateHistoryCallback(state_history)) +# 3. Inspect states. +state_history +``` +""" +struct StateHistoryCallback{A} + states::A +end +StateHistoryCallback() = StateHistoryCallback(Any[]) + +function (cb::StateHistoryCallback)(rng, model, sampler, sample, state, i; kwargs...) + push!(cb.states, state) + return nothing +end From e30718d2c835caab1bbe90f9edd7e0d16d99634d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 13 Oct 2021 16:07:43 +0100 Subject: [PATCH 11/51] fixed a bug --- src/stepping.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stepping.jl b/src/stepping.jl index 4afb52a..118f748 100644 --- a/src/stepping.jl +++ b/src/stepping.jl @@ -97,7 +97,7 @@ function AbstractMCMC.step( rng, make_tempered_model(spl, model, ts.Δ[ts.Δ_index[i]]), spl.internal_sampler, - ts.states[ts.chain_index[i]][2]; + ts.states[i][2]; kwargs... ) for i in 1:length(ts.Δ) From 9e8607a8c70fbb55a4ba9f38b2a3ae02b6626aca Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 20 Oct 2021 01:07:55 +0100 Subject: [PATCH 12/51] updated the StateHistoryCallback a bit --- test/utils.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/utils.jl b/test/utils.jl index 5947852..d357388 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -13,12 +13,16 @@ AbstractMCMC.sample(model, sampler, N; callback=StateHistoryCallback(state_histo state_history ``` """ -struct StateHistoryCallback{A} +struct StateHistoryCallback{A,F} states::A + selector::F end StateHistoryCallback() = StateHistoryCallback(Any[]) +function StateHistoryCallback(states, selector=deepcopy) + return StateHistoryCallback{typeof(states), typeof(selector)}(states, selector) +end function (cb::StateHistoryCallback)(rng, model, sampler, sample, state, i; kwargs...) - push!(cb.states, state) + push!(cb.states, cb.selector(state)) return nothing end From 3bb79dfc8597be217f220ba1c456862c247b855f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 20 Oct 2021 01:09:31 +0100 Subject: [PATCH 13/51] made the distinction between chains and processes clearer --- src/MCMCTempering.jl | 5 +- src/stepping.jl | 173 +++++++++++++++++++++++++++++++------------ src/swapping.jl | 30 ++++---- src/tempered.jl | 11 +++ 4 files changed, 157 insertions(+), 62 deletions(-) diff --git a/src/MCMCTempering.jl b/src/MCMCTempering.jl index 54833af..096d86a 100644 --- a/src/MCMCTempering.jl +++ b/src/MCMCTempering.jl @@ -25,7 +25,10 @@ function AbstractMCMC.bundle_samples( chain_type::Type; kwargs... ) - AbstractMCMC.bundle_samples(ts, model, sampler.internal_sampler, state, chain_type; kwargs...) + AbstractMCMC.bundle_samples( + ts, model, sampler_for_chain(sampler, state), state_for_chain(state), chain_type; + kwargs... + ) end end diff --git a/src/stepping.jl b/src/stepping.jl index 118f748..1e115f2 100644 --- a/src/stepping.jl +++ b/src/stepping.jl @@ -1,30 +1,25 @@ """ mutable struct TemperedState - states :: Array{Any} + transitions_and_states :: Array{Any} Δ :: Vector{<:Real} Δ_index :: Vector{<:Integer} chain_index :: Vector{<:Integer} step_counter :: Integer total_steps :: Integer - Δ_history :: Array{<:Real, 2} - Δ_index_history :: Array{<:Integer, 2} Ρ :: Vector{AdaptiveState} end -A `TemperedState` struct contains the `states` of each of the parallel chains +A `TemperedState` struct contains the `transitions_and_states` of each of the parallel chains used throughout parallel tempering as pairs of `Transition`s and `VarInfo`s, it also stores necessary information for tempering: -- `states` is an Array of pairs of `Transition`s and `VarInfo`s, one for each - tempered chain -- `Δ` contains the ordered sequence of inverse temperatures +- `transitions_and_states` is a collection of `(transition, state)` pairs, one for each tempered chain. +- `Δ` contains the ordered sequence of inverse temperatures. - `Δ_index` contains the current ordering to apply the temperatures to each chain, tracking swaps, i.e., contains the index `Δ_index[i] = j` of the temperature in `Δ`, `Δ[j]`, to apply to chain `i` - `chain_index` contains the index `chain_index[i] = k` of the chain tempered by `Δ[i]` NOTE that to convert between this and `Δ_index` we simply use the `sortperm()` function - `step_counter` maintains the number of steps taken since the last swap attempt - `total_steps` maintains the count of the total number of steps taken -- `Δ_index_history` records the history of swaps that occur in sampling by recording the `Δ_index` at each step -- `Δ_history` records the values of the inverse temperatures, these will change if adaptation is being used - `Ρ` contains all of the information required for adaptation of Δ Example of swaps across 4 chains and the values of `chain_index` and `Δ_index`: @@ -42,7 +37,7 @@ Chains: chain_index: Δ_index: | | | | """ @concrete struct TemperedState - states + transitions_and_states Δ Δ_index chain_index @@ -51,6 +46,74 @@ Chains: chain_index: Δ_index: Ρ end +""" + transition_for_chain(state[, I...]) + +Return the transition corresponding to the chain indexed by `I...`. +If `I...` is not specified, the transition corresponding to `β=1.0` will be returned, i.e. `I = (1, )`. +""" +transition_for_chain(state::TemperedState) = transition_for_chain(state, 1) +transition_for_chain(state::TemperedState, I...) = state.transitions_and_states[state.Δ_index[I...]][1] + +""" + transition_for_process(state, I...) + +Return the transition corresponding to the process indexed by `I...`. +""" +transition_for_process(state::TemperedState, I...) = state.transitions_and_states[I...][1] + +""" + state_for_chain(state[, I...]) + +Return the state corresponding to the chain indexed by `I...`. +If `I...` is not specified, the state corresponding to `β=1.0` will be returned. +""" +state_for_chain(state::TemperedState) = state_for_chain(state, 1) +state_for_chain(state::TemperedState, I...) = state.transitions_and_states[I...][2] + +""" + state_for_process(state, I...) + +Return the state corresponding to the process indexed by `I...`. +""" +state_for_process(state::TemperedState, I...) = state.transitions_and_states[I...][2] + +""" + β_for_chain(state[, I...]) + +Return the β corresponding to the chain indexed by `I...`. +If `I...` is not specified, the β corresponding to `β=1.0` will be returned. +""" +β_for_chain(state::TemperedState) = β_for_chain(state, 1) +β_for_chain(state::TemperedState, I...) = state.Δ[state.Δ_index[I...]] + +""" + β_for_process(state, I...) + +Return the β corresponding to the process indexed by `I...`. +""" +β_for_process(state::TemperedState, I...) = state.Δ[I...] + +""" + sampler_for_chain(sampler::TemperedSampler, state::TemperedState[, I...]) + +Return the sampler corresponding to the chain indexed by `I...`. +If `I...` is not specified, the sampler corresponding to `β=1.0` will be returned. +""" +sampler_for_chain(sampler::TemperedSampler, state::TemperedState) = sampler_for_chain(sampler, state, 1) +function sampler_for_chain(sampler::TemperedSampler, state::TemperedState, I...) + return getsampler(sampler.internal_sampler, state.Δ_index[I...]) +end + +""" + sampler_for_process(sampler::TemperedSampler, state::TemperedState, I...) + +Return the sampler corresponding to the process indexed by `I...`. +""" +function sampler_for_process(sampler::TemperedSampler, state::TemperedState, I...) + return getsampler(sampler.internal_sampler, I...) +end + """ For each `β` in `Δ`, carry out a step with a tempered model at the corresponding `β` inverse temperature, resulting in a list of transitions and states, the transition associated with `β₀ = 1` is then returned with the @@ -63,68 +126,84 @@ function AbstractMCMC.step( init_params=nothing, kwargs... ) + # `TemperedState` has the transitions and states in the order of + # the processes, and performs swaps by moving the (inverse) temperatures + # `β` between the processes, rather than moving states between processes + # and keeping the `β` local to each process. + # + # Therefore we iterate over the processes and then extract the corresponding + # `β`, `sampler` and `state`, and take a initialize. transitions_and_states = [ AbstractMCMC.step( rng, + # TODO: Should we also have one a `β_for_process` for the sampler to + # cover the initial step? Do we even _need_ this `Δ_init[1]`? + # Can we not just assume that the `Δ` is always in the "correct" initial order? make_tempered_model(spl, model, spl.Δ[spl.Δ_init[i]]), - spl.internal_sampler; + getsampler(spl, i); init_params=init_params !== nothing ? init_params[i] : nothing, kwargs... ) - for i in 1:length(spl.Δ) + for i in 1:numtemps(spl) ] - return ( - # Get the left-most `(transition, state)` pair, then get the `transition`. - first(first(transitions_and_states)), - TemperedState( - transitions_and_states, spl.Δ, spl.Δ_init, sortperm(spl.Δ_init), 1, 1, spl.Ρ - ) + + state = TemperedState( + transitions_and_states, spl.Δ, copy(spl.Δ_init), sortperm(spl.Δ_init), 1, 1, spl.Ρ ) + + return transition_for_chain(state), state end function AbstractMCMC.step( rng::Random.AbstractRNG, model, spl::TemperedSampler, - ts::TemperedState; + state::TemperedState; kwargs... ) - if ts.step_counter == spl.N_swap - ts = swap_step(rng, model, spl, ts) - @set! ts.step_counter = 0 + if state.step_counter == spl.N_swap + state = swap_step(rng, model, spl, state) + @set! state.step_counter = 0 else - @set! ts.states = [ + # `TemperedState` has the transitions and states in the order of + # the processes, and performs swaps by moving the (inverse) temperatures + # `β` between the processes, rather than moving states between processes + # and keeping the `β` local to each process. + # + # Therefore we iterate over the processes and then extract the corresponding + # `β`, `sampler` and `state`, and take a step. + @set! state.transitions_and_states = [ AbstractMCMC.step( rng, - make_tempered_model(spl, model, ts.Δ[ts.Δ_index[i]]), - spl.internal_sampler, - ts.states[i][2]; + make_tempered_model(spl, model, β_for_process(state, i)), + sampler_for_process(spl, state, i), + state_for_process(state, i); kwargs... ) - for i in 1:length(ts.Δ) + for i in 1:numtemps(spl) ] - @set! ts.step_counter += 1 + @set! state.step_counter += 1 end - @set! ts.total_steps += 1 - # Use `chain_index[1]` to ensure the sample from the target is always returned for the step. - return ts.states[ts.chain_index[1]][1], ts + @set! state.total_steps += 1 + # We want to return the transition for the _first_ chain, i.e. the chain usually corresponding to `β=1.0`. + return transition_for_chain(state), state end """ - swap_step([strategy::AbstractSwapStrategy, ]rng, model, spl, ts) + swap_step([strategy::AbstractSwapStrategy, ]rng, model, spl, state) Uses the internals of the passed `TemperedSampler` - `spl` - and `TemperedState` - -`ts` - to perform a "swap step" between temperatures, in accordance with the relevant +`state` - to perform a "swap step" between temperatures, in accordance with the relevant swap strategy. """ function swap_step( rng::Random.AbstractRNG, model, sampler::TemperedSampler, - ts::TemperedState + state::TemperedState ) - return swap_step(swapstrategy(sampler), rng, model, sampler, ts) + return swap_step(swapstrategy(sampler), rng, model, sampler, state) end function swap_step( @@ -132,11 +211,11 @@ function swap_step( rng::Random.AbstractRNG, model, sampler::TemperedSampler, - ts::TemperedState + state::TemperedState ) - L = length(ts.Δ) - 1 + L = length(state.Δ) - 1 k = rand(rng, 1:L) - return swap_attempt(rng, model, sampler.internal_sampler, ts, k, sampler.adapt, ts.total_steps / L) + return swap_attempt(rng, model, sampler, state, k, sampler.adapt, state.total_steps / L) end function swap_step( @@ -144,17 +223,17 @@ function swap_step( rng::Random.AbstractRNG, model, sampler::TemperedSampler, - ts::TemperedState + state::TemperedState ) - L = length(ts.Δ) - 1 + L = numtemps(sampler) - 1 levels = Vector{Int}(undef, L) Random.randperm!(rng, levels) # Iterate through all levels and attempt swaps. for k in levels - ts = swap_attempt(rng, model, sampler.internal_sampler, ts, k, sampler.adapt, ts.total_steps) + state = swap_attempt(rng, model, sampler, state, k, sampler.adapt, state.total_steps) end - return ts + return state end function swap_step( @@ -162,11 +241,11 @@ function swap_step( rng::Random.AbstractRNG, model, sampler::TemperedSampler, - ts::TemperedState + state::TemperedState ) - L = length(ts.Δ) - 1 + L = numtemps(sampler) - 1 # Alternate between swapping odds and evens. - levels = if ts.total_steps % (2 * sampler.N_swap) == 0 + levels = if state.total_steps % (2 * sampler.N_swap) == 0 1:2:L else 2:2:L @@ -174,7 +253,7 @@ function swap_step( # Iterate through all levels and attempt swaps. for k in levels - ts = swap_attempt(rng, model, sampler.internal_sampler, ts, k, sampler.adapt, ts.total_steps) + state = swap_attempt(rng, model, sampler, state, k, sampler.adapt, state.total_steps) end - return ts + return state end diff --git a/src/swapping.jl b/src/swapping.jl index 9d6d058..bef0d50 100644 --- a/src/swapping.jl +++ b/src/swapping.jl @@ -51,7 +51,7 @@ struct NonReversibleSwap <: AbstractSwapStrategy end Swaps the `k`th and `k + 1`th temperatures. Use `sortperm()` to convert the `chain_index` to a `Δ_index` to be used in tempering moves. """ -function swap_betas(chain_index, k) +function swap_betas(chain_index, Δ_index, k) chain_index[k], chain_index[k + 1] = chain_index[k + 1], chain_index[k] return sortperm(chain_index), chain_index end @@ -78,37 +78,39 @@ end """ - swap_attempt(model, sampler, states, k, Δ, Δ_index) + swap_attempt(rng, model, sampler, state, k, adapt) Attempt to swap the temperatures of two chains by tempering the densities and calculating the swap acceptance ratio; then swapping if it is accepted. """ -function swap_attempt(rng, model, sampler, ts, k, adapt, n) +function swap_attempt(rng, model, sampler, state, k, adapt, total_steps) # Extract the relevant transitions. - transitionk = first(ts.states[ts.chain_index[k]]) - transitionkp1 = first(ts.states[ts.chain_index[k + 1]]) + transitionk = transition_for_chain(state, k) + transitionkp1 = transition_for_chain(state, k + 1) # Evaluate logdensity for both parameters for each tempered density. + # NOTE: Here we want to propose swaps between the neighboring _chains_ not processes, + # and so we get the `β` and `sampler` corresponding to the k-th and (k+1)-th chains. logπk_θk, logπk_θkp1 = compute_tempered_logdensities( - model, sampler, transitionk, transitionkp1, ts.Δ[k] + model, sampler_for_chain(sampler, state, k), transitionk, transitionkp1, β_for_chain(state, k) ) logπkp1_θkp1, logπkp1_θk = compute_tempered_logdensities( - model, sampler, transitionkp1, transitionk, ts.Δ[k + 1] + model, sampler_for_chain(sampler, state, k + 1), transitionkp1, transitionk, β_for_chain(state, k + 1) ) # If the proposed temperature swap is accepted according `logα`, # swap the temperatures for future steps. logα = swap_acceptance_pt(logπk_θk, logπk_θkp1, logπkp1_θk, logπkp1_θkp1) if -Random.randexp(rng) ≤ logα - Δ_index, chain_index = swap_betas(ts.chain_index, k) - @set! ts.Δ_index = Δ_index - @set! ts.chain_index = chain_index + Δ_index, chain_index = swap_betas(state.chain_index, state.Δ_index, k) + @set! state.Δ_index = Δ_index + @set! state.chain_index = chain_index end # Adaptation steps affects Ρ and Δ, as the Ρ is adapted before a new Δ is generated and returned if adapt - P, Δ = adapt_ladder(ts.Ρ, ts.Δ, k, min(one(logα), exp(logα)), n) - @set! ts.Ρ = P - @set! ts.Δ = Δ + P, Δ = adapt_ladder(state.Ρ, state.Δ, k, min(one(logα), exp(logα)), total_steps) + @set! state.Ρ = P + @set! state.Δ = Δ end - return ts + return state end diff --git a/src/tempered.jl b/src/tempered.jl index 85f13fd..2ac962c 100644 --- a/src/tempered.jl +++ b/src/tempered.jl @@ -25,6 +25,17 @@ end swapstrategy(sampler::TemperedSampler) = sampler.swap_strategy +getsampler(samplers, I...) = getindex(samplers, I...) +getsampler(sampler::AbstractMCMC.AbstractSampler, I...) = sampler +getsampler(sampler::TemperedSampler, I...) = getsampler(sampler.internal_sampler, I...) + +""" + numsteps(sampler::TemperedSampler) + +Return number of temperatures used by `sampler`. +""" +numtemps(sampler::TemperedSampler) = length(sampler.Δ) + """ tempered(internal_sampler, Δ; kwargs...) From 90722c9db33b8accc2a751fcef03f9d559c12945 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 20 Oct 2021 01:10:04 +0100 Subject: [PATCH 14/51] added tests --- test/runtests.jl | 74 +++++++++++++++++++++++++++++++----------------- 1 file changed, 48 insertions(+), 26 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 7fd8128..fdd190b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,6 +14,8 @@ include("compat.jl") @testset "MvNormal 2D" begin d = 2 nsamples = 20_000 + swap_every_n = 2 + function logdensity(x) logpdf(MvNormal(ones(length(x)), I), x) end @@ -26,35 +28,55 @@ include("compat.jl") # Set up our sampler with a joint multivariate Normal proposal. spl_inner = RWMH(MvNormal(zeros(d), 1e-1I)) - spl = tempered(spl_inner, Δ, MCMCTempering.StandardSwap(); adapt=false, N_swap=2) - # Useful for analysis. - states = [] - callback = StateHistoryCallback(states) + swapstrategies = [ + MCMCTempering.StandardSwap(), + MCMCTempering.RandomPermutationSwap(), + MCMCTempering.NonReversibleSwap() + ] + @testset "$swapstrategy" for swapstrategy in swapstrategies + swapstrategy = MCMCTempering.NonReversibleSwap() + spl = tempered(spl_inner, Δ, swapstrategy; adapt=false, N_swap=swap_every_n) + + # TODO: Remove or make use of. + # # Useful for analysis. + # states = [] + # callback = StateHistoryCallback(states) + callback = (args...; kwargs...) -> nothing + + # Sample. + samples = AbstractMCMC.sample(model, spl, nsamples; callback=callback, progress=false); + + # # Extract the history of chain indices. + # Δ_index_history_list = map(states) do state + # state.Δ_index + # end + # Δ_index_history = permutedims(reduce(hcat, Δ_index_history_list), (2, 1)) + + # Get example state. + state = states[end] + chain = if spl isa MCMCTempering.TemperedSampler + AbstractMCMC.bundle_samples( + samples, model, spl.internal_sampler, MCMCTempering.state_for_chain(state), MCMCChains.Chains + ) + else + AbstractMCMC.bundle_samples(samples, model, spl, state, MCMCChains.Chains) + end; + + # Thin chain and discard burnin. + chain_thinned = chain[length(chain) ÷ 2 + 1:swap_every_n:end] + # Extract some summary statistics to compare. + desc = describe(chain_thinned)[1].nt + μ = desc.mean + σ = desc.std - # Sample. - samples = AbstractMCMC.sample(model, spl, nsamples; callback=callback); - states + # HACK: These bounds are quite generous. We're swapping quite frequently here + # so some of the strategies results in a rather large variance of the estimators + # it seems. + @test norm(μ - ones(length(μ))) ≤ 2e-1 + @test norm(σ - ones(length(σ))) ≤ 3e-1 - # Extract the history of chain indices. - Δ_index_history_list = map(states) do state - state.Δ_index + # TODO: Add some tests so ensure that we are doing _some_ swapping? end - Δ_index_history = permutedims(reduce(hcat, Δ_index_history_list), (2, 1)) - - # Get example state. - state = states[end] - chain = if spl isa MCMCTempering.TemperedSampler - AbstractMCMC.bundle_samples( - samples, model, spl.internal_sampler, state.states[first(state.chain_index)][2], MCMCChains.Chains - ) - else - AbstractMCMC.bundle_samples(samples, model, spl, state, MCMCChains.Chains) - end; - - - μ = mean(chain[length(chain) ÷ 2 + 1:10:end]).nt.mean - # HACK: This is quite a large threshold. - @test norm(μ - ones(length(μ))) ≤ 2e-1 end end From 50ad1ec107217857b9acb386abdbd42bd8c34795 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 20 Oct 2021 01:10:36 +0100 Subject: [PATCH 15/51] fixed incorrect statement --- src/swapping.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/swapping.jl b/src/swapping.jl index bef0d50..7b34f2f 100644 --- a/src/swapping.jl +++ b/src/swapping.jl @@ -39,9 +39,6 @@ struct RandomPermutationSwap <: AbstractSwapStrategy end At every swap step taken, this strategy _deterministically_ traverses first the odd chain indices, proposing swaps between neighbors, and then in the _next_ swap step taken traverses even chain indices, proposing swaps between neighbors. - -Note that this method is _not_ reversible, and does not satisfy detailed balance. -As a result, this method is asymptotically biased. """ struct NonReversibleSwap <: AbstractSwapStrategy end From 0c204ac9a93a9bf1421c13748e9f4b3095a351c4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 20 Oct 2021 02:30:33 +0100 Subject: [PATCH 16/51] renamed some fields to be more descriptive and fixed left-over bug --- src/stepping.jl | 36 ++++++++++++++++++------------------ src/swapping.jl | 24 ++++++++++++++++-------- 2 files changed, 34 insertions(+), 26 deletions(-) diff --git a/src/stepping.jl b/src/stepping.jl index 1e115f2..32565c5 100644 --- a/src/stepping.jl +++ b/src/stepping.jl @@ -2,8 +2,8 @@ mutable struct TemperedState transitions_and_states :: Array{Any} Δ :: Vector{<:Real} - Δ_index :: Vector{<:Integer} - chain_index :: Vector{<:Integer} + chain_to_process :: Vector{<:Integer} + process_to_chain :: Vector{<:Integer} step_counter :: Integer total_steps :: Integer Ρ :: Vector{AdaptiveState} @@ -14,33 +14,33 @@ used throughout parallel tempering as pairs of `Transition`s and `VarInfo`s, it also stores necessary information for tempering: - `transitions_and_states` is a collection of `(transition, state)` pairs, one for each tempered chain. - `Δ` contains the ordered sequence of inverse temperatures. -- `Δ_index` contains the current ordering to apply the temperatures to each chain, tracking swaps, - i.e., contains the index `Δ_index[i] = j` of the temperature in `Δ`, `Δ[j]`, to apply to chain `i` -- `chain_index` contains the index `chain_index[i] = k` of the chain tempered by `Δ[i]` - NOTE that to convert between this and `Δ_index` we simply use the `sortperm()` function +- `chain_to_process` contains the current ordering to apply the temperatures to each chain, tracking swaps, + i.e., contains the index `chain_to_process[i] = j` of the temperature in `Δ`, `Δ[j]`, to apply to chain `i` +- `process_to_chain` contains the index `process_to_chain[i] = k` of the chain tempered by `Δ[i]` + NOTE that to convert between this and `chain_to_process` we simply use the `sortperm()` function - `step_counter` maintains the number of steps taken since the last swap attempt - `total_steps` maintains the count of the total number of steps taken - `Ρ` contains all of the information required for adaptation of Δ -Example of swaps across 4 chains and the values of `chain_index` and `Δ_index`: +Example of swaps across 4 chains and the values of `process_to_chain` and `chain_to_process`: -Chains: chain_index: Δ_index: -| | | | 1 2 3 4 1 2 3 4 +Chains: process_to_chain: chain_to_process: +| | | | 1 2 3 4 1 2 3 4 | | | | - V | | 2 1 3 4 2 1 3 4 + V | | 2 1 3 4 2 1 3 4 Λ | | -| | | | 2 1 3 4 2 1 3 4 +| | | | 2 1 3 4 2 1 3 4 | | | | -| V | 2 3 1 4 3 1 2 4 +| V | 2 3 1 4 3 1 2 4 | Λ | -| | | | 2 3 1 4 3 1 2 4 +| | | | 2 3 1 4 3 1 2 4 | | | | """ @concrete struct TemperedState transitions_and_states Δ - Δ_index - chain_index + chain_to_process + process_to_chain step_counter total_steps Ρ @@ -53,7 +53,7 @@ Return the transition corresponding to the chain indexed by `I...`. If `I...` is not specified, the transition corresponding to `β=1.0` will be returned, i.e. `I = (1, )`. """ transition_for_chain(state::TemperedState) = transition_for_chain(state, 1) -transition_for_chain(state::TemperedState, I...) = state.transitions_and_states[state.Δ_index[I...]][1] +transition_for_chain(state::TemperedState, I...) = state.transitions_and_states[state.chain_to_process[I...]][1] """ transition_for_process(state, I...) @@ -85,7 +85,7 @@ Return the β corresponding to the chain indexed by `I...`. If `I...` is not specified, the β corresponding to `β=1.0` will be returned. """ β_for_chain(state::TemperedState) = β_for_chain(state, 1) -β_for_chain(state::TemperedState, I...) = state.Δ[state.Δ_index[I...]] +β_for_chain(state::TemperedState, I...) = state.Δ[state.chain_to_process[I...]] """ β_for_process(state, I...) @@ -102,7 +102,7 @@ If `I...` is not specified, the sampler corresponding to `β=1.0` will be return """ sampler_for_chain(sampler::TemperedSampler, state::TemperedState) = sampler_for_chain(sampler, state, 1) function sampler_for_chain(sampler::TemperedSampler, state::TemperedState, I...) - return getsampler(sampler.internal_sampler, state.Δ_index[I...]) + return getsampler(sampler.internal_sampler, state.chain_to_process[I...]) end """ diff --git a/src/swapping.jl b/src/swapping.jl index 7b34f2f..69b61a3 100644 --- a/src/swapping.jl +++ b/src/swapping.jl @@ -43,14 +43,24 @@ taken traverses even chain indices, proposing swaps between neighbors. struct NonReversibleSwap <: AbstractSwapStrategy end """ - swap_betas(chain_index, k) + swap_betas!(chain_to_process, process_to_chain, k) Swaps the `k`th and `k + 1`th temperatures. -Use `sortperm()` to convert the `chain_index` to a `Δ_index` to be used in tempering moves. """ -function swap_betas(chain_index, Δ_index, k) - chain_index[k], chain_index[k + 1] = chain_index[k + 1], chain_index[k] - return sortperm(chain_index), chain_index +function swap_betas!(chain_to_process, process_to_chain, k) + # TODO: Use BangBang's `@set!!` to also support tuples? + # Extract the process index for each of the chains. + process_for_chain_k, process_for_chain_kp1 = chain_to_process[k], chain_to_process[k + 1] + + # Switch the mapping of the `chain → process` map. + # The temperature for the k-th chain will now be moved from its current process + # to the process for the (k + 1)-th chain, and vice versa. + chain_to_process[k], chain_to_process[k + 1] = process_for_chain_kp1, process_for_chain_k + + # Swap the mapping of the `process → chain` map. + # The process that used to have the k-th chain, now has the (k+1)-th chain, and vice versa. + process_to_chain[process_for_chain_k], process_to_chain[process_for_chain_kp1] = k + 1, k + return chain_to_process, process_to_chain end @@ -98,9 +108,7 @@ function swap_attempt(rng, model, sampler, state, k, adapt, total_steps) # swap the temperatures for future steps. logα = swap_acceptance_pt(logπk_θk, logπk_θkp1, logπkp1_θk, logπkp1_θkp1) if -Random.randexp(rng) ≤ logα - Δ_index, chain_index = swap_betas(state.chain_index, state.Δ_index, k) - @set! state.Δ_index = Δ_index - @set! state.chain_index = chain_index + swap_betas!(state.chain_to_process, state.process_to_chain, k) end # Adaptation steps affects Ρ and Δ, as the Ρ is adapted before a new Δ is generated and returned From e2bbc900e0fa805c0b79db4f95dc0e63e7544dbb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 20 Oct 2021 02:30:57 +0100 Subject: [PATCH 17/51] updated tests --- test/runtests.jl | 41 ++++++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index fdd190b..0cad8b2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,12 +16,15 @@ include("compat.jl") nsamples = 20_000 swap_every_n = 2 + μ_true = [-1.0, 1.0] + σ_true = [1.0, √(10.0)] + function logdensity(x) - logpdf(MvNormal(ones(length(x)), I), x) + logpdf(MvNormal(μ_true, Diagonal(σ_true.^2)), x) end # Sampler parameters. - Δ = MCMCTempering.check_Δ(vcat(0.25:0.1:0.9, 0.91:0.005:1.0)) + Δ = MCMCTempering.check_Δ(0.5:0.01:1.0) # Construct a DensityModel. model = DensityModel(logdensity) @@ -34,24 +37,22 @@ include("compat.jl") MCMCTempering.RandomPermutationSwap(), MCMCTempering.NonReversibleSwap() ] - @testset "$swapstrategy" for swapstrategy in swapstrategies - swapstrategy = MCMCTempering.NonReversibleSwap() + + @testset "$(swapstrategy)" for swapstrategy in swapstrategies spl = tempered(spl_inner, Δ, swapstrategy; adapt=false, N_swap=swap_every_n) - # TODO: Remove or make use of. - # # Useful for analysis. - # states = [] - # callback = StateHistoryCallback(states) - callback = (args...; kwargs...) -> nothing + # Useful for analysis. + states = [] + callback = StateHistoryCallback(states) # Sample. samples = AbstractMCMC.sample(model, spl, nsamples; callback=callback, progress=false); - # # Extract the history of chain indices. - # Δ_index_history_list = map(states) do state - # state.Δ_index - # end - # Δ_index_history = permutedims(reduce(hcat, Δ_index_history_list), (2, 1)) + # Extract the history of chain indices. + process_to_chain_history_list = map(states) do state + state.process_to_chain + end + process_to_chain_history = permutedims(reduce(hcat, process_to_chain_history_list), (2, 1)) # Get example state. state = states[end] @@ -64,7 +65,9 @@ include("compat.jl") end; # Thin chain and discard burnin. - chain_thinned = chain[length(chain) ÷ 2 + 1:swap_every_n:end] + chain_thinned = chain[length(chain) ÷ 2 + 1:5swap_every_n:end] + show(stdout, MIME"text/plain"(), chain_thinned) + # Extract some summary statistics to compare. desc = describe(chain_thinned)[1].nt μ = desc.mean @@ -73,10 +76,10 @@ include("compat.jl") # HACK: These bounds are quite generous. We're swapping quite frequently here # so some of the strategies results in a rather large variance of the estimators # it seems. - @test norm(μ - ones(length(μ))) ≤ 2e-1 - @test norm(σ - ones(length(σ))) ≤ 3e-1 - - # TODO: Add some tests so ensure that we are doing _some_ swapping? + show(stdout, MIME"text/plain"(), norm(μ - μ_true)) + show(stdout, MIME"text/plain"(), norm(σ - σ_true)) + @test norm(μ - μ_true) ≤ 0.5 + @test norm(σ - σ_true) ≤ 0.5 end end end From e7466cc8a580f906b43f1f175ecb470f0280594b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 20 Oct 2021 04:31:50 +0100 Subject: [PATCH 18/51] removed some show from tests --- test/runtests.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 0cad8b2..aea42ee 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -76,8 +76,6 @@ include("compat.jl") # HACK: These bounds are quite generous. We're swapping quite frequently here # so some of the strategies results in a rather large variance of the estimators # it seems. - show(stdout, MIME"text/plain"(), norm(μ - μ_true)) - show(stdout, MIME"text/plain"(), norm(σ - σ_true)) @test norm(μ - μ_true) ≤ 0.5 @test norm(σ - σ_true) ≤ 0.5 end From 0a5951769b9cb471add97fbed41332ec47007141 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 20 Oct 2021 20:55:59 +0100 Subject: [PATCH 19/51] began updating docstrings --- Project.toml | 4 ++ src/MCMCTempering.jl | 2 + src/stepping.jl | 104 ++++++++++++++++++++++++++++--------------- 3 files changed, 75 insertions(+), 35 deletions(-) diff --git a/Project.toml b/Project.toml index 102bf22..d998c29 100644 --- a/Project.toml +++ b/Project.toml @@ -7,12 +7,16 @@ version = "0.1.1" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" [compat] AbstractMCMC = "3.2" +ConcreteStructs = "0.2" Distributions = "0.24, 0.25" +DocStringExtensions = "0.8" +Setfield = "0.7" julia = "1" [extras] diff --git a/src/MCMCTempering.jl b/src/MCMCTempering.jl index 096d86a..17e6baf 100644 --- a/src/MCMCTempering.jl +++ b/src/MCMCTempering.jl @@ -7,6 +7,8 @@ import Random using ConcreteStructs: @concrete using Setfield: @set, @set! +using DocStringExtensions + include("adaptation.jl") include("swapping.jl") include("tempered.jl") diff --git a/src/stepping.jl b/src/stepping.jl index 32565c5..ef93c54 100644 --- a/src/stepping.jl +++ b/src/stepping.jl @@ -1,48 +1,82 @@ """ - mutable struct TemperedState - transitions_and_states :: Array{Any} - Δ :: Vector{<:Real} - chain_to_process :: Vector{<:Integer} - process_to_chain :: Vector{<:Integer} - step_counter :: Integer - total_steps :: Integer - Ρ :: Vector{AdaptiveState} - end + TemperedState + +A general implementation of a state for a [`TemperedSampler`](@ref). + +# Fields + +$(FIELDS) + +# Description + +Suppose we're running 4 chains `X`, `Y`, `Z`, and `W`, each targeting a distribution for different +(inverse) temperatures `β`, say, `1.0`, `0.75`, `0.5`, and `0.25`, respectively. That is, we're mainly +interested in the chain `(X[1], X[2], … )` which targets the distribution with `β=1.0`. + +Moreover, suppose we also have 4 workers/processes for which we run these chains in "parallel" +(it can also be in serial, but for the sake of demonstration imagine it's parallel). + +When can then perform a swap in two different ways: +1. Swap the the _states_ between each process, i.e. permute `transitions_and_states`. +2. Swap the _temperatures_ between each process, i.e. permute `Δ`. + +(1) is possibly the most intuitive approach since it means that the i-th worker/process +corresponds to the i-th chain; in this case, process 1 corresponds to `X`, process 2 to `Y`, etc. +The downside is that we need to move (potentially high-dimensional) states between the +workers/processes. + +(2) on the other hand does _not_ preserve the direct process-chain correspondance. +We now need to keep track of which process has which chain, from this we can +reconstruct each of the chains `X`, `Y`, etc. afterwards. +On the other hand, this means that we only need to transfer a pair of numbers +representing the (inverse) temperatures between workers rather than the full states. + +The current implementation follows approach (2). + +Here's an example realization of using five steps of sampling and swap-attempts: + +``` +Chains: process_to_chain chain_to_process Δ[process_to_chain[i]] +| | | | 1 2 3 4 1 2 3 4 1.00 0.75 0.50 0.25 +| | | | + V | | 2 1 3 4 2 1 3 4 0.75 1.00 0.50 0.25 + Λ | | +| | | | 2 1 3 4 2 1 3 4 0.75 1.00 0.50 0.25 +| | | | +| V | 2 3 1 4 3 1 2 4 0.75 0.50 1.00 0.25 +| Λ | +| | | | 2 3 1 4 3 1 2 4 0.75 0.50 1.00 0.25 +| | | | + +In this case, the chain `X` can be reconstructed as: + +```julia +X[1] = states[1].transitions_and_states[1] +X[2] = states[2].transitions_and_states[2] +X[3] = states[3].transitions_and_states[2] +X[4] = states[4].transitions_and_states[3] +X[5] = states[5].transitions_and_states[3] +``` + +The indices here are exactly those represented by `states[k].chain_to_process[1]`. -A `TemperedState` struct contains the `transitions_and_states` of each of the parallel chains -used throughout parallel tempering as pairs of `Transition`s and `VarInfo`s, -it also stores necessary information for tempering: -- `transitions_and_states` is a collection of `(transition, state)` pairs, one for each tempered chain. -- `Δ` contains the ordered sequence of inverse temperatures. -- `chain_to_process` contains the current ordering to apply the temperatures to each chain, tracking swaps, - i.e., contains the index `chain_to_process[i] = j` of the temperature in `Δ`, `Δ[j]`, to apply to chain `i` -- `process_to_chain` contains the index `process_to_chain[i] = k` of the chain tempered by `Δ[i]` - NOTE that to convert between this and `chain_to_process` we simply use the `sortperm()` function -- `step_counter` maintains the number of steps taken since the last swap attempt -- `total_steps` maintains the count of the total number of steps taken -- `Ρ` contains all of the information required for adaptation of Δ - -Example of swaps across 4 chains and the values of `process_to_chain` and `chain_to_process`: - -Chains: process_to_chain: chain_to_process: -| | | | 1 2 3 4 1 2 3 4 -| | | | - V | | 2 1 3 4 2 1 3 4 - Λ | | -| | | | 2 1 3 4 2 1 3 4 -| | | | -| V | 2 3 1 4 3 1 2 4 -| Λ | -| | | | 2 3 1 4 3 1 2 4 -| | | | +``` """ @concrete struct TemperedState + "collection of `(transition, state)` pairs for each process" transitions_and_states + "collection of (inverse) temperatures β corresponding to each process" Δ + "collection indices such that `chain_to_process[i] = j` if the j-th process corresponds to the i-th chain" chain_to_process + "collection indices such that `process_chain_to[j] = i` if the i-th chain corresponds to the j-th process" process_to_chain + # TODO: Remove this and just introduce a `RepeatedSampler` or something. + "keeps track of the number of since the last attempted swap" step_counter + "total number of steps taken" total_steps + "contains all necessary information for adaptation of Δ" Ρ end From f7a7f319a27771ce66677505b3f7902da913220e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 21 Oct 2021 08:09:52 +0100 Subject: [PATCH 20/51] fixed docstring for TemeperedState --- src/stepping.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/stepping.jl b/src/stepping.jl index ef93c54..8900c62 100644 --- a/src/stepping.jl +++ b/src/stepping.jl @@ -47,6 +47,7 @@ Chains: process_to_chain chain_to_process Δ[process_to_chain[i]] | Λ | | | | | 2 3 1 4 3 1 2 4 0.75 0.50 1.00 0.25 | | | | +``` In this case, the chain `X` can be reconstructed as: @@ -59,8 +60,6 @@ X[5] = states[5].transitions_and_states[3] ``` The indices here are exactly those represented by `states[k].chain_to_process[1]`. - -``` """ @concrete struct TemperedState "collection of `(transition, state)` pairs for each process" From 696d8d121c94ee4ed4af849ca9df642712e24bbe Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 21 Oct 2021 11:02:05 +0100 Subject: [PATCH 21/51] fix exports --- src/MCMCTempering.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/MCMCTempering.jl b/src/MCMCTempering.jl index 17e6baf..5f2d6bc 100644 --- a/src/MCMCTempering.jl +++ b/src/MCMCTempering.jl @@ -15,9 +15,13 @@ include("tempered.jl") include("ladders.jl") include("stepping.jl") include("model.jl") -include("plotting.jl") -export tempered, TemperedSampler, plot_swaps, plot_ladders, make_tempered_model, get_tempered_loglikelihoods_and_params, make_tempered_loglikelihood, get_params +export tempered, + TemperedSampler, + make_tempered_model, + StandardSwap, + RandomPermutationSwap, + NonReversibleSwap function AbstractMCMC.bundle_samples( ts::Vector, From bbb2fc2c35b9119495af60cab9edcd65ed90f720 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 21 Oct 2021 11:02:38 +0100 Subject: [PATCH 22/51] a bunch of renaming --- src/ladders.jl | 9 ++--- src/stepping.jl | 81 +++++++++++++++++++-------------------- src/swapping.jl | 20 ++++++---- src/tempered.jl | 99 ++++++++++++++++++++++++++---------------------- test/runtests.jl | 2 +- 5 files changed, 108 insertions(+), 103 deletions(-) diff --git a/src/ladders.jl b/src/ladders.jl index 722cfaf..ac65373 100644 --- a/src/ladders.jl +++ b/src/ladders.jl @@ -1,4 +1,3 @@ -# Why these? """ get_scaling_val(Nt, swap_strategy) @@ -9,12 +8,12 @@ get_scaling_val(Nt, ::NonReversibleSwap) = 2 get_scaling_val(Nt, ::RandomPermutationSwap) = 1 """ - generate_Δ(Nt, swap_strategy) + generate_inverse_temperatures(Nt, swap_strategy) Returns a temperature ladder `Δ` containing `Nt` temperatures, generated in accordance with the chosen `swap_strategy`. """ -function generate_Δ(Nt, swap_strategy) +function generate_inverse_temperatures(Nt, swap_strategy) scaling_val = get_scaling_val(Nt, swap_strategy) Δ = zeros(Nt) Δ[1] = 1.0 @@ -28,11 +27,11 @@ end """ - check_Δ(Δ) + check_inverse_temperatures(Δ) Checks and returns a sorted `Δ` containing `{β₀, ..., βₙ}` conforming such that `1 = β₀ > β₁ > ... > βₙ ≥ 0` """ -function check_Δ(Δ) +function check_inverse_temperatures(Δ) if !all(zero.(Δ) .≤ Δ .≤ one.(Δ)) error("Temperature schedule provided has values outside of the acceptable range, ensure all values are in [0, 1].") end diff --git a/src/stepping.jl b/src/stepping.jl index 8900c62..d8c9ffb 100644 --- a/src/stepping.jl +++ b/src/stepping.jl @@ -18,7 +18,7 @@ Moreover, suppose we also have 4 workers/processes for which we run these chains When can then perform a swap in two different ways: 1. Swap the the _states_ between each process, i.e. permute `transitions_and_states`. -2. Swap the _temperatures_ between each process, i.e. permute `Δ`. +2. Swap the _temperatures_ between each process, i.e. permute `inverse_temperatures`. (1) is possibly the most intuitive approach since it means that the i-th worker/process corresponds to the i-th chain; in this case, process 1 corresponds to `X`, process 2 to `Y`, etc. @@ -36,16 +36,16 @@ The current implementation follows approach (2). Here's an example realization of using five steps of sampling and swap-attempts: ``` -Chains: process_to_chain chain_to_process Δ[process_to_chain[i]] -| | | | 1 2 3 4 1 2 3 4 1.00 0.75 0.50 0.25 +Chains: process_to_chain chain_to_process inverse_temperatures[process_to_chain[i]] +| | | | 1 2 3 4 1 2 3 4 1.00 0.75 0.50 0.25 | | | | - V | | 2 1 3 4 2 1 3 4 0.75 1.00 0.50 0.25 + V | | 2 1 3 4 2 1 3 4 0.75 1.00 0.50 0.25 Λ | | -| | | | 2 1 3 4 2 1 3 4 0.75 1.00 0.50 0.25 +| | | | 2 1 3 4 2 1 3 4 0.75 1.00 0.50 0.25 | | | | -| V | 2 3 1 4 3 1 2 4 0.75 0.50 1.00 0.25 +| V | 2 3 1 4 3 1 2 4 0.75 0.50 1.00 0.25 | Λ | -| | | | 2 3 1 4 3 1 2 4 0.75 0.50 1.00 0.25 +| | | | 2 3 1 4 3 1 2 4 0.75 0.50 1.00 0.25 | | | | ``` @@ -65,17 +65,14 @@ The indices here are exactly those represented by `states[k].chain_to_process[1] "collection of `(transition, state)` pairs for each process" transitions_and_states "collection of (inverse) temperatures β corresponding to each process" - Δ + inverse_temperatures "collection indices such that `chain_to_process[i] = j` if the j-th process corresponds to the i-th chain" chain_to_process "collection indices such that `process_chain_to[j] = i` if the i-th chain corresponds to the j-th process" process_to_chain - # TODO: Remove this and just introduce a `RepeatedSampler` or something. - "keeps track of the number of since the last attempted swap" - step_counter "total number of steps taken" total_steps - "contains all necessary information for adaptation of Δ" + "contains all necessary information for adaptation of inverse_temperatures" Ρ end @@ -118,14 +115,14 @@ Return the β corresponding to the chain indexed by `I...`. If `I...` is not specified, the β corresponding to `β=1.0` will be returned. """ β_for_chain(state::TemperedState) = β_for_chain(state, 1) -β_for_chain(state::TemperedState, I...) = state.Δ[state.chain_to_process[I...]] +β_for_chain(state::TemperedState, I...) = state.inverse_temperatures[state.chain_to_process[I...]] """ β_for_process(state, I...) Return the β corresponding to the process indexed by `I...`. """ -β_for_process(state::TemperedState, I...) = state.Δ[I...] +β_for_process(state::TemperedState, I...) = state.inverse_temperatures[I...] """ sampler_for_chain(sampler::TemperedSampler, state::TemperedState[, I...]) @@ -135,7 +132,7 @@ If `I...` is not specified, the sampler corresponding to `β=1.0` will be return """ sampler_for_chain(sampler::TemperedSampler, state::TemperedState) = sampler_for_chain(sampler, state, 1) function sampler_for_chain(sampler::TemperedSampler, state::TemperedState, I...) - return getsampler(sampler.internal_sampler, state.chain_to_process[I...]) + return getsampler(sampler.sampler, state.chain_to_process[I...]) end """ @@ -144,18 +141,13 @@ end Return the sampler corresponding to the process indexed by `I...`. """ function sampler_for_process(sampler::TemperedSampler, state::TemperedState, I...) - return getsampler(sampler.internal_sampler, I...) + return getsampler(sampler.sampler, I...) end -""" -For each `β` in `Δ`, carry out a step with a tempered model at the corresponding `β` inverse temperature, -resulting in a list of transitions and states, the transition associated with `β₀ = 1` is then returned with the -rest of the information being stored in the state. -""" function AbstractMCMC.step( rng::Random.AbstractRNG, model, - spl::TemperedSampler; + sampler::TemperedSampler; init_params=nothing, kwargs... ) @@ -169,19 +161,23 @@ function AbstractMCMC.step( transitions_and_states = [ AbstractMCMC.step( rng, - # TODO: Should we also have one a `β_for_process` for the sampler to - # cover the initial step? Do we even _need_ this `Δ_init[1]`? - # Can we not just assume that the `Δ` is always in the "correct" initial order? - make_tempered_model(spl, model, spl.Δ[spl.Δ_init[i]]), - getsampler(spl, i); + make_tempered_model(sampler, model, sampler.inverse_temperatures), + getsampler(sampler, i); init_params=init_params !== nothing ? init_params[i] : nothing, kwargs... ) - for i in 1:numtemps(spl) + for i in 1:numtemps(sampler) ] + process_to_chain = 1:length(sampler.inverse_temperatures) state = TemperedState( - transitions_and_states, spl.Δ, copy(spl.Δ_init), sortperm(spl.Δ_init), 1, 1, spl.Ρ + transitions_and_states, + sampler.inverse_temperatures, + process_to_chain, + process_to_chain, + 1, + 1, + sampler.Ρ ) return transition_for_chain(state), state @@ -189,13 +185,12 @@ end function AbstractMCMC.step( rng::Random.AbstractRNG, model, - spl::TemperedSampler, + sampler::TemperedSampler, state::TemperedState; kwargs... ) - if state.step_counter == spl.N_swap - state = swap_step(rng, model, spl, state) - @set! state.step_counter = 0 + if sampler.swap_every % state.total_steps == 0 + state = swap_step(rng, model, sampler, state) else # `TemperedState` has the transitions and states in the order of # the processes, and performs swaps by moving the (inverse) temperatures @@ -207,14 +202,13 @@ function AbstractMCMC.step( @set! state.transitions_and_states = [ AbstractMCMC.step( rng, - make_tempered_model(spl, model, β_for_process(state, i)), - sampler_for_process(spl, state, i), + make_tempered_model(sampler, model, β_for_process(state, i)), + sampler_for_process(sampler, state, i), state_for_process(state, i); kwargs... ) - for i in 1:numtemps(spl) + for i in 1:numtemps(sampler) ] - @set! state.step_counter += 1 end @set! state.total_steps += 1 @@ -224,11 +218,12 @@ end """ - swap_step([strategy::AbstractSwapStrategy, ]rng, model, spl, state) + swap_step([strategy::AbstractSwapStrategy, ]rng, model, sampler, state) -Uses the internals of the passed `TemperedSampler` - `spl` - and `TemperedState` - -`state` - to perform a "swap step" between temperatures, in accordance with the relevant -swap strategy. +Return new `state`, now with temperatures swapped according to `strategy`. + +If no `strategy` is provided, the return-value of [`swapstrategy`](@ref) called on `sampler` +is used. """ function swap_step( rng::Random.AbstractRNG, @@ -246,7 +241,7 @@ function swap_step( sampler::TemperedSampler, state::TemperedState ) - L = length(state.Δ) - 1 + L = numtemps(sampler) - 1 k = rand(rng, 1:L) return swap_attempt(rng, model, sampler, state, k, sampler.adapt, state.total_steps / L) end @@ -278,7 +273,7 @@ function swap_step( ) L = numtemps(sampler) - 1 # Alternate between swapping odds and evens. - levels = if state.total_steps % (2 * sampler.N_swap) == 0 + levels = if state.total_steps % (2 * sampler.swap_every) == 0 1:2:L else 2:2:L diff --git a/src/swapping.jl b/src/swapping.jl index 69b61a3..5685da9 100644 --- a/src/swapping.jl +++ b/src/swapping.jl @@ -15,8 +15,6 @@ a swap between chains `i` and `i + 1`. This approach goes under a number of names, e.g. Parallel Tempering (PT) MCMC and Replica-Exchange MCMC.[^PTPH05] -The sampling of the chain index ensures reversibility/detailed balance is satisfied. - # References [^PTPH05]: Earl, D. J., & Deem, M. W., Parallel tempering: theory, applications, and new perspectives, Physical Chemistry Chemical Physics, 7(23), 3910–3916 (2005). """ @@ -27,8 +25,6 @@ struct StandardSwap <: AbstractSwapStrategy end At every swap step taken, this strategy randomly shuffles all the chain indices and then iterates through them, proposing swaps for neighboring chains. - -The shuffling of chain indices ensures reversibility/detailed balance is satisfied. """ struct RandomPermutationSwap <: AbstractSwapStrategy end @@ -39,13 +35,18 @@ struct RandomPermutationSwap <: AbstractSwapStrategy end At every swap step taken, this strategy _deterministically_ traverses first the odd chain indices, proposing swaps between neighbors, and then in the _next_ swap step taken traverses even chain indices, proposing swaps between neighbors. + +See [^SYED19] for more on this approach. + +# References +[^SYED19]: Syed, S., Bouchard-Côté, Alexandre, Deligiannidis, G., & Doucet, A., Non-reversible Parallel Tempering: A Scalable Highly Parallel MCMC Scheme, arXiv:1905.02939, (2019). """ struct NonReversibleSwap <: AbstractSwapStrategy end """ swap_betas!(chain_to_process, process_to_chain, k) -Swaps the `k`th and `k + 1`th temperatures. +Swaps the `k`th and `k + 1`th temperatures in place. """ function swap_betas!(chain_to_process, process_to_chain, k) # TODO: Use BangBang's `@set!!` to also support tuples? @@ -111,11 +112,14 @@ function swap_attempt(rng, model, sampler, state, k, adapt, total_steps) swap_betas!(state.chain_to_process, state.process_to_chain, k) end - # Adaptation steps affects Ρ and Δ, as the Ρ is adapted before a new Δ is generated and returned + # Adaptation steps affects `Ρ` and `inverse_temperatures`, as the `Ρ` is + # adapted before a new `inverse_temperatures` is generated and returned. if adapt - P, Δ = adapt_ladder(state.Ρ, state.Δ, k, min(one(logα), exp(logα)), total_steps) + P, inverse_temperatures = adapt_ladder( + state.Ρ, state.inverse_temperatures, k, min(one(logα), exp(logα)), total_steps + ) @set! state.Ρ = P - @set! state.Δ = Δ + @set! state.inverse_temperatures = inverse_temperatures end return state end diff --git a/src/tempered.jl b/src/tempered.jl index 2ac962c..cb70845 100644 --- a/src/tempered.jl +++ b/src/tempered.jl @@ -1,84 +1,91 @@ """ - struct TemperedSampler{T} <: AbstractMCMC.AbstractSampler - internal_sampler :: T - Δ :: Vector{<:Real} - Δ_init :: Vector{<:Integer} - N_swap :: Integer - swap_strategy :: Symbol - end + TemperedSampler <: AbstractMCMC.AbstractSampler -A `TemperedSampler` struct wraps an `internal_sampler` (could just be an algorithm) alongside: -- A temperature ladder `Δ` containing a list of inverse temperatures `β`s -- The initial state of the tempered chains `Δ_init` in terms of which `β` each chain should begin at -- The number of steps between each temperature swap attempt `N_swap` -- The `swap_strategy` defining how these swaps should be carried out +A `TemperedSampler` struct wraps an `sampler` and samples using parallel tempering. + +# Fields + +$(FIELDS) """ -struct TemperedSampler{A,TΔ,TP,TSwap} <: AbstractMCMC.AbstractSampler - internal_sampler :: A - Δ :: TΔ - Δ_init :: Vector{Int} - N_swap :: Integer - swap_strategy :: TSwap - adapt :: Bool - Ρ :: TP +@concrete struct TemperedSampler <: AbstractMCMC.AbstractSampler + "sampler(s) used to target the tempered distributions" + sampler + "collection of inverse temperatures β; β[i] correponds i-th tempered model" + inverse_temperatures + "number of steps of `sampler` to take before proposing swaps" + swap_every + "the swap strategy that will be used when proposing swaps" + swap_strategy + # TODO: This should be replaced with `P` just being some `NoAdapt` type. + "boolean flag specifying whether or not to adapt" + adapt + "adaptation parameters" + Ρ end swapstrategy(sampler::TemperedSampler) = sampler.swap_strategy getsampler(samplers, I...) = getindex(samplers, I...) getsampler(sampler::AbstractMCMC.AbstractSampler, I...) = sampler -getsampler(sampler::TemperedSampler, I...) = getsampler(sampler.internal_sampler, I...) +getsampler(sampler::TemperedSampler, I...) = getsampler(sampler.sampler, I...) """ numsteps(sampler::TemperedSampler) Return number of temperatures used by `sampler`. """ -numtemps(sampler::TemperedSampler) = length(sampler.Δ) +numtemps(sampler::TemperedSampler) = length(sampler.inverse_temperatures) """ - tempered(internal_sampler, Δ; kwargs...) + tempered(sampler, inverse_temperatures; kwargs...) OR - tempered(internal_sampler, Nt; kwargs...) + tempered(sampler, Nt::Integer; kwargs...) + +Return tempered version of `sampler` using the provided `inverse_temperatures` or +inverse temperatures generated from `Nt` and the `swap_strategy`. # Arguments -- `internal_sampler` is an algorithm or sampler object to be used for underlying sampling and to apply tempering to +- `sampler` is an algorithm or sampler object to be used for underlying sampling and to apply tempering to - The temperature schedule can be defined either explicitly or just as an integer number of temperatures, i.e. as: - - `Δ` containing a sequence of 'inverse temperatures' {β₀, ..., βₙ} where 0 ≤ βₙ < ... < β₁ < β₀ = 1 + - `inverse_temperatures` containing a sequence of 'inverse temperatures' {β₀, ..., βₙ} where 0 ≤ βₙ < ... < β₁ < β₀ = 1 OR - - `Nt::Integer`, specifying the number of inverse temperatures to include in a generated `Δ` -- `swap_strategy::AbstractSwapStrategy` is the way in which temperature swaps are made, one of: - `:standard` as in original proposed algorithm, a single randomly picked swap is proposed - `:nonrev` alternate even/odd swaps as in Syed, Bouchard-Côté, Deligiannidis, Doucet, arXiv:1905.02939 such that a reverse swap cannot be made in immediate succession - `:randperm` generates a permutation in order to swap in a random order -- `Δ_init::Vector{<:Integer}` is a list containing a sequence including the integers `1:length(Δ)` and determines the starting temperature of each chain - i.e. [3, 1, 2, 4] across temperatures [1.0, 0.1, 0.01, 0.001] would mean the first chain starts at temperature 0.01, second starts at 1.0, etc. -- `N_swap::Integer` steps are carried out between each tempering swap step attempt + - `Nt::Integer`, specifying the number of inverse temperatures to include in a generated `inverse_temperatures` + +# Keyword arguments +- `swap_strategy::AbstractSwapStrategy` is the way in which temperature swaps are made. +- `swap_every::Integer` steps are carried out between each tempering swap step attempt + +# See also +- [`TemperedSampler`](@ref) +- For more on the swap strategies: + - [`AbstractSwapStrategy`](@ref) + - [`StandardSwap`](@ref) + - [`RandomPermutationSwap`](@ref) + - [`NonReversibleSwap`](@ref) """ function tempered( - internal_sampler, + sampler, Nt::Integer, swap_strategy::AbstractSwapStrategy = StandardSwap(); kwargs... ) - return tempered(internal_sampler, generate_Δ(Nt, swap_strategy), swap_strategy; kwargs...) + return tempered(sampler, generate_inverse_temperatures(Nt, swap_strategy), swap_strategy; kwargs...) end function tempered( - internal_sampler, - Δ::Vector{<:Real}, + sampler, + inverse_temperatures::Vector{<:Real}, swap_strategy::AbstractSwapStrategy; - Δ_init::Vector{<:Integer} = collect(1:length(Δ)), - N_swap::Integer = 1, + swap_every::Integer = 1, adapt::Bool = true, adapt_target::Real = 0.234, - adapt_scale::Real = get_scaling_val(length(Δ), swap_strategy), + adapt_scale::Real = get_scaling_val(length(inverse_temperatures), swap_strategy), adapt_step::Real = 0.66, kwargs... ) - Δ = check_Δ(Δ) - length(Δ) > 1 || error("More than one inverse temperatures must be provided.") - N_swap >= 1 || error("This must be a positive integer.") - Ρ = init_adaptation(Δ, adapt_target, adapt_scale, adapt_step) - return TemperedSampler(internal_sampler, Δ, Δ_init, N_swap, swap_strategy, adapt, Ρ) + inverse_temperatures = check_inverse_temperatures(inverse_temperatures) + length(inverse_temperatures) > 1 || error("More than one inverse temperatures must be provided.") + swap_every >= 1 || error("This must be a positive integer.") + Ρ = init_adaptation(inverse_temperatures, adapt_target, adapt_scale, adapt_step) + return TemperedSampler(sampler, inverse_temperatures, swap_every, swap_strategy, adapt, Ρ) end diff --git a/test/runtests.jl b/test/runtests.jl index aea42ee..8c3acf5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,7 +24,7 @@ include("compat.jl") end # Sampler parameters. - Δ = MCMCTempering.check_Δ(0.5:0.01:1.0) + Δ = MCMCTempering.check_inverse_temperatures(0.5:0.01:1.0) # Construct a DensityModel. model = DensityModel(logdensity) From 8ac73748f8be073efb3e15d582b412351de2e38a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 21 Oct 2021 11:02:45 +0100 Subject: [PATCH 23/51] deleted plotting functionality --- src/plotting.jl | 12 ------------ 1 file changed, 12 deletions(-) delete mode 100644 src/plotting.jl diff --git a/src/plotting.jl b/src/plotting.jl deleted file mode 100644 index 21c2a8d..0000000 --- a/src/plotting.jl +++ /dev/null @@ -1,12 +0,0 @@ - -""" -When sample is called with the `save_state` kwarg set to `true`, the chain can be used to plot the tempering swaps that occurred during sampling -""" -function plot_swaps(chain) - plot(chain.info.samplerstate.Δ_index_history) -end - - -function plot_ladders(chain) - plot(chain.info.samplerstate.Δ_history) -end From f7c46e7bd4040493f2c83342993a79fc13c7171b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 21 Oct 2021 13:59:38 +0100 Subject: [PATCH 24/51] fixed bug and added should_swap method --- src/stepping.jl | 25 ++++++++++++++++++++----- test/runtests.jl | 2 +- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/src/stepping.jl b/src/stepping.jl index d8c9ffb..4501e03 100644 --- a/src/stepping.jl +++ b/src/stepping.jl @@ -144,6 +144,15 @@ function sampler_for_process(sampler::TemperedSampler, state::TemperedState, I.. return getsampler(sampler.sampler, I...) end +""" + should_swap(sampler, state) + +Return `true` if a swap should happen at this iteration, and `false` otherwise. +""" +function should_swap(sampler::TemperedSampler, state::TemperedState) + return state.total_steps % sampler.swap_every == 0 +end + function AbstractMCMC.step( rng::Random.AbstractRNG, model, @@ -161,7 +170,7 @@ function AbstractMCMC.step( transitions_and_states = [ AbstractMCMC.step( rng, - make_tempered_model(sampler, model, sampler.inverse_temperatures), + make_tempered_model(sampler, model, sampler.inverse_temperatures[i]), getsampler(sampler, i); init_params=init_params !== nothing ? init_params[i] : nothing, kwargs... @@ -169,19 +178,22 @@ function AbstractMCMC.step( for i in 1:numtemps(sampler) ] - process_to_chain = 1:length(sampler.inverse_temperatures) + # Make sure to collect, because we'll be using `setindex!(!)` later. + process_to_chain = collect(1:length(sampler.inverse_temperatures)) + # Need to `copy` because this might be mutated. + chain_to_process = copy(process_to_chain) state = TemperedState( transitions_and_states, sampler.inverse_temperatures, process_to_chain, - process_to_chain, - 1, + chain_to_process, 1, sampler.Ρ ) return transition_for_chain(state), state end + function AbstractMCMC.step( rng::Random.AbstractRNG, model, @@ -189,7 +201,7 @@ function AbstractMCMC.step( state::TemperedState; kwargs... ) - if sampler.swap_every % state.total_steps == 0 + if should_swap(sampler, state) state = swap_step(rng, model, sampler, state) else # `TemperedState` has the transitions and states in the order of @@ -281,6 +293,9 @@ function swap_step( # Iterate through all levels and attempt swaps. for k in levels + # TODO: For this swapping strategy, we should really be using the adaptation from Syed et. al. (2019), + # but with the current one: shouldn't we at least divide `state.total_steps` by 2 since it will + # take use two swap-attempts before we have tried swapping all of them (in expectation). state = swap_attempt(rng, model, sampler, state, k, sampler.adapt, state.total_steps) end return state diff --git a/test/runtests.jl b/test/runtests.jl index 8c3acf5..d30b4a7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,7 +24,7 @@ include("compat.jl") end # Sampler parameters. - Δ = MCMCTempering.check_inverse_temperatures(0.5:0.01:1.0) + inverse_temperatures = MCMCTempering.check_inverse_temperatures(0.5:0.01:1.0) # Construct a DensityModel. model = DensityModel(logdensity) From dcb2a0db7cbbb9b26d36609cf5793293ad57d9b3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 21 Oct 2021 20:44:23 +0100 Subject: [PATCH 25/51] improved tests --- test/runtests.jl | 72 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 56 insertions(+), 16 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index d30b4a7..6e5c758 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,7 +14,7 @@ include("compat.jl") @testset "MvNormal 2D" begin d = 2 nsamples = 20_000 - swap_every_n = 2 + swap_every = 2 μ_true = [-1.0, 1.0] σ_true = [1.0, √(10.0)] @@ -24,7 +24,7 @@ include("compat.jl") end # Sampler parameters. - inverse_temperatures = MCMCTempering.check_inverse_temperatures(0.5:0.01:1.0) + inverse_temperatures = MCMCTempering.check_inverse_temperatures(vcat(0.5:0.05:0.7, 0.71:0.0025:1.0)) # Construct a DensityModel. model = DensityModel(logdensity) @@ -38,8 +38,20 @@ include("compat.jl") MCMCTempering.NonReversibleSwap() ] + # First we run MH to have something to compare to. + samples_mh = AbstractMCMC.sample(model, spl_inner, nsamples; progress=false); + chain_mh = AbstractMCMC.bundle_samples(samples_mh, model, spl_inner, samples_mh[1], MCMCChains.Chains) + chain_thinned_mh = chain_mh[length(chain_mh) ÷ 2 + 1:5swap_every:end] + + # Extract some summary statistics to compare. + desc_mh = describe(chain_thinned_mh)[1].nt + μ_mh = desc_mh.mean + σ_mh = desc_mh.std + ess_mh = MCMCChains.ess_rhat(chain_thinned_mh).nt.ess + @testset "$(swapstrategy)" for swapstrategy in swapstrategies - spl = tempered(spl_inner, Δ, swapstrategy; adapt=false, N_swap=swap_every_n) + swapstrategy = swapstrategies[1] + spl = tempered(spl_inner, inverse_temperatures, swapstrategy; adapt=false, swap_every=swap_every) # Useful for analysis. states = [] @@ -54,18 +66,39 @@ include("compat.jl") end process_to_chain_history = permutedims(reduce(hcat, process_to_chain_history_list), (2, 1)) + # Check that the swapping has been done correctly. + process_to_chain_uniqueness = map(states) do state + length(unique(state.process_to_chain)) == length(state.process_to_chain) + end + @test all(process_to_chain_uniqueness) + + if any(isa.(Ref(swapstrategy), [MCMCTempering.StandardSwap, MCMCTempering.NonReversibleSwap])) + # For these strategies, the index process should not move by more than 1. + @test all(abs.(diff(process_to_chain_history[:, 1])) .≤ 1) + end + + chain_to_process_uniqueness = map(states) do state + length(unique(state.chain_to_process)) == length(state.chain_to_process) + end + @test all(chain_to_process_uniqueness) + + # Tests that we have at least swapped some times (say at least 10% of attempted swaps). + swap_success_indicators = map(eachrow(diff(process_to_chain_history; dims=1))) do row + # Some of the strategies performs multiple swaps in a swap-iteration, + # but we want to count the number of iterations for which we had a successful swap, + # i.e. only count non-zero elements in a row _once_. Hence the `min`. + min(1, sum(abs, row)) + end + @test sum(swap_success_indicators) ≥ (nsamples / swap_every) * 0.1 + # Get example state. state = states[end] - chain = if spl isa MCMCTempering.TemperedSampler - AbstractMCMC.bundle_samples( - samples, model, spl.internal_sampler, MCMCTempering.state_for_chain(state), MCMCChains.Chains - ) - else - AbstractMCMC.bundle_samples(samples, model, spl, state, MCMCChains.Chains) - end; + chain = AbstractMCMC.bundle_samples( + samples, model, spl.sampler, MCMCTempering.state_for_chain(state), MCMCChains.Chains + ) # Thin chain and discard burnin. - chain_thinned = chain[length(chain) ÷ 2 + 1:5swap_every_n:end] + chain_thinned = chain[length(chain) ÷ 2 + 1:5swap_every:end] show(stdout, MIME"text/plain"(), chain_thinned) # Extract some summary statistics to compare. @@ -73,11 +106,18 @@ include("compat.jl") μ = desc.mean σ = desc.std - # HACK: These bounds are quite generous. We're swapping quite frequently here - # so some of the strategies results in a rather large variance of the estimators - # it seems. - @test norm(μ - μ_true) ≤ 0.5 - @test norm(σ - σ_true) ≤ 0.5 + # `StandardSwap` is quite unreliable, so struggling to come up with reasonable tests. + if !(swapstrategy isa StandardSwap) + # HACK: These bounds are quite generous. We're swapping quite frequently here + # so some of the strategies results in a rather large variance of the estimators + # it seems. + @test norm(μ - μ_true) ≤ 0.5 + @test norm(σ - σ_true) ≤ 0.5 + + # Comparison to just running the internal sampler. + ess = MCMCChains.ess_rhat(chain_thinned).nt.ess + @test all(ess .≥ ess_mh) + end end end end From 287b5016330b3dcc19236bfc20553cd3c311fb1e Mon Sep 17 00:00:00 2001 From: Carlos Parada Date: Fri, 19 Nov 2021 18:24:25 -0800 Subject: [PATCH 26/51] Typo --- src/model.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model.jl b/src/model.jl index eb1c53a..38f39d0 100644 --- a/src/model.jl +++ b/src/model.jl @@ -3,7 +3,7 @@ Return an instance representing a model. -The return-type depends on it's usage in [`compute_tempered_logdensities`](@ref). +The return-type depends on its usage in [`compute_tempered_logdensities`](@ref). """ function make_tempered_model end From 6239710aa1d45ff75785d8554b12283066a8640d Mon Sep 17 00:00:00 2001 From: Carlos Parada Date: Fri, 19 Nov 2021 18:41:00 -0800 Subject: [PATCH 27/51] Typo --- test/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/utils.jl b/test/utils.jl index d357388..e838850 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,7 +1,7 @@ """ StateHistoryCallback -Defines a callable which simply pushes the `state` onto the `states` container.! +Defines a callable which pushes the `state` onto the `states` container. Example usage when used with AbstractMCMC.jl: ```julia From 8c1b8ff677673065dae20cb66d33053813be0adc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 7 Dec 2021 18:15:28 +0000 Subject: [PATCH 28/51] implemented adaptation scheme for inverse temperatures using a geometric schedule --- src/adaptation.jl | 108 +++++++++++++++++++++++++++++++++++++--------- src/stepping.jl | 4 +- src/swapping.jl | 11 ++--- 3 files changed, 95 insertions(+), 28 deletions(-) diff --git a/src/adaptation.jl b/src/adaptation.jl index fc211f1..5b427fa 100644 --- a/src/adaptation.jl +++ b/src/adaptation.jl @@ -1,18 +1,41 @@ +using Distributions: StatsFuns + @concrete struct PolynomialStep η c end function get(step::PolynomialStep, k::Real) - step.c * (k + 1.) ^ (-step.η) + return step.c * (k + 1) ^ (-step.η) end - struct AdaptiveState{T1<:Real,T2<:Real,P<:PolynomialStep} - swap_target_ar :: T1 - logscale :: T2 - step :: P + swap_target_ar::T1 + scale_unconstrained::T2 + step::P end +""" + weight(ρ::AdaptiveState) + +Return the weight/scale to be used in the mapping `β[ℓ] ↦ β[ℓ + 1]`. + +# Notes +In Eq. (13) in [^MIAS12] they use the relation + + β[ℓ + 1] = β[ℓ] * w(ρ) + +with + + w(ρ) = exp(-exp(ρ)) + +because we want `w(ρ) ∈ (0, 1)` while `ρ ∈ ℝ`. As an alternative, we use +`StatsFuns.logistic(ρ)` which is numerically more stable than `exp(-exp(ρ))` and +leads to less extreme values, i.e. 0 or 1. + +# References +[^MIAS12] Miasojedow, B., Moulines, E., & Vihola, M., Adaptive Parallel Tempering Algorithm, (2012). +""" +weight(ρ::AdaptiveState) = StatsFuns.logistic(ρ.scale_unconstrained) function init_adaptation( Δ::Vector{<:Real}, @@ -22,31 +45,74 @@ function init_adaptation( ) Nt = length(Δ) step = PolynomialStep(γ, Nt - 1) - Ρ = [AdaptiveState(swap_target, log(scale), step) for _ in 1:(Nt - 1)] - return Ρ + ρs = [ + AdaptiveState(swap_target, StatsFuns.logit(scale), step) + for _ in 1:(Nt - 1) + ] + return ρs end -function rhos_to_ladder(Ρ, Δ) - β′ = Δ[1] - for i in 1:length(Ρ) - β′ += exp(Ρ[i].logscale) - Δ[i + 1] = Δ[1] / β′ - end - return Δ -end +""" + adapt!!(ρ::AdaptiveState, swap_ar, n) + +Return increment used to update `ρ`. +Corresponds to the increment in Eq. (14) from [^MIAS12]. -function adapt_rho(ρ::AdaptiveState, swap_ar, n) +# References +[^MIAS12] Miasojedow, B., Moulines, E., & Vihola, M., Adaptive Parallel Tempering Algorithm, (2012). +""" +function adapt!!(ρ::AdaptiveState, swap_ar, n) swap_diff = swap_ar - ρ.swap_target_ar γ = get(ρ.step, n) - return γ * swap_diff + return @set ρ.scale_unconstrained = ρ.scale_unconstrained + γ * swap_diff end +""" + adapt!!(ρ::AdaptiveState, Δ, k, swap_ar, n) + adapt!!(ρ::AbstractVector{<:AdaptiveState}, Δ, k, swap_ar, n) -function adapt_ladder(P, Δ, k, swap_ar, n) - P[k] = let Pk = P[k] - @set Pk.logscale += adapt_rho(Pk, swap_ar, n) +Return adapted state(s) given that we just proposed a swap of the `k`-th +and `(k + 1)`-th temperatures with acceptance ratio `swap_ar`. +""" +adapt!!(ρ::AdaptiveState, Δ, k, swap_ar, n) = adapt!!(ρ, swap_ar, n) +function adapt!!(ρs::AbstractVector{<:AdaptiveState}, Δ, k, swap_ar, n) + ρs[k] = adapt!!(ρs[k], swap_ar, n) + return ρs +end + +""" + update_inverse_temperatures(ρ::AdaptiveState, Δ_current) + update_inverse_temperatures(ρ::AbstractVector{<:AdaptiveState}, Δ_current) + +Return updated inverse temperatures computed from adaptation state(s) and `Δ_current`. + +If `ρ` is a `AbstractVector`, then it should be of length `length(Δ_current) - 1`, +with `ρ[k]` corresponding to the adaptation state for the `k`-th inverse temperature. + +This performs an update similar to Eq. (13) in [^MIAS12], with the only possible deviation +being how we compute the scaling factor from `ρ`: see [`weight`](@ref) for information. + +# References +[^MIAS12] Miasojedow, B., Moulines, E., & Vihola, M., Adaptive Parallel Tempering Algorithm, (2012). +""" +function update_inverse_temperatures(ρ::AdaptiveState, Δ_current) + Δ = Δ_current + N = length(Δ) + @assert length(ρs) ≥ N - 1 "number of adaptive states < number of temperatures" + + for ℓ in 1:N - 1 + @inbounds Δ[ℓ + 1] = Δ[ℓ] * weight(ρ) + end + return Δ +end + +function update_inverse_temperatures(ρs::AbstractVector{<:AdaptiveState}, Δ_current) + Δ = Δ_current + Δ[1] = Δ_current[1] + for ℓ in 1:length(Δ) - 1 + @inbounds Δ[ℓ + 1] = Δ[ℓ] * weight(ρs[ℓ]) end - return P, rhos_to_ladder(P, Δ) + return Δ end diff --git a/src/stepping.jl b/src/stepping.jl index 4501e03..0f8d120 100644 --- a/src/stepping.jl +++ b/src/stepping.jl @@ -73,7 +73,7 @@ The indices here are exactly those represented by `states[k].chain_to_process[1] "total number of steps taken" total_steps "contains all necessary information for adaptation of inverse_temperatures" - Ρ + adaptation_states end """ @@ -188,7 +188,7 @@ function AbstractMCMC.step( process_to_chain, chain_to_process, 1, - sampler.Ρ + sampler.adaptation_states ) return transition_for_chain(state), state diff --git a/src/swapping.jl b/src/swapping.jl index 5685da9..a336c72 100644 --- a/src/swapping.jl +++ b/src/swapping.jl @@ -112,14 +112,15 @@ function swap_attempt(rng, model, sampler, state, k, adapt, total_steps) swap_betas!(state.chain_to_process, state.process_to_chain, k) end - # Adaptation steps affects `Ρ` and `inverse_temperatures`, as the `Ρ` is + # Adaptation steps affects `ρs` and `inverse_temperatures`, as the `ρs` is # adapted before a new `inverse_temperatures` is generated and returned. if adapt - P, inverse_temperatures = adapt_ladder( - state.Ρ, state.inverse_temperatures, k, min(one(logα), exp(logα)), total_steps + ρs = adapt!!( + state.adaptation_states, state.inverse_temperatures, + k, min(one(logα), exp(logα)), total_steps ) - @set! state.Ρ = P - @set! state.inverse_temperatures = inverse_temperatures + @set! state.adaptation_states = ρs + @set! state.inverse_temperatures = update_inverse_temperatures(ρs, state.inverse_temperatures) end return state end From f70690f1a8e28050eba0e583b8db3aaeb519a411 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 7 Dec 2021 18:16:01 +0000 Subject: [PATCH 29/51] made some changes to some code that I cannot understand the original intent behind --- src/ladders.jl | 14 ++++++++------ src/tempered.jl | 8 ++++---- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/ladders.jl b/src/ladders.jl index ac65373..65312c0 100644 --- a/src/ladders.jl +++ b/src/ladders.jl @@ -14,13 +14,15 @@ Returns a temperature ladder `Δ` containing `Nt` temperatures, generated in accordance with the chosen `swap_strategy`. """ function generate_inverse_temperatures(Nt, swap_strategy) + # Apparently, here we increase the temperature by a constant + # factor which depends on `swap_strategy`. scaling_val = get_scaling_val(Nt, swap_strategy) - Δ = zeros(Nt) - Δ[1] = 1.0 - β′ = Δ[1] - for i ∈ 1:(Nt - 1) - β′ += scaling_val - Δ[i + 1] = Δ[1] / β′ + Δ = Vector{Float64}(undef, Nt) + Δ[1] = 1 + T = Δ[1] + for i in 1:(Nt - 1) + T += scaling_val + Δ[i + 1] = inv(T) end return Δ end diff --git a/src/tempered.jl b/src/tempered.jl index cb70845..70ce7f3 100644 --- a/src/tempered.jl +++ b/src/tempered.jl @@ -20,7 +20,7 @@ $(FIELDS) "boolean flag specifying whether or not to adapt" adapt "adaptation parameters" - Ρ + adaptation_states end swapstrategy(sampler::TemperedSampler) = sampler.swap_strategy @@ -79,13 +79,13 @@ function tempered( swap_every::Integer = 1, adapt::Bool = true, adapt_target::Real = 0.234, - adapt_scale::Real = get_scaling_val(length(inverse_temperatures), swap_strategy), + adapt_scale::Real = √2, adapt_step::Real = 0.66, kwargs... ) inverse_temperatures = check_inverse_temperatures(inverse_temperatures) length(inverse_temperatures) > 1 || error("More than one inverse temperatures must be provided.") swap_every >= 1 || error("This must be a positive integer.") - Ρ = init_adaptation(inverse_temperatures, adapt_target, adapt_scale, adapt_step) - return TemperedSampler(sampler, inverse_temperatures, swap_every, swap_strategy, adapt, Ρ) + adaptation_states = init_adaptation(inverse_temperatures, adapt_target, inv(adapt_scale), adapt_step) + return TemperedSampler(sampler, inverse_temperatures, swap_every, swap_strategy, adapt, adaptation_states) end From afa2900ddaa357077a0ee6adf942820806691ea7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 7 Dec 2021 20:47:05 +0000 Subject: [PATCH 30/51] added parameter for controlling which type of schedule to use when adapting the temperatures --- src/adaptation.jl | 150 +++++++++++++++++++++++++++++++++++++++------- src/tempered.jl | 5 +- 2 files changed, 131 insertions(+), 24 deletions(-) diff --git a/src/adaptation.jl b/src/adaptation.jl index 5b427fa..a4d71f2 100644 --- a/src/adaptation.jl +++ b/src/adaptation.jl @@ -8,14 +8,39 @@ function get(step::PolynomialStep, k::Real) return step.c * (k + 1) ^ (-step.η) end -struct AdaptiveState{T1<:Real,T2<:Real,P<:PolynomialStep} +""" + Geometric + +Specifies a geometric schedule for the inverse temperatures. + +See also: [`AdaptiveState`](@ref), [`update_inverse_temperatures`](@ref), and +[`weight`](@ref). +""" +struct Geometric end + +""" + InverselyAdditive + +Specifies an additive schedule for the temperatures (not _inverse_ temperatures). + +See also: [`AdaptiveState`](@ref), [`update_inverse_temperatures`](@ref), and +[`weight`](@ref). +""" +struct InverselyAdditive end + +struct AdaptiveState{S,T1<:Real,T2<:Real,P<:PolynomialStep} + schedule_type::S swap_target_ar::T1 scale_unconstrained::T2 step::P end +function AdaptiveState(swap_target_ar, scale_unconstrained, step) + return AdaptiveState(InverselyAdditive(), swap_target_ar, scale_unconstrained, step) +end + """ - weight(ρ::AdaptiveState) + weight(ρ::AdaptiveState{<:Geometric}) Return the weight/scale to be used in the mapping `β[ℓ] ↦ β[ℓ + 1]`. @@ -32,12 +57,39 @@ because we want `w(ρ) ∈ (0, 1)` while `ρ ∈ ℝ`. As an alternative, we use `StatsFuns.logistic(ρ)` which is numerically more stable than `exp(-exp(ρ))` and leads to less extreme values, i.e. 0 or 1. +This the same approach as mentioned in [^ATCH11]. + # References -[^MIAS12] Miasojedow, B., Moulines, E., & Vihola, M., Adaptive Parallel Tempering Algorithm, (2012). +[^MIAS12]: Miasojedow, B., Moulines, E., & Vihola, M., Adaptive Parallel Tempering Algorithm, (2012). +[^ATCH11]: Atchade, Yves F, Roberts, G. O., & Rosenthal, J. S., Towards optimal scaling of metropolis-coupled markov chain monte carlo, Statistics and Computing, 21(4), 555–568 (2011). +""" +weight(ρ::AdaptiveState{<:Geometric}) = StatsFuns.logistic(ρ.scale_unconstrained) + +""" + weight(ρ::AdaptiveState{<:InverselyAdditive}) + +Return the weight/scale to be used in the mapping `β[ℓ] ↦ β[ℓ + 1]`. """ -weight(ρ::AdaptiveState) = StatsFuns.logistic(ρ.scale_unconstrained) +weight(ρ::AdaptiveState{<:InverselyAdditive}) = exp(ρ.scale_unconstrained) + +function init_adaptation( + schedule::InverselyAdditive, + Δ::Vector{<:Real}, + swap_target::Real, + scale::Real, + γ::Real +) + Nt = length(Δ) + step = PolynomialStep(γ, Nt - 1) + ρs = [ + AdaptiveState(schedule, swap_target, log(scale), step) + for _ in 1:(Nt - 1) + ] + return ρs +end function init_adaptation( + schedule::Geometric, Δ::Vector{<:Real}, swap_target::Real, scale::Real, @@ -46,7 +98,9 @@ function init_adaptation( Nt = length(Δ) step = PolynomialStep(γ, Nt - 1) ρs = [ - AdaptiveState(swap_target, StatsFuns.logit(scale), step) + # TODO: Figure out a good way to make use of the `scale` here + # rather than a default value of `√2`. + AdaptiveState(schedule, swap_target, StatsFuns.logit(inv(√2)), step) for _ in 1:(Nt - 1) ] return ρs @@ -56,12 +110,10 @@ end """ adapt!!(ρ::AdaptiveState, swap_ar, n) -Return increment used to update `ρ`. +Return updated `ρ` based on swap acceptance ratio `swap_ar` and iteration `n`. -Corresponds to the increment in Eq. (14) from [^MIAS12]. - -# References -[^MIAS12] Miasojedow, B., Moulines, E., & Vihola, M., Adaptive Parallel Tempering Algorithm, (2012). +See [`update_inverse_temperatures`](@ref) to see how we compute the resulting +inverse temperatures from the adapted state `ρ`. """ function adapt!!(ρ::AdaptiveState, swap_ar, n) swap_diff = swap_ar - ρ.swap_target_ar @@ -83,36 +135,88 @@ function adapt!!(ρs::AbstractVector{<:AdaptiveState}, Δ, k, swap_ar, n) end """ - update_inverse_temperatures(ρ::AdaptiveState, Δ_current) - update_inverse_temperatures(ρ::AbstractVector{<:AdaptiveState}, Δ_current) + update_inverse_temperatures(ρ::AdaptiveState{<:Geometric}, Δ_current) + update_inverse_temperatures(ρ::AbstractVector{<:AdaptiveState{<:Geometric}}, Δ_current) Return updated inverse temperatures computed from adaptation state(s) and `Δ_current`. +This update is similar to Eq. (13) in [^MIAS12], with the only possible deviation +being how we compute the scaling factor from `ρ`: see [`weight`](@ref) for information. + If `ρ` is a `AbstractVector`, then it should be of length `length(Δ_current) - 1`, with `ρ[k]` corresponding to the adaptation state for the `k`-th inverse temperature. -This performs an update similar to Eq. (13) in [^MIAS12], with the only possible deviation -being how we compute the scaling factor from `ρ`: see [`weight`](@ref) for information. - # References -[^MIAS12] Miasojedow, B., Moulines, E., & Vihola, M., Adaptive Parallel Tempering Algorithm, (2012). +[^MIAS12]: Miasojedow, B., Moulines, E., & Vihola, M., Adaptive Parallel Tempering Algorithm, (2012). """ -function update_inverse_temperatures(ρ::AdaptiveState, Δ_current) - Δ = Δ_current +function update_inverse_temperatures(ρ::AdaptiveState{<:Geometric}, Δ_current) + Δ = similar(Δ_current) + β₀ = Δ_current[1] + Δ[1] = β₀ + + β = inv(β₀) + for ℓ in 1:length(Δ) - 1 + # TODO: Is it worth it to do this on log-scale instead? + β *= weight(ρ) + @inbounds Δ[ℓ + 1] = β + end + return Δ +end + +function update_inverse_temperatures(ρs::AbstractVector{<:AdaptiveState{<:Geometric}}, Δ_current) + Δ = similar(Δ_current) N = length(Δ) @assert length(ρs) ≥ N - 1 "number of adaptive states < number of temperatures" + β₀ = Δ_current[1] + Δ[1] = β₀ + + β = β₀ for ℓ in 1:N - 1 - @inbounds Δ[ℓ + 1] = Δ[ℓ] * weight(ρ) + # TODO: Is it worth it to do this on log-scale instead? + β *= weight(ρs[ℓ]) + @inbounds Δ[ℓ + 1] = β end return Δ end -function update_inverse_temperatures(ρs::AbstractVector{<:AdaptiveState}, Δ_current) - Δ = Δ_current - Δ[1] = Δ_current[1] +""" + update_inverse_temperatures(ρ::AdaptiveState{<:InverselyAdditive}, Δ_current) + update_inverse_temperatures(ρ::AbstractVector{<:AdaptiveState{<:InverselyAdditive}}, Δ_current) + +Return updated inverse temperatures computed from adaptation state(s) and `Δ_current`. + +This update increments the temperature (not _inverse_ temperature) by a positive constant, +which is adapted by `ρ`. + +If `ρ` is a `AbstractVector`, then it should be of length `length(Δ_current) - 1`, +with `ρ[k]` corresponding to the adaptation state for the `k`-th inverse temperature. +""" +function update_inverse_temperatures(ρ::AdaptiveState{<:InverselyAdditive}, Δ_current) + Δ = similar(Δ_current) + β₀ = Δ_current[1] + Δ[1] = β₀ + + T = inv(β₀) for ℓ in 1:length(Δ) - 1 - @inbounds Δ[ℓ + 1] = Δ[ℓ] * weight(ρs[ℓ]) + T += weight(ρ) + @inbounds Δ[ℓ + 1] = inv(T) + end + return Δ +end + +function update_inverse_temperatures(ρs::AbstractVector{<:AdaptiveState{<:InverselyAdditive}}, Δ_current) + Δ = similar(Δ_current) + N = length(Δ) + @assert length(ρs) ≥ N - 1 "number of adaptive states < number of temperatures" + + β₀ = Δ_current[1] + Δ[1] = β₀ + + T = inv(β₀) + for ℓ in 1:N - 1 + T += weight(ρs[ℓ]) + @inbounds Δ[ℓ + 1] = inv(T) end return Δ end diff --git a/src/tempered.jl b/src/tempered.jl index 70ce7f3..3e0065d 100644 --- a/src/tempered.jl +++ b/src/tempered.jl @@ -81,11 +81,14 @@ function tempered( adapt_target::Real = 0.234, adapt_scale::Real = √2, adapt_step::Real = 0.66, + adapt_schedule=InverselyAdditive(), kwargs... ) inverse_temperatures = check_inverse_temperatures(inverse_temperatures) length(inverse_temperatures) > 1 || error("More than one inverse temperatures must be provided.") swap_every >= 1 || error("This must be a positive integer.") - adaptation_states = init_adaptation(inverse_temperatures, adapt_target, inv(adapt_scale), adapt_step) + adaptation_states = init_adaptation( + adapt_schedule, inverse_temperatures, adapt_target, inv(adapt_scale), adapt_step + ) return TemperedSampler(sampler, inverse_temperatures, swap_every, swap_strategy, adapt, adaptation_states) end From c0d9e6168dfa588b17cade8a9c5b986f51af885b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 16 Dec 2021 01:24:35 +0000 Subject: [PATCH 31/51] make number of steps taken for each adaptor part of their state --- src/adaptation.jl | 26 ++++++++++++++++---------- src/swapping.jl | 3 +-- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/adaptation.jl b/src/adaptation.jl index a4d71f2..a607c90 100644 --- a/src/adaptation.jl +++ b/src/adaptation.jl @@ -33,12 +33,17 @@ struct AdaptiveState{S,T1<:Real,T2<:Real,P<:PolynomialStep} swap_target_ar::T1 scale_unconstrained::T2 step::P + n::Int end function AdaptiveState(swap_target_ar, scale_unconstrained, step) return AdaptiveState(InverselyAdditive(), swap_target_ar, scale_unconstrained, step) end +function AdaptiveState(schedule_type, swap_target_ar, scale_unconstrained, step) + return AdaptiveState(schedule_type, swap_target_ar, scale_unconstrained, step, 1) +end + """ weight(ρ::AdaptiveState{<:Geometric}) @@ -108,29 +113,30 @@ end """ - adapt!!(ρ::AdaptiveState, swap_ar, n) + adapt!!(ρ::AdaptiveState, swap_ar) -Return updated `ρ` based on swap acceptance ratio `swap_ar` and iteration `n`. +Return updated `ρ` based on swap acceptance ratio `swap_ar`. See [`update_inverse_temperatures`](@ref) to see how we compute the resulting inverse temperatures from the adapted state `ρ`. """ -function adapt!!(ρ::AdaptiveState, swap_ar, n) +function adapt!!(ρ::AdaptiveState, swap_ar) swap_diff = swap_ar - ρ.swap_target_ar - γ = get(ρ.step, n) - return @set ρ.scale_unconstrained = ρ.scale_unconstrained + γ * swap_diff + γ = get(ρ.step, ρ.n) + ρ_new = @set ρ.scale_unconstrained = ρ.scale_unconstrained + γ * swap_diff + return @set ρ_new.n += 1 end """ - adapt!!(ρ::AdaptiveState, Δ, k, swap_ar, n) - adapt!!(ρ::AbstractVector{<:AdaptiveState}, Δ, k, swap_ar, n) + adapt!!(ρ::AdaptiveState, Δ, k, swap_ar) + adapt!!(ρ::AbstractVector{<:AdaptiveState}, Δ, k, swap_ar) Return adapted state(s) given that we just proposed a swap of the `k`-th and `(k + 1)`-th temperatures with acceptance ratio `swap_ar`. """ -adapt!!(ρ::AdaptiveState, Δ, k, swap_ar, n) = adapt!!(ρ, swap_ar, n) -function adapt!!(ρs::AbstractVector{<:AdaptiveState}, Δ, k, swap_ar, n) - ρs[k] = adapt!!(ρs[k], swap_ar, n) +adapt!!(ρ::AdaptiveState, Δ, k, swap_ar) = adapt!!(ρ, swap_ar) +function adapt!!(ρs::AbstractVector{<:AdaptiveState}, Δ, k, swap_ar) + ρs[k] = adapt!!(ρs[k], swap_ar) return ρs end diff --git a/src/swapping.jl b/src/swapping.jl index a336c72..cf2e2ae 100644 --- a/src/swapping.jl +++ b/src/swapping.jl @@ -116,8 +116,7 @@ function swap_attempt(rng, model, sampler, state, k, adapt, total_steps) # adapted before a new `inverse_temperatures` is generated and returned. if adapt ρs = adapt!!( - state.adaptation_states, state.inverse_temperatures, - k, min(one(logα), exp(logα)), total_steps + state.adaptation_states, state.inverse_temperatures, k, min(one(logα), exp(logα)) ) @set! state.adaptation_states = ρs @set! state.inverse_temperatures = update_inverse_temperatures(ρs, state.inverse_temperatures) From 4563d33fd4f4281c2f24b074f34a9072d1cae676 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 14 Nov 2022 18:13:42 +0000 Subject: [PATCH 32/51] improvements to parameterization of the adaptation techniques --- src/MCMCTempering.jl | 2 ++ src/adaptation.jl | 32 +++++++++++++++++++++----------- src/stepping.jl | 6 ++++++ src/tempered.jl | 9 +++++---- 4 files changed, 34 insertions(+), 15 deletions(-) diff --git a/src/MCMCTempering.jl b/src/MCMCTempering.jl index 5f2d6bc..2529260 100644 --- a/src/MCMCTempering.jl +++ b/src/MCMCTempering.jl @@ -7,6 +7,8 @@ import Random using ConcreteStructs: @concrete using Setfield: @set, @set! +using InverseFunctions + using DocStringExtensions include("adaptation.jl") diff --git a/src/adaptation.jl b/src/adaptation.jl index a607c90..98b8cfb 100644 --- a/src/adaptation.jl +++ b/src/adaptation.jl @@ -18,6 +18,8 @@ See also: [`AdaptiveState`](@ref), [`update_inverse_temperatures`](@ref), and """ struct Geometric end +defaultscale(::Geometric, inverse_temperatures) = eltype(inverse_temperatures)(0.9) + """ InverselyAdditive @@ -28,6 +30,8 @@ See also: [`AdaptiveState`](@ref), [`update_inverse_temperatures`](@ref), and """ struct InverselyAdditive end +defaultscale(::InverselyAdditive, inverse_temperatures) = eltype(inverse_temperatures)(0.9) + struct AdaptiveState{S,T1<:Real,T2<:Real,P<:PolynomialStep} schedule_type::S swap_target_ar::T1 @@ -68,28 +72,34 @@ This the same approach as mentioned in [^ATCH11]. [^MIAS12]: Miasojedow, B., Moulines, E., & Vihola, M., Adaptive Parallel Tempering Algorithm, (2012). [^ATCH11]: Atchade, Yves F, Roberts, G. O., & Rosenthal, J. S., Towards optimal scaling of metropolis-coupled markov chain monte carlo, Statistics and Computing, 21(4), 555–568 (2011). """ -weight(ρ::AdaptiveState{<:Geometric}) = StatsFuns.logistic(ρ.scale_unconstrained) +weight(ρ::AdaptiveState{<:Geometric}) = geometric_weight_constrain(ρ.scale_unconstrained) +geometric_weight_constrain(x) = StatsFuns.logistic(x) +geometric_weight_unconstrain(y) = inverse(StatsFuns.logistic)(y) """ weight(ρ::AdaptiveState{<:InverselyAdditive}) Return the weight/scale to be used in the mapping `β[ℓ] ↦ β[ℓ + 1]`. """ -weight(ρ::AdaptiveState{<:InverselyAdditive}) = exp(ρ.scale_unconstrained) +weight(ρ::AdaptiveState{<:InverselyAdditive}) = inversely_additive_weight_constrain(ρ.scale_unconstrained) +inversely_additive_weight_constrain(x) = exp(-x) +inversely_additive_weight_unconstrain(y) = -log(y) function init_adaptation( schedule::InverselyAdditive, Δ::Vector{<:Real}, swap_target::Real, scale::Real, - γ::Real + η::Real, + stepsize::Real ) Nt = length(Δ) - step = PolynomialStep(γ, Nt - 1) + step = PolynomialStep(η, stepsize) ρs = [ - AdaptiveState(schedule, swap_target, log(scale), step) + AdaptiveState(schedule, swap_target, inversely_additive_weight_unconstrain(scale), step) for _ in 1:(Nt - 1) ] + ρs = AdaptiveState(schedule, swap_target, log(scale), step) return ρs end @@ -98,16 +108,16 @@ function init_adaptation( Δ::Vector{<:Real}, swap_target::Real, scale::Real, - γ::Real + η::Real, + stepsize::Real ) Nt = length(Δ) - step = PolynomialStep(γ, Nt - 1) + step = PolynomialStep(η, stepsize) ρs = [ - # TODO: Figure out a good way to make use of the `scale` here - # rather than a default value of `√2`. - AdaptiveState(schedule, swap_target, StatsFuns.logit(inv(√2)), step) + AdaptiveState(schedule, swap_target, geometric_weight_unconstrain(scale), step) for _ in 1:(Nt - 1) ] + ρs = AdaptiveState(schedule, swap_target, geometric_weight_unconstrain(scale), step) return ρs end @@ -121,7 +131,7 @@ See [`update_inverse_temperatures`](@ref) to see how we compute the resulting inverse temperatures from the adapted state `ρ`. """ function adapt!!(ρ::AdaptiveState, swap_ar) - swap_diff = swap_ar - ρ.swap_target_ar + swap_diff = ρ.swap_target_ar - swap_ar γ = get(ρ.step, ρ.n) ρ_new = @set ρ.scale_unconstrained = ρ.scale_unconstrained + γ * swap_diff return @set ρ_new.n += 1 diff --git a/src/stepping.jl b/src/stepping.jl index 0f8d120..abeb455 100644 --- a/src/stepping.jl +++ b/src/stepping.jl @@ -74,6 +74,12 @@ The indices here are exactly those represented by `states[k].chain_to_process[1] total_steps "contains all necessary information for adaptation of inverse_temperatures" adaptation_states + "flag which specifies wether this was a swap-step or not" + is_swap + "swap acceptance ratios on log-scale" + swap_acceptance_ratios +end + end """ diff --git a/src/tempered.jl b/src/tempered.jl index 3e0065d..c2b6397 100644 --- a/src/tempered.jl +++ b/src/tempered.jl @@ -79,16 +79,17 @@ function tempered( swap_every::Integer = 1, adapt::Bool = true, adapt_target::Real = 0.234, - adapt_scale::Real = √2, - adapt_step::Real = 0.66, - adapt_schedule=InverselyAdditive(), + adapt_stepsize::Real = 1, + adapt_eta::Real = 0.66, + adapt_schedule = Geometric(), + adapt_scale = defaultscale(adapt_schedule, inverse_temperatures), kwargs... ) inverse_temperatures = check_inverse_temperatures(inverse_temperatures) length(inverse_temperatures) > 1 || error("More than one inverse temperatures must be provided.") swap_every >= 1 || error("This must be a positive integer.") adaptation_states = init_adaptation( - adapt_schedule, inverse_temperatures, adapt_target, inv(adapt_scale), adapt_step + adapt_schedule, inverse_temperatures, adapt_target, adapt_scale, adapt_eta, adapt_stepsize ) return TemperedSampler(sampler, inverse_temperatures, swap_every, swap_strategy, adapt, adaptation_states) end From 86fbf0d1a9448f1f4158cf2dfe42850d89723249 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 14 Nov 2022 18:17:14 +0000 Subject: [PATCH 33/51] updated test env --- Project.toml | 19 +++++-------------- test/Project.toml | 16 ++++++++++++++++ 2 files changed, 21 insertions(+), 14 deletions(-) create mode 100644 test/Project.toml diff --git a/Project.toml b/Project.toml index d998c29..84709e8 100644 --- a/Project.toml +++ b/Project.toml @@ -8,24 +8,15 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" [compat] -AbstractMCMC = "3.2" +AbstractMCMC = "3.2, 4" ConcreteStructs = "0.2" Distributions = "0.24, 0.25" -DocStringExtensions = "0.8" -Setfield = "0.7" +DocStringExtensions = "0.8, 0.9" +InverseFunctions = "0.1" +Setfield = "0.7, 0.8, 1" julia = "1" - -[extras] -AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170" -Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" -StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["Test", "AdvancedMH", "MCMCChains", "Bijectors", "StatsPlots", "LinearAlgebra"] diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..e09628d --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,16 @@ +[deps] +AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" +AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170" +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +AbstractMCMC = "3.2, 4" +AdvancedMH = "0.6" +Bijectors = "0.10" +Distributions = "0.24, 0.25" +MCMCChains = "5.5" +julia = "1" From 229059b7646703c4777b5402f334d039123d31be Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 14 Nov 2022 18:17:37 +0000 Subject: [PATCH 34/51] keep track of swapping ratios --- src/stepping.jl | 9 ++++++++- src/swapping.jl | 6 +++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/stepping.jl b/src/stepping.jl index abeb455..eaa1ef5 100644 --- a/src/stepping.jl +++ b/src/stepping.jl @@ -194,7 +194,9 @@ function AbstractMCMC.step( process_to_chain, chain_to_process, 1, - sampler.adaptation_states + sampler.adaptation_states, + false, + Dict{Int,Float64}() ) return transition_for_chain(state), state @@ -207,8 +209,12 @@ function AbstractMCMC.step( state::TemperedState; kwargs... ) + # Reset. + @set! state.swap_acceptance_ratios = empty(state.swap_acceptance_ratios) + if should_swap(sampler, state) state = swap_step(rng, model, sampler, state) + @set! state.is_swap = true else # `TemperedState` has the transitions and states in the order of # the processes, and performs swaps by moving the (inverse) temperatures @@ -227,6 +233,7 @@ function AbstractMCMC.step( ) for i in 1:numtemps(sampler) ] + @set! state.is_swap = false end @set! state.total_steps += 1 diff --git a/src/swapping.jl b/src/swapping.jl index cf2e2ae..a9232cc 100644 --- a/src/swapping.jl +++ b/src/swapping.jl @@ -108,10 +108,14 @@ function swap_attempt(rng, model, sampler, state, k, adapt, total_steps) # If the proposed temperature swap is accepted according `logα`, # swap the temperatures for future steps. logα = swap_acceptance_pt(logπk_θk, logπk_θkp1, logπkp1_θk, logπkp1_θkp1) - if -Random.randexp(rng) ≤ logα + should_swap = -Random.randexp(rng) ≤ logα + if should_swap swap_betas!(state.chain_to_process, state.process_to_chain, k) end + # Keep track of the (log) acceptance ratios. + state.swap_acceptance_ratios[k] = logα + # Adaptation steps affects `ρs` and `inverse_temperatures`, as the `ρs` is # adapted before a new `inverse_temperatures` is generated and returned. if adapt From c4583f90762546cdf216d3a43a8a647ded495085 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 14 Nov 2022 18:17:57 +0000 Subject: [PATCH 35/51] tests are now runnable --- test/runtests.jl | 73 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 49 insertions(+), 24 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 6e5c758..d142645 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,22 +16,21 @@ include("compat.jl") nsamples = 20_000 swap_every = 2 - μ_true = [-1.0, 1.0] + μ_true = [-5.0, 5.0] σ_true = [1.0, √(10.0)] - function logdensity(x) - logpdf(MvNormal(μ_true, Diagonal(σ_true.^2)), x) - end + logdensity(x) = logpdf(MvNormal(μ_true, Diagonal(σ_true.^2)), x) # Sampler parameters. - inverse_temperatures = MCMCTempering.check_inverse_temperatures(vcat(0.5:0.05:0.7, 0.71:0.0025:1.0)) + inverse_temperatures = MCMCTempering.check_inverse_temperatures([0.25, 0.5, 0.75, 1.0]) # Construct a DensityModel. model = DensityModel(logdensity) # Set up our sampler with a joint multivariate Normal proposal. - spl_inner = RWMH(MvNormal(zeros(d), 1e-1I)) + spl_inner = RWMH(MvNormal(zeros(d), Diagonal(σ_true.^2))) + # Different swap strategies to test. swapstrategies = [ MCMCTempering.StandardSwap(), MCMCTempering.RandomPermutationSwap(), @@ -50,38 +49,62 @@ include("compat.jl") ess_mh = MCMCChains.ess_rhat(chain_thinned_mh).nt.ess @testset "$(swapstrategy)" for swapstrategy in swapstrategies - swapstrategy = swapstrategies[1] - spl = tempered(spl_inner, inverse_temperatures, swapstrategy; adapt=false, swap_every=swap_every) + acceptance_rate_target = 0.234 + # Number of iterations needed to obtain `nsamples` of non-swap iterations. + nsamples_tempered = Int(ceil(nsamples * swap_every ÷ (swap_every - 1))) + spl = tempered( + spl_inner, inverse_temperatures, swapstrategy; + adapt=false, # TODO: Test adaptation. Seems to work in some cases though. + adapt_schedule=MCMCTempering.Geometric(), + adapt_stepsize=1, + adapt_eta=0.66, + adapt_target=0.234, + swap_every=swap_every + ) # Useful for analysis. states = [] callback = StateHistoryCallback(states) # Sample. - samples = AbstractMCMC.sample(model, spl, nsamples; callback=callback, progress=false); - + samples = AbstractMCMC.sample(model, spl, nsamples_tempered; callback=callback, progress=true); + βs = mapreduce(Base.Fix2(getproperty, :inverse_temperatures), hcat, states) + + states_swapped = filter(Base.Fix2(getproperty, :is_swap), states) + swap_acceptance_ratios = mapreduce( + collect ∘ values ∘ Base.Fix2(getproperty, :swap_acceptance_ratios), + vcat, + states_swapped + ) + # Check that the adaptation did something useful. + if spl.adapt + swap_acceptance_ratios = map(Base.Fix1(min, 1.0) ∘ exp, swap_acceptance_ratios) + empirical_acceptance_rate = sum(swap_acceptance_ratios) / length(swap_acceptance_ratios) + @test acceptance_rate_target ≈ empirical_acceptance_rate atol = 0.05 + end + # Extract the history of chain indices. process_to_chain_history_list = map(states) do state state.process_to_chain end process_to_chain_history = permutedims(reduce(hcat, process_to_chain_history_list), (2, 1)) - + # Check that the swapping has been done correctly. process_to_chain_uniqueness = map(states) do state length(unique(state.process_to_chain)) == length(state.process_to_chain) end @test all(process_to_chain_uniqueness) - + if any(isa.(Ref(swapstrategy), [MCMCTempering.StandardSwap, MCMCTempering.NonReversibleSwap])) # For these strategies, the index process should not move by more than 1. @test all(abs.(diff(process_to_chain_history[:, 1])) .≤ 1) end - + chain_to_process_uniqueness = map(states) do state length(unique(state.chain_to_process)) == length(state.chain_to_process) end @test all(chain_to_process_uniqueness) - + # Tests that we have at least swapped some times (say at least 10% of attempted swaps). swap_success_indicators = map(eachrow(diff(process_to_chain_history; dims=1))) do row # Some of the strategies performs multiple swaps in a swap-iteration, @@ -90,33 +113,35 @@ include("compat.jl") min(1, sum(abs, row)) end @test sum(swap_success_indicators) ≥ (nsamples / swap_every) * 0.1 - + # Get example state. state = states[end] chain = AbstractMCMC.bundle_samples( samples, model, spl.sampler, MCMCTempering.state_for_chain(state), MCMCChains.Chains ) - + # Thin chain and discard burnin. chain_thinned = chain[length(chain) ÷ 2 + 1:5swap_every:end] show(stdout, MIME"text/plain"(), chain_thinned) - + # Extract some summary statistics to compare. desc = describe(chain_thinned)[1].nt μ = desc.mean σ = desc.std - + # `StandardSwap` is quite unreliable, so struggling to come up with reasonable tests. if !(swapstrategy isa StandardSwap) - # HACK: These bounds are quite generous. We're swapping quite frequently here - # so some of the strategies results in a rather large variance of the estimators - # it seems. - @test norm(μ - μ_true) ≤ 0.5 - @test norm(σ - σ_true) ≤ 0.5 + @test μ ≈ μ_true rtol=0.05 + # NOTE(torfjelde): The variance is usually quite large for the tempered chains + # and I don't quite know if this is expected or not. + # @test norm(σ - σ_true) ≤ 0.5 + # Comparison to just running the internal sampler. ess = MCMCChains.ess_rhat(chain_thinned).nt.ess - @test all(ess .≥ ess_mh) + # HACK: Just make sure it's not doing _horrible_. Though we'd hope it would + # actually do better than the internal sampler. + @test all(ess .≥ ess_mh .* 0.5) end end end From a1efd11dc9fbdb92f2598905293e28e7f8fa4fdd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 14 Nov 2022 18:26:31 +0000 Subject: [PATCH 36/51] commented out unused code --- src/adaptation.jl | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/adaptation.jl b/src/adaptation.jl index 98b8cfb..1ac23ef 100644 --- a/src/adaptation.jl +++ b/src/adaptation.jl @@ -95,10 +95,11 @@ function init_adaptation( ) Nt = length(Δ) step = PolynomialStep(η, stepsize) - ρs = [ - AdaptiveState(schedule, swap_target, inversely_additive_weight_unconstrain(scale), step) - for _ in 1:(Nt - 1) - ] + # TODO: One common state or one per temperature? + # ρs = [ + # AdaptiveState(schedule, swap_target, inversely_additive_weight_unconstrain(scale), step) + # for _ in 1:(Nt - 1) + # ] ρs = AdaptiveState(schedule, swap_target, log(scale), step) return ρs end @@ -113,10 +114,11 @@ function init_adaptation( ) Nt = length(Δ) step = PolynomialStep(η, stepsize) - ρs = [ - AdaptiveState(schedule, swap_target, geometric_weight_unconstrain(scale), step) - for _ in 1:(Nt - 1) - ] + # TODO: One common state or one per temperature? + # ρs = [ + # AdaptiveState(schedule, swap_target, geometric_weight_unconstrain(scale), step) + # for _ in 1:(Nt - 1) + # ] ρs = AdaptiveState(schedule, swap_target, geometric_weight_unconstrain(scale), step) return ρs end From c2215726bc6426b69a994d55e1e38475e1f93787 Mon Sep 17 00:00:00 2001 From: HarrisonWilde Date: Wed, 16 Nov 2022 13:42:35 +0000 Subject: [PATCH 37/51] Corrected typo --- src/stepping.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/stepping.jl b/src/stepping.jl index eaa1ef5..afec46f 100644 --- a/src/stepping.jl +++ b/src/stepping.jl @@ -80,8 +80,6 @@ The indices here are exactly those represented by `states[k].chain_to_process[1] swap_acceptance_ratios end -end - """ transition_for_chain(state[, I...]) From 632cca9bdad4e34218c7ecec811764a9cc0e4d63 Mon Sep 17 00:00:00 2001 From: HarrisonWilde Date: Wed, 16 Nov 2022 13:43:40 +0000 Subject: [PATCH 38/51] Added 1D GMM, sort of works for it --- test/runtests.jl | 107 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index d142645..93d7ffc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,6 +11,113 @@ include("utils.jl") include("compat.jl") @testset "MCMCTempering.jl" begin + @testset "GMM 1D" begin + + nsamples = 1_000_000 + + gmm = MixtureModel(Normal, [(-3, 1.5), (3, 1.5), (15, 1.5), (90, 1.5)], [0.175, 0.25, 0.275, 0.3]) + + logdensity(x) = logpdf(gmm, x) + + # Construct a DensityModel. + model = AdvancedMH.DensityModel(logdensity) + + # Create non-tempered baseline chain via RWMH + sampler_rwmh = RWMH(Normal()) + samples = AbstractMCMC.sample(model, sampler_rwmh, nsamples) + chain = AbstractMCMC.bundle_samples(samples, model, sampler_rwmh, samples[1], MCMCChains.Chains) + + # Simple geometric ladder + inverse_temperatures = MCMCTempering.check_inverse_temperatures(0.05 .^ [0, 1, 2]) + + acceptance_rate_target = 0.234 + + # Number of iterations needed to obtain `nsamples` of non-swap iterations. + swap_every = 2 + nsamples_tempered = Int(ceil(nsamples * swap_every ÷ (swap_every - 1))) + tempered_sampler_rwmh = tempered( + sampler_rwmh, inverse_temperatures, MCMCTempering.StandardSwap(); + swap_every=swap_every + ) + + # Useful for analysis. + states = [] + callback = StateHistoryCallback(states) + + # Sample. + samples = AbstractMCMC.sample(model, tempered_sampler_rwmh, nsamples_tempered; callback=callback, progress=true); + βs = mapreduce(Base.Fix2(getproperty, :inverse_temperatures), hcat, states) + + states_swapped = filter(Base.Fix2(getproperty, :is_swap), states) + swap_acceptance_ratios = mapreduce( + collect ∘ values ∘ Base.Fix2(getproperty, :swap_acceptance_ratios), + vcat, + states_swapped + ) + + # Extract the history of chain indices. + process_to_chain_history_list = map(states) do state + state.process_to_chain + end + process_to_chain_history = permutedims(reduce(hcat, process_to_chain_history_list), (2, 1)) + + # Check that the swapping has been done correctly. + process_to_chain_uniqueness = map(states) do state + length(unique(state.process_to_chain)) == length(state.process_to_chain) + end + @test all(process_to_chain_uniqueness) + + if any(isa.(Ref(swapstrategy), [MCMCTempering.StandardSwap, MCMCTempering.NonReversibleSwap])) + # For these strategies, the index process should not move by more than 1. + @test all(abs.(diff(process_to_chain_history[:, 1])) .≤ 1) + end + + chain_to_process_uniqueness = map(states) do state + length(unique(state.chain_to_process)) == length(state.chain_to_process) + end + @test all(chain_to_process_uniqueness) + + # Tests that we have at least swapped some times (say at least 10% of attempted swaps). + swap_success_indicators = map(eachrow(diff(process_to_chain_history; dims=1))) do row + # Some of the strategies performs multiple swaps in a swap-iteration, + # but we want to count the number of iterations for which we had a successful swap, + # i.e. only count non-zero elements in a row _once_. Hence the `min`. + min(1, sum(abs, row)) + end + @test sum(swap_success_indicators) ≥ (nsamples / swap_every) * 0.1 + + # Get example state. + state = states[end] + chain = AbstractMCMC.bundle_samples( + samples, model, spl.sampler, MCMCTempering.state_for_chain(state), MCMCChains.Chains + ) + + # Thin chain and discard burnin. + chain_thinned = chain[length(chain) ÷ 2 + 1:5swap_every:end] + show(stdout, MIME"text/plain"(), chain_thinned) + + # Extract some summary statistics to compare. + desc = describe(chain_thinned)[1].nt + μ = desc.mean + σ = desc.std + + # `StandardSwap` is quite unreliable, so struggling to come up with reasonable tests. + if !(swapstrategy isa StandardSwap) + @test μ ≈ μ_true rtol=0.05 + + # NOTE(torfjelde): The variance is usually quite large for the tempered chains + # and I don't quite know if this is expected or not. + # @test norm(σ - σ_true) ≤ 0.5 + + # Comparison to just running the internal sampler. + ess = MCMCChains.ess_rhat(chain_thinned).nt.ess + # HACK: Just make sure it's not doing _horrible_. Though we'd hope it would + # actually do better than the internal sampler. + @test all(ess .≥ ess_mh .* 0.5) + end + + end + @testset "MvNormal 2D" begin d = 2 nsamples = 20_000 From f30acfa961d5de90dcd6492d80acf440a8e8ba2d Mon Sep 17 00:00:00 2001 From: HarrisonWilde Date: Wed, 16 Nov 2022 13:48:59 +0000 Subject: [PATCH 39/51] Make `StandardSwap` the default strategy when one isn't provided --- src/tempered.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tempered.jl b/src/tempered.jl index c2b6397..ea57c53 100644 --- a/src/tempered.jl +++ b/src/tempered.jl @@ -70,12 +70,12 @@ function tempered( swap_strategy::AbstractSwapStrategy = StandardSwap(); kwargs... ) - return tempered(sampler, generate_inverse_temperatures(Nt, swap_strategy), swap_strategy; kwargs...) + return tempered(sampler, generate_inverse_temperatures(Nt, swap_strategy); kwargs...) end function tempered( sampler, inverse_temperatures::Vector{<:Real}, - swap_strategy::AbstractSwapStrategy; + swap_strategy::AbstractSwapStrategy = StandardSwap(); swap_every::Integer = 1, adapt::Bool = true, adapt_target::Real = 0.234, From 0a0b1311233ee20403b41fb28bfa05e16960136e Mon Sep 17 00:00:00 2001 From: HarrisonWilde Date: Wed, 16 Nov 2022 16:02:45 +0000 Subject: [PATCH 40/51] Fixing test case for GMM --- test/runtests.jl | 31 +++++++++---------------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 93d7ffc..7222ce5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,8 +35,12 @@ include("compat.jl") # Number of iterations needed to obtain `nsamples` of non-swap iterations. swap_every = 2 nsamples_tempered = Int(ceil(nsamples * swap_every ÷ (swap_every - 1))) + tempered_sampler_rwmh = tempered( - sampler_rwmh, inverse_temperatures, MCMCTempering.StandardSwap(); + sampler_rwmh, + inverse_temperatures, + MCMCTempering.StandardSwap(); + adapt = false, swap_every=swap_every ) @@ -45,7 +49,7 @@ include("compat.jl") callback = StateHistoryCallback(states) # Sample. - samples = AbstractMCMC.sample(model, tempered_sampler_rwmh, nsamples_tempered; callback=callback, progress=true); + samples = AbstractMCMC.sample(model, tempered_sampler_rwmh, nsamples_tempered; callback=callback, progress=true) βs = mapreduce(Base.Fix2(getproperty, :inverse_temperatures), hcat, states) states_swapped = filter(Base.Fix2(getproperty, :is_swap), states) @@ -67,10 +71,8 @@ include("compat.jl") end @test all(process_to_chain_uniqueness) - if any(isa.(Ref(swapstrategy), [MCMCTempering.StandardSwap, MCMCTempering.NonReversibleSwap])) - # For these strategies, the index process should not move by more than 1. - @test all(abs.(diff(process_to_chain_history[:, 1])) .≤ 1) - end + # For these strategies, the index process should not move by more than 1. + @test all(abs.(diff(process_to_chain_history[:, 1])) .≤ 1) chain_to_process_uniqueness = map(states) do state length(unique(state.chain_to_process)) == length(state.chain_to_process) @@ -89,7 +91,7 @@ include("compat.jl") # Get example state. state = states[end] chain = AbstractMCMC.bundle_samples( - samples, model, spl.sampler, MCMCTempering.state_for_chain(state), MCMCChains.Chains + samples, model, tempered_sampler_rwmh.sampler, MCMCTempering.state_for_chain(state), MCMCChains.Chains ) # Thin chain and discard burnin. @@ -100,21 +102,6 @@ include("compat.jl") desc = describe(chain_thinned)[1].nt μ = desc.mean σ = desc.std - - # `StandardSwap` is quite unreliable, so struggling to come up with reasonable tests. - if !(swapstrategy isa StandardSwap) - @test μ ≈ μ_true rtol=0.05 - - # NOTE(torfjelde): The variance is usually quite large for the tempered chains - # and I don't quite know if this is expected or not. - # @test norm(σ - σ_true) ≤ 0.5 - - # Comparison to just running the internal sampler. - ess = MCMCChains.ess_rhat(chain_thinned).nt.ess - # HACK: Just make sure it's not doing _horrible_. Though we'd hope it would - # actually do better than the internal sampler. - @test all(ess .≥ ess_mh .* 0.5) - end end From 910dc6d5d9af2ad7244ed5bea896df3fc20d4bf6 Mon Sep 17 00:00:00 2001 From: HarrisonWilde Date: Wed, 16 Nov 2022 16:18:24 +0000 Subject: [PATCH 41/51] Implementing burn-in, introduces depedency on StatsBase --- Project.toml | 1 + src/MCMCTempering.jl | 1 + src/stepping.jl | 88 +++++++++++++++++++++++++++++++++----------- 3 files changed, 69 insertions(+), 21 deletions(-) diff --git a/Project.toml b/Project.toml index 84709e8..c5d3892 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] AbstractMCMC = "3.2, 4" diff --git a/src/MCMCTempering.jl b/src/MCMCTempering.jl index 2529260..6cddf89 100644 --- a/src/MCMCTempering.jl +++ b/src/MCMCTempering.jl @@ -1,6 +1,7 @@ module MCMCTempering import AbstractMCMC +import StatsBase import Distributions import Random diff --git a/src/stepping.jl b/src/stepping.jl index afec46f..24d8a73 100644 --- a/src/stepping.jl +++ b/src/stepping.jl @@ -157,6 +157,18 @@ function should_swap(sampler::TemperedSampler, state::TemperedState) return state.total_steps % sampler.swap_every == 0 end +function StatsBase.sample( + rng::Random.AbstractRNG, + model::AbstractMCMC.AbstractModel, + sampler::TemperedSampler, + N::Integer; + discard_initial = 0, + kwargs... +) + return AbstractMCMC.mcmcsample(rng, model, sampler, N; burnin = discard_initial, kwargs...) +end + +# Inital step function AbstractMCMC.step( rng::Random.AbstractRNG, model, @@ -201,6 +213,59 @@ function AbstractMCMC.step( end function AbstractMCMC.step( + rng::Random.AbstractRNG, + model, + sampler::TemperedSampler, + state::TemperedState; + burnin = 0, + kwargs... +) + + if state.total_steps < burnin + state = no_swap_step(rng, model, sampler, state; kwargs...) + elseif state.total_steps == burnin + println("Finished burning in...") + state = no_swap_step(rng, model, sampler, state; kwargs...) + else + state = full_step(rng, model, sampler, state; kwargs...) + end + + @set! state.total_steps += 1 + + # We want to return the transition for the _first_ chain, i.e. the chain usually corresponding to `β=1.0`. + return transition_for_chain(state), state +end + +function no_swap_step( + rng::Random.AbstractRNG, + model, + sampler::TemperedSampler, + state::TemperedState; + kwargs... +) + # `TemperedState` has the transitions and states in the order of + # the processes, and performs swaps by moving the (inverse) temperatures + # `β` between the processes, rather than moving states between processes + # and keeping the `β` local to each process. + # + # Therefore we iterate over the processes and then extract the corresponding + # `β`, `sampler` and `state`, and take a step. + @set! state.transitions_and_states = [ + AbstractMCMC.step( + rng, + make_tempered_model(sampler, model, β_for_process(state, i)), + sampler_for_process(sampler, state, i), + state_for_process(state, i); + kwargs... + ) + for i in 1:numtemps(sampler) + ] + @set! state.is_swap = false + + return state +end + +function full_step( rng::Random.AbstractRNG, model, sampler::TemperedSampler, @@ -214,32 +279,13 @@ function AbstractMCMC.step( state = swap_step(rng, model, sampler, state) @set! state.is_swap = true else - # `TemperedState` has the transitions and states in the order of - # the processes, and performs swaps by moving the (inverse) temperatures - # `β` between the processes, rather than moving states between processes - # and keeping the `β` local to each process. - # - # Therefore we iterate over the processes and then extract the corresponding - # `β`, `sampler` and `state`, and take a step. - @set! state.transitions_and_states = [ - AbstractMCMC.step( - rng, - make_tempered_model(sampler, model, β_for_process(state, i)), - sampler_for_process(sampler, state, i), - state_for_process(state, i); - kwargs... - ) - for i in 1:numtemps(sampler) - ] - @set! state.is_swap = false + no_swap_step(rng, model, sampler, state; kwargs...) end - @set! state.total_steps += 1 # We want to return the transition for the _first_ chain, i.e. the chain usually corresponding to `β=1.0`. - return transition_for_chain(state), state + return state end - """ swap_step([strategy::AbstractSwapStrategy, ]rng, model, sampler, state) From a4437f97ec4dfc9ec3eda0419237203740098e44 Mon Sep 17 00:00:00 2001 From: HarrisonWilde Date: Wed, 16 Nov 2022 16:48:40 +0000 Subject: [PATCH 42/51] Fixed error with burnin --- src/stepping.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/stepping.jl b/src/stepping.jl index 24d8a73..67c7757 100644 --- a/src/stepping.jl +++ b/src/stepping.jl @@ -221,11 +221,9 @@ function AbstractMCMC.step( kwargs... ) - if state.total_steps < burnin - state = no_swap_step(rng, model, sampler, state; kwargs...) - elseif state.total_steps == burnin - println("Finished burning in...") + if state.total_steps <= burnin state = no_swap_step(rng, model, sampler, state; kwargs...) + @set! state.is_swap = false else state = full_step(rng, model, sampler, state; kwargs...) end @@ -260,7 +258,6 @@ function no_swap_step( ) for i in 1:numtemps(sampler) ] - @set! state.is_swap = false return state end @@ -279,7 +276,8 @@ function full_step( state = swap_step(rng, model, sampler, state) @set! state.is_swap = true else - no_swap_step(rng, model, sampler, state; kwargs...) + state = no_swap_step(rng, model, sampler, state; kwargs...) + @set! state.is_swap = false end # We want to return the transition for the _first_ chain, i.e. the chain usually corresponding to `β=1.0`. From b9c6a9093ae1c23d79a5b342f89e56bb54ddbb31 Mon Sep 17 00:00:00 2001 From: HarrisonWilde Date: Wed, 16 Nov 2022 16:55:16 +0000 Subject: [PATCH 43/51] cleaning up working_code --- .gitignore | 5 +- working_code/bayesode.jl | 138 ------------------------------------ working_code/experiments.jl | 13 ---- working_code/neuralode.jl | 74 ------------------- 4 files changed, 2 insertions(+), 228 deletions(-) delete mode 100644 working_code/bayesode.jl delete mode 100644 working_code/experiments.jl delete mode 100644 working_code/neuralode.jl diff --git a/.gitignore b/.gitignore index fcbca32..534f7b3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,3 @@ -/Manifest.toml +Manifest.toml *.DS_Store -*.png -deprecated +working_code diff --git a/working_code/bayesode.jl b/working_code/bayesode.jl deleted file mode 100644 index 12732bc..0000000 --- a/working_code/bayesode.jl +++ /dev/null @@ -1,138 +0,0 @@ -using Turing, Distributions, DifferentialEquations - -# Import MCMCChain, Plots, and StatsPlots for visualizations and diagnostics. -using MCMCChains, Plots, StatsPlots - -# Set a seed for reproducibility. -using Random -Random.seed!(14); -using MCMCTempering - - - -function lotka_volterra(du,u,p,t) - x, y = u - α, β, γ, δ = p - du[1] = (α - β*y)x # dx = - du[2] = (δ*x - γ)y # dy = -end - -p = [1.5, 1.0, 3.0, 1.0] -u0 = [1.0,1.0] -prob1 = ODEProblem(lotka_volterra,u0,(0.0,10.0),p) -sol = solve(prob1,Tsit5()) -plot(sol) - -sol1 = solve(prob1,Tsit5(),saveat=0.1) -odedata = Array(sol1) + 0.8 * randn(size(Array(sol1))) -plot(sol1, alpha = 0.3, legend = false); scatter!(sol1.t, odedata') - -Turing.setadbackend(:forwarddiff) - -@model function fitlv(data, prob1) - σ ~ InverseGamma(2, 3) # ~ is the tilde character - α ~ truncated(Normal(1.5,0.5),0.5,2.5) - β ~ truncated(Normal(1.0,0.5),0,2) - γ ~ truncated(Normal(3.0,0.5),1,4) - δ ~ truncated(Normal(1.0,0.5),0,2) - - p = [α,β,γ,δ] - prob = remake(prob1, p=p) - predicted = solve(prob,Tsit5(),saveat=0.1) - - for i = 1:length(predicted) - data[:,i] ~ MvNormal(predicted[i], σ) - end -end - -model = fitlv(odedata, prob1) - -# This next command runs 3 independent chains without using multithreading. -chain1_mh = sample(model, MH(), MCMCThreads(), 10000, 4) -chain1_nuts = sample(model, NUTS(.65), 1000) -# chain1_mh_test = mapreduce(c -> sample(model, NUTS(), 1000), chainscat, 1:3) - - -chain2_mh = sample(model, Tempered(MH(), 2), MCMCThreads(), 10000, 4) -chain2_nuts = sample(model, Tempered(NUTS(.65), 2), 1000) - -chain3_mh = sample(model, Tempered(MH(), 3), MCMCThreads(), 10000, 30) -chain3_nuts = sample(model, Tempered(NUTS(.65), 3), MCMCThreads(), 1000, 30) - -chain4_mh = sample(model, Tempered(MH(), 4), MCMCThreads(), 10000, 30) -chain4_nuts = sample(model, Tempered(NUTS(.65), 4), MCMCThreads(), 1000, 30) - - -plot_swaps(chain2_mh) - -plot(chain1_mh) -plot(chain2_mh) - -interchain_stats(chain1_mh) -interchain_stats(chain2_mh) - - - -# Pumas example - -function theop_model_Depots1Central1(du, u, p, t) - Depot, Central = u - Ka, CL, Vc = p - du[1] = -Ka * Depot # d Depot = - du[2] = Ka * Depot - (CL / Vc) * Central # d Central = -end - -u0 = [1.0, 1.0] -p = [2.0, 0.2, 0.8, 2.0] -prob = ODEProblem(theop_model_Depots1Central1,u0,(0.0, 10.0),p) -sol = solve(prob, Tsit5()) -plot(sol) - -@model function theopmodel_bayes(dv, SEX, WT) - - N = length(dv) - - θ ~ arraydist(truncated.(Normal.([2.0, 0.2, 0.8, 2.0], 1.0), 0.0, 10.0)) - - ωKa ~ Gamma(1.0, 0.2) - ωCL ~ Gamma(1.0, 0.2) - ωVc ~ Gamma(1.0, 0.2) - - σ ~ Gamma(1.0, 0.5) - - ηKa ~ filldist(Normal(0.0, ωKa), N) - ηCL ~ filldist(Normal(0.0, ωCL), N) - ηVc ~ filldist(Normal(0.0, ωVc), N) - - for i in 1:N - Ka = (SEX[i] == 1 ? θ[1] : θ[4]) * exp(ηKa[i]) - CL = θ[2]*(WT[i]/70) * exp(ηCL[i]) - Vc = θ[3] * exp(ηVc[i]) - - p = [Ka, CL, Vc] - prob = remake(prob1, p=p) - predicted = solve(prob, Tsit5(), saveat=0.1) - - μ[i] = predicted[i,2] / Vc - dv[i] .~ Normal.(μ, σ) - end - dv - -end - -using Pumas - - - - -# BayesNeuralODE example - -using BayesNeuralODE - -N = 1 -prior_std = likelihood_std = 1.0 -model = BNO.generate_turing_model(:spiral, N, prior_std, likelihood_std) - -bno_chain_1 = sample(model, NUTS(.6), 100) - -bno_chain_2 = sample(model, Tempered(NUTS(.6), 4), MCMCThreads(), 1000, 4) diff --git a/working_code/experiments.jl b/working_code/experiments.jl deleted file mode 100644 index 091e1cc..0000000 --- a/working_code/experiments.jl +++ /dev/null @@ -1,13 +0,0 @@ -function interchain_stats(chains) - - d = Dict() - - for param in chains.name_map.parameters - μ = std(mean(chains[param], dims=1)) - σ = std(std(chains[param], dims=1)) - push!(d, param => Dict(:μ => μ, :σ => σ)) - end - - return d - -end \ No newline at end of file diff --git a/working_code/neuralode.jl b/working_code/neuralode.jl deleted file mode 100644 index b6c1eb8..0000000 --- a/working_code/neuralode.jl +++ /dev/null @@ -1,74 +0,0 @@ -using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, Plots, AdvancedHMC -using JLD, StatsPlots, Distributions - -u0 = [2.0; 0.0] -datasize = 40 -tspan = (0.0, 1) -tsteps = range(tspan[1], tspan[2], length = datasize) - -function trueODEfunc(du, u, p, t) - true_A = [-0.1 2.0; -2.0 -0.1] - du .= ((u.^3)'true_A)' -end - -prob_trueode = ODEProblem(trueODEfunc, u0, tspan) -mean_ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps)) -ode_data = mean_ode_data .+ 0.1 .* randn(size(mean_ode_data)..., 30) - -####DEFINE THE NEURAL ODE##### -dudt2 = FastChain((x, p) -> x.^3, - FastDense(2, 50, relu), - FastDense(50, 2)) -prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps) - -function predict_neuralode(p) - Array(prob_neuralode(u0, p)) -end -function loss_neuralode(p) - pred = predict_neuralode(p) - loss = sum(abs2, ode_data .- pred) - return loss, pred -end - -function l(θ) - lp = logpdf(MvNormal(zeros(length(θ) - 1), 1.0), θ[1:end-1]) - ll = sum(logpdf.(Normal.(ode_data, θ[end]), predict_neuralode(θ[1:end-1]))) - return lp + ll -end -function lp(θ) - return logpdf(MvNormal(zeros(length(θ) - 1), 1.0), θ[1:end-1]) -end -function ll(θ) - return sum(logpdf.(Normal.(ode_data, θ[end]), predict_neuralode(θ[1:end-1]))) -end - -function dldθ(θ) - x, lambda = Flux.Zygote.pullback(l,θ) - grad = first(lambda(1)) - return x, grad -end -function dlpdθ(θ) - x, lambda = Flux.Zygote.pullback(lp,θ) - grad = first(lambda(1)) - return x, grad -end -function dlldθ(θ) - x, lambda = Flux.Zygote.pullback(ll,θ) - grad = first(lambda(1)) - return x, grad -end - -init = [Float64.(prob_neuralode.p); 1.0] - -opt = DiffEqFlux.sciml_train(x -> -l(x), init, ADAM(0.05), maxiters = 1500) -pmin = opt.minimizer; -metric = DiagEuclideanMetric(length(pmin)) -h = Hamiltonian(metric, l, dldθ) -integrator = Leapfrog(find_good_stepsize(h, pmin)) -prop = AdvancedHMC.NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator, 10) -adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.5, prop.integrator)) - -samples, stats = sample(h, prop, pmin, 500, adaptor, 500; progress=true) - -using MCMCTempering -tempered_samples = sample() \ No newline at end of file From 79cbbf004d16809b14a3cb4b9ca42e167e53bb4e Mon Sep 17 00:00:00 2001 From: HarrisonWilde Date: Wed, 16 Nov 2022 20:45:18 +0000 Subject: [PATCH 44/51] QoL improvements on the code --- src/MCMCTempering.jl | 3 +- src/adaptation.jl | 12 ++-- src/ladders.jl | 37 ++++++----- src/states.jl | 129 ++++++++++++++++++++++++++++++++++++ src/stepping.jl | 151 ------------------------------------------- src/tempered.jl | 47 +++++++++----- test/runtests.jl | 6 +- 7 files changed, 193 insertions(+), 192 deletions(-) create mode 100644 src/states.jl diff --git a/src/MCMCTempering.jl b/src/MCMCTempering.jl index 6cddf89..357d4da 100644 --- a/src/MCMCTempering.jl +++ b/src/MCMCTempering.jl @@ -14,6 +14,7 @@ using DocStringExtensions include("adaptation.jl") include("swapping.jl") +include("states.jl") include("tempered.jl") include("ladders.jl") include("stepping.jl") @@ -35,7 +36,7 @@ function AbstractMCMC.bundle_samples( kwargs... ) AbstractMCMC.bundle_samples( - ts, model, sampler_for_chain(sampler, state), state_for_chain(state), chain_type; + ts, model, sampler_for_chain(sampler, state, 1), state_for_chain(state, 1), chain_type; kwargs... ) end diff --git a/src/adaptation.jl b/src/adaptation.jl index 1ac23ef..5d80a9d 100644 --- a/src/adaptation.jl +++ b/src/adaptation.jl @@ -18,7 +18,7 @@ See also: [`AdaptiveState`](@ref), [`update_inverse_temperatures`](@ref), and """ struct Geometric end -defaultscale(::Geometric, inverse_temperatures) = eltype(inverse_temperatures)(0.9) +defaultscale(::Geometric, Δ) = eltype(Δ)(0.9) """ InverselyAdditive @@ -30,7 +30,7 @@ See also: [`AdaptiveState`](@ref), [`update_inverse_temperatures`](@ref), and """ struct InverselyAdditive end -defaultscale(::InverselyAdditive, inverse_temperatures) = eltype(inverse_temperatures)(0.9) +defaultscale(::InverselyAdditive, Δ) = eltype(Δ)(0.9) struct AdaptiveState{S,T1<:Real,T2<:Real,P<:PolynomialStep} schedule_type::S @@ -93,12 +93,12 @@ function init_adaptation( η::Real, stepsize::Real ) - Nt = length(Δ) + N_it = length(Δ) step = PolynomialStep(η, stepsize) # TODO: One common state or one per temperature? # ρs = [ # AdaptiveState(schedule, swap_target, inversely_additive_weight_unconstrain(scale), step) - # for _ in 1:(Nt - 1) + # for _ in 1:(N_it - 1) # ] ρs = AdaptiveState(schedule, swap_target, log(scale), step) return ρs @@ -112,12 +112,12 @@ function init_adaptation( η::Real, stepsize::Real ) - Nt = length(Δ) + N_it = length(Δ) step = PolynomialStep(η, stepsize) # TODO: One common state or one per temperature? # ρs = [ # AdaptiveState(schedule, swap_target, geometric_weight_unconstrain(scale), step) - # for _ in 1:(Nt - 1) + # for _ in 1:(N_it - 1) # ] ρs = AdaptiveState(schedule, swap_target, geometric_weight_unconstrain(scale), step) return ρs diff --git a/src/ladders.jl b/src/ladders.jl index 65312c0..7e3e547 100644 --- a/src/ladders.jl +++ b/src/ladders.jl @@ -1,26 +1,25 @@ """ - get_scaling_val(Nt, swap_strategy) + get_scaling_val(N_it, swap_strategy) -Calculates the correct scaling factor for polynomial step size between temperatures. +Calculates a scaling factor for polynomial step size between inverse temperatures. """ -get_scaling_val(Nt, ::StandardSwap) = Nt - 1 -get_scaling_val(Nt, ::NonReversibleSwap) = 2 -get_scaling_val(Nt, ::RandomPermutationSwap) = 1 +get_scaling_val(N_it, ::StandardSwap) = N_it - 1 +get_scaling_val(N_it, ::NonReversibleSwap) = 2 """ - generate_inverse_temperatures(Nt, swap_strategy) + generate_inverse_temperatures(N_it, swap_strategy) -Returns a temperature ladder `Δ` containing `Nt` temperatures, +Returns a temperature ladder `Δ` containing `N_it` values, generated in accordance with the chosen `swap_strategy`. """ -function generate_inverse_temperatures(Nt, swap_strategy) +function generate_inverse_temperatures(N_it, swap_strategy) # Apparently, here we increase the temperature by a constant # factor which depends on `swap_strategy`. - scaling_val = get_scaling_val(Nt, swap_strategy) - Δ = Vector{Float64}(undef, Nt) + scaling_val = get_scaling_val(N_it, swap_strategy) + Δ = Vector{Float64}(undef, N_it) Δ[1] = 1 T = Δ[1] - for i in 1:(Nt - 1) + for i in 1:(N_it - 1) T += scaling_val Δ[i + 1] = inv(T) end @@ -34,12 +33,18 @@ end Checks and returns a sorted `Δ` containing `{β₀, ..., βₙ}` conforming such that `1 = β₀ > β₁ > ... > βₙ ≥ 0` """ function check_inverse_temperatures(Δ) + if length(Δ) <= 1 + error("More than one inverse temperatures must be provided.") + end if !all(zero.(Δ) .≤ Δ .≤ one.(Δ)) - error("Temperature schedule provided has values outside of the acceptable range, ensure all values are in [0, 1].") + error("The temperature ladder provided has values outside of the acceptable range, ensure all values are in [0, 1].") end - Δ = sort(Δ; rev=true) - if Δ[1] != one(Δ[1]) - error("Δ must contain 1, as β₀.") + Δ_sorted = sort(Δ; rev=true) + if Δ_sorted[1] != one(Δ_sorted[1]) + error("The temperature ladder must contain 1.") end - return Δ + if Δ_sorted != Δ + println("The temperature was sorted to ensure decreasing order.") + end + return Δ_sorted end diff --git a/src/states.jl b/src/states.jl new file mode 100644 index 0000000..482b28b --- /dev/null +++ b/src/states.jl @@ -0,0 +1,129 @@ +""" + TemperedState + +A general implementation of a state for a [`TemperedSampler`](@ref). + +# Fields + +$(FIELDS) + +# Description + +Suppose we're running 4 chains `X`, `Y`, `Z`, and `W`, each targeting a distribution for different +(inverse) temperatures `β`, say, `1.0`, `0.75`, `0.5`, and `0.25`, respectively. That is, we're mainly +interested in the chain `(X[1], X[2], … )` which targets the distribution with `β=1.0`. + +Moreover, suppose we also have 4 workers/processes for which we run these chains in "parallel" +(can also be serial wlog). + +We can then perform a swap in two different ways: +1. Swap the the _states_ between each process, i.e. permute `transitions_and_states`. +2. Swap the _temperatures_ between each process, i.e. permute `inverse_temperatures`. + +(1) is possibly the most intuitive approach since it means that the i-th worker/process +corresponds to the i-th chain; in this case, process 1 corresponds to `X`, process 2 to `Y`, etc. +The downside is that we need to move (potentially high-dimensional) states between the +workers/processes. + +(2) on the other hand does _not_ preserve the direct process-chain correspondance. +We now need to keep track of which process has which chain, from this we can +reconstruct each of the chains `X`, `Y`, etc. afterwards. +This means that we need only transfer a pair of numbers representing the (inverse) +temperatures between workers rather than the full states. + +This implementation follows approach (2). + +Here's an exemplar realisation of five steps of sampling and swap-attempts: + +``` +Chains: process_to_chain chain_to_process inverse_temperatures[process_to_chain[i]] +| | | | 1 2 3 4 1 2 3 4 1.00 0.75 0.50 0.25 +| | | | + V | | 2 1 3 4 2 1 3 4 0.75 1.00 0.50 0.25 + Λ | | +| | | | 2 1 3 4 2 1 3 4 0.75 1.00 0.50 0.25 +| | | | +| V | 2 3 1 4 3 1 2 4 0.75 0.50 1.00 0.25 +| Λ | +| | | | 2 3 1 4 3 1 2 4 0.75 0.50 1.00 0.25 +| | | | +``` + +In this case, the chain `X` can be reconstructed as: + +```julia +X[1] = states[1].transitions_and_states[1] +X[2] = states[2].transitions_and_states[2] +X[3] = states[3].transitions_and_states[2] +X[4] = states[4].transitions_and_states[3] +X[5] = states[5].transitions_and_states[3] +``` + +The indices here are exactly those represented by `states[k].chain_to_process[1]`. +""" +@concrete struct TemperedState + "collection of `(transition, state)` pairs for each process" + transitions_and_states + "collection of (inverse) temperatures β corresponding to each process" + inverse_temperatures + "collection indices such that `chain_to_process[i] = j` if the j-th process corresponds to the i-th chain" + chain_to_process + "collection indices such that `process_chain_to[j] = i` if the i-th chain corresponds to the j-th process" + process_to_chain + "total number of steps taken" + total_steps + "contains all necessary information for adaptation of inverse_temperatures" + adaptation_states + "flag which specifies wether this was a swap-step or not" + is_swap + "swap acceptance ratios on log-scale" + swap_acceptance_ratios +end + +""" + transition_for_chain(state[, I...]) + +Return the transition corresponding to the chain indexed by `I...`. +If `I...` is not specified, the transition corresponding to `β=1.0` will be returned, i.e. `I = (1, )`. +""" +transition_for_chain(state::TemperedState) = transition_for_chain(state, 1) +transition_for_chain(state::TemperedState, I...) = state.transitions_and_states[state.chain_to_process[I...]][1] + +""" + transition_for_process(state, I...) + +Return the transition corresponding to the process indexed by `I...`. +""" +transition_for_process(state::TemperedState, I...) = state.transitions_and_states[I...][1] + +""" + state_for_chain(state[, I...]) + +Return the state corresponding to the chain indexed by `I...`. +If `I...` is not specified, the state corresponding to `β=1.0` will be returned. +""" +state_for_chain(state::TemperedState) = state_for_chain(state, 1) +state_for_chain(state::TemperedState, I...) = state.transitions_and_states[I...][2] + +""" + state_for_process(state, I...) + +Return the state corresponding to the process indexed by `I...`. +""" +state_for_process(state::TemperedState, I...) = state.transitions_and_states[I...][2] + +""" + β_for_chain(state[, I...]) + +Return the β corresponding to the chain indexed by `I...`. +If `I...` is not specified, the β corresponding to `β=1.0` will be returned. +""" +β_for_chain(state::TemperedState) = β_for_chain(state, 1) +β_for_chain(state::TemperedState, I...) = state.inverse_temperatures[state.chain_to_process[I...]] + +""" + β_for_process(state, I...) + +Return the β corresponding to the process indexed by `I...`. +""" +β_for_process(state::TemperedState, I...) = state.inverse_temperatures[I...] \ No newline at end of file diff --git a/src/stepping.jl b/src/stepping.jl index 67c7757..a0fdbf2 100644 --- a/src/stepping.jl +++ b/src/stepping.jl @@ -1,153 +1,3 @@ -""" - TemperedState - -A general implementation of a state for a [`TemperedSampler`](@ref). - -# Fields - -$(FIELDS) - -# Description - -Suppose we're running 4 chains `X`, `Y`, `Z`, and `W`, each targeting a distribution for different -(inverse) temperatures `β`, say, `1.0`, `0.75`, `0.5`, and `0.25`, respectively. That is, we're mainly -interested in the chain `(X[1], X[2], … )` which targets the distribution with `β=1.0`. - -Moreover, suppose we also have 4 workers/processes for which we run these chains in "parallel" -(it can also be in serial, but for the sake of demonstration imagine it's parallel). - -When can then perform a swap in two different ways: -1. Swap the the _states_ between each process, i.e. permute `transitions_and_states`. -2. Swap the _temperatures_ between each process, i.e. permute `inverse_temperatures`. - -(1) is possibly the most intuitive approach since it means that the i-th worker/process -corresponds to the i-th chain; in this case, process 1 corresponds to `X`, process 2 to `Y`, etc. -The downside is that we need to move (potentially high-dimensional) states between the -workers/processes. - -(2) on the other hand does _not_ preserve the direct process-chain correspondance. -We now need to keep track of which process has which chain, from this we can -reconstruct each of the chains `X`, `Y`, etc. afterwards. -On the other hand, this means that we only need to transfer a pair of numbers -representing the (inverse) temperatures between workers rather than the full states. - -The current implementation follows approach (2). - -Here's an example realization of using five steps of sampling and swap-attempts: - -``` -Chains: process_to_chain chain_to_process inverse_temperatures[process_to_chain[i]] -| | | | 1 2 3 4 1 2 3 4 1.00 0.75 0.50 0.25 -| | | | - V | | 2 1 3 4 2 1 3 4 0.75 1.00 0.50 0.25 - Λ | | -| | | | 2 1 3 4 2 1 3 4 0.75 1.00 0.50 0.25 -| | | | -| V | 2 3 1 4 3 1 2 4 0.75 0.50 1.00 0.25 -| Λ | -| | | | 2 3 1 4 3 1 2 4 0.75 0.50 1.00 0.25 -| | | | -``` - -In this case, the chain `X` can be reconstructed as: - -```julia -X[1] = states[1].transitions_and_states[1] -X[2] = states[2].transitions_and_states[2] -X[3] = states[3].transitions_and_states[2] -X[4] = states[4].transitions_and_states[3] -X[5] = states[5].transitions_and_states[3] -``` - -The indices here are exactly those represented by `states[k].chain_to_process[1]`. -""" -@concrete struct TemperedState - "collection of `(transition, state)` pairs for each process" - transitions_and_states - "collection of (inverse) temperatures β corresponding to each process" - inverse_temperatures - "collection indices such that `chain_to_process[i] = j` if the j-th process corresponds to the i-th chain" - chain_to_process - "collection indices such that `process_chain_to[j] = i` if the i-th chain corresponds to the j-th process" - process_to_chain - "total number of steps taken" - total_steps - "contains all necessary information for adaptation of inverse_temperatures" - adaptation_states - "flag which specifies wether this was a swap-step or not" - is_swap - "swap acceptance ratios on log-scale" - swap_acceptance_ratios -end - -""" - transition_for_chain(state[, I...]) - -Return the transition corresponding to the chain indexed by `I...`. -If `I...` is not specified, the transition corresponding to `β=1.0` will be returned, i.e. `I = (1, )`. -""" -transition_for_chain(state::TemperedState) = transition_for_chain(state, 1) -transition_for_chain(state::TemperedState, I...) = state.transitions_and_states[state.chain_to_process[I...]][1] - -""" - transition_for_process(state, I...) - -Return the transition corresponding to the process indexed by `I...`. -""" -transition_for_process(state::TemperedState, I...) = state.transitions_and_states[I...][1] - -""" - state_for_chain(state[, I...]) - -Return the state corresponding to the chain indexed by `I...`. -If `I...` is not specified, the state corresponding to `β=1.0` will be returned. -""" -state_for_chain(state::TemperedState) = state_for_chain(state, 1) -state_for_chain(state::TemperedState, I...) = state.transitions_and_states[I...][2] - -""" - state_for_process(state, I...) - -Return the state corresponding to the process indexed by `I...`. -""" -state_for_process(state::TemperedState, I...) = state.transitions_and_states[I...][2] - -""" - β_for_chain(state[, I...]) - -Return the β corresponding to the chain indexed by `I...`. -If `I...` is not specified, the β corresponding to `β=1.0` will be returned. -""" -β_for_chain(state::TemperedState) = β_for_chain(state, 1) -β_for_chain(state::TemperedState, I...) = state.inverse_temperatures[state.chain_to_process[I...]] - -""" - β_for_process(state, I...) - -Return the β corresponding to the process indexed by `I...`. -""" -β_for_process(state::TemperedState, I...) = state.inverse_temperatures[I...] - -""" - sampler_for_chain(sampler::TemperedSampler, state::TemperedState[, I...]) - -Return the sampler corresponding to the chain indexed by `I...`. -If `I...` is not specified, the sampler corresponding to `β=1.0` will be returned. -""" -sampler_for_chain(sampler::TemperedSampler, state::TemperedState) = sampler_for_chain(sampler, state, 1) -function sampler_for_chain(sampler::TemperedSampler, state::TemperedState, I...) - return getsampler(sampler.sampler, state.chain_to_process[I...]) -end - -""" - sampler_for_process(sampler::TemperedSampler, state::TemperedState, I...) - -Return the sampler corresponding to the process indexed by `I...`. -""" -function sampler_for_process(sampler::TemperedSampler, state::TemperedState, I...) - return getsampler(sampler.sampler, I...) -end - """ should_swap(sampler, state) @@ -168,7 +18,6 @@ function StatsBase.sample( return AbstractMCMC.mcmcsample(rng, model, sampler, N; burnin = discard_initial, kwargs...) end -# Inital step function AbstractMCMC.step( rng::Random.AbstractRNG, model, diff --git a/src/tempered.jl b/src/tempered.jl index ea57c53..661ab4a 100644 --- a/src/tempered.jl +++ b/src/tempered.jl @@ -1,7 +1,7 @@ """ TemperedSampler <: AbstractMCMC.AbstractSampler -A `TemperedSampler` struct wraps an `sampler` and samples using parallel tempering. +A `TemperedSampler` struct wraps a sampler upon which to apply the Parallel Tempering algorithm. # Fields @@ -32,50 +32,68 @@ getsampler(sampler::TemperedSampler, I...) = getsampler(sampler.sampler, I...) """ numsteps(sampler::TemperedSampler) -Return number of temperatures used by `sampler`. +Return number of inverse temperatures used by `sampler`. """ numtemps(sampler::TemperedSampler) = length(sampler.inverse_temperatures) +""" + sampler_for_chain(sampler::TemperedSampler, state::TemperedState[, I...]) + +Return the sampler corresponding to the chain indexed by `I...`. +If `I...` is not specified, the sampler corresponding to `β=1.0` will be returned. +""" +sampler_for_chain(sampler::TemperedSampler, state::TemperedState) = sampler_for_chain(sampler, state, 1) +function sampler_for_chain(sampler::TemperedSampler, state::TemperedState, I...) + return getsampler(sampler.sampler, state.chain_to_process[I...]) +end + +""" + sampler_for_process(sampler::TemperedSampler, state::TemperedState, I...) +Return the sampler corresponding to the process indexed by `I...`. """ - tempered(sampler, inverse_temperatures; kwargs...) +function sampler_for_process(sampler::TemperedSampler, state::TemperedState, I...) + return getsampler(sampler.sampler, I...) +end + +""" + tempered(sampler, inverse_temperatures::Vector{<:Real}; kwargs...) OR - tempered(sampler, Nt::Integer; kwargs...) + tempered(sampler, N_it::Integer; kwargs...) Return tempered version of `sampler` using the provided `inverse_temperatures` or -inverse temperatures generated from `Nt` and the `swap_strategy`. +inverse temperatures generated from `N_it` and the `swap_strategy`. # Arguments - `sampler` is an algorithm or sampler object to be used for underlying sampling and to apply tempering to - The temperature schedule can be defined either explicitly or just as an integer number of temperatures, i.e. as: - `inverse_temperatures` containing a sequence of 'inverse temperatures' {β₀, ..., βₙ} where 0 ≤ βₙ < ... < β₁ < β₀ = 1 OR - - `Nt::Integer`, specifying the number of inverse temperatures to include in a generated `inverse_temperatures` + - `N_it::Integer`, specifying the number of inverse temperatures to include in a generated `inverse_temperatures` # Keyword arguments -- `swap_strategy::AbstractSwapStrategy` is the way in which temperature swaps are made. -- `swap_every::Integer` steps are carried out between each tempering swap step attempt +- `swap_strategy::AbstractSwapStrategy` is the way in which inverse temperature swaps between chains are made +- `swap_every::Integer` steps are carried out between each attempt at a swap # See also - [`TemperedSampler`](@ref) - For more on the swap strategies: - [`AbstractSwapStrategy`](@ref) - [`StandardSwap`](@ref) - - [`RandomPermutationSwap`](@ref) - [`NonReversibleSwap`](@ref) """ function tempered( sampler, - Nt::Integer, + N_it::Integer, swap_strategy::AbstractSwapStrategy = StandardSwap(); kwargs... ) - return tempered(sampler, generate_inverse_temperatures(Nt, swap_strategy); kwargs...) + return tempered(sampler, generate_inverse_temperatures(N_it, swap_strategy); swap_strategy = swap_strategy, kwargs...) end function tempered( sampler, - inverse_temperatures::Vector{<:Real}, - swap_strategy::AbstractSwapStrategy = StandardSwap(); + inverse_temperatures::Vector{<:Real}; + swap_strategy::AbstractSwapStrategy = StandardSwap(), swap_every::Integer = 1, adapt::Bool = true, adapt_target::Real = 0.234, @@ -85,9 +103,8 @@ function tempered( adapt_scale = defaultscale(adapt_schedule, inverse_temperatures), kwargs... ) - inverse_temperatures = check_inverse_temperatures(inverse_temperatures) - length(inverse_temperatures) > 1 || error("More than one inverse temperatures must be provided.") swap_every >= 1 || error("This must be a positive integer.") + inverse_temperatures = check_inverse_temperatures(inverse_temperatures) adaptation_states = init_adaptation( adapt_schedule, inverse_temperatures, adapt_target, adapt_scale, adapt_eta, adapt_stepsize ) diff --git a/test/runtests.jl b/test/runtests.jl index 7222ce5..9cd20a7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,8 +38,7 @@ include("compat.jl") tempered_sampler_rwmh = tempered( sampler_rwmh, - inverse_temperatures, - MCMCTempering.StandardSwap(); + inverse_temperatures; adapt = false, swap_every=swap_every ) @@ -147,7 +146,8 @@ include("compat.jl") # Number of iterations needed to obtain `nsamples` of non-swap iterations. nsamples_tempered = Int(ceil(nsamples * swap_every ÷ (swap_every - 1))) spl = tempered( - spl_inner, inverse_temperatures, swapstrategy; + spl_inner, inverse_temperatures; + swap_strategy = swapstrategy, adapt=false, # TODO: Test adaptation. Seems to work in some cases though. adapt_schedule=MCMCTempering.Geometric(), adapt_stepsize=1, From 4336516bdaa224332f87ef28b82531a8f1239c4e Mon Sep 17 00:00:00 2001 From: HarrisonWilde Date: Wed, 16 Nov 2022 20:49:14 +0000 Subject: [PATCH 45/51] Removing `StatsBase` dependency and `discard_initial` override --- Project.toml | 1 - src/MCMCTempering.jl | 1 - src/stepping.jl | 44 +++++++------------------------------------- 3 files changed, 7 insertions(+), 39 deletions(-) diff --git a/Project.toml b/Project.toml index c5d3892..84709e8 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,6 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" -StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] AbstractMCMC = "3.2, 4" diff --git a/src/MCMCTempering.jl b/src/MCMCTempering.jl index 357d4da..5bb5686 100644 --- a/src/MCMCTempering.jl +++ b/src/MCMCTempering.jl @@ -1,7 +1,6 @@ module MCMCTempering import AbstractMCMC -import StatsBase import Distributions import Random diff --git a/src/stepping.jl b/src/stepping.jl index a0fdbf2..291b18a 100644 --- a/src/stepping.jl +++ b/src/stepping.jl @@ -7,17 +7,6 @@ function should_swap(sampler::TemperedSampler, state::TemperedState) return state.total_steps % sampler.swap_every == 0 end -function StatsBase.sample( - rng::Random.AbstractRNG, - model::AbstractMCMC.AbstractModel, - sampler::TemperedSampler, - N::Integer; - discard_initial = 0, - kwargs... -) - return AbstractMCMC.mcmcsample(rng, model, sampler, N; burnin = discard_initial, kwargs...) -end - function AbstractMCMC.step( rng::Random.AbstractRNG, model, @@ -66,15 +55,18 @@ function AbstractMCMC.step( model, sampler::TemperedSampler, state::TemperedState; - burnin = 0, kwargs... ) - if state.total_steps <= burnin + # Reset. + @set! state.swap_acceptance_ratios = empty(state.swap_acceptance_ratios) + + if should_swap(sampler, state) + state = swap_step(rng, model, sampler, state) + @set! state.is_swap = true + else state = no_swap_step(rng, model, sampler, state; kwargs...) @set! state.is_swap = false - else - state = full_step(rng, model, sampler, state; kwargs...) end @set! state.total_steps += 1 @@ -111,28 +103,6 @@ function no_swap_step( return state end -function full_step( - rng::Random.AbstractRNG, - model, - sampler::TemperedSampler, - state::TemperedState; - kwargs... -) - # Reset. - @set! state.swap_acceptance_ratios = empty(state.swap_acceptance_ratios) - - if should_swap(sampler, state) - state = swap_step(rng, model, sampler, state) - @set! state.is_swap = true - else - state = no_swap_step(rng, model, sampler, state; kwargs...) - @set! state.is_swap = false - end - - # We want to return the transition for the _first_ chain, i.e. the chain usually corresponding to `β=1.0`. - return state -end - """ swap_step([strategy::AbstractSwapStrategy, ]rng, model, sampler, state) From 915b610fbd404cb3df780b311dde63ba86fdc188 Mon Sep 17 00:00:00 2001 From: HarrisonWilde Date: Wed, 16 Nov 2022 20:58:27 +0000 Subject: [PATCH 46/51] Adding back accidentally deleted RandomPermutationSwap stuff --- src/ladders.jl | 1 + src/tempered.jl | 1 + 2 files changed, 2 insertions(+) diff --git a/src/ladders.jl b/src/ladders.jl index 7e3e547..2cccab5 100644 --- a/src/ladders.jl +++ b/src/ladders.jl @@ -5,6 +5,7 @@ Calculates a scaling factor for polynomial step size between inverse temperature """ get_scaling_val(N_it, ::StandardSwap) = N_it - 1 get_scaling_val(N_it, ::NonReversibleSwap) = 2 +get_scaling_val(N_it, ::RandomPermutationSwap) = 1 """ generate_inverse_temperatures(N_it, swap_strategy) diff --git a/src/tempered.jl b/src/tempered.jl index 661ab4a..0ba54d7 100644 --- a/src/tempered.jl +++ b/src/tempered.jl @@ -81,6 +81,7 @@ inverse temperatures generated from `N_it` and the `swap_strategy`. - [`AbstractSwapStrategy`](@ref) - [`StandardSwap`](@ref) - [`NonReversibleSwap`](@ref) + - [`RandomPermutationSwap`](@ref) """ function tempered( sampler, From 4898193d112a0cc552710a3c21643508ea262629 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 17 Nov 2022 12:12:57 +0000 Subject: [PATCH 47/51] cleaned up testing a bit --- test/compat/advancedmh.jl | 7 +- test/runtests.jl | 358 ++++++++++++++++++-------------------- 2 files changed, 174 insertions(+), 191 deletions(-) diff --git a/test/compat/advancedmh.jl b/test/compat/advancedmh.jl index 2f88b85..363b06e 100644 --- a/test/compat/advancedmh.jl +++ b/test/compat/advancedmh.jl @@ -7,8 +7,8 @@ function MCMCTempering.make_tempered_model(sampler, m::DensityModel, β) return DensityModel(Base.Fix1(*, β) ∘ m.logdensity) end -# Now we need to make swapping possible. -# This should return a callable which evaluates to the temperered logdensity. +# Now we need to make swapping possible, which requires computing +# the log density of the tempered model at the candidate states. function MCMCTempering.compute_tempered_logdensities( model::DensityModel, sampler, @@ -16,10 +16,7 @@ function MCMCTempering.compute_tempered_logdensities( transition_other::AdvancedMH.Transition, β ) - # Just re-use computation from transition. - # lp = transition.lp lp = β * AdvancedMH.logdensity(model, transition.params) - # Compute for the other. lp_other = β * AdvancedMH.logdensity(model, transition_other.params) return lp, lp_other end diff --git a/test/runtests.jl b/test/runtests.jl index 9cd20a7..b08c41f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,109 +10,188 @@ using AbstractMCMC include("utils.jl") include("compat.jl") +""" + test_and_sample_model(model, sampler, inverse_temperatures[, swap_strategy]; kwargs...) + +Run the tempered version of `sampler` on `model` and return the resulting chain. + +Several properties of the tempered sampler are tested before returning: +- No invalid swappings has occured. +- Swaps were successfully performed at least a given portion of the time. + +# Arguments +- `model`: The model to temper and sample from. +- `sampler`: The sampler to temper and use to sample from `model`. +- `inverse_temperatures`: The inverse temperatures to for tempering.. +- `swap_strategy`: The swap strategy to use. + +# Keyword arguments +- `mean_swap_lower_bound`: A lower bound on the acceptance rate of swaps performed, e.g. if set to `0.1` then at least 10% of attempted swaps should be accepted. Defaults to `0.1`. +- `num_iterations`: The number of iterations to run the sampler for. Defaults to `2_000`. +- `swap_every`: The number of iterations between each swap attempt. Defaults to `2`. +- `adapt_target`: The target acceptance rate for the swaps. Defaults to `0.234`. +- `adapt_rtol`: The relative tolerance for the check of average swap acceptance rate and target swap acceptance rate. Defaults to `0.1`. +- `adapt_atol`: The absolute tolerance for the check of average swap acceptance rate and target swap acceptance rate. Defaults to `0.05`. +- `kwargs...`: Additional keyword arguments to pass to `MCMCTempering.tempered`. +""" +function test_and_sample_model( + model, + sampler, + inverse_temperatures, + swap_strategy=MCMCTempering.StandardSwap(); + mean_swap_rate_lower_bound=0.1, + num_iterations=2_000, + swap_every=2, + adapt_target=0.234, + adapt_rtol=0.1, + adapt_atol=0.05, + kwargs... +) + # TODO: Remove this when no longer necessary. + num_iterations_tempered = Int(ceil(num_iterations * swap_every ÷ (swap_every - 1))) + + # Make the tempered sampler. + sampler_tempered = tempered( + sampler, + inverse_temperatures; + swap_strategy=swap_strategy, + swap_every=swap_every, + adapt_target=adapt_target, + kwargs... + ) + + # Store the states. + states_tempered = [] + callback = StateHistoryCallback(states_tempered) + + # Sample. + samples_tempered = AbstractMCMC.sample( + model, sampler_tempered, num_iterations_tempered; callback=callback, progress=true + ) + + # Extract the states that were swapped. + states_swapped = filter(Base.Fix2(getproperty, :is_swap), states_tempered) + # Swap acceptance ratios should be compared against the target acceptance in case of adaptation. + swap_acceptance_ratios = mapreduce( + collect ∘ values ∘ Base.Fix2(getproperty, :swap_acceptance_ratios), + vcat, + states_swapped + ) + # Check that adaptation did something useful. + if sampler_tempered.adapt + swap_acceptance_ratios = map(Base.Fix1(min, 1.0) ∘ exp, swap_acceptance_ratios) + empirical_acceptance_rate = sum(swap_acceptance_ratios) / length(swap_acceptance_ratios) + @test adapt_target ≈ empirical_acceptance_rate atol = adapt_atol rtol = adapt_rtol + + # TODO: Maybe check something related to the temperatures themselves in case of adaptation. + # E.g. converged values shouldn't all be 0 or something. + # βs = mapreduce(Base.Fix2(getproperty, :inverse_temperatures), hcat, states) + end + + # Extract the history of chain indices. + process_to_chain_history_list = map(states_tempered) do state + state.process_to_chain + end + process_to_chain_history = permutedims(reduce(hcat, process_to_chain_history_list), (2, 1)) + + # Check that the swapping has been done correctly. + process_to_chain_uniqueness = map(states_tempered) do state + length(unique(state.process_to_chain)) == length(state.process_to_chain) + end + @test all(process_to_chain_uniqueness) + + # For the currently implemented strategies, the index process should not move by more than 1. + @test all(abs.(diff(process_to_chain_history[:, 1])) .≤ 1) + + chain_to_process_uniqueness = map(states_tempered) do state + length(unique(state.chain_to_process)) == length(state.chain_to_process) + end + @test all(chain_to_process_uniqueness) + + # Tests that we have at least swapped some times (say at least 10% of attempted swaps). + swap_success_indicators = map(eachrow(diff(process_to_chain_history; dims=1))) do row + # Some of the strategies performs multiple swaps in a swap-iteration, + # but we want to count the number of iterations for which we had a successful swap, + # i.e. only count non-zero elements in a row _once_. Hence the `min`. + min(1, sum(abs, row)) + end + @test sum(swap_success_indicators) ≥ (num_iterations_tempered / swap_every) * mean_swap_rate_lower_bound + + # Compare the tempered sampler to the untempered sampler. + state_tempered = states_tempered[end] + chain_tempered = AbstractMCMC.bundle_samples( + samples_tempered, model, sampler_tempered.sampler, MCMCTempering.state_for_chain(state_tempered), MCMCChains.Chains + ) + # Only pick out the samples after swapping. + # TODO: Remove this when no longer necessary. + chain_tempered = chain_tempered[swap_every:swap_every:end] + return chain_tempered +end + +function compare_chains( + chain::MCMCChains.Chains, chain_tempered::MCMCChains.Chains; + atol=1e-6, rtol=1e-6, + compare_std=true, + compare_ess=true +) + desc = describe(chain)[1].nt + desc_tempered = describe(chain_tempered)[1].nt + + # Compare the means. + @test desc.mean ≈ desc_tempered.mean atol = atol rtol = rtol + + # Compare the std. of the chains. + if compare_std + @test desc.std ≈ desc_tempered.std atol = atol rtol = rtol + end + + # Compare the ESS. + if compare_ess + ess = MCMCChains.ess_rhat(chain).nt.ess + ess_tempered = MCMCChains.ess_rhat(chain_tempered).nt.ess + # HACK: Just make sure it's not doing _horrible_. Though we'd hope it would + # actually do better than the internal sampler. + @test all(ess .≥ ess_tempered .* 0.5) + end +end + + @testset "MCMCTempering.jl" begin @testset "GMM 1D" begin - - nsamples = 1_000_000 - + num_iterations = 100_000 gmm = MixtureModel(Normal, [(-3, 1.5), (3, 1.5), (15, 1.5), (90, 1.5)], [0.175, 0.25, 0.275, 0.3]) - logdensity(x) = logpdf(gmm, x) - # Construct a DensityModel. + # Setup non-tempered. model = AdvancedMH.DensityModel(logdensity) - - # Create non-tempered baseline chain via RWMH sampler_rwmh = RWMH(Normal()) - samples = AbstractMCMC.sample(model, sampler_rwmh, nsamples) - chain = AbstractMCMC.bundle_samples(samples, model, sampler_rwmh, samples[1], MCMCChains.Chains) # Simple geometric ladder inverse_temperatures = MCMCTempering.check_inverse_temperatures(0.05 .^ [0, 1, 2]) - acceptance_rate_target = 0.234 - - # Number of iterations needed to obtain `nsamples` of non-swap iterations. - swap_every = 2 - nsamples_tempered = Int(ceil(nsamples * swap_every ÷ (swap_every - 1))) - - tempered_sampler_rwmh = tempered( + # Run the samplers. + chain_tempered = test_and_sample_model( + model, sampler_rwmh, - inverse_temperatures; - adapt = false, - swap_every=swap_every + [1.0, 0.5, 0.25, 0.125], + num_iterations=num_iterations, + swap_every=2, + adapt=false, ) - # Useful for analysis. - states = [] - callback = StateHistoryCallback(states) - - # Sample. - samples = AbstractMCMC.sample(model, tempered_sampler_rwmh, nsamples_tempered; callback=callback, progress=true) - βs = mapreduce(Base.Fix2(getproperty, :inverse_temperatures), hcat, states) - - states_swapped = filter(Base.Fix2(getproperty, :is_swap), states) - swap_acceptance_ratios = mapreduce( - collect ∘ values ∘ Base.Fix2(getproperty, :swap_acceptance_ratios), - vcat, - states_swapped - ) - - # Extract the history of chain indices. - process_to_chain_history_list = map(states) do state - state.process_to_chain - end - process_to_chain_history = permutedims(reduce(hcat, process_to_chain_history_list), (2, 1)) - - # Check that the swapping has been done correctly. - process_to_chain_uniqueness = map(states) do state - length(unique(state.process_to_chain)) == length(state.process_to_chain) - end - @test all(process_to_chain_uniqueness) - - # For these strategies, the index process should not move by more than 1. - @test all(abs.(diff(process_to_chain_history[:, 1])) .≤ 1) - - chain_to_process_uniqueness = map(states) do state - length(unique(state.chain_to_process)) == length(state.chain_to_process) - end - @test all(chain_to_process_uniqueness) - - # Tests that we have at least swapped some times (say at least 10% of attempted swaps). - swap_success_indicators = map(eachrow(diff(process_to_chain_history; dims=1))) do row - # Some of the strategies performs multiple swaps in a swap-iteration, - # but we want to count the number of iterations for which we had a successful swap, - # i.e. only count non-zero elements in a row _once_. Hence the `min`. - min(1, sum(abs, row)) - end - @test sum(swap_success_indicators) ≥ (nsamples / swap_every) * 0.1 - - # Get example state. - state = states[end] - chain = AbstractMCMC.bundle_samples( - samples, model, tempered_sampler_rwmh.sampler, MCMCTempering.state_for_chain(state), MCMCChains.Chains - ) - - # Thin chain and discard burnin. - chain_thinned = chain[length(chain) ÷ 2 + 1:5swap_every:end] - show(stdout, MIME"text/plain"(), chain_thinned) - - # Extract some summary statistics to compare. - desc = describe(chain_thinned)[1].nt - μ = desc.mean - σ = desc.std - + # # Compare the chains. + # compare_chains(chain, chain_tempered, atol=1e-1, compare_std=false, compare_ess=true) end @testset "MvNormal 2D" begin d = 2 - nsamples = 20_000 + num_iterations = 20_000 swap_every = 2 μ_true = [-5.0, 5.0] σ_true = [1.0, √(10.0)] - logdensity(x) = logpdf(MvNormal(μ_true, Diagonal(σ_true.^2)), x) + logdensity(x) = logpdf(MvNormal(μ_true, Diagonal(σ_true .^ 2)), x) # Sampler parameters. inverse_temperatures = MCMCTempering.check_inverse_temperatures([0.25, 0.5, 0.75, 1.0]) @@ -121,7 +200,10 @@ include("compat.jl") model = DensityModel(logdensity) # Set up our sampler with a joint multivariate Normal proposal. - spl_inner = RWMH(MvNormal(zeros(d), Diagonal(σ_true.^2))) + sampler = RWMH(MvNormal(zeros(d), Diagonal(σ_true .^ 2))) + # Sample for the non-tempered model for comparison. + samples = AbstractMCMC.sample(model, sampler, num_iterations) + chain = AbstractMCMC.bundle_samples(samples, model, sampler, samples[1], MCMCChains.Chains) # Different swap strategies to test. swapstrategies = [ @@ -130,113 +212,17 @@ include("compat.jl") MCMCTempering.NonReversibleSwap() ] - # First we run MH to have something to compare to. - samples_mh = AbstractMCMC.sample(model, spl_inner, nsamples; progress=false); - chain_mh = AbstractMCMC.bundle_samples(samples_mh, model, spl_inner, samples_mh[1], MCMCChains.Chains) - chain_thinned_mh = chain_mh[length(chain_mh) ÷ 2 + 1:5swap_every:end] - - # Extract some summary statistics to compare. - desc_mh = describe(chain_thinned_mh)[1].nt - μ_mh = desc_mh.mean - σ_mh = desc_mh.std - ess_mh = MCMCChains.ess_rhat(chain_thinned_mh).nt.ess - @testset "$(swapstrategy)" for swapstrategy in swapstrategies - acceptance_rate_target = 0.234 - # Number of iterations needed to obtain `nsamples` of non-swap iterations. - nsamples_tempered = Int(ceil(nsamples * swap_every ÷ (swap_every - 1))) - spl = tempered( - spl_inner, inverse_temperatures; - swap_strategy = swapstrategy, - adapt=false, # TODO: Test adaptation. Seems to work in some cases though. - adapt_schedule=MCMCTempering.Geometric(), - adapt_stepsize=1, - adapt_eta=0.66, - adapt_target=0.234, - swap_every=swap_every - ) - - # Useful for analysis. - states = [] - callback = StateHistoryCallback(states) - - # Sample. - samples = AbstractMCMC.sample(model, spl, nsamples_tempered; callback=callback, progress=true); - βs = mapreduce(Base.Fix2(getproperty, :inverse_temperatures), hcat, states) - - states_swapped = filter(Base.Fix2(getproperty, :is_swap), states) - swap_acceptance_ratios = mapreduce( - collect ∘ values ∘ Base.Fix2(getproperty, :swap_acceptance_ratios), - vcat, - states_swapped - ) - # Check that the adaptation did something useful. - if spl.adapt - swap_acceptance_ratios = map(Base.Fix1(min, 1.0) ∘ exp, swap_acceptance_ratios) - empirical_acceptance_rate = sum(swap_acceptance_ratios) / length(swap_acceptance_ratios) - @test acceptance_rate_target ≈ empirical_acceptance_rate atol = 0.05 - end - - # Extract the history of chain indices. - process_to_chain_history_list = map(states) do state - state.process_to_chain - end - process_to_chain_history = permutedims(reduce(hcat, process_to_chain_history_list), (2, 1)) - - # Check that the swapping has been done correctly. - process_to_chain_uniqueness = map(states) do state - length(unique(state.process_to_chain)) == length(state.process_to_chain) - end - @test all(process_to_chain_uniqueness) - - if any(isa.(Ref(swapstrategy), [MCMCTempering.StandardSwap, MCMCTempering.NonReversibleSwap])) - # For these strategies, the index process should not move by more than 1. - @test all(abs.(diff(process_to_chain_history[:, 1])) .≤ 1) - end - - chain_to_process_uniqueness = map(states) do state - length(unique(state.chain_to_process)) == length(state.chain_to_process) - end - @test all(chain_to_process_uniqueness) - - # Tests that we have at least swapped some times (say at least 10% of attempted swaps). - swap_success_indicators = map(eachrow(diff(process_to_chain_history; dims=1))) do row - # Some of the strategies performs multiple swaps in a swap-iteration, - # but we want to count the number of iterations for which we had a successful swap, - # i.e. only count non-zero elements in a row _once_. Hence the `min`. - min(1, sum(abs, row)) - end - @test sum(swap_success_indicators) ≥ (nsamples / swap_every) * 0.1 - - # Get example state. - state = states[end] - chain = AbstractMCMC.bundle_samples( - samples, model, spl.sampler, MCMCTempering.state_for_chain(state), MCMCChains.Chains + chain_tempered = test_and_sample_model( + model, + sampler, + inverse_temperatures, + num_iterations=num_iterations, + swap_every=swap_every, + swapstrategy=swapstrategy, + adapt=false, ) - - # Thin chain and discard burnin. - chain_thinned = chain[length(chain) ÷ 2 + 1:5swap_every:end] - show(stdout, MIME"text/plain"(), chain_thinned) - - # Extract some summary statistics to compare. - desc = describe(chain_thinned)[1].nt - μ = desc.mean - σ = desc.std - - # `StandardSwap` is quite unreliable, so struggling to come up with reasonable tests. - if !(swapstrategy isa StandardSwap) - @test μ ≈ μ_true rtol=0.05 - - # NOTE(torfjelde): The variance is usually quite large for the tempered chains - # and I don't quite know if this is expected or not. - # @test norm(σ - σ_true) ≤ 0.5 - - # Comparison to just running the internal sampler. - ess = MCMCChains.ess_rhat(chain_thinned).nt.ess - # HACK: Just make sure it's not doing _horrible_. Though we'd hope it would - # actually do better than the internal sampler. - @test all(ess .≥ ess_mh .* 0.5) - end + compare_chains(chain, chain_tempered, rtol=0.1, compare_std=false, compare_ess=true) end end end From 13cb3b277d62c10eb2dc2dece41e23dd46393f17 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 17 Nov 2022 12:13:16 +0000 Subject: [PATCH 48/51] made the compute_tempered_logdensities a bit more general --- src/swapping.jl | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/swapping.jl b/src/swapping.jl index a9232cc..14cebc7 100644 --- a/src/swapping.jl +++ b/src/swapping.jl @@ -67,11 +67,14 @@ end """ compute_tempered_logdensities(model, sampler, transition, transition_other, β) + compute_tempered_logdensities(model, sampler, sampler_other, transition, transition_other, state, state_other, β, β_other) Return `(logπ(transition, β), logπ(transition_other, β))` where `logπ(x, β)` denotes the log-density for `model` with inverse-temperature `β`. """ -function compute_tempered_logdensities end +function compute_tempered_logdensities(model, sampler, sampler_other, transition, transition_other, state, state_other, β, β_other) + return compute_tempered_logdensities(model, sampler, transition, transition_other, β) +end """ swap_acceptance_pt(logπk, logπkp1) @@ -92,17 +95,22 @@ Attempt to swap the temperatures of two chains by tempering the densities and calculating the swap acceptance ratio; then swapping if it is accepted. """ function swap_attempt(rng, model, sampler, state, k, adapt, total_steps) + # TODO: Allow arbitrary `k` rather than just `k + 1`. # Extract the relevant transitions. + samplerk = sampler_for_chain(sampler, state, k) + samplerkp1 = sampler_for_chain(sampler, state, k + 1) transitionk = transition_for_chain(state, k) transitionkp1 = transition_for_chain(state, k + 1) + statek = state_for_chain(state, k) + statekp1 = state_for_chain(state, k + 1) + βk = β_for_chain(state, k) + βkp1 = β_for_chain(state, k + 1) # Evaluate logdensity for both parameters for each tempered density. - # NOTE: Here we want to propose swaps between the neighboring _chains_ not processes, - # and so we get the `β` and `sampler` corresponding to the k-th and (k+1)-th chains. logπk_θk, logπk_θkp1 = compute_tempered_logdensities( - model, sampler_for_chain(sampler, state, k), transitionk, transitionkp1, β_for_chain(state, k) + model, samplerk, samplerkp1, transitionk, transitionkp1, statek, statekp1, βk, βkp1 ) logπkp1_θkp1, logπkp1_θk = compute_tempered_logdensities( - model, sampler_for_chain(sampler, state, k + 1), transitionkp1, transitionk, β_for_chain(state, k + 1) + model, samplerkp1, samplerk, transitionkp1, transitionk, statekp1, statek, βkp1, βk ) # If the proposed temperature swap is accepted according `logα`, From cb6937ed5982e272fce17b9e4cf32ef731a912c8 Mon Sep 17 00:00:00 2001 From: HarrisonWilde Date: Thu, 17 Nov 2022 17:13:40 +0000 Subject: [PATCH 49/51] Implementing `tempered_sample` to allow for no-swap burn-in and easier usage --- Project.toml | 1 + src/MCMCTempering.jl | 7 +++-- src/{tempered.jl => sampler.jl} | 32 ++++++++++---------- src/sampling.jl | 53 +++++++++++++++++++++++++++++++++ src/{states.jl => state.jl} | 2 ++ src/stepping.jl | 26 ++++++++++++++-- 6 files changed, 101 insertions(+), 20 deletions(-) rename src/{tempered.jl => sampler.jl} (82%) create mode 100644 src/sampling.jl rename src/{states.jl => state.jl} (98%) diff --git a/Project.toml b/Project.toml index 84709e8..af97b01 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" +ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" diff --git a/src/MCMCTempering.jl b/src/MCMCTempering.jl index 5bb5686..1752dfc 100644 --- a/src/MCMCTempering.jl +++ b/src/MCMCTempering.jl @@ -4,6 +4,7 @@ import AbstractMCMC import Distributions import Random +using ProgressLogging: ProgressLogging using ConcreteStructs: @concrete using Setfield: @set, @set! @@ -13,13 +14,15 @@ using DocStringExtensions include("adaptation.jl") include("swapping.jl") -include("states.jl") -include("tempered.jl") +include("state.jl") +include("sampler.jl") +include("sampling.jl") include("ladders.jl") include("stepping.jl") include("model.jl") export tempered, + tempered_sample, TemperedSampler, make_tempered_model, StandardSwap, diff --git a/src/tempered.jl b/src/sampler.jl similarity index 82% rename from src/tempered.jl rename to src/sampler.jl index 0ba54d7..a7289ec 100644 --- a/src/tempered.jl +++ b/src/sampler.jl @@ -57,11 +57,11 @@ function sampler_for_process(sampler::TemperedSampler, state::TemperedState, I.. end """ - tempered(sampler, inverse_temperatures::Vector{<:Real}; kwargs...) + tempered(sampler, inverse_temperatures; kwargs...) OR - tempered(sampler, N_it::Integer; kwargs...) + tempered(sampler, N_it; swap_strategy=StandardSwap(), kwargs...) -Return tempered version of `sampler` using the provided `inverse_temperatures` or +Return a tempered version of `sampler` using the provided `inverse_temperatures` or inverse temperatures generated from `N_it` and the `swap_strategy`. # Arguments @@ -69,7 +69,7 @@ inverse temperatures generated from `N_it` and the `swap_strategy`. - The temperature schedule can be defined either explicitly or just as an integer number of temperatures, i.e. as: - `inverse_temperatures` containing a sequence of 'inverse temperatures' {β₀, ..., βₙ} where 0 ≤ βₙ < ... < β₁ < β₀ = 1 OR - - `N_it::Integer`, specifying the number of inverse temperatures to include in a generated `inverse_temperatures` + - `N_it`, specifying the integer number of inverse temperatures to include in a generated `inverse_temperatures` # Keyword arguments - `swap_strategy::AbstractSwapStrategy` is the way in which inverse temperature swaps between chains are made @@ -84,24 +84,24 @@ inverse temperatures generated from `N_it` and the `swap_strategy`. - [`RandomPermutationSwap`](@ref) """ function tempered( - sampler, - N_it::Integer, - swap_strategy::AbstractSwapStrategy = StandardSwap(); + sampler::AbstractMCMC.AbstractSampler, + N_it::Integer; + swap_strategy::AbstractSwapStrategy=StandardSwap(), kwargs... ) return tempered(sampler, generate_inverse_temperatures(N_it, swap_strategy); swap_strategy = swap_strategy, kwargs...) end function tempered( - sampler, + sampler::AbstractMCMC.AbstractSampler, inverse_temperatures::Vector{<:Real}; - swap_strategy::AbstractSwapStrategy = StandardSwap(), - swap_every::Integer = 1, - adapt::Bool = true, - adapt_target::Real = 0.234, - adapt_stepsize::Real = 1, - adapt_eta::Real = 0.66, - adapt_schedule = Geometric(), - adapt_scale = defaultscale(adapt_schedule, inverse_temperatures), + swap_strategy::AbstractSwapStrategy=StandardSwap(), + swap_every::Integer=1, + adapt::Bool=false, + adapt_target::Real=0.234, + adapt_stepsize::Real=1, + adapt_eta::Real=0.66, + adapt_schedule=Geometric(), + adapt_scale=defaultscale(adapt_schedule, inverse_temperatures), kwargs... ) swap_every >= 1 || error("This must be a positive integer.") diff --git a/src/sampling.jl b/src/sampling.jl new file mode 100644 index 0000000..d7b7f44 --- /dev/null +++ b/src/sampling.jl @@ -0,0 +1,53 @@ +""" + tempered_sample([rng, ], model, sampler, N, inverse_temperatures; kwargs...) + OR + tempered_sample([rng, ], model, sampler, N, N_it; swap_strategy=StandardSwap(), kwargs...) + +Generate `N` samples from `model` using a tempered version of the provided `sampler`. +Provide either `inverse_temperatures` or `N_it` (and a `swap_strategy`) to generate some + +# Keyword arguments +- `N_burnin::Integer` burn-in steps will be carried out before any swapping between chains is attempted +- `swap_strategy::AbstractSwapStrategy` is the way in which inverse temperature swaps between chains are made +- `swap_every::Integer` steps are carried out between each attempt at a swap + +# See also +- [`tempered`](@ref) +- [`TemperedSampler`](@ref) +- For more on the swap strategies: + - [`AbstractSwapStrategy`](@ref) + - [`StandardSwap`](@ref) + - [`NonReversibleSwap`](@ref) + - [`RandomPermutationSwap`](@ref) +""" +function tempered_sample( + model, + sampler::AbstractMCMC.AbstractSampler, + N::Integer, + tempering_argument::Union{Integer, Vector{<:Real}}; + kwargs... +) + return tempered_sample(Random.default_rng(), model, sampler, N, tempering_argument; kwargs...) +end +function tempered_sample( + rng, + model, + sampler::AbstractMCMC.AbstractSampler, + N::Integer, + N_it::Integer; + swap_strategy::AbstractSwapStrategy = StandardSwap(), + kwargs... +) + return tempered_sample(model, sampler, N, generate_inverse_temperatures(N_it, swap_strategy); swap_strategy=swap_strategy, kwargs...) +end +function tempered_sample( + rng, + model, + sampler::AbstractMCMC.AbstractSampler, + N::Integer, + inverse_temperatures::Vector{<:Real}; + kwargs... +) + tempered_sampler = tempered(sampler, inverse_temperatures; kwargs...) + return AbstractMCMC.sample(rng, model, tempered_sampler, N; kwargs...) +end \ No newline at end of file diff --git a/src/states.jl b/src/state.jl similarity index 98% rename from src/states.jl rename to src/state.jl index 482b28b..12be39b 100644 --- a/src/states.jl +++ b/src/state.jl @@ -72,6 +72,8 @@ The indices here are exactly those represented by `states[k].chain_to_process[1] process_to_chain "total number of steps taken" total_steps + "number of burn-in steps taken" + burnin_steps "contains all necessary information for adaptation of inverse_temperatures" adaptation_states "flag which specifies wether this was a swap-step or not" diff --git a/src/stepping.jl b/src/stepping.jl index 291b18a..6ae334a 100644 --- a/src/stepping.jl +++ b/src/stepping.jl @@ -1,5 +1,5 @@ """ - should_swap(sampler, state) +should_swap(sampler, state) Return `true` if a swap should happen at this iteration, and `false` otherwise. """ @@ -11,6 +11,8 @@ function AbstractMCMC.step( rng::Random.AbstractRNG, model, sampler::TemperedSampler; + N_burnin::Integer=0, + burnin_progress::Bool=AbstractMCMC.PROGRESS[], init_params=nothing, kwargs... ) @@ -42,11 +44,32 @@ function AbstractMCMC.step( process_to_chain, chain_to_process, 1, + 0, sampler.adaptation_states, false, Dict{Int,Float64}() ) + if N_burnin > 0 + AbstractMCMC.@ifwithprogresslogger burnin_progress name = "Burn-in" begin + # Determine threshold values for progress logging + # (one update per 0.5% of progress) + if burnin_progress + threshold = N_burnin ÷ 200 + next_update = threshold + end + + for i in 1:N_burnin + if burnin_progress && i >= next_update + ProgressLogging.@logprogress i / N_burnin + next_update = i + threshold + end + state = no_swap_step(rng, model, sampler, state; kwargs...) + @set! state.burnin_steps += 1 + end + end + end + return transition_for_chain(state), state end @@ -57,7 +80,6 @@ function AbstractMCMC.step( state::TemperedState; kwargs... ) - # Reset. @set! state.swap_acceptance_ratios = empty(state.swap_acceptance_ratios) From db4b8421f7f989eb4364a53e9260b1645179183e Mon Sep 17 00:00:00 2001 From: HarrisonWilde Date: Thu, 17 Nov 2022 18:14:19 +0000 Subject: [PATCH 50/51] Tweaking sample call --- src/sampling.jl | 6 ++++-- src/stepping.jl | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/sampling.jl b/src/sampling.jl index d7b7f44..733158b 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -24,11 +24,12 @@ function tempered_sample( model, sampler::AbstractMCMC.AbstractSampler, N::Integer, - tempering_argument::Union{Integer, Vector{<:Real}}; + arg::Union{Integer, Vector{<:Real}}; kwargs... ) - return tempered_sample(Random.default_rng(), model, sampler, N, tempering_argument; kwargs...) + return tempered_sample(Random.default_rng(), model, sampler, N, arg; kwargs...) end + function tempered_sample( rng, model, @@ -40,6 +41,7 @@ function tempered_sample( ) return tempered_sample(model, sampler, N, generate_inverse_temperatures(N_it, swap_strategy); swap_strategy=swap_strategy, kwargs...) end + function tempered_sample( rng, model, diff --git a/src/stepping.jl b/src/stepping.jl index 6ae334a..d46b0b1 100644 --- a/src/stepping.jl +++ b/src/stepping.jl @@ -16,6 +16,7 @@ function AbstractMCMC.step( init_params=nothing, kwargs... ) + # `TemperedState` has the transitions and states in the order of # the processes, and performs swaps by moving the (inverse) temperatures # `β` between the processes, rather than moving states between processes From fce26aa3397a593b69def6254baae554fcd4dad7 Mon Sep 17 00:00:00 2001 From: HarrisonWilde Date: Thu, 17 Nov 2022 18:57:35 +0000 Subject: [PATCH 51/51] Working burnin --- src/sampler.jl | 4 ++-- src/stepping.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sampler.jl b/src/sampler.jl index a7289ec..f180de9 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -95,7 +95,7 @@ function tempered( sampler::AbstractMCMC.AbstractSampler, inverse_temperatures::Vector{<:Real}; swap_strategy::AbstractSwapStrategy=StandardSwap(), - swap_every::Integer=1, + swap_every::Integer=2, adapt::Bool=false, adapt_target::Real=0.234, adapt_stepsize::Real=1, @@ -104,7 +104,7 @@ function tempered( adapt_scale=defaultscale(adapt_schedule, inverse_temperatures), kwargs... ) - swap_every >= 1 || error("This must be a positive integer.") + swap_every >= 2 || error("This must be a positive integer greater than 1.") inverse_temperatures = check_inverse_temperatures(inverse_temperatures) adaptation_states = init_adaptation( adapt_schedule, inverse_temperatures, adapt_target, adapt_scale, adapt_eta, adapt_stepsize diff --git a/src/stepping.jl b/src/stepping.jl index d46b0b1..33af093 100644 --- a/src/stepping.jl +++ b/src/stepping.jl @@ -1,5 +1,5 @@ """ -should_swap(sampler, state) + should_swap(sampler, state) Return `true` if a swap should happen at this iteration, and `false` otherwise. """