diff --git a/Project.toml b/Project.toml index edef540..d3056ac 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MCMCTempering" uuid = "ce233488-44ea-4441-b732-192676ce2298" authors = ["Harrison Wilde and contributors"] -version = "0.3.2" +version = "0.4.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -17,7 +17,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" [compat] -AbstractMCMC = "3.2, 4" +AbstractMCMC = "5" ConcreteStructs = "0.2" Distributions = "0.24, 0.25" DocStringExtensions = "0.8, 0.9" diff --git a/docs/make.jl b/docs/make.jl index 8875bad..2f1d621 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -7,7 +7,8 @@ makedocs( sitename = "MCMCTempering", format = Documenter.HTML(), modules = [MCMCTempering], - pages=["Home" => "index.md", "getting-started.md", "api.md"], + pages = ["Home" => "index.md", "getting-started.md", "api.md"], + warnonly = true ) # Deply! diff --git a/docs/src/api.md b/docs/src/api.md index 5aed321..30c0768 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -70,6 +70,17 @@ MCMCTempering.swap_step ## Other samplers +To make a sampler work with MCMCTempering.jl, the sampler needs to implement a few methods: + +```@docs +MCMCTempering.getparams +MCMCTempering.getlogprob +MCMCTempering.getparams_and_logprob +MCMCTempering.setparams_and_logprob!! +``` + +Other useful methods are: + ```@docs MCMCTempering.saveall ``` diff --git a/docs/src/getting-started.md b/docs/src/getting-started.md index 8326d5e..41b79ee 100644 --- a/docs/src/getting-started.md +++ b/docs/src/getting-started.md @@ -74,13 +74,13 @@ Let's instead try to use a _tempered_ version of `RWMH`. _But_ before we do that To do that we need to implement two methods. 
First we need to tell MCMCTempering how to extract the parameters, and potentially the log-probabilities, from a `AdvancedMH.Transition`: -```@docs +```@docs; canonical=false MCMCTempering.getparams_and_logprob ``` And similarly, we need a way to _update_ the parameters and the log-probabilities of a `AdvancedMH.Transition`: -```@docs +```@docs; canonical=false MCMCTempering.setparams_and_logprob!! ``` @@ -159,10 +159,7 @@ using AdvancedHMC: AdvancedHMC using ForwardDiff: ForwardDiff # for automatic differentation of the logdensity # Creation of the sampler. -metric = AdvancedHMC.DiagEuclideanMetric(1) -integrator = AdvancedHMC.Leapfrog(0.1) -proposal = AdvancedHMC.StaticTrajectory(integrator, 8) -sampler = AdvancedHMC.HMCSampler(proposal, metric) +sampler = AdvancedHMC.HMC(0.1, 8) sampler_tempered = MCMCTempering.TemperedSampler(sampler, inverse_temperatures) # Sample! @@ -172,6 +169,7 @@ chain = sample( target_model, sampler, num_iterations; chain_type=MCMCChains.Chains, param_names=["x"], + n_adapts=0, # HACK: need this to make AdvancedHMC.jl happy :/ ) plot(chain, size=figsize) ``` @@ -205,7 +203,8 @@ chain_tempered_all = sample( StableRNG(42), target_model, sampler_tempered, num_iterations; chain_type=Vector{MCMCChains.Chains}, - param_names=["x"] + param_names=["x"], + n_adapts=0, # HACK: need this to make AdvancedHMC.jl happy :/ ); ``` @@ -289,7 +288,8 @@ chain_tempered_all = sample( StableRNG(42), target_model, sampler_tempered, num_iterations; chain_type=Vector{MCMCChains.Chains}, - param_names=["x"] + param_names=["x"], + n_adapts=0, # HACK: need this to make AdvancedHMC.jl happy :/ ); ``` diff --git a/src/samplers/multi.jl b/src/samplers/multi.jl index d8edaf4..1316e67 100644 --- a/src/samplers/multi.jl +++ b/src/samplers/multi.jl @@ -149,16 +149,16 @@ function AbstractMCMC.step( rng::Random.AbstractRNG, model::MultiModel, sampler::MultiSampler; - init_params=nothing, + initial_params=nothing, kwargs... 
) @assert length(model.models) == length(sampler.samplers) "Number of models and samplers must be equal" - # TODO: Handle `init_params` properly. Make sure that they respect the container-types used in + # TODO: Handle `initial_params` properly. Make sure that they respect the container-types used in # `MultiModel` and `MultiSampler`. - init_params_multi = initparams(model, init_params) - transition_and_states = asyncmap(model.models, sampler.samplers, init_params_multi) do model, sampler, init_params - AbstractMCMC.step(rng, model, sampler; init_params, kwargs...) + initial_params_multi = initparams(model, initial_params) + transition_and_states = asyncmap(model.models, sampler.samplers, initial_params_multi) do model, sampler, initial_params + AbstractMCMC.step(rng, model, sampler; initial_params, kwargs...) end return MultipleTransitions(map(first, transition_and_states)), MultipleStates(map(last, transition_and_states)) diff --git a/test/Project.toml b/test/Project.toml index 492577c..9178d47 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -19,13 +19,13 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] -AbstractMCMC = "3.2, 4" -AdvancedHMC = "0.4" -AdvancedMH = "0.7" -Bijectors = "0.10" +AbstractMCMC = "5" +AdvancedHMC = "0.6" +AdvancedMH = "0.8" +Bijectors = "0.13" Distributions = "0.24, 0.25" LogDensityProblems = "2" LogDensityProblemsAD = "1" MCMCChains = "6" -Turing = "0.24" +Turing = "0.34" julia = "1" diff --git a/test/abstractmcmc.jl b/test/abstractmcmc.jl index 7c0de14..24dcf35 100644 --- a/test/abstractmcmc.jl +++ b/test/abstractmcmc.jl @@ -168,7 +168,7 @@ @testset "SwapSampler" begin # SwapSampler without tempering (i.e. in a composition and using `MultiModel`, etc.) 
- init_params = [[5.0], [5.0]] + initial_params = [[5.0], [5.0]] mdl1 = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), DistributionLogDensity(Normal(4.9999, 1))) mdl2 = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), DistributionLogDensity(Normal(5.0001, 1))) spl1 = RWMH(MvNormal(Zeros(dimension(mdl1)), I)) @@ -181,7 +181,7 @@ # Sample! rng = Random.default_rng() - transition, state = AbstractMCMC.step(rng, product_model, spl_full; init_params) + transition, state = AbstractMCMC.step(rng, product_model, spl_full; initial_params) transitions = typeof(transition)[] # A bit of warm-up. @@ -222,7 +222,10 @@ # Check that means are roughly okay. params_resolved = map(first ∘ MCMCTempering.getparams, transitions_resolved_inner) - @test vec(mean(params_resolved; dims=2)) ≈ [5.0, 5.0] atol = 0.3 + mean_tmp = vec(median(params_resolved; dims=2)) + for i in 1:2 + @test mean_tmp[i] ≈ 5.0 atol = 0.5 + end # A composition of `SwapSampler` and `MultiSampler` has special `AbstractMCMC.bundle_samples`. @testset "bundle_samples with Vector" begin diff --git a/test/compat.jl b/test/compat.jl index 7052bb1..0f670d7 100644 --- a/test/compat.jl +++ b/test/compat.jl @@ -10,7 +10,11 @@ MCMCTempering.getparams_and_logprob(transition::AdvancedMH.GradientTransition) = function MCMCTempering.setparams_and_logprob!!(model, transition::AdvancedMH.GradientTransition, params, lp) # NOTE: We have to re-compute the gradient here because this will be used in the subsequent `step` for # the MALA sampler. - return AdvancedMH.GradientTransition(params, AdvancedMH.logdensity_and_gradient(model, params)...) 
+ return AdvancedMH.GradientTransition( + params, + AdvancedMH.logdensity_and_gradient(model, params)..., + transition.accepted + ) end # AdvancedHMC.jl diff --git a/test/runtests.jl b/test/runtests.jl index 160cabf..0533fa8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,7 +24,7 @@ Several properties of the tempered sampler are tested before returning: - `adapt_atol`: The absolute tolerance for the check of average swap acceptance rate and target swap acceptance rate. Defaults to `0.05`. - `mean_swap_rate_bound`: A bound on the acceptance rate of swaps performed, e.g. if set to `0.1` and `compare_mean_swap_rate=≥` then at least 10% of attempted swaps should be accepted. Defaults to `0.1`. - `compare_mean_swap_rate`: a binary function for comparing average swap rate against `mean_swap_rate_bound`. Defaults to `≥`. -- `init_params`: The initial parameters to use for the sampler. Defaults to `nothing`. +- `initial_params`: The initial parameters to use for the sampler. Defaults to `nothing`. - `param_names`: The names of the parameters in the chain; used to construct the resulting chain. Defaults to `missing`. - `progress`: Whether to show a progress bar. Defaults to `false`. """ @@ -41,10 +41,12 @@ function test_and_sample_model( adapt_target=0.234, adapt_rtol=0.1, adapt_atol=0.05, - init_params=nothing, + initial_params=nothing, param_names=missing, progress=false, - minimum_roundtrips=nothing + minimum_roundtrips=nothing, + rng=make_rng(), + kwargs... ) # Make the tempered sampler. sampler_tempered = tempered( @@ -64,8 +66,9 @@ function test_and_sample_model( # Sample. samples_tempered = AbstractMCMC.sample( - model, sampler_tempered, num_iterations; - callback=callback, progress=progress, init_params=init_params + rng, model, sampler_tempered, num_iterations; + callback=callback, progress=progress, initial_params=initial_params, + kwargs... 
) if !isnothing(minimum_roundtrips) @@ -245,7 +248,7 @@ end [1.0, 1e-3], # extreme temperatures -> don't exect much swapping to occur num_iterations=num_iterations, adapt=false, - init_params=[[0.0], [1000.0]], # initialized far apart + initial_params=[[0.0], [1000.0]], # initialized far apart # At MOST 1% of swaps should be successful. mean_swap_rate_bound=0.01, compare_mean_swap_rate=≤, @@ -302,7 +305,7 @@ end # Set up our sampler with a joint multivariate Normal proposal. sampler = RWMH(MvNormal(zeros(d), Diagonal(σ_true .^ 2))) # Sample for the non-tempered model for comparison. - samples = AbstractMCMC.sample(model, sampler, num_iterations; progress=false) + samples = AbstractMCMC.sample(make_rng(), model, sampler, num_iterations; progress=false) chain = AbstractMCMC.bundle_samples( samples, MCMCTempering.maybe_wrap_model(model), sampler, samples[1], MCMCChains.Chains ) @@ -326,19 +329,41 @@ end adapt=false, # Make sure we have _some_ roundtrips. minimum_roundtrips=10, + rng=make_rng(), ) - compare_chains(chain, chain_tempered, rtol=0.1, compare_ess=true) + # Some swap strategies are not great. + ess_slack_ratio = if swap_strategy isa Union{MCMCTempering.SingleRandomSwap,MCMCTempering.SingleSwap} + 0.25 + else + 0.5 + end + compare_chains(chain, chain_tempered, rtol=0.1, compare_ess=true, compare_ess_slack=ess_slack_ratio) end end @testset "Turing.jl" begin + # Let's fix a default seed that we can reuse throughout to get reproducible results. + seed = 42 + + # And let's set the seed explicitly for reproducibility. + Random.seed!(seed) + # Instantiate model. - model_dppl = DynamicPPL.TestUtils.demo_assume_dot_observe() + DynamicPPL.@model function demo_model(x) + s ~ Exponential() + m ~ Normal(0, s) + x .~ Normal(m, s) + end + xs_true = rand(Normal(2, 1), 100) + model_dppl = demo_model(xs_true) + + # Construct the `VarInfo` for the model. + vi = DynamicPPL.VarInfo(model_dppl) # Move to unconstrained space.
- vi = DynamicPPL.link!!(DynamicPPL.VarInfo(model_dppl), model_dppl) + vi = DynamicPPL.link!!(vi, model_dppl) # Get some initial values in unconstrained space. - init_params = copy(vi[:]) + initial_params = rand(length(vi[:])) # Get the parameter names. param_names = map(Symbol, DynamicPPL.TestUtils.varnames(model_dppl)) # Get bijector so we can get back to unconstrained space afterwards. @@ -349,9 +374,9 @@ end @testset "Tempering of models" begin beta = 0.5 model_tempered = MCMCTempering.make_tempered_model(model, beta) - @test logdensity(model_tempered, init_params) ≈ beta * logdensity(model, init_params) - @test last(logdensity_and_gradient(model_tempered, init_params)) ≈ - beta .* last(logdensity_and_gradient(model, init_params)) + @test logdensity(model_tempered, initial_params) ≈ beta * logdensity(model, initial_params) + @test last(logdensity_and_gradient(model_tempered, initial_params)) ≈ + beta .* last(logdensity_and_gradient(model, initial_params)) end @testset "AdvancedHMC.jl" begin @@ -360,12 +385,19 @@ end # Set up HMC smpler. initial_ϵ = 0.1 integrator = AdvancedHMC.Leapfrog(initial_ϵ) - proposal = AdvancedHMC.NUTS{AdvancedHMC.MultinomialTS, AdvancedHMC.GeneralisedNoUTurn}(integrator) + proposal = AdvancedHMC.HMCKernel(AdvancedHMC.Trajectory{AdvancedHMC.MultinomialTS}( + integrator, AdvancedHMC.GeneralisedNoUTurn() + )) metric = AdvancedHMC.DiagEuclideanMetric(LogDensityProblems.dimension(model)) - sampler_hmc = AdvancedHMC.HMCSampler(proposal, metric) + sampler_hmc = AdvancedHMC.HMCSampler(proposal, metric, AdvancedHMC.NoAdaptation()) # Sample using HMC. - samples_hmc = sample(model, sampler_hmc, num_iterations; init_params=copy(init_params), progress=false) + samples_hmc = sample( + make_rng(seed), model, sampler_hmc, num_iterations; + n_adapts=0, # FIXME(torfjelde): Remove once AHMC.jl has fixed this.
+ initial_params=copy(initial_params), + progress=false + ) chain_hmc = AbstractMCMC.bundle_samples( samples_hmc, MCMCTempering.maybe_wrap_model(model), sampler_hmc, samples_hmc[1], MCMCChains.Chains; param_names=param_names, @@ -375,8 +407,9 @@ end # Make sure that we get the "same" result when only using the inverse temperature 1. sampler_tempered = MCMCTempering.TemperedSampler(sampler_hmc, [1]) chain_tempered = sample( - model, sampler_tempered, num_iterations; - init_params=copy(init_params), + make_rng(seed), model, sampler_tempered, num_iterations; + n_adapts=0, # FIXME(torfjelde): Remove once AHMC.jl has fixed this. + initial_params=copy(initial_params), chain_type=MCMCChains.Chains, param_names=param_names, progress=false, @@ -398,9 +431,11 @@ end num_iterations=num_iterations, adapt=false, mean_swap_rate_bound=0.1, - init_params=copy(init_params), + initial_params=copy(initial_params), param_names=param_names, - progress=false + progress=false, + n_adapts=0, # FIXME(torfjelde): Remove once AHMC.jl has fixed this. + rng=make_rng(seed), ) map_parameters!(b, chain_tempered) compare_chains( @@ -421,8 +456,8 @@ end # Sample using MALA. chain_mh = AbstractMCMC.sample( - model, sampler_mh, num_iterations; - init_params=copy(init_params), + make_rng(), model, sampler_mh, num_iterations; + initial_params=copy(initial_params), progress=false, chain_type=MCMCChains.Chains, param_names=param_names, @@ -432,8 +467,8 @@ end # Make sure that we get the "same" result when only using the inverse temperature 1.
sampler_tempered = MCMCTempering.TemperedSampler(sampler_mh, [1]) chain_tempered = sample( - model, sampler_tempered, num_iterations; - init_params=copy(init_params), + make_rng(), model, sampler_tempered, num_iterations; + initial_params=copy(initial_params), chain_type=MCMCChains.Chains, param_names=param_names, progress=false, @@ -455,8 +490,9 @@ end num_iterations=num_iterations, adapt=false, mean_swap_rate_bound=0.1, - init_params=copy(init_params), - param_names=param_names + initial_params=copy(initial_params), + param_names=param_names, + rng=make_rng(), ) map_parameters!(b, chain_tempered) diff --git a/test/simple_gaussian.jl b/test/simple_gaussian.jl index 1ff19b8..2c49527 100644 --- a/test/simple_gaussian.jl +++ b/test/simple_gaussian.jl @@ -8,7 +8,7 @@ tempered_dists = [MvNormal(Zeros(1), I / β) for β in inverse_temperatures] tempered_multimodel = MCMCTempering.MultiModel(map(LogDensityModel ∘ DistributionLogDensity, tempered_dists)) - init_params = zeros(length(μ)) + initial_params = zeros(length(μ)) num_samples = 1_000 num_burnin = num_samples ÷ 2 @@ -23,52 +23,76 @@ # Sample. 
@testset "TemperedSampler" begin chains_product = sample( - DistributionLogDensity(tempered_dists[1]), rwmh_tempered, num_samples; - init_params, + make_rng(), DistributionLogDensity(tempered_dists[1]), rwmh_tempered, num_samples; + initial_params, bundle_resolve_swaps=true, chain_type=Vector{MCMCChains.Chains}, progress=false, discard_initial=num_burnin, thinning=thin, ) - test_chains_with_monotonic_variance(chains_product, Zeros(length(chains_product)), std_true_dict) + test_chains_with_monotonic_variance( + chains_product, + Zeros(length(chains_product)), + std_true_dict, + min_atol=2e-1, + max_atol=5e-1 + ) end @testset "MultiSampler without swapping" begin chains_product = sample( - tempered_multimodel, rwmh_product, num_samples; - init_params, + make_rng(), tempered_multimodel, rwmh_product, num_samples; + initial_params, chain_type=Vector{MCMCChains.Chains}, progress=false, discard_initial=num_burnin, thinning=thin, ) - test_chains_with_monotonic_variance(chains_product, Zeros(length(chains_product)), std_true_dict) + test_chains_with_monotonic_variance( + chains_product, + Zeros(length(chains_product)), + std_true_dict, + min_atol=2e-1, + max_atol=5e-1 + ) end @testset "MultiSampler with swapping (saveall=true)" begin chains_product = sample( - tempered_multimodel, rwmh_product_with_swap, num_samples; - init_params, + make_rng(), tempered_multimodel, rwmh_product_with_swap, num_samples; + initial_params, bundle_resolve_swaps=true, chain_type=Vector{MCMCChains.Chains}, progress=false, discard_initial=num_burnin, thinning=thin, ) - test_chains_with_monotonic_variance(chains_product, Zeros(length(chains_product)), std_true_dict) + test_chains_with_monotonic_variance( + chains_product, + Zeros(length(chains_product)), + std_true_dict, + min_atol=2e-1, + max_atol=5e-1 + ) end @testset "MultiSampler with swapping (saveall=true)" begin chains_product = sample( - tempered_multimodel, Setfield.@set(rwmh_product_with_swap.saveall = Val(false)), num_samples; - 
init_params, + make_rng(), tempered_multimodel, Setfield.@set(rwmh_product_with_swap.saveall = Val(false)), num_samples; + initial_params, chain_type=Vector{MCMCChains.Chains}, progress=false, discard_initial=num_burnin, thinning=thin, ) - test_chains_with_monotonic_variance(chains_product, Zeros(length(chains_product)), std_true_dict) + test_chains_with_monotonic_variance( + chains_product, + Zeros(length(chains_product)), + std_true_dict, + min_atol=3e-1, + max_atol=5e-1 + ) end end diff --git a/test/test_utils.jl b/test/test_utils.jl index d343a8e..d13319a 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -16,18 +16,21 @@ function to_dict(c::MCMCChains.ChainDataFrame, col::Symbol) end """ - atol_for_chain(chain; significance=1e-3, kind=Statistics.mean) + atol_for_chain(chain; significance=0.05, kind=Statistics.mean, min_atol=0, max_atol=Inf) Return a dictionary of absolute tolerances for each parameter in `chain`, computed as the confidence interval width for the mean of the parameter with `significance`. """ -function atol_for_chain(chain; significance=1e-3, kind=Statistics.mean) +function atol_for_chain(chain; significance=0.05, kind=Statistics.mean, min_atol=0, max_atol=Inf) param_names = names(chain, :parameters) # Can reject H0 if, say, `abs(mean(chain2) - mean(chain1)) > confidence_width`. # Or alternatively, compare means but with `atol` set to the `confidence_width`. # NOTE: Failure to reject, i.e. passing the tests, does not imply that the means are equal. mcse = to_dict(MCMCChains.mcse(chain; kind), :mcse) - return Dict(sym => quantile(Normal(0, mcse[sym]), 1 - significance/2) for sym in param_names) + return Dict( + sym => max(min_atol, min(max_atol, quantile(Normal(0, mcse[sym]), 1 - significance/2))) + for sym in param_names + ) end thin_to(chain, n) = chain[1:length(chain) ÷ n:end] @@ -48,6 +51,7 @@ end function test_means(chain::MCMCChains.Chains, mean_true::AbstractDict; n=length(chain), kwargs...)
chain = thin_to(chain, n) atol = atol_for_chain(chain; kwargs...) + @debug "mean" [(mean(chain[sym]), atol[sym]) for sym in names(chain, :parameters)] @test all(isapprox(mean(chain[sym]), 0, atol=atol[sym]) for sym in names(chain, :parameters)) end @@ -67,7 +71,7 @@ end function test_std(chain::MCMCChains.Chains, std_true::AbstractDict; n=length(chain), kwargs...) chain = thin_to(chain, n) atol = atol_for_chain(chain; kind=Statistics.std, kwargs...) - @info "std" [(std(chain[sym]), std_true[sym], atol[sym]) for sym in names(chain, :parameters)] + @debug "std" [(std(chain[sym]), std_true[sym], atol[sym]) for sym in names(chain, :parameters)] @test all(isapprox(std(chain[sym]), std_true[sym], atol=atol[sym]) for sym in names(chain, :parameters)) end @@ -117,10 +121,20 @@ and `std_true`, respectively. Also test that the standard deviation is monotonic - `significance`: The significance level of the test. - `kwargs...`: Passed to `atol_for_chain`. """ -function test_chains_with_monotonic_variance(chains, mean_true, std_true; significance=1e-4, kwargs...) +function test_chains_with_monotonic_variance(chains, mean_true, std_true; significance=0.05, kwargs...) @testset "chain $i" for i = 1:length(chains) test_means(chains[i], mean_true[i]; kwargs...) test_std(chains[i], std_true[i]; kwargs...) end - test_std_monotonicity(chains; significance=0.05) + test_std_monotonicity(chains; significance=significance) end + +""" + make_rng([seed]) + +Create a random number generator. + +# Arguments +- `seed`: The seed for the random number generator. Default is `42`. +""" +make_rng(seed=42) = Random.Xoshiro(seed)