Remove hmc.jl and mh.jl in light of upstreamed "getparams" into AbstractMCMC #2367

Open
1 of 2 tasks
sunxd3 opened this issue Oct 14, 2024 · 1 comment

sunxd3 commented Oct 14, 2024

Now that TuringLang/AbstractMCMC.jl#86 has been merged, we can start removing the HMC and MH samplers in favor of the ExternalSampler interface.

This would involve the following steps:

  • add getparams and setparams!! implementations in AdvancedHMC and AdvancedMH
  • alias HMC(...; adtype) in Turing as externalsampler(AdvancedHMC.hmc(...); adtype)

sunxd3 commented Nov 11, 2024

Some notes on completing the second item:

One can already use AdvancedHMC and AdvancedMH through the ExternalSampler interface (more details can be found at https://turinglang.org/docs/tutorials/docs-16-using-turing-external-samplers/)

julia> using Turing; using AdvancedHMC: AdvancedHMC; using ADTypes

julia> model = Turing.DynamicPPL.TestUtils.DEMO_MODELS[1]
DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_dot_assume_dot_observe), (:x, Symbol("##arg#269")), (), (), Tuple{Vector{Float64}, DynamicPPL.TypeWrap{Vector{Float64}}}, Tuple{}, DynamicPPL.DefaultContext}(DynamicPPL.TestUtils.demo_dot_assume_dot_observe, (x = [1.5, 2.0], var"##arg#269" = DynamicPPL.TypeWrap{Vector{Float64}}()), NamedTuple(), DynamicPPL.DefaultContext())

julia> spl = externalsampler(AdvancedHMC.HMC(0.2, 30); adtype=AutoForwardDiff())

Turing.Inference.ExternalSampler{AdvancedHMC.HMC{AdvancedHMC.Leapfrog{Float64}, Symbol}, AutoForwardDiff{nothing, Nothing}, true}(AdvancedHMC.HMC{AdvancedHMC.Leapfrog{Float64}, Symbol}(30, Leapfrog=0.2), :diagonal), AutoForwardDiff())

julia> sample(model, spl, 10)
Sampling 100%|██████████████████████████████████████████| Time: 
....

The mechanism is as follows: at the end of the step call with an external sampler, transition_to_turing and state_to_turing are called. These two functions then call getparams on the sampler state of the external samplers. getparams for AdvancedHMC.HMCState and AdvancedMH.Transition (AdvancedMH uses Transition as state) are defined in abstractmcmc.jl.
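To illustrate the idea, here is a minimal, package-free sketch of why a wrapper only needs getparams (and setparams!!) to work with a foreign sampler state. Every name below (ToyState, wrap_transition, and the local getparams / setparams!! definitions) is a hypothetical stand-in defined in this snippet, not the actual code in Turing, AbstractMCMC, AdvancedHMC or AdvancedMH.

```julia
# Hypothetical stand-in for an external sampler's state; not a real library type.
struct ToyState
    position::Vector{Float64}   # current parameter values
    logp::Float64               # log density stored alongside the position
end

# Local analogue of `getparams`: extract the flat parameter vector from a state.
getparams(state::ToyState) = state.position

# Local analogue of `setparams!!`: return a state carrying the new parameters.
# This toy version keeps the old `logp`; a real implementation would likely
# need to recompute any cached quantities.
setparams!!(state::ToyState, params::Vector{Float64}) = ToyState(params, state.logp)

# Toy version of what a wrapper might do at the end of a step: it relies only
# on `getparams`, not on the internal layout of the state.
wrap_transition(state) = (params = getparams(state),)

state = ToyState([0.1, -0.3], -1.2)
transition = wrap_transition(state)
new_state = setparams!!(state, [1.0, 2.0])
```

The point of the sketch is the dispatch boundary: once a package defines these two accessors for its own state type, the wrapper does not have to know anything else about that type.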

So we should be able to move the HMC (NUTS, etc.) and MH samplers that are currently in Turing.Inference to the ExternalSampler interface in a relatively painless way. (And currently we can only move these two.)

That being said, we need to be particularly careful about the differences between Turing's API and the APIs of AdvancedHMC (and AdvancedMH). E.g. NUTS (Turing vs. AdvancedHMC): Turing provides some default arguments.
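As a rough sketch of what such an alias could look like, the helper below forwards Turing-style arguments to AdvancedHMC through externalsampler, reusing only the calls shown in the REPL session above. The function name hmc_via_external and its default adtype are assumptions made for illustration; this is not Turing's actual constructor, and a real alias would have to reproduce whatever defaults Turing's HMC and NUTS currently provide.

```julia
using Turing
using AdvancedHMC: AdvancedHMC
using ADTypes

# Hypothetical helper, not part of Turing: forwards to AdvancedHMC via
# `externalsampler`. The default `adtype` is an assumption for illustration.
function hmc_via_external(step_size::Real, n_leapfrog::Int; adtype=AutoForwardDiff())
    return externalsampler(AdvancedHMC.HMC(step_size, n_leapfrog); adtype=adtype)
end

# Same arguments as the REPL example above.
spl = hmc_via_external(0.2, 30)
```

Any argument that Turing fills in by default but AdvancedHMC requires explicitly would need to be handled inside a wrapper like this one.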
