Integrating Turing and MarginalLogDensities #2398
This looks neat 👀 Can you elaborate a bit on what mathematical model the following line should correspond to?

marginalmodel = marginalize(fullmodel, (:x,))

Would this marginalize out `x`? (And MarginalLogDensities.jl looks really nice, btw. If people are interested in model comparison, this will probably be a place we'll point people :))
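Presumably something along these lines (a sketch; the names `a`, `b`, `x`, `y` are borrowed from the example models discussed further down the thread), with the latent `x` integrated out of the joint:

$$
p(a, b \mid y) \;\propto\; p(a)\, p(b) \int p(y \mid x)\, p(x \mid a, b)\, \mathrm{d}x ,
$$

so that the marginalized model is over `a` and `b` only.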
IIUC you effectively want to integrate `x` out of the joint. But does this mean that you want to drop both the `x ~ ...` and `y ~ ...` terms, i.e. end up with

@model function MarginalizedExample()
    a ~ SomeDist()
    b ~ AnotherDist()
    Turing.@addlogprob! MarginalLogDensity(...)
end

? Because that won't be possible with Turing.jl, unfortunately 😕 At least not in a completely automated fashion.
Yes, your math is correct. One way to do this is to wrap the model's log density and hand it to MarginalLogDensities.jl directly; to do it the other way, starting with a Turing model, my hypothetical `marginalize` function would set that wrapper up for you.

All this should be basically possible, no?
So I do believe this is achievable in some models, but the problem is that I think it'll be difficult to verify that the user is actually doing the right thing 😕

Approach to implementing this

What I'm worried about is that both the target variable and the likelihood are absorbed into the `MarginalLogDensity`, i.e. that

@model function Example(y)
    a ~ SomeDist()
    b ~ AnotherDist()
    mu = somefunction(a, b)
    x ~ MvNormal(mu, 1.0)
    y ~ MvNormal(x, 1.0)
end

should become

@model function MarginalizedExample()
    a ~ SomeDist()
    b ~ AnotherDist()
    Turing.@addlogprob! MarginalLogDensity(...)
end

Correct? However, this is not something we can do automatically in Turing.jl. In Turing.jl, your model always looks like the original definition, i.e.

@model function Example(y)
    a ~ SomeDist()
    b ~ AnotherDist()
    mu = somefunction(a, b)
    x ~ MvNormal(mu, 1.0)
    y ~ MvNormal(x, 1.0)
end

but you alter the model by changing the behavior of each of the individual `~` statements. Hence, to "convert"

x ~ MvNormal(mu, 1.0)
y ~ MvNormal(x, 1.0)

into

Turing.@addlogprob! MarginalLogDensity(...)

automatically, we need to overload how those `~` statements are executed. We can technically do this, for example by having `marginalize` construct a new model in which those statements are replaced by the `@addlogprob!` call.

Is this okay?

However, the question is: how do we know whether this is an okay thing to do? For example, IIUC, the variable(s) being marginalised need to be a parent of the observation variable in the DAG. For example, in the above case, we cannot marginalise out only `a` or `b`, since they are not parents of `y`.

The problem is that Turing.jl doesn't currently have access to the DAG of the model, and so performing such checks is not really possible atm 😕
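To spell out the factorisation behind "both the target variable and the likelihood are absorbed" (this is just standard algebra for the `Example(y)` model above, not anything Turing-specific):

$$
p(a, b, x, y) = p(a)\, p(b)\, p(x \mid a, b)\, p(y \mid x),
\qquad
p(a, b, y) = p(a)\, p(b) \int p(x \mid a, b)\, p(y \mid x)\, \mathrm{d}x .
$$

The integral couples exactly the factors contributed by `x ~ MvNormal(mu, 1.0)` and `y ~ MvNormal(x, 1.0)`, which is why neither can survive as a separate `~` statement once `x` is marginalised out.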
I guess I'm not clear why any of the internal details of the model need to come into play here... if a Turing model can be turned into a black-box log-density function, MarginalLogDensities.jl can wrap that directly. Here's a concrete example of what I'm suggesting:

using Turing
using Distributions
using MarginalLogDensities
import Zygote
a = 0.5
b = 3
c = 2
d = 3
s = collect(-5:0.5:5)
μ = a.*s.^2 .+ b.*s
x = rand(MvNormal(μ, c))
y = rand(MvNormal(x, d))
@model function demo(y, s)
    a ~ Normal(0, 10)
    b ~ Normal(0, 10)
    logc ~ Normal()
    c = exp(logc)
    μ = a.*s.^2 .+ b.*s
    x ~ MvNormal(μ, c)
    y ~ MvNormal(x, 3.0)
end
m = demo(y, s)
ctx = Turing.Optimisation.OptimizationContext(Turing.DefaultContext())
old = Turing.Optimisation.OptimLogDensity(m, ctx)
# marginalize the `x` variables, indices 4:24
mld = MarginalLogDensity((u, p) -> -old(u), randn(24), 4:24, (), LaplaceApprox(adtype=AutoZygote()))
mld(randn(3))
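For reference, `LaplaceApprox` handles the integral over the marginalised block (here the 21 entries of `x`, indices 4:24) with the standard Laplace approximation, roughly

$$
\int \exp\{f(x, \theta)\}\, \mathrm{d}x \;\approx\; \exp\{f(\hat{x}, \theta)\}\, (2\pi)^{d/2}\, \bigl\lvert -\nabla_x^2 f(\hat{x}, \theta) \bigr\rvert^{-1/2},
$$

where $\hat{x}$ maximises $f(\cdot, \theta)$ over the marginalised coordinates and $d$ is their number; the Hessian here is why second-order AD comes up further down the thread.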
Ah sorry, that probably wasn't clear from my end 😬 I was thinking of more convoluted scenarios, e.g.

@model function demo()
    s ~ InverseGamma(2, 3)
    x ~ Normal(0, sqrt(s))
    y ~ Normal(0, sqrt(s))
    return (s, x)
end
model = demo() | (y = 1.0,) # model over `s` and `x`
model() # => (s = rand(InverseGamma(2, 3)), x = rand(Normal(0, sqrt(s))))
marginalized_model = marginalize(model, @varname(x)) # model is now only over `s`
marginalized_model() # => (s = rand(InverseGamma(2, 3)), x = argmax(logjoint(model, x)))

where the marginalised `x` is returned at its mode rather than sampled. Does that make sense? It was a bit unclear to me how to do this + whether someone would actually want to do this.
But yes, this would be very easy to do :)

using Turing, MarginalLogDensities, LogDensityProblems, Zygote
"""
    marginalize(model::Model, varnames::Vector; method=LaplaceApprox())

Returns a `MarginalLogDensity` with `varnames` marginalized out from the `model`.
"""
function marginalize(model::DynamicPPL.Model, varnames::Vector; method=LaplaceApprox())
    # Determine the indices for the variables to marginalise out.
    varinfo = DynamicPPL.typed_varinfo(model)
    varindices = DynamicPPL.getranges(varinfo, varnames)
    # Construct the marginal log-density model.
    # TODO(torfjelde): Should link and use optimization context to avoid inclusion of Jacobian corrections in the log-density.
    f = Base.Fix1(LogDensityProblems.logdensity, DynamicPPL.LogDensityFunction(model))
    mdl = MarginalLogDensity((u, p) -> f(u), varinfo[:], varindices, (), method)
    return mdl
end
a = 0.5
b = 3
c = 2
d = 3
s = collect(-5:0.5:5)
μ = a.*s.^2 .+ b.*s
x = rand(MvNormal(μ, c))
y = rand(MvNormal(x, d))
@model function demo(y, s)
    a ~ Normal(0, 10)
    b ~ Normal(0, 10)
    logc ~ Normal()
    c = exp(logc)
    μ = a.*s.^2 .+ b.*s
    x ~ MvNormal(μ, c)
    y ~ MvNormal(x, 3.0)
end
model = demo(y, s);
mdl = marginalize(model, [@varname(x)]; method=LaplaceApprox(adtype=AutoZygote()));
mdl(randn(3))

With the "bugfix" below, this currently works like a charm :) Though it should probably use the optimization context + linking, i.e.

function marginalize(model::DynamicPPL.Model, varnames::Vector; method=LaplaceApprox())
    # Determine the indices for the variables to marginalise out.
    varinfo = DynamicPPL.typed_varinfo(model)
    varindices = DynamicPPL.getranges(varinfo, varnames)
    # Construct the marginal log-density model.
    # Use a linked `varinfo` so that we're working in unconstrained space, and `OptimizationContext` to ensure
    # that the log-abs-det Jacobian terms are not included.
    context = Turing.Optimisation.OptimizationContext(DynamicPPL.leafcontext(model.context))
    varinfo_linked = DynamicPPL.link(varinfo, model)
    f = Base.Fix1(LogDensityProblems.logdensity, DynamicPPL.LogDensityFunction(varinfo_linked, model, context))
    # HACK: need the sign-flip here because `OptimizationContext` is a hacky impl :/
    mdl = MarginalLogDensity((u, p) -> -f(u), varinfo_linked[:], varindices, (), method)
    return mdl
end

but that requires depending on / making an extension for Turing.jl instead of just DynamicPPL.jl (the package which defines the modeling syntax, etc.). At the moment, this requires a "bugfix" to DynamicPPL.jl (I'm using an internal function in a way it isn't currently meant to be used):

# FIXME(torfjelde): This is an internal function that isn't CURRENTLY meant to be used in this way,
# but we can add the following missing def to make it work.
# It takes in a vector of variable names and returns the corresponding indices.
function DynamicPPL.getranges(varinfo::DynamicPPL.TypedVarInfo, vns::Vector{<:DynamicPPL.VarName})
    # Here we need to keep track of the offset.
    offset = 0
    vns_all = DynamicPPL.keys(varinfo)
    return mapreduce(vcat, vns_all; init=Int[]) do vn
        # First we need to get the range so we can add it to the total offset.
        r = DynamicPPL.getrange(varinfo, vn)
        length_vn = length(r)
        # Then we check if `vn` is one of the variables we're extracting the ranges for.
        # TODO(torfjelde): Maybe use `findall` + `subsumes` instead?
        index = findfirst(isequal(vn), vns)
        return if index === nothing
            # If not, we shift the offset and return an empty array.
            offset += length_vn
            Int[]
        else
            # Otherwise, we return the (offset) range and update the offset.
            r = r .+ offset
            offset += length_vn
            r
        end
    end
end
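As a quick sanity check of what that returns (hypothetical REPL use; it assumes the `demo(y, s)` model and data defined above, where `a`, `b`, and `logc` are scalars and `x` has length 21, so the result should match the hard-coded `4:24` from the earlier comment):

varinfo = DynamicPPL.typed_varinfo(demo(y, s))
DynamicPPL.getranges(varinfo, [@varname(x)])
# expected: the indices 4:24 (as a Vector{Int}); `a`, `b`, and `logc` occupy 1:3 and are skipped.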
Btw, it's quite unfortunate that we have to use Zygote.jl here to get 2nd order information 😕 It's really not well-suited for Turing.jl models.
Oh, and if you wanted to use this for sampling, you could just add the following:

#########################
## To sample from this ##
#########################
# 1. Add LogDensityProblems.jl interface to it and we can suddenly use samplers.
LogDensityProblems.logdensity(mdl::MarginalLogDensity, u) = mdl(u)
LogDensityProblems.dimension(mdl::MarginalLogDensity) = length(mdl.iv)
function LogDensityProblems.capabilities(mdl::MarginalLogDensity)
    return LogDensityProblems.LogDensityOrder{0}()
end
# 2. Unfortunately, we have to use the sampler packages explicitly.
using AbstractMCMC, AdvancedMH, MCMCChains
spl = AdvancedMH.RWMH(LogDensityProblems.dimension(mdl))
samples = sample(
    mdl, spl, 1000;
    chain_type=MCMCChains.Chains,
    # HACK: this is a dirty way to extract the variable names in a model; it won't work in general.
    # But general methods exist, so we can fix that.
    param_names=setdiff(keys(DynamicPPL.untyped_varinfo(model)), [@varname(x)])
)

Results in:

julia> samples = sample(
mdl, spl, 1000;
chain_type=MCMCChains.Chains,
# HACK: this is a dirty way to extract the variable names in a model.
param_names=setdiff(keys(DynamicPPL.untyped_varinfo(model)), [@varname(x)])
)
Sampling 100%|█████████████████████████████████████████████████████████████| Time: 0:00:03
Chains MCMC chain (1000×4×1 Array{Float64, 3}):
Iterations = 1:1:1000
Number of chains = 1
Samples per chain = 1000
parameters = a, b, logc
internals = lp
Summary Statistics
  parameters       mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
      Symbol    Float64   Float64   Float64    Float64    Float64   Float64       Missing

           a     0.6280    0.3893    0.0576    40.8026    18.2906    1.4236       missing
           b   -15.2641    4.2187    2.0435     4.2878     4.6395    1.4834       missing
        logc     1.0824    1.1025    0.5967     4.0816     4.4534    1.4062       missing

Quantiles
  parameters       2.5%      25.0%      50.0%      75.0%      97.5%
      Symbol    Float64    Float64    Float64    Float64    Float64

           a    -0.6893     0.6643     0.6970     0.6970     1.3287
           b   -17.2110   -17.1433   -17.1433   -16.4625    -2.2291
        logc     0.4505     0.4505     0.4505     1.2029     3.7519
Cool, thank you for the clarification, I knew I was missing something! I would not have thought to try using any of the fancy conditioning syntax, and turning a Turing model into a `DynamicPPL.LogDensityFunction` like that is good to know about. And thanks for the worked example; I was trying to figure out how to get the variable names and indices automatically but couldn't do it myself. This feels like a good outline of a usable extension... does it make more sense as an MLD extension in Turing, or a Turing extension in MLD?
We shouldn't need to; MarginalLogDensities is totally backend-agnostic (in theory, anyway).
I think, given that it returns an MLD object, it maybe makes more sense as an MLD extension? And yes, I do think this is pretty close to being a functional extension :) If you want, I'm happy to make a PR, or you can make a PR and I can support? The only annoyance is that you end up touching some Turing/DynamicPPL internals, e.g. the `getranges` workaround above.
Yeah no, I saw this, which is very nice :) It was more a comment on the AD packages; I tried both ForwardDiff.jl and ReverseDiff.jl, but both failed upon Hessian computation.
Thanks, @ElOceanografo; I will keep it as a Turing extension if you like. Please feel free to open a PR. It would be helpful for an ongoing research project on modular inference methods.
That's kind of what I was thinking as well. If there are ways to make the interface a bit more robust to changes in those Turing internals, that would be great, but MLD is an experimental package with few users at the moment, so I'm okay with a moderate risk of breakage. If you want to make a PR, go for it; otherwise I will try to get to it in the next week or so.
Try again with ForwardDiff; there was a bug that just got fixed here: ElOceanografo/MarginalLogDensities.jl#36. But in general, yes, things are a bit more brittle in practice than they should be in theory...
Yeah, maybe it's better to do it as an extension to Turing.jl, given that it accesses Turing.jl internals 🤔 I think the likelihood of us breaking the functionality is greater than you doing so, so it seems sensible to put it in Turing.jl (at least for now). I'll open a PR with some minor tweaks to the above then 👍 Buuut I will likely need some help to get it merged (a bit limited on time these days); maybe you could help add some tests, @ElOceanografo? :)
Yeah, definitely. Just ping me when you've got a PR and let me know where you'd like the tests to go.
Done! See #2421 :)
Moving a discussion with @yebai from Slack to here. @PavanChaggar asked if there was a way to do Laplace approximation in Turing, and I gave this little example of how it could be accomplished with MarginalLogDensities.jl:
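(In outline, and using only the pieces that appear in the comments above, it wrapped the negative log joint from `OptimLogDensity` and handed it to `MarginalLogDensity` with a Laplace approximation over the latent `x`; the full version, including the `demo(y, s)` model and simulated data, is reproduced earlier in this thread.)

using Turing, MarginalLogDensities
import Zygote

m = demo(y, s)  # the `demo` model and data from the comments above
ctx = Turing.Optimisation.OptimizationContext(Turing.DefaultContext())
old = Turing.Optimisation.OptimLogDensity(m, ctx)  # negative log joint as a callable
# Marginalize the `x` variables (indices 4:24) with a Laplace approximation.
mld = MarginalLogDensity((u, p) -> -old(u), randn(24), 4:24, (), LaplaceApprox(adtype=AutoZygote()))
mld(randn(3))  # approximate marginal log density of (a, b, logc)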
This issue is to discuss how this capability might be integrated better into Turing, probably via a package extension. (See also #1382, which I opened before I made MLD). From a user perspective, an interface like this makes sense to me:
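(Something along these lines; `mymodel` and `data` are placeholders, the sampler choice is arbitrary, and `marginalize` is the hypothetical function quoted in the first comment above, not an existing API.)

fullmodel = mymodel(data)                       # any Turing model with a latent `x`
marginalmodel = marginalize(fullmodel, (:x,))   # marginalize `x` out, e.g. via a Laplace approximation
chain = sample(marginalmodel, NUTS(), 1000)     # sample the remaining parameters
mode = maximum_a_posteriori(marginalmodel)      # ...or optimize them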
I think there are two basic ways to implement this:

1. `marginalize` constructs a new, marginalized `DynamicPPL.Model`, or
2. it constructs a `MarginalLogDensities.MarginalLogDensity`, with new methods for `sample`, `maximum_a_posteriori`, and `maximum_likelihood` defined for it.

I'm not very familiar with Turing's internals, so happy to be corrected if there are other approaches that make more sense....
The other current roadblock is making calls to `MarginalLogDensity` objects differentiable (ElOceanografo/MarginalLogDensities.jl#34). This is doable, I just need to do it.