# Improving interaction between different implementations of the interface #85
How do you feel about adding something like:

…

to make it easier for samplers to interact across packages? Then you just need to implement `state_from_transition` for the different types to get cross-package compat.

Another issue is also the `model` argument, which often is specific to a particular sampler implementation. We've previously spoken about generalizing this so that we don't have a bunch of these lying around, e.g. `AdvancedMH.DensityModel` and `AdvancedHMC.DifferentiableDensityModel`, but we should also maybe add a function `getmodel(model, sampler, state)` or something too, which is `identity` by default but allows one to provide a model type which encodes a bunch of different models for specific samplers. E.g. in the case of a `MixtureSampler` you might have a `MixtureState` which, among other things, holds the current sampler index, and a `ManyModels` which simply wraps a collection of models corresponding to each of the components/samplers in the `MixtureSampler`.

I've been running into quite a few scenarios recently where I'd love to have something like this, e.g. wanting to implement `MixtureSampler` and `CompositionSampler`.
@devmotion @cpfiffer @yebai thoughts?
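To make the `getmodel` idea concrete, here's a minimal sketch; the `MixtureState` and `ManyModels` layouts here are assumptions for illustration, not a settled design:

```julia
# Hypothetical state for a MixtureSampler: tracks which component sampler is
# currently active, plus the per-component states.
struct MixtureState{S}
    index::Int
    states::S
end

# Hypothetical wrapper bundling one model per component sampler.
struct ManyModels{M}
    models::M
end

# Default: the model is used as-is (the `identity` behaviour mentioned above).
getmodel(model, sampler, state) = model

# For a mixture, pick out the model matching the active component.
getmodel(model::ManyModels, sampler, state::MixtureState) = model.models[state.index]
```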
---

Yep, I 100% love both of these -- I think it's a bit of a shortcoming in our interface methods that we push the parameter stuff to the side. This seems minimal and unobtrusive enough to be a good fit for AbstractMCMC.
---

We have already implemented some similar functions for working with transitions in Turing, e.g. …
---

I am less sure about … Generally, I think it would be helpful to be more honest about the supported model types (possibly reusing model types such as the discussed …).
---

Awesome!

I 100% agree with this. We had some issues reaching a consensus the last time we discussed it, but maybe we can now 👍 How about this direction (I'm trying to do something similar to what I think you proposed before, @devmotion)?

```julia
struct DensityModel{F} <: AbstractModel
    logdensity::F
end

logdensity(model::DensityModel, args...) = model.logdensity(args...)

"""
    Differentiable{N}

Represents N-th order differentiability.
"""
struct Differentiable{N} end

const NonDifferentiable = Differentiable{0}
const FirstOrderDifferentiable = Differentiable{1}
const SecondOrderDifferentiable = Differentiable{2}

# A combination of two models is only as differentiable as the less smooth one.
function Base.:+(::Differentiable{N1}, ::Differentiable{N2}) where {N1,N2}
    return Differentiable{min(N1, N2)}()
end

"""
    differentiable(model)

Return an instance of `Differentiable{N}`, where `N` represents the order.
"""
differentiable(model::DensityModel) = differentiable(model.logdensity)

"""
    PosteriorModel

Represents a model which can be decomposed into a prior and a likelihood.
"""
struct PosteriorModel{P1,P2} <: AbstractModel
    logprior::P1
    loglikelihood::P2
end

logprior(model::PosteriorModel, args...) = model.logprior(args...)
loglikelihood(model::PosteriorModel, args...) = model.loglikelihood(args...)
logdensity(model::PosteriorModel, args...) = logprior(model, args...) + loglikelihood(model, args...)

function differentiable(model::PosteriorModel)
    return differentiable(model.logprior) + differentiable(model.loglikelihood)
end
```

Then we can also add (but in a different package; maybe Bijectors.jl itself or Turing.jl):

```julia
struct TransformedModel{M,B} <: AbstractMCMC.AbstractModel
    model::M
    transform::B
end

function AbstractMCMC.logdensity(tmodel::TransformedModel, y)
    # Map `y` back to the model's space and pick up the log-Jacobian correction.
    x, logjac = forward(tmodel.transform, y)
    return AbstractMCMC.logdensity(tmodel.model, x) + logjac
end

function AbstractMCMC.differentiable(tmodel::TransformedModel)
    return AbstractMCMC.differentiable(tmodel.model)
end
```

And then things would "just work". We might also want some of the following methods (though implementations should go somewhere else): …
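For concreteness, a quick usage sketch of the above; the `differentiable(::Function)` fallback is an assumption added here, not part of the proposal:

```julia
# Assumed fallback for this sketch: treat an arbitrary function as
# non-differentiable unless stated otherwise.
differentiable(::Function) = NonDifferentiable()

prior = DensityModel(x -> -0.5 * sum(abs2, x))
posterior = PosteriorModel(x -> -0.5 * sum(abs2, x), x -> -sum(abs2, x .- 1))

logdensity(prior, [1.0, 2.0])      # calls the wrapped closure
logdensity(posterior, [1.0, 2.0])  # logprior(...) + loglikelihood(...)

# `+` combines differentiability orders by taking the minimum:
FirstOrderDifferentiable() + SecondOrderDifferentiable()  # Differentiable{1}()
differentiable(posterior)                                 # Differentiable{0}()
```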
---

I LOVE this sketch. It's super minimal, and I think it's flexible enough to meet a bunch of downstream needs. I'm happy with putting it in AbstractMCMC since it touches so few things.
---

Thinking about this again due to work on the Gibbs sampler for Turing.jl (TuringLang/Turing.jl#2099), I think we need the following from a …
In short, we need something like:

```julia
# Needs to be implemented on a case-by-case basis.
function getparams_and_logprob(sampler, state)
    # TODO: implement
end

function setparams_and_logprob!!(sampler, state, params, logprob)
    # TODO: implement
end

# Default getters and setters, derived from the pair above.
getparams(sampler, state) = first(getparams_and_logprob(sampler, state))
getlogprob(sampler, state) = last(getparams_and_logprob(sampler, state))

function setparams!!(sampler, state, params)
    return setparams_and_logprob!!(sampler, state, params, getlogprob(sampler, state))
end

function setlogprob!!(sampler, state, logprob)
    return setparams_and_logprob!!(sampler, state, getparams(sampler, state), logprob)
end

# Default implementation.
function state_from(model_dst, sampler_dst, sampler_src, state_dst, state_src)
    # Extract parameters and logprob from the source sampler's state.
    params_src, lp_src = getparams_and_logprob(sampler_src, state_src)
    # Set the parameters and logprob in the destination sampler's state.
    return setparams_and_logprob!!(sampler_dst, state_dst, params_src, lp_src)
end

function state_from_with_recompute_logprob(model_dst, sampler_dst, sampler_src, state_dst, state_src)
    # Extract parameters from the source sampler's state.
    params_src = getparams(sampler_src, state_src)
    # Set the parameters in the destination sampler's state.
    state_dst = setparams!!(sampler_dst, state_dst, params_src)
    # Re-evaluate the log density of the destination model.
    return recompute_logprob!!(model_dst, sampler_dst, state_dst)
end

# Default implementation.
function recompute_logprob!!(model::AbstractMCMC.LogDensityModel, sampler, state)
    # Re-evaluate the log density at the state's current parameters and store it.
    params = getparams(sampler, state)
    lp = LogDensityProblems.logdensity(model.logdensity, params)
    return setlogprob!!(sampler, state, lp)
end
```
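As a concrete illustration of the case-by-case part, a sampler whose state is just a parameter vector plus a cached log-probability could implement the pair as follows (`ToyState` is a made-up type for this sketch):

```julia
# Hypothetical state type, for illustration only.
struct ToyState{P,L}
    params::P
    logprob::L
end

getparams_and_logprob(sampler, state::ToyState) = (state.params, state.logprob)

# The `!!` convention signals "may mutate"; with an immutable state we just
# return a new one.
function setparams_and_logprob!!(sampler, state::ToyState, params, logprob)
    return ToyState(params, logprob)
end
```

With just these two methods, the defaults above provide `getparams`, `getlogprob`, `setparams!!`, `setlogprob!!`, and `state_from` for free.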
For example, if we want compositions of samplers, we can do that as:

```julia
function composition_step(
    rng::Random.AbstractRNG,
    model_outer,
    model_inner,
    sampler_outer,
    sampler_inner,
    state_outer,
    state_inner;
    kwargs...
)
    # Take a step with the inner model.
    transition_inner, state_inner = AbstractMCMC.step(
        rng, model_inner, sampler_inner, state_inner; kwargs...
    )
    # Update the outer state from the inner state.
    state_outer = if composition_requires_recompute_logprob(
        model_outer, sampler_outer, sampler_inner, state_outer, state_inner
    )
        state_from_with_recompute_logprob(model_outer, sampler_outer, sampler_inner, state_outer, state_inner)
    else
        state_from(model_outer, sampler_outer, sampler_inner, state_outer, state_inner)
    end
    # Take a step with the outer sampler.
    transition_outer, state_outer = AbstractMCMC.step(
        rng, model_outer, sampler_outer, state_outer; kwargs...
    )
    return (transition_inner, transition_outer), (state_inner, state_outer)
end
```
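Note that `composition_requires_recompute_logprob` isn't defined anywhere above; as an assumption on my part, a conservative default would be to always recompute and let specific sampler combinations opt out:

```julia
# Hedged default: recomputing is always correct, just potentially wasteful.
# Specific (sampler_outer, sampler_inner) pairs can add methods returning
# `false` when the cached log-probability is known to stay valid.
function composition_requires_recompute_logprob(model_outer, sampler_outer, sampler_inner, state_outer, state_inner)
    return true
end
```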
Another example is Gibbs sampling, though this only requires …

```julia
function gibbs_step(
    rng::Random.AbstractRNG,
    model_dst,
    sampler_dst,
    sampler_src,
    state_dst,
    state_src;
    kwargs...
)
    # `model_dst` might be different here, e.g. conditioned on new values, so we
    # need to check whether we have to recompute the log-probability.
    if gibbs_requires_recompute_logprob(model_dst, sampler_dst, sampler_src, state_dst, state_src)
        # Re-evaluate the log density of the destination model.
        state_dst = recompute_logprob!!(model_dst, sampler_dst, state_dst)
    end
    # Step!
    return AbstractMCMC.step(rng, model_dst, sampler_dst, state_dst; kwargs...)
end
```
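Likewise, `gibbs_requires_recompute_logprob` is left unspecified in the sketch; a hedged default in the same spirit could be:

```julia
# Hedged default (an assumption, not from this thread): in Gibbs the
# conditioned model typically changes between sub-steps, so recompute unless
# a method for the specific combination says otherwise.
function gibbs_requires_recompute_logprob(model_dst, sampler_dst, sampler_src, state_dst, state_src)
    return true
end
```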
EDIT: Currently giving this …