From 23141be4cfa82fa0a5fdf0a1d5829e0d802c6771 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 16:35:05 +0100 Subject: [PATCH] formatting --- src/symbolic/Symbolic.jl | 40 +++++++++++++++++++++------------------- src/symbolic/contexts.jl | 1 - src/symbolic/rules.jl | 29 ++++++++++++++++------------- src/varinfo.jl | 4 ++-- 4 files changed, 39 insertions(+), 35 deletions(-) diff --git a/src/symbolic/Symbolic.jl b/src/symbolic/Symbolic.jl index 13c84de88..696103e19 100644 --- a/src/symbolic/Symbolic.jl +++ b/src/symbolic/Symbolic.jl @@ -1,15 +1,16 @@ module Symbolic -import ..DynamicPPL -import ..DynamicPPL: Model, VarInfo, AbstractSampler, SampleFromPrior, VarName, DefaultContext +using ..DynamicPPL: DynamicPPL +import ..DynamicPPL: + Model, VarInfo, AbstractSampler, SampleFromPrior, VarName, DefaultContext -import Random -import Bijectors +using Random: Random +using Bijectors: Bijectors using Distributions -import Symbolics +using Symbolics: Symbolics import Symbolics: SymbolicUtils -issym(x::Union{Symbolics.Num, SymbolicUtils.Symbolic}) = true +issym(x::Union{Symbolics.Num,SymbolicUtils.Symbolic}) = true issym(x) = false include("rules.jl") @@ -19,18 +20,19 @@ symbolize(args...; kwargs...) = symbolize(Random.GLOBAL_RNG, args...; kwargs...) function symbolize( rng::Random.AbstractRNG, m::Model, - vi::VarInfo = VarInfo(m); - spl = SampleFromPrior(), - ctx = DefaultContext(), - include_data = false + vi::VarInfo=VarInfo(m); + spl=SampleFromPrior(), + ctx=DefaultContext(), + include_data=false, ) - m(rng, vi, spl, ctx); + m(rng, vi, spl, ctx) θ_orig = vi[spl] # Symbolic `logpdf` for fixed observations. + # TODO: don't `collect` once symbolic arrays are mature enough. Symbolics.@variables θ[1:length(θ_orig)] - vi = VarInfo{Real}(vi, spl, θ, 0.0); - m(vi, ctx); + vi = VarInfo{Real}(vi, spl, θ, 0.0) + m(vi, ctx) return vi, θ end @@ -49,23 +51,23 @@ function dependencies(ctx::SymbolicContext, vn::VarName) Symbolics.get_variables(a) end end -function dependencies(ctx::SymbolicContext, symbolic = false) +function dependencies(ctx::SymbolicContext, symbolic=false) vn2var = ctx.vn2var var2vn = Dict(values(vn2var) .=> keys(vn2var)) return Dict( - (symbolic ? vn2var[vn] : vn) => map(x -> symbolic ? x : var2vn[x], dependencies(ctx, vn)) - for vn in keys(ctx.vn2var) + (symbolic ? vn2var[vn] : vn) => + map(x -> symbolic ? x : var2vn[x], dependencies(ctx, vn)) for + vn in keys(ctx.vn2var) ) end -function dependencies(m::Model, symbolic = false) +function dependencies(m::Model, symbolic=false) ctx = SymbolicContext(DefaultContext()) - vi = symbolize(m, VarInfo(m), ctx = ctx) + vi = symbolize(m, VarInfo(m); ctx=ctx) return dependencies(ctx, symbolic) end - function symbolic_logp(m::Model) vi, θ = symbolize(m) lp = DynamicPPL.getlogp(vi) diff --git a/src/symbolic/contexts.jl b/src/symbolic/contexts.jl index cc9eb62fc..faf371e8f 100644 --- a/src/symbolic/contexts.jl +++ b/src/symbolic/contexts.jl @@ -17,7 +17,6 @@ function DynamicPPL.tilde_assume(rng, ctx::SymbolicContext, sampler, right, vn, return DynamicPPL.tilde_assume(rng, ctx.ctx, sampler, right, vn, inds, vi) end - # TODO: Make it more useful when working with symbolic observations. # observe function DynamicPPL.tilde_observe(ctx::SymbolicContext, sampler, right, left, vi) diff --git a/src/symbolic/rules.jl b/src/symbolic/rules.jl index d0a8421ce..a075b1e98 100644 --- a/src/symbolic/rules.jl +++ b/src/symbolic/rules.jl @@ -1,5 +1,5 @@ -import Bijectors -import Symbolics +using Bijectors: Bijectors +using Symbolics: Symbolics using Symbolics.SymbolicUtils Symbolics.@register Bijectors.logpdf_with_trans(dist, r, istrans) @@ -11,8 +11,12 @@ islogpdf(x) = false # HACK: Apparently this is needed for disambiguiation. # TODO: Open issue. -Symbolics.:<ₑ(a::Real, b::Symbolics.Num) = Symbolics.:<ₑ(Symbolics.value(a), Symbolics.value(b)) -Symbolics.:<ₑ(a::Symbolics.Num, b::Real) = Symbolics.:<ₑ(Symbolics.value(a), Symbolics.value(b)) +function Symbolics.:<ₑ(a::Real, b::Symbolics.Num) + return Symbolics.:<ₑ(Symbolics.value(a), Symbolics.value(b)) +end +function Symbolics.:<ₑ(a::Symbolics.Num, b::Real) + return Symbolics.:<ₑ(Symbolics.value(a), Symbolics.value(b)) +end ############# ### Rules ### @@ -22,13 +26,15 @@ const rmnum_rule = @rule (~x) => Symbolics.value(~x) const addnum_rule = @rule (~x) => Symbolics.Num(~x) # In the case where we want to work directly with the `x ~ Distribution` statements, the following rules can be useful: -const logpdf_rule = @rule (~x ~ ~d) => Distributions.logpdf(Symbolics.Num(~d), Symbolics.Num(~x)); +const logpdf_rule = @rule (~x ~ ~d) => + Distributions.logpdf(Symbolics.Num(~d), Symbolics.Num(~x)); const rand_rule = @rule (~x ~ ~d) => Distributions.rand(Symbolics.Num(~d)) # We don't want to trace into `Bijectors.logpdf_with_trans`, so we just replace it with `logpdf`. islogpdf_with_trans(f::Function) = f === Bijectors.logpdf_with_trans islogpdf_with_trans(x) = false -const logpdf_with_trans_rule = @rule (~f::islogpdf_with_trans)(~dist, ~x, ~istrans) => logpdf(~dist, ~x) +const logpdf_with_trans_rule = @rule (~f::islogpdf_with_trans)(~dist, ~x, ~istrans) => + logpdf(~dist, ~x) # Attempt to expand `logpdf` to get analytical expressions. # The idea is that `getlogpdf(d, args)` should return a method of the following signature: @@ -39,11 +45,8 @@ const logpdf_with_trans_rule = @rule (~f::islogpdf_with_trans)(~dist, ~x, ~istra # HACK: this is very hacky but you get the idea import Distributions: StatsFuns function getlogpdf(d, args) - replacements = Dict( - :Normal => StatsFuns.normlogpdf, - :Gamma => StatsFuns.gammalogpdf - ) - + replacements = Dict(:Normal => StatsFuns.normlogpdf, :Gamma => StatsFuns.gammalogpdf) + dsym = Symbol(d) if haskey(replacements, dsym) return replacements[dsym] @@ -52,8 +55,8 @@ function getlogpdf(d, args) end end -const analytic_rule = @rule (~f::islogpdf)((~d::isdist)(~~args), ~x) => getlogpdf(~d, ~~args)(map(Symbolics.Num, (~~args))..., Symbolics.Num(~x)) - +const analytic_rule = @rule (~f::islogpdf)((~d::isdist)(~~args), ~x) => + getlogpdf(~d, ~~args)(map(Symbolics.Num, (~~args))..., Symbolics.Num(~x)) ################# ### Rewriters ### diff --git a/src/varinfo.jl b/src/varinfo.jl index 4bad2a50a..d7991c73d 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -124,12 +124,12 @@ end function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector, lp::T) where {T} md = newmetadata(old_vi.metadata, Val(getspace(spl)), x) - VarInfo(md, Base.RefValue{T}(lp), Ref(get_num_produce(old_vi))) + return VarInfo(md, Base.RefValue{T}(lp), Ref(get_num_produce(old_vi))) end function VarInfo{T}(old_vi::TypedVarInfo, spl, x::AbstractVector) where {T} md = newmetadata(old_vi.metadata, Val(getspace(spl)), x) - VarInfo(md, Base.RefValue{T}(0.0), Ref(get_num_produce(old_vi))) + return VarInfo(md, Base.RefValue{T}(0.0), Ref(get_num_produce(old_vi))) end function VarInfo(