-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Making use of Symbolics.jl/SymbolicUtils.jl #234
Conversation
src/context_implementations.jl
Outdated
value, | ||
vi, | ||
) | ||
increment_num_produce!(vi) | ||
return Distributions.loglikelihood(dist, value) | ||
return sum(Distributions.logpdf(dist, value)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is left-over from some earlier experimentation I did. We should make loglikelihood
a primitive, and then work from there I think.
function symbolize( | ||
rng::Random.AbstractRNG, | ||
m::Model, | ||
vi::VarInfo=VarInfo(m); | ||
spl=SampleFromPrior(), | ||
ctx=DefaultContext(), | ||
include_data=false, | ||
) | ||
m(rng, vi, spl, ctx) | ||
θ_orig = vi[spl] | ||
|
||
# Symbolic `logpdf` for fixed observations. | ||
# TODO: don't `collect` once symbolic arrays are mature enough. | ||
Symbolics.@variables θ[1:length(θ_orig)] | ||
vi = VarInfo{Real}(vi, spl, θ) | ||
m(vi, ctx) | ||
|
||
return vi, θ | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- Execute model once to get shape of latent variables.
- Construct symbolic variables.
- Execute model on symbolic variables.
vi
(the trace struct) now contains a symbolic representation of the logjoint retrievable throughgetlogp(vi)
.
function dependencies(m::Model, symbolic=false) | ||
ctx = SymbolicContext(DefaultContext()) | ||
vi = symbolize(m, VarInfo(m); ctx=ctx) | ||
|
||
return dependencies(ctx, symbolic) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Uses "contextual" dispatch to overload the corresponding *tilde_
statements, stroing the mapping from symbolic variable θ[i]
to VarName
.
function getlogpdf(d, args) | ||
replacements = Dict(:Normal => StatsFuns.normlogpdf, :Gamma => StatsFuns.gammalogpdf) | ||
|
||
dsym = Symbol(d) | ||
if haskey(replacements, dsym) | ||
return replacements[dsym] | ||
else | ||
return d | ||
end | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The idea behind all this was to replace the "non-tracable" logpdf
impls from Distributions.jl with traceable impls from StatsFuns.jl. After #292 we could probably avoid this by simply using MeasureTheory.jl instead:)
Also, this is super messy and probably not the greatest; I blame it on the fact that I had no idea what I was doing:)
Closed in favour of TuringLang/AbstractPPL.jl#47 |
I'm just going to put this here before heading to bed, but there's some more stuff to do here:
Variable
for the model arguments yet.Anyways, it's pretty dope.
Expression generation
EDIT: This doesn't work right now as it seems the original rewriters are now out of date. We can still extract an expression using
vi, θ Symbolic.symbolize(m); getlogp(vi)
but we won't have tracing through thelogpdf
computation.Dependencies
Can of course use LightGraphs.jl, etc. to generate visualizations of the models too:
Resulting in: