Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Jul 24, 2021
1 parent 70c9997 commit 23141be
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 35 deletions.
40 changes: 21 additions & 19 deletions src/symbolic/Symbolic.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
module Symbolic

import ..DynamicPPL
import ..DynamicPPL: Model, VarInfo, AbstractSampler, SampleFromPrior, VarName, DefaultContext
using ..DynamicPPL: DynamicPPL
import ..DynamicPPL:
Model, VarInfo, AbstractSampler, SampleFromPrior, VarName, DefaultContext

import Random
import Bijectors
using Random: Random
using Bijectors: Bijectors
using Distributions
import Symbolics
using Symbolics: Symbolics
import Symbolics: SymbolicUtils

issym(x::Union{Symbolics.Num, SymbolicUtils.Symbolic}) = true
issym(x::Union{Symbolics.Num,SymbolicUtils.Symbolic}) = true
issym(x) = false

include("rules.jl")
Expand All @@ -19,18 +20,19 @@ symbolize(args...; kwargs...) = symbolize(Random.GLOBAL_RNG, args...; kwargs...)
function symbolize(
rng::Random.AbstractRNG,
m::Model,
vi::VarInfo = VarInfo(m);
spl = SampleFromPrior(),
ctx = DefaultContext(),
include_data = false
vi::VarInfo=VarInfo(m);
spl=SampleFromPrior(),
ctx=DefaultContext(),
include_data=false,
)
m(rng, vi, spl, ctx);
m(rng, vi, spl, ctx)
θ_orig = vi[spl]

# Symbolic `logpdf` for fixed observations.
# TODO: don't `collect` once symbolic arrays are mature enough.
Symbolics.@variables θ[1:length(θ_orig)]
vi = VarInfo{Real}(vi, spl, θ, 0.0);
m(vi, ctx);
vi = VarInfo{Real}(vi, spl, θ, 0.0)
m(vi, ctx)

return vi, θ
end
Expand All @@ -49,23 +51,23 @@ function dependencies(ctx::SymbolicContext, vn::VarName)
Symbolics.get_variables(a)
end
end
function dependencies(ctx::SymbolicContext, symbolic = false)
function dependencies(ctx::SymbolicContext, symbolic=false)
vn2var = ctx.vn2var
var2vn = Dict(values(vn2var) .=> keys(vn2var))
return Dict(
(symbolic ? vn2var[vn] : vn) => map(x -> symbolic ? x : var2vn[x], dependencies(ctx, vn))
for vn in keys(ctx.vn2var)
(symbolic ? vn2var[vn] : vn) =>
map(x -> symbolic ? x : var2vn[x], dependencies(ctx, vn)) for
vn in keys(ctx.vn2var)
)
end

function dependencies(m::Model, symbolic = false)
function dependencies(m::Model, symbolic=false)
ctx = SymbolicContext(DefaultContext())
vi = symbolize(m, VarInfo(m), ctx = ctx)
vi = symbolize(m, VarInfo(m); ctx=ctx)

return dependencies(ctx, symbolic)
end


