
Making use of Symbolics.jl/SymbolicUtils.jl #234

Closed · wants to merge 13 commits

Conversation

@torfjelde (Member) commented Apr 27, 2021

I'm just going to put this here before heading to bed, but there's some more stuff to do here:

  • Allow making observations/model arguments symbolic. I've verified that this works, but I just haven't gotten around to automatically constructing Variable for the model arguments yet.
  • Docstrings/tests
  • ???
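The first item above could plausibly build on Symbolics.jl's array variables. As a minimal sketch (my own illustration, not code from this PR) of what making observations symbolic might look like:

```julia
using Symbolics

# Hypothetical sketch: replace the observed data vector with symbolic
# variables, so the log-density becomes symbolic in the observations too.
n = 3
Symbolics.@variables x[1:n]   # symbolic stand-in for the data vector
xs = collect(x)               # Vector{Num} we can index like ordinary data

# Any arithmetic on the symbolic data builds an expression tree:
expr = sum(abs2, xs)
```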

Anyways, it's pretty dope.

julia> using DynamicPPL, Distributions

julia> @model function demo(x, ::Type{TV} = Vector{Float64}) where {TV}
           # Just to demonstrate that we're not restricted to `x ~ DistKnownAtExpansionTime`
           s_prior = InverseGamma()
           s ~ s_prior

           num_obs = length(x)
           m = TV(undef, num_obs)
           m[1] ~ Normal(0, s)
           x[1] ~ Normal(x[1], s)
           for i = 2:num_obs
               m[i] ~ Normal(m[i - 1], s)
               x[i] ~ Normal(m[i], s)
           end
           return m, x
       end;

julia> m = demo(randn(10) .+ 1);

Expression generation

EDIT: This doesn't work right now, as it seems the original rewriters are now out of date. We can still extract an expression using `vi, θ = Symbolic.symbolize(m); getlogp(vi)`, but we won't have tracing through the logpdf computation.

julia> lp, θ = Symbolic.symbolic_logp(m) # can drop the constant in front too if we want
(-18.378770664093455 - (2.0log(θ₁)) - (20log(sqrt(θ₁))) - (0.5abs2(θ₂*(sqrt(θ₁)^-1))) - (0.5abs2((θ₁₀ - θ₉)*(sqrt(θ₁)^-1))) - (0.5abs2((θ₁₁ - θ₁₀)*(sqrt(θ₁)^-1))) - (0.5abs2((θ₃ - θ₂)*(sqrt(θ₁)^-1))) - (0.5abs2((θ₄ - θ₃)*(sqrt(θ₁)^-1))) - (0.5abs2((θ₅ - θ₄)*(sqrt(θ₁)^-1))) - (0.5abs2((θ₆ - θ₅)*(sqrt(θ₁)^-1))) - (0.5abs2((θ₇ - θ₆)*(sqrt(θ₁)^-1))) - (0.5abs2((θ₈ - θ₇)*(sqrt(θ₁)^-1))) - (0.5abs2((θ₉ - θ₈)*(sqrt(θ₁)^-1))) - (0.5abs2((sqrt(θ₁)^-1)*(-1.4166318988896798 - θ₁₀))) - (0.5abs2((sqrt(θ₁)^-1)*(1.9786928906469687 - θ₁₁))) - (0.5abs2((sqrt(θ₁)^-1)*(-0.7713281377661316 - θ₃))) - (0.5abs2((sqrt(θ₁)^-1)*(0.5953344430416955 - θ₄))) - (0.5abs2((sqrt(θ₁)^-1)*(0.4888163369370143 - θ₅))) - (0.5abs2((sqrt(θ₁)^-1)*(0.9869369683122523 - θ₆))) - (0.5abs2((sqrt(θ₁)^-1)*(0.18968854568519278 - θ₇))) - (0.5abs2((sqrt(θ₁)^-1)*(2.20765228354745 - θ₈))) - (0.5abs2((sqrt(θ₁)^-1)*(1.1746311120999957 - θ₉))) - (θ₁^-1), Symbolics.Num[θ₁, θ₂, θ₃, θ₄, θ₅, θ₆, θ₇, θ₈, θ₉, θ₁₀, θ₁₁])

julia> ∂lp = Symbolics.gradient(lp, θ)
11-element Vector{Symbolics.Num}:
                                     θ₁^-2 + 0.5(θ₂^2)*(sqrt(θ₁)^-4) + 0.5(sqrt(θ₁)^-4)*((θ₁₀ - θ₉)^2) + 0.5(sqrt(θ₁)^-4)*((θ₁₁ - θ₁₀)^2) + 0.5(sqrt(θ₁)^-4)*((θ₃ - θ₂)^2) + 0.5(sqrt(θ₁)^-4)*((θ₄ - θ₃)^2) + 0.5(sqrt(θ₁)^-4)*((θ₅ - θ₄)^2) + 0.5(sqrt(θ₁)^-4)*((θ₆ - θ₅)^2) + 0.5(sqrt(θ₁)^-4)*((θ₇ - θ₆)^2) + 0.5(sqrt(θ₁)^-4)*((θ₈ - θ₇)^2) + 0.5(sqrt(θ₁)^-4)*((θ₉ - θ₈)^2) + 0.5(sqrt(θ₁)^-4)*((-1.4166318988896798 - θ₁₀)^2) + 0.5(sqrt(θ₁)^-4)*((1.9786928906469687 - θ₁₁)^2) + 0.5(sqrt(θ₁)^-4)*((-0.7713281377661316 - θ₃)^2) + 0.5(sqrt(θ₁)^-4)*((0.5953344430416955 - θ₄)^2) + 0.5(sqrt(θ₁)^-4)*((0.4888163369370143 - θ₅)^2) + 0.5(sqrt(θ₁)^-4)*((0.9869369683122523 - θ₆)^2) + 0.5(sqrt(θ₁)^-4)*((0.18968854568519278 - θ₇)^2) + 0.5(sqrt(θ₁)^-4)*((2.20765228354745 - θ₈)^2) + 0.5(sqrt(θ₁)^-4)*((1.1746311120999957 - θ₉)^2) - (2.0(θ₁^-1)) - ((10//1)*(sqrt(θ₁)^-2))
  (θ₃ - θ₂)*(sqrt(θ₁)^-2) - (θ₂*(sqrt(θ₁)^-2))
   (θ₄ - θ₃)*(sqrt(θ₁)^-2) + (sqrt(θ₁)^-2)*(-0.7713281377661316 - θ₃) - ((θ₃ - θ₂)*(sqrt(θ₁)^-2))
    (θ₅ - θ₄)*(sqrt(θ₁)^-2) + (sqrt(θ₁)^-2)*(0.5953344430416955 - θ₄) - ((θ₄ - θ₃)*(sqrt(θ₁)^-2))
    (θ₆ - θ₅)*(sqrt(θ₁)^-2) + (sqrt(θ₁)^-2)*(0.4888163369370143 - θ₅) - ((θ₅ - θ₄)*(sqrt(θ₁)^-2))
    (θ₇ - θ₆)*(sqrt(θ₁)^-2) + (sqrt(θ₁)^-2)*(0.9869369683122523 - θ₆) - ((θ₆ - θ₅)*(sqrt(θ₁)^-2))
    (θ₈ - θ₇)*(sqrt(θ₁)^-2) + (sqrt(θ₁)^-2)*(0.18968854568519278 - θ₇) - ((θ₇ - θ₆)*(sqrt(θ₁)^-2))
    (θ₉ - θ₈)*(sqrt(θ₁)^-2) + (sqrt(θ₁)^-2)*(2.20765228354745 - θ₈) - ((θ₈ - θ₇)*(sqrt(θ₁)^-2))
   (θ₁₀ - θ₉)*(sqrt(θ₁)^-2) + (sqrt(θ₁)^-2)*(1.1746311120999957 - θ₉) - ((θ₉ - θ₈)*(sqrt(θ₁)^-2))
 (θ₁₁ - θ₁₀)*(sqrt(θ₁)^-2) + (sqrt(θ₁)^-2)*(-1.4166318988896798 - θ₁₀) - ((θ₁₀ - θ₉)*(sqrt(θ₁)^-2))
                              (sqrt(θ₁)^-2)*(1.9786928906469687 - θ₁₁) - ((θ₁₁ - θ₁₀)*(sqrt(θ₁)^-2))

julia> ∂f, ∂f! = Symbolics.build_function(∂lp, θ, expression = false);

julia> ∂f(rand(Symbolics.value(length(θ))))
11-element Vector{Float64}:
 26.085486889785475
 -1.1940138187926816
 -0.709786284651692
 -1.8974997722462343
  3.2298551373276108
  0.07043163767794725
  3.301109716439576
  3.527331147067538
  0.6372253272009718
 -7.676985933953628
  3.717132883407671
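For readers unfamiliar with the pipeline, the same variables → gradient → compiled-function workflow can be reproduced on a toy log-density with stock Symbolics.jl. This is my own minimal sketch, independent of the PR's code:

```julia
using Symbolics

# Toy stand-in for the symbolic log-density above: an isotropic Gaussian
# in two variables.
Symbolics.@variables θ[1:2]
θv = collect(θ)
lp = -0.5 * (abs2(θv[1]) + abs2(θv[2]))

# Symbolic gradient, then compile it into a plain Julia function; the
# second return value is an in-place variant.
∂lp = Symbolics.gradient(lp, θv)
∂f, ∂f! = Symbolics.build_function(∂lp, θv; expression = false)

∂f([1.0, 2.0])   # == [-1.0, -2.0]
```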

Dependencies

julia> Symbolic.dependencies(m) # `VarName` -> `VarName`
Dict{VarName, Vector{T} where T} with 11 entries:
  m[5]  => VarName[m[4], s]
  m[4]  => VarName[m[3], s]
  m[8]  => VarName[m[7], s]
  s     => Union{typeof(var), VarName}[]
  m[6]  => VarName[m[5], s]
  m[3]  => VarName[m[2], s]
  m[1]  => [s]
  m[7]  => VarName[m[6], s]
  m[2]  => VarName[m[1], s]
  m[10] => VarName[m[9], s]
  m[9]  => VarName[m[8], s]

julia> Symbolic.dependencies(m, true) # `Symbol` -> `Symbol`
Dict{Symbolics.Num, Vector{T} where T} with 11 entries:
  θ₉  => SymbolicUtils.Sym{Real, Nothing}[θ₈, θ₁]
  θ₇  => SymbolicUtils.Sym{Real, Nothing}[θ₆, θ₁]
  θ₅  => SymbolicUtils.Sym{Real, Nothing}[θ₄, θ₁]
  θ₆  => SymbolicUtils.Sym{Real, Nothing}[θ₅, θ₁]
  θ₃  => SymbolicUtils.Sym{Real, Nothing}[θ₂, θ₁]
  θ₁₀ => SymbolicUtils.Sym{Real, Nothing}[θ₉, θ₁]
  θ₂  => SymbolicUtils.Sym{Real, Nothing}[θ₁]
  θ₁₁ => SymbolicUtils.Sym{Real, Nothing}[θ₁₀, θ₁]
  θ₁  => Any[]
  θ₄  => SymbolicUtils.Sym{Real, Nothing}[θ₃, θ₁]
  θ₈  => SymbolicUtils.Sym{Real, Nothing}[θ₇, θ₁]

Can of course use LightGraphs.jl, etc. to generate visualizations of the models too:

julia> using LightGraphs, MetaGraphs

julia> function graph(m::Model)
           deps = Symbolic.dependencies(m)

           g = MetaDiGraph(length(deps))
           for (i, vn) in enumerate(keys(deps))
               set_prop!(g, i, :vn, vn)
           end
           set_indexing_prop!(g, :vn)

           for vn in keys(deps)
               for parent in deps[vn]
                   add_edge!(g, (g[parent, :vn], g[vn, :vn]))
               end
           end
           return g
       end
graph (generic function with 1 method)

julia> g = graph(m)
{11, 19} directed Int64 metagraph with Float64 weights defined by :weight (default weight 1.0)

julia> using GraphRecipes, Plots

julia> nodelabels = map(vertices(g)) do v
           get_prop(g, v, :vn)
       end;

julia> graphplot(g, names=nodelabels)

Resulting in:

[Screenshot: graph plot of the model's dependency structure]

Review comment on a partial diff hunk:

value,
vi,
)
increment_num_produce!(vi)
return Distributions.loglikelihood(dist, value)
return sum(Distributions.logpdf(dist, value))

@torfjelde (Member Author):
This is left-over from some earlier experimentation I did. We should make loglikelihood a primitive, and then work from there I think.

(Several review threads on src/symbolic/Symbolic.jl, src/symbolic/rules.jl, and src/varinfo.jl were marked outdated/resolved.)
@torfjelde torfjelde marked this pull request as draft July 24, 2021 15:36
Comment on lines +23 to +41
function symbolize(
    rng::Random.AbstractRNG,
    m::Model,
    vi::VarInfo=VarInfo(m);
    spl=SampleFromPrior(),
    ctx=DefaultContext(),
    include_data=false,
)
    m(rng, vi, spl, ctx)
    θ_orig = vi[spl]

    # Symbolic `logpdf` for fixed observations.
    # TODO: don't `collect` once symbolic arrays are mature enough.
    Symbolics.@variables θ[1:length(θ_orig)]
    vi = VarInfo{Real}(vi, spl, θ)
    m(vi, ctx)

    return vi, θ
end
@torfjelde (Member Author):

  1. Execute model once to get shape of latent variables.
  2. Construct symbolic variables.
  3. Execute model on symbolic variables.
  4. vi (the trace struct) now contains a symbolic representation of the logjoint retrievable through getlogp(vi).
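The key trick in steps 2–3 is that ordinary Julia code run on `Num` inputs produces an expression tree instead of a number. A minimal self-contained illustration (made-up names, not PR code):

```julia
using Symbolics

# Stand-in for the model's log-density code: any plain Julia function.
logp_kernel(θ) = -abs2(θ[2] - θ[1])

# Step 2: construct symbolic variables; step 3: run the code on them.
Symbolics.@variables θ[1:2]
expr = logp_kernel(collect(θ))   # a symbolic expression, not a number

# Substituting concrete values recovers the numeric result:
Symbolics.substitute(expr, Dict(θ[1] => 0.0, θ[2] => 1.0))   # -> -1.0
```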

Comment on lines +67 to +72
function dependencies(m::Model, symbolic=false)
    ctx = SymbolicContext(DefaultContext())
    vi = symbolize(m, VarInfo(m); ctx=ctx)

    return dependencies(ctx, symbolic)
end
@torfjelde (Member Author):

Uses "contextual" dispatch to overload the corresponding `tilde_*` statements, storing the mapping from symbolic variable θ[i] to VarName.
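As a toy picture of what such a context can record (pure illustration; names like `DepRecorder` are made up and not from this PR):

```julia
# Each tilde-statement overload would, in effect, note which flattened
# symbolic index a model variable was assigned.
struct DepRecorder
    varname_to_idx::Dict{Symbol,Int}   # e.g. :s => 1 means `s` is θ[1]
end
DepRecorder() = DepRecorder(Dict{Symbol,Int}())

# Assign the next free symbolic index to `vn` on first sight.
record!(r::DepRecorder, vn::Symbol) =
    get!(r.varname_to_idx, vn, length(r.varname_to_idx) + 1)

r = DepRecorder()
record!(r, :s)        # s    ↦ θ[1]
record!(r, :m1)       # m[1] ↦ θ[2]
record!(r, :s)        # already seen; stays θ[1]
r.varname_to_idx      # Dict(:s => 1, :m1 => 2)
```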

Comment on lines +47 to +56
function getlogpdf(d, args)
    replacements = Dict(:Normal => StatsFuns.normlogpdf, :Gamma => StatsFuns.gammalogpdf)

    dsym = Symbol(d)
    if haskey(replacements, dsym)
        return replacements[dsym]
    else
        return d
    end
end
@torfjelde (Member Author):

The idea behind all this was to replace the "non-traceable" logpdf impls from Distributions.jl with traceable impls from StatsFuns.jl. After #292 we could probably avoid this by simply using MeasureTheory.jl instead. :)

Also, this is super messy and probably not the greatest; I blame it on the fact that I had no idea what I was doing. :)
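To see why the StatsFuns kernels are the traceable ones: they are plain algebraic functions of their arguments, so symbolic inputs should flow straight through them. A minimal check (my own example, not PR code):

```julia
using Symbolics, StatsFuns

# `normlogpdf(μ, σ, x)` is ordinary arithmetic in its arguments, so calling
# it on symbolic inputs yields a symbolic expression...
Symbolics.@variables μ σ x
expr = StatsFuns.normlogpdf(μ, σ, x)

# ...whereas `logpdf(Normal(μ, σ), x)` dispatches on a distribution struct,
# which the symbolic tracer cannot see through without rules like the above.
```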

@yebai (Member) commented Mar 17, 2022

Closed in favour of TuringLang/AbstractPPL.jl#47

@yebai yebai closed this Mar 17, 2022
@yebai yebai deleted the tor/symbolics branch March 17, 2022 17:59