From 98a10418d0b43a9b16ddb3f1a265cd844edbbf23 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Thu, 31 Oct 2024 13:24:35 +0000 Subject: [PATCH] Add `getparams` and `setparams!!` following `AbstractMCMC` v5.5 and v5.6 (#103) * add `getparams` and `setparams!!` * add BangBang as dep, add functions for `GradientTransition` * increase atol * Update test/runtests.jl Co-authored-by: Penelope Yong * update functions with `model` arguments * fix test errors * remove BangBang dep --------- Co-authored-by: Penelope Yong --- Project.toml | 8 ++++---- src/AdvancedMH.jl | 14 ++++++++++++++ src/MALA.jl | 18 ++++++++++++++++-- test/runtests.jl | 22 +++++++++++++++++++++- 4 files changed, 55 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index bb95eea..f89912d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AdvancedMH" uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170" -version = "0.8.3" +version = "0.8.4" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -23,18 +23,18 @@ AdvancedMHMCMCChainsExt = "MCMCChains" AdvancedMHStructArraysExt = "StructArrays" [compat] -AbstractMCMC = "5" +AbstractMCMC = "5.6" DiffResults = "1" Distributions = "0.25" FillArrays = "1" ForwardDiff = "0.10" +LinearAlgebra = "1.6" LogDensityProblems = "2" MCMCChains = "6.0.4" +Random = "1.6" Requires = "1" StructArrays = "0.6" julia = "1.6" -LinearAlgebra = "1.6" -Random = "1.6" [extras] DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" diff --git a/src/AdvancedMH.jl b/src/AdvancedMH.jl index 0d87524..7ff9759 100644 --- a/src/AdvancedMH.jl +++ b/src/AdvancedMH.jl @@ -140,6 +140,20 @@ function __init__() end end +# AbstractMCMC.jl interface +function AbstractMCMC.getparams(t::Transition) + return t.params +end + +# TODO (sunxd): remove `DensityModel` in favor of `AbstractMCMC.LogDensityModel` +function AbstractMCMC.setparams!!(model::DensityModelOrLogDensityModel, t::Transition, params) + return Transition( + params, + logdensity(model, params), + t.accepted + ) +end + # Include inference methods. include("proposal.jl") include("mh-core.jl") diff --git a/src/MALA.jl b/src/MALA.jl index 7c260c7..527f338 100644 --- a/src/MALA.jl +++ b/src/MALA.jl @@ -11,7 +11,7 @@ MALA(d::RandomWalkProposal) = MALA{typeof(d)}(d) MALA(d) = MALA(RandomWalkProposal(d)) -struct GradientTransition{T<:Union{Vector, Real, NamedTuple}, L<:Real, G<:Union{Vector, Real, NamedTuple}} <: AbstractTransition +struct GradientTransition{T<:Union{Vector,Real,NamedTuple},L<:Real,G<:Union{Vector,Real,NamedTuple}} <: AbstractTransition params::T lp::L gradient::G @@ -20,6 +20,20 @@ end logdensity(model::DensityModelOrLogDensityModel, t::GradientTransition) = t.lp +function AbstractMCMC.getparams(t::GradientTransition) + return t.params +end + +function AbstractMCMC.setparams!!(model::DensityModelOrLogDensityModel, t::GradientTransition, params) + lp, gradient = logdensity_and_gradient(model, params) + return GradientTransition( + params, + lp, + gradient, + t.accepted + ) +end + propose(::Random.AbstractRNG, ::MALA, ::DensityModelOrLogDensityModel) = error("please specify initial parameters") function transition(sampler::MALA, model::DensityModelOrLogDensityModel, params, accepted) return GradientTransition(params, logdensity_and_gradient(model, params)..., accepted) @@ -88,6 +102,6 @@ logdensity_and_gradient(::DensityModelOrLogDensityModel, ::Any) function logdensity_and_gradient(model::AbstractMCMC.LogDensityModel, params) check_capabilities(model) return LogDensityProblems.logdensity_and_gradient(model.logdensity, params) - end +end diff --git a/test/runtests.jl b/test/runtests.jl index 8c4f4fa..fd06296 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,5 @@ using AdvancedMH +using AbstractMCMC using DiffResults using Distributions using ForwardDiff @@ -33,6 +34,25 @@ include("util.jl") LogDensityProblems.logdensity(::typeof(density), θ) = density(θ) LogDensityProblems.dimension(::typeof(density)) = 2 + @testset "getparams/setparams!! (AbstractMCMC interface)" begin + t1, _ = AbstractMCMC.step(Random.default_rng(), model, StaticMH([Normal(0, 1), Normal(0, 1)])) + t2, _ = AbstractMCMC.step(Random.default_rng(), model, MALA(x -> MvNormal(x, I)); initial_params=ones(2)) + for t in [t1, t2] + @test AbstractMCMC.getparams(model, t) == t.params + + new_transition = AbstractMCMC.setparams!!(model, t, AbstractMCMC.getparams(model, t)) + @test new_transition.lp == t.lp + @test new_transition.accepted == t.accepted + @test new_transition.params == t.params + if hasfield(typeof(t), :gradient) + @test new_transition.gradient == t.gradient + end + + t_replaced = AbstractMCMC.setparams!!(model, t, [1.0, 2.0]) + @test t_replaced.params == [1.0, 2.0] + end + end + @testset "StaticMH" begin # Set up our sampler with initial parameters. spl1 = StaticMH([Normal(0,1), Normal(0, 1)]) @@ -69,7 +89,7 @@ include("util.jl") @test mean(chain1.σ) ≈ 1.0 atol=0.1 @test mean(chain2.μ) ≈ 0.0 atol=0.1 @test mean(chain2.σ) ≈ 1.0 atol=0.1 - @test mean(chain3.μ) ≈ 0.0 atol=0.1 + @test mean(chain3.μ) ≈ 0.0 atol=0.15 @test mean(chain3.σ) ≈ 1.0 atol=0.1 end