function symbolic_logp(m::Model)
vi, θ = symbolize(m)
lp = DynamicPPL.getlogp(vi)
Expand Down
1 change: 0 additions & 1 deletion src/symbolic/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ function DynamicPPL.tilde_assume(rng, ctx::SymbolicContext, sampler, right, vn,
return DynamicPPL.tilde_assume(rng, ctx.ctx, sampler, right, vn, inds, vi)
end


# TODO: Make it more useful when working with symbolic observations.
# observe
function DynamicPPL.tilde_observe(ctx::SymbolicContext, sampler, right, left, vi)
Expand Down
29 changes: 16 additions & 13 deletions src/symbolic/rules.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import Bijectors
import Symbolics
using Bijectors: Bijectors
using Symbolics: Symbolics
using Symbolics.SymbolicUtils

Symbolics.@register Bijectors.logpdf_with_trans(dist, r, istrans)
Expand All @@ -11,8 +11,12 @@ islogpdf(x) = false

# HACK: Apparently this is needed for disambiguiation.
# TODO: Open issue.
Symbolics.:<(a::Real, b::Symbolics.Num) = Symbolics.:<(Symbolics.value(a), Symbolics.value(b))
Symbolics.:<(a::Symbolics.Num, b::Real) = Symbolics.:<(Symbolics.value(a), Symbolics.value(b))
function Symbolics.:<(a::Real, b::Symbolics.Num)
return Symbolics.:<(Symbolics.value(a), Symbolics.value(b))
end
function Symbolics.:<(a::Symbolics.Num, b::Real)
return Symbolics.:<(Symbolics.value(a), Symbolics.value(b))
end

#############
### Rules ###
Expand All @@ -22,13 +26,15 @@ const rmnum_rule = @rule (~x) => Symbolics.value(~x)
const addnum_rule = @rule (~x) => Symbolics.Num(~x)

# In the case where we want to work directly with the `x ~ Distribution` statements, the following rules can be useful:
const logpdf_rule = @rule (~x ~ ~d) => Distributions.logpdf(Symbolics.Num(~d), Symbolics.Num(~x));
const logpdf_rule = @rule (~x ~ ~d) =>
Distributions.logpdf(Symbolics.Num(~d), Symbolics.Num(~x));
const rand_rule = @rule (~x ~ ~d) => Distributions.rand(Symbolics.Num(~d))

# We don't want to trace into `Bijectors.logpdf_with_trans`, so we just replace it with `logpdf`.
islogpdf_with_trans(f::Function) = f === Bijectors.logpdf_with_trans
islogpdf_with_trans(x) = false
const logpdf_with_trans_rule = @rule (~f::islogpdf_with_trans)(~dist, ~x, ~istrans) => logpdf(~dist, ~x)
const logpdf_with_trans_rule = @rule (~f::islogpdf_with_trans)(~dist, ~x, ~istrans) =>
logpdf(~dist, ~x)

# Attempt to expand `logpdf` to get analytical expressions.
# The idea is that `getlogpdf(d, args)` should return a method of the following signature:
Expand All @@ -39,11 +45,8 @@ const logpdf_with_trans_rule = @rule (~f::islogpdf_with_trans)(~dist, ~x, ~istra
# HACK: this is very hacky but you get the idea
import Distributions: StatsFuns
function getlogpdf(d, args)
replacements = Dict(
:Normal => StatsFuns.normlogpdf,
:Gamma => StatsFuns.gammalogpdf
)

replacements = Dict(:Normal => StatsFuns.normlogpdf, :Gamma => StatsFuns.gammalogpdf)

dsym = Symbol(d)
if haskey(replacements, dsym)
return replacements[dsym]
Expand All @@ -52,8 +55,8 @@ function getlogpdf(d, args)
end
end

const analytic_rule = @rule (~f::islogpdf)((~d::isdist)(~~args), ~x) => getlogpdf(~d, ~~args)(map(Symbolics.Num, (~~args))..., Symbolics.Num(~x))

const analytic_rule = @rule (~f::islogpdf)((~d::isdist)(~~args), ~x) =>
getlogpdf(~d, ~~args)(map(Symbolics.Num, (~~args))..., Symbolics.Num(~x))

#################
### Rewriters ###
Expand Down
4 changes: 2 additions & 2 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,12 @@ end

function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector, lp::T) where {T}
md = newmetadata(old_vi.metadata, Val(getspace(spl)), x)
VarInfo(md, Base.RefValue{T}(lp), Ref(get_num_produce(old_vi)))
return VarInfo(md, Base.RefValue{T}(lp), Ref(get_num_produce(old_vi)))
end

function VarInfo{T}(old_vi::TypedVarInfo, spl, x::AbstractVector) where {T}
md = newmetadata(old_vi.metadata, Val(getspace(spl)), x)
VarInfo(md, Base.RefValue{T}(0.0), Ref(get_num_produce(old_vi)))
return VarInfo(md, Base.RefValue{T}(0.0), Ref(get_num_produce(old_vi)))
end

function VarInfo(
Expand Down

0 comments on commit 23141be

Please sign in to comment.