General improvements and fixes #133

Merged
52 commits merged on Dec 4, 2022
Changes from 27 commits
Commits
4d48a35
additional deps and test deps
torfjelde Sep 11, 2021
2df4a2b
updated stepping code to use AbstractSwapStrategy and made TemperedSa…
torfjelde Oct 4, 2021
83c5c49
made TemperedSampler concretely typed and fixed soem docs
torfjelde Oct 4, 2021
9e9153c
introduced AbstractSwapStrategy and removed get_params and make_tempe…
torfjelde Oct 4, 2021
26f12f1
updated stepping.jl
torfjelde Oct 4, 2021
aa88ae4
added docstring for make_tempered_model
torfjelde Oct 4, 2021
2523a9d
updated adaptation.jl and made structs concrete
torfjelde Oct 4, 2021
7462ebf
updated ladders.jl
torfjelde Oct 4, 2021
3fb13e9
addressed some comments
torfjelde Oct 13, 2021
90e60f2
added tests
torfjelde Oct 13, 2021
e30718d
fixed a bug
torfjelde Oct 13, 2021
9e8607a
updated the StateHistoryCallback a bit
torfjelde Oct 20, 2021
3bb79df
made the distinction between chains and processes clearer
torfjelde Oct 20, 2021
90722c9
added tests
torfjelde Oct 20, 2021
50ad1ec
fixed incorrect statement
torfjelde Oct 20, 2021
0c204ac
renamed some fields to be more descriptive and fixed left-over bug
torfjelde Oct 20, 2021
e2bbc90
updated tests
torfjelde Oct 20, 2021
e7466cc
removed some show from tests
torfjelde Oct 20, 2021
0a59517
began updating docstrings
torfjelde Oct 20, 2021
f7a7f31
fixed docstring for TemeperedState
torfjelde Oct 21, 2021
696d8d1
fix exports
torfjelde Oct 21, 2021
bbb2fc2
a bunch of renaming
torfjelde Oct 21, 2021
8ac7374
deleted plotting functionality
torfjelde Oct 21, 2021
f7c46e7
fixed bug and added should_swap method
torfjelde Oct 21, 2021
dcb2a0d
improved tests
torfjelde Oct 21, 2021
287b501
Typo
ParadaCarleton Nov 20, 2021
6239710
Typo
ParadaCarleton Nov 20, 2021
8c1b8ff
implemented adaptation scheme for inverse temperatures using a geomet…
torfjelde Dec 7, 2021
f70690f
made some changes to some code that I cannot understand the original …
torfjelde Dec 7, 2021
afa2900
added parameter for controlling which type of schedule to use when ad…
torfjelde Dec 7, 2021
c0d9e61
make number of steps taken for each adaptor part of their state
torfjelde Dec 16, 2021
4563d33
improvements to parameterization of the adaptation techniques
torfjelde Nov 14, 2022
86fbf0d
updated test env
torfjelde Nov 14, 2022
229059b
keep track of swapping ratios
torfjelde Nov 14, 2022
c4583f9
tests are now runnable
torfjelde Nov 14, 2022
a1efd11
commented out unused code
torfjelde Nov 14, 2022
c221572
Corrected typo
HarrisonWilde Nov 16, 2022
632cca9
Added 1D GMM, sort of works for it
HarrisonWilde Nov 16, 2022
f30acfa
Make `StandardSwap` the default strategy when one
HarrisonWilde Nov 16, 2022
0a0b131
Fixing test case for GMM
HarrisonWilde Nov 16, 2022
910dc6d
Implementing burn-in, introduces depedency on StatsBase
HarrisonWilde Nov 16, 2022
a4437f9
Fixed error with burnin
HarrisonWilde Nov 16, 2022
b9c6a90
cleaning up working_code
HarrisonWilde Nov 16, 2022
79cbbf0
QoL improvements on the code
HarrisonWilde Nov 16, 2022
4336516
Removing `StatsBase` dependency and `discard_initial` override
HarrisonWilde Nov 16, 2022
915b610
Adding back accidentally deleted RandomPermutationSwap stuff
HarrisonWilde Nov 16, 2022
4898193
cleaned up testing a bit
torfjelde Nov 17, 2022
13cb3b2
made the compute_tempered_logdensities a bit more general
torfjelde Nov 17, 2022
cb6937e
Implementing `tempered_sample` to allow for no-swap burn-in and easie…
HarrisonWilde Nov 17, 2022
db4b842
Tweaking sample call
HarrisonWilde Nov 17, 2022
fce26aa
Working burnin
HarrisonWilde Nov 17, 2022
8a6b47e
Merge pull request #137 from TuringLang/harry/improvements_additions
yebai Dec 4, 2022
13 changes: 12 additions & 1 deletion Project.toml
@@ -5,16 +5,27 @@ version = "0.1.1"
 
 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
+ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
 
 [compat]
 AbstractMCMC = "3.2"
+ConcreteStructs = "0.2"
 Distributions = "0.24, 0.25"
+DocStringExtensions = "0.8"
+Setfield = "0.7"
 julia = "1"
 
 [extras]
+AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
+Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
+StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Test"]
+test = ["Test", "AdvancedMH", "MCMCChains", "Bijectors", "StatsPlots", "LinearAlgebra"]
20 changes: 16 additions & 4 deletions src/MCMCTempering.jl
@@ -4,15 +4,24 @@ import AbstractMCMC
 import Distributions
 import Random
 
+using ConcreteStructs: @concrete
+using Setfield: @set, @set!
+
+using DocStringExtensions
+
 include("adaptation.jl")
+include("swapping.jl")
 include("tempered.jl")
 include("ladders.jl")
 include("stepping.jl")
 include("model.jl")
-include("swapping.jl")
-include("plotting.jl")
 
-export tempered, TemperedSampler, plot_swaps, plot_ladders, make_tempered_model, get_tempered_loglikelihoods_and_params, make_tempered_loglikelihood, get_params
+export tempered,
+    TemperedSampler,
+    make_tempered_model,
+    StandardSwap,
+    RandomPermutationSwap,
+    NonReversibleSwap
 
 function AbstractMCMC.bundle_samples(
     ts::Vector,
@@ -22,7 +31,10 @@ function AbstractMCMC.bundle_samples(
     chain_type::Type;
     kwargs...
 )
-    AbstractMCMC.bundle_samples(ts, model, sampler.internal_sampler, state, chain_type; kwargs...)
+    AbstractMCMC.bundle_samples(
+        ts, model, sampler_for_chain(sampler, state), state_for_chain(state), chain_type;
+        kwargs...
+    )
 end
 
 end
32 changes: 15 additions & 17 deletions src/adaptation.jl
@@ -1,20 +1,16 @@
-
-struct PolynomialStep
-    η :: Real
-    c :: Real
+@concrete struct PolynomialStep
+    η
+    c
 end
 function get(step::PolynomialStep, k::Real)
     step.c * (k + 1.) ^ (-step.η)
 end
 
 
-struct AdaptiveState
-    swap_target_ar :: Real
-    scale :: Base.RefValue{<:Real}
-    step :: PolynomialStep
-end
-function AdaptiveState(swap_target::Real, scale::Real, step::PolynomialStep)
-    AdaptiveState(swap_target, Ref(log(scale)), step)
+struct AdaptiveState{T1<:Real,T2<:Real,P<:PolynomialStep}

Member

Any reason not to @concrete this? I think I'd prefer it for consistency.

Member Author

I didn't do it here because in this case I think the type restrictions are reasonable, but I'm not super-invested either way 🤷

+    swap_target_ar :: T1
+    logscale :: T2

Member

log_scale

Member Author

The Julia community is quite fond of squash-case, hence why I did it like this.

Member
@ParadaCarleton ParadaCarleton Nov 20, 2021

Wasn't aware of squash case in variable names/struct fields, just in function names, but makes sense.

+    step :: P
 end
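
The trade-off raised in the `@concrete` thread above can be checked with Base Julia alone. The sketch below uses two hypothetical stand-in types (`AbstractlyTyped` and `Parametric`, neither of which is in the package) to show that a field annotated `::Real` stays abstractly typed, while a type-parameterised field becomes concrete once the struct is instantiated. As far as I understand it, `@concrete` generates type parameters of this kind automatically for unannotated fields.

```julia
# Hypothetical stand-ins for illustration; not the package's types.
struct AbstractlyTyped
    logscale::Real          # field type is the abstract type Real
end

struct Parametric{T<:Real}
    logscale::T             # field type is fixed once T is known
end

a = AbstractlyTyped(0.5)
p = Parametric(0.5)

# Field type as recorded in each instance's type.
abstract_field = fieldtype(typeof(a), :logscale)
concrete_field = fieldtype(typeof(p), :logscale)

println(abstract_field, " is concrete: ", isconcretetype(abstract_field))
println(concrete_field, " is concrete: ", isconcretetype(concrete_field))
```

The explicit `T1<:Real` style in the diff keeps the `Real` restriction that a bare `@concrete` field would not impose, which appears to be the point the author is making.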


@@ -26,15 +22,15 @@ function init_adaptation(
 )
     Nt = length(Δ)
     step = PolynomialStep(γ, Nt - 1)
-    Ρ = [AdaptiveState(swap_target, scale, step) for _ in 1:(Nt - 1)]
+    Ρ = [AdaptiveState(swap_target, log(scale), step) for _ in 1:(Nt - 1)]

Member

Perhaps the function should take log_scale as its input, to keep consistency? Switching around between scale+log_scale will probably cause a bug at some point.

Member Author

Yep, this is essentially what's holding the PR back atm. This was left-over from before, not something I introduced; I've just tried to make it consistent (pretty sure the previous impl was buggy). But I need to find the source of the adaptation method so I can figure out exactly what the correct impl is.

     return Ρ
 end
 
 
 function rhos_to_ladder(Ρ, Δ)
     β′ = Δ[1]
     for i in 1:length(Ρ)
-        β′ += exp(Ρ[i].scale[])
+        β′ += exp(Ρ[i].logscale)
         Δ[i + 1] = Δ[1] / β′
     end
     return Δ
@@ -48,7 +44,9 @@ function adapt_rho(ρ::AdaptiveState, swap_ar, n)
 end
 
 
-function adapt_ladder(Ρ, Δ, k, swap_ar, n)
-    Ρ[k].scale[] += adapt_rho(Ρ[k], swap_ar, n)
-    return Ρ, rhos_to_ladder(Ρ, Δ)
-end
+function adapt_ladder(P, Δ, k, swap_ar, n)
+    P[k] = let Pk = P[k]
+        @set Pk.logscale += adapt_rho(Pk, swap_ar, n)
+    end
+    return P, rhos_to_ladder(P, Δ)
+end
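
The new `adapt_ladder` swaps in-place mutation through a `Ref` for a rebuilt immutable struct. The sketch below shows the same pattern without Setfield, using a hypothetical `ToyState` type and a hand-written `with_logscale` helper (neither is part of the package). The increment `δ` is an assumed placeholder for whatever `adapt_rho` returns.

```julia
# Hypothetical immutable state, standing in for AdaptiveState.
struct ToyState{T<:Real}
    swap_target_ar::T
    logscale::T
end

# Hand-written single-field update: build a new value, leave the old one alone.
with_logscale(s::ToyState, logscale) = ToyState(s.swap_target_ar, logscale)

states = [ToyState(0.234, log(2.0)) for _ in 1:3]
k = 2
δ = 0.1  # assumed adaptation increment, e.g. the result of adapt_rho

# Replace the k-th element of the vector rather than mutating the struct.
states[k] = with_logscale(states[k], states[k].logscale + δ)
```

`@set Pk.logscale += ...` does this copy-with-one-field-changed step generically, so no per-struct helper is needed.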
31 changes: 11 additions & 20 deletions src/ladders.jl
@@ -1,46 +1,37 @@
 """
     get_scaling_val(Nt, swap_strategy)
 
-Calculates the correct scaling factor for polynomial step size between temperatures
+Calculates the correct scaling factor for polynomial step size between temperatures.
 """
-function get_scaling_val(Nt, swap_strategy)
-    # Why these?
-    if swap_strategy == :standard
-        scaling_val = Nt - 1
-    elseif swap_strategy == :nonrev
-        scaling_val = 2
-    else
-        scaling_val = 1
-    end
-    return scaling_val
-end
-
+get_scaling_val(Nt, ::StandardSwap) = Nt - 1
+get_scaling_val(Nt, ::NonReversibleSwap) = 2
+get_scaling_val(Nt, ::RandomPermutationSwap) = 1
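
The three one-line methods above replace symbol comparison with dispatch on strategy types. A self-contained sketch follows. The strategy structs are defined locally as empty stand-ins: the real ones live in `src/swapping.jl`, which is not shown here, and may carry fields. The supertype name is taken from the commit messages.

```julia
# Local stand-ins; the package's own definitions may differ.
abstract type AbstractSwapStrategy end
struct StandardSwap <: AbstractSwapStrategy end
struct NonReversibleSwap <: AbstractSwapStrategy end
struct RandomPermutationSwap <: AbstractSwapStrategy end

# One method per strategy, as in the diff.
get_scaling_val(Nt, ::StandardSwap) = Nt - 1
get_scaling_val(Nt, ::NonReversibleSwap) = 2
get_scaling_val(Nt, ::RandomPermutationSwap) = 1

for strategy in (StandardSwap(), NonReversibleSwap(), RandomPermutationSwap())
    println(nameof(typeof(strategy)), " => ", get_scaling_val(5, strategy))
end
```

One behavioural difference from the old `if`/`else`: an unrecognised strategy now raises a `MethodError` instead of silently falling through to `1`.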

"""
-    generate_Δ(Nt, swap_strategy)
+    generate_inverse_temperatures(Nt, swap_strategy)
 
 Returns a temperature ladder `Δ` containing `Nt` temperatures,
-generated in accordance with the chosen `swap_strategy`
+generated in accordance with the chosen `swap_strategy`.
 """
-function generate_Δ(Nt, swap_strategy)
+function generate_inverse_temperatures(Nt, swap_strategy)
     scaling_val = get_scaling_val(Nt, swap_strategy)
-    Δ = zeros(Real, Nt)
+    Δ = zeros(Nt)
     Δ[1] = 1.0
     β′ = Δ[1]
     for i ∈ 1:(Nt - 1)
-        β′ += exp(scaling_val)
+        β′ += scaling_val

Member

Should this be scaling_val or log(scaling_val)? Not sure.

         Δ[i + 1] = Δ[1] / β′
     end
     return Δ
 end
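
To make the reviewer's question concrete, here is a worked sketch of the recurrence in the new code, written as a standalone function `ladder` (a hypothetical name) that takes the scaling value directly. It also checks one consistency property: feeding `log(s)` through the `exp(...)` accumulation used by `rhos_to_ladder` in `src/adaptation.jl` reproduces the same ladder. That suggests, without settling the question, that the plain `scaling_val` increment agrees with the adaptation code when `scale` equals `scaling_val`.

```julia
# Standalone version of the recurrence: β′ grows by `s` per level,
# and each inverse temperature is 1 / β′.
function ladder(Nt, s)
    Δ = zeros(Nt)
    Δ[1] = 1.0
    β′ = Δ[1]
    for i in 1:(Nt - 1)
        β′ += s
        Δ[i + 1] = Δ[1] / β′
    end
    return Δ
end

# Same accumulation, but starting from log-scales as rhos_to_ladder does.
function ladder_from_logscales(logscales)
    Δ = ones(length(logscales) + 1)
    β′ = Δ[1]
    for (i, logscale) in enumerate(logscales)
        β′ += exp(logscale)
        Δ[i + 1] = Δ[1] / β′
    end
    return Δ
end

Nt = 4
s = Nt - 1                      # the StandardSwap scaling value
Δ = ladder(Nt, s)               # β′ runs 1, 4, 7, 10
Δ_log = ladder_from_logscales(fill(log(s), Nt - 1))
```

With the previous `exp(scaling_val)` increment, the same `Nt = 4` case would step by `exp(3) ≈ 20.1` per level, giving a much steeper ladder.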


"""
-    check_Δ(Δ)
+    check_inverse_temperatures(Δ)
 
 Checks and returns a sorted `Δ` containing `{β₀, ..., βₙ}` conforming such that `1 = β₀ > β₁ > ... > βₙ ≥ 0`
 """
-function check_Δ(Δ)
+function check_inverse_temperatures(Δ)
     if !all(zero.(Δ) .≤ Δ .≤ one.(Δ))
         error("Temperature schedule provided has values outside of the acceptable range, ensure all values are in [0, 1].")
     end
9 changes: 5 additions & 4 deletions src/model.jl
@@ -1,8 +1,9 @@
 """
     make_tempered_model(sampler, model, args...)
 
-# struct TemperedModel <: AbstractPPL.AbstractProbabilisticProgram
-# model :: DynamicPPL.Model
-# β :: AbstractFloat
-# end
+Return an instance representing a model.
+
+The return-type depends on its usage in [`compute_tempered_logdensities`](@ref).
+"""
+function make_tempered_model end
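
`function make_tempered_model end` declares a generic function with no methods, so each sampler integration can add its own. Below is a minimal sketch of what a downstream method might look like. The types `ToySampler`, `ToyModel` and `ToyTemperedModel`, and the choice of an inverse temperature `β` as the extra argument, are all assumptions made for illustration; none of them come from the package.

```julia
# Declaration only: zero methods until someone extends it.
function make_tempered_model end

# Hypothetical types for illustration.
struct ToySampler end
struct ToyModel
    logdensity::Function
end
struct ToyTemperedModel
    model::ToyModel
    β::Float64
end

# Assumed extension: wrap the model together with an inverse temperature β.
make_tempered_model(::ToySampler, model::ToyModel, β) = ToyTemperedModel(model, β)

tempered = make_tempered_model(ToySampler(), ToyModel(x -> -x^2 / 2), 0.5)
```

Calling the function with argument types that no method covers raises a `MethodError`, which makes a missing integration visible at the call site.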

12 changes: 0 additions & 12 deletions src/plotting.jl

This file was deleted.
