Skip to content

Commit

Permalink
Added extensions for AdvancedHMC.jl and AdvancedMH.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Oct 7, 2024
1 parent 600b82c commit e332060
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 14 deletions.
14 changes: 14 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,18 @@ ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"

[weakdeps]
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"

[extensions]
MCMCTemperingAdvancedHMCExt = ["AdvancedHMC"]
MCMCTemperingAdvancedMHExt = ["AdvancedMH"]

[compat]
AbstractMCMC = "5"
AdvancedHMC = "0.6"
AdvancedMH = "0.8"
ConcreteStructs = "0.2"
Distributions = "0.24, 0.25"
DocStringExtensions = "0.8, 0.9"
Expand All @@ -27,3 +37,7 @@ MCMCChains = "5, 6"
ProgressLogging = "0.1"
Setfield = "0.7, 0.8, 1"
julia = "1"

[extras]
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
19 changes: 19 additions & 0 deletions ext/MCMCTemperingAdvancedHMCExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
module MCMCTemperingAdvancedHMCExt

using MCMCTempering: MCMCTempering, Setfield
using AdvancedHMC: AdvancedHMC

MCMCTempering.getparams_and_logprob(t::AdvancedHMC.Transition) = t.z.θ, t.z.ℓπ.value
MCMCTempering.getparams_and_logprob(state::AdvancedHMC.HMCState) = MCMCTempering.getparams_and_logprob(state.transition)

# TODO: Implement `state_from` instead, to avoid re-computation of gradients if possible.
function MCMCTempering.setparams_and_logprob!!(model, state::AdvancedHMC.HMCState, params, lp)
# NOTE: Need to recompute the gradient because it might be used in the next integration step.
hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model)
return Setfield.@set state.transition.z = AdvancedHMC.phasepoint(
hamiltonian, params, state.transition.z.r;
ℓκ=state.transition.z.ℓκ
)
end

end
18 changes: 5 additions & 13 deletions test/compat.jl → ext/MCMCTemperingAdvancedMHExt.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# AdvancedMH.jl
module MCMCTemperingAdvancedMHExt

using MCMCTempering: MCMCTempering, Setfield
using AdvancedMH: AdvancedMH

MCMCTempering.getparams_and_logprob(transition::AdvancedMH.Transition) = transition.params, transition.lp
function MCMCTempering.setparams_and_logprob!!(transition::AdvancedMH.Transition, params, lp)
Setfield.@set! transition.params = params
Expand All @@ -17,16 +21,4 @@ function MCMCTempering.setparams_and_logprob!!(model, transition::AdvancedMH.Gra
)
end

# AdvancedHMC.jl
MCMCTempering.getparams_and_logprob(t::AdvancedHMC.Transition) = t.z.θ, t.z.ℓπ.value
MCMCTempering.getparams_and_logprob(state::AdvancedHMC.HMCState) = MCMCTempering.getparams_and_logprob(state.transition)

# TODO: Implement `state_from` instead, to avoid re-computation of gradients if possible.
function MCMCTempering.setparams_and_logprob!!(model, state::AdvancedHMC.HMCState, params, lp)
# NOTE: Need to recompute the gradient because it might be used in the next integration step.
hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model)
return Setfield.@set state.transition.z = AdvancedHMC.phasepoint(
hamiltonian, params, state.transition.z.r;
ℓκ=state.transition.z.ℓκ
)
end
1 change: 0 additions & 1 deletion test/setup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,3 @@ using Turing: Turing, DynamicPPL

include("utils.jl")
include("test_utils.jl")
include("compat.jl")

0 comments on commit e332060

Please sign in to comment.