From 67ef9e00935aa73439278f33c1297f386793ece9 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Thu, 28 Nov 2019 11:19:36 -0800 Subject: [PATCH] Dev (#62) * update dependencies * remove junk files * remove junk files * Reorganizing (Good idea Kusti) * internals.md * update README * ignore .vscode directory * Revert "ignore .vscode directory" This reverts commit 5753f5828bd7d6320cd68928a252161c9a364ea6. * Delete settings.json * split jointdistribution.jl * parameterize model by module * make `runtests` throw errors * Parameterize Model by module * debugging * Update for https://github.com/thautwarm/GeneralizedGenerated.jl/pull/28 * insert `$sympy`(SymPy) & `$as`(TransformVariables.jl); (#63) tests passed * Package upper bounds * fix demo * Being less fancy with logpdf * Move logpdf(m,x,::typeof(codegen)) to codegen.jl * Fixing stuff I just broke * working on markovblanket * working on Markov Blanket * markov blanket * add bodyVariables * export bodyVariables * make `toposort` us `arguments` instead of `freeVariables` * up * ixnay on the ype piracytay * Type infer for module param (#66) * apply changes * fix #65: might need more changes * Update dependencies + README * update README * update README --- .travis.yml | 2 +- Project.toml | 33 +- README.md | 6 +- examples/2019-11-07-demo.jmd | 2 +- examples/2019-11-07-demo.md | 397 ++++++------------------- src/core/model.jl | 5 + src/core/statement.jl | 4 + src/core/toposort.jl | 2 +- src/core/utils.jl | 9 +- src/importance.jl | 8 +- src/particles.jl | 14 +- src/primitives/likelihood-weighting.jl | 8 +- src/primitives/logpdf.jl | 16 +- src/primitives/rand.jl | 17 +- src/primitives/xform.jl | 10 +- src/symbolic/codegen.jl | 7 +- src/symbolic/reduce.jl | 8 +- src/symbolic/symbolic.jl | 114 +++---- src/transforms/markovblanket.jl | 93 +++++- 19 files changed, 341 insertions(+), 414 deletions(-) diff --git a/.travis.yml b/.travis.yml index 29322829..d57bc217 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,7 +3,7 @@ os: - linux - osx julia: - - 1.2 + - 1.3 - nightly env: global: diff --git a/Project.toml b/Project.toml index 663b804f..f1ca34fb 100644 --- a/Project.toml +++ b/Project.toml @@ -40,8 +40,37 @@ TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -MonteCarloMeasurements = ">=0.3.5" -julia = "^1.2" +AdvancedHMC = "0.2" +Bijectors = "0.4" +DiffResults = "0.0.4" +Distributions = "0.21" +DynamicHMC = "2.1" +FillArrays = "0.8" +ForwardDiff = "0.10" +GeneralizedGenerated = "0.2" +Graphs = "0.10" +IRTools = "0.3" +IterTools = "1.3" +LazyArrays = "0.14" +LogDensityProblems = "0.9" +MLStyle = "0.3" +MacroTools = "0.5" +MonteCarloMeasurements = "0.5" +NamedTupleTools = "0.12" +Plots = "0.28" +PyCall = "1.91" +Reexport = "0.2" +ResumableFunctions = "0.5" +ReverseDiff = "0.3" +SimpleGraphs = "0.3" +SimplePartitions = "0.2" +SimplePosets = "0.0" +StatsFuns = "0.9" +Stheno = "0.3" +SymPy = "1.0" +TransformVariables = "0.3" +Zygote = "0.4" +julia = "^1.3" [extras] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" diff --git a/README.md b/README.md index a5537d94..f311e50f 100644 --- a/README.md +++ b/README.md @@ -98,7 +98,7 @@ Often these are easier to work with in terms of `particles` (built using [MonteC julia> post = dynamicHMC(m(X=truth.X), (y=truth.y,)); julia> particles(post) -(β = Particles{Float64,1000}[0.558 ± 0.25, 0.768 ± 0.49],) +(β = Particles{Float64,1000}[0.548 ± 0.24, 0.751 ± 0.5],) ```` @@ -204,11 +204,11 @@ julia> using BenchmarkTools julia> @btime logpdf($m2(X=X), $truth) - 802.533 ns (16 allocations: 464 bytes) + 1.989 μs (47 allocations: 1.05 KiB) -15.84854642585797 julia> @btime logpdf($m2(X=X), $truth, $codegen) - 324.463 ns (5 allocations: 208 bytes) + 313.860 ns (5 allocations: 208 bytes) -15.848546425857968 ```` diff --git a/examples/2019-11-07-demo.jmd b/examples/2019-11-07-demo.jmd index 44da6b36..601508f1 100644 --- a/examples/2019-11-07-demo.jmd +++ b/examples/2019-11-07-demo.jmd @@ -13,7 +13,7 @@ weave("examples/2019-11-07-demo.jmd" ```julia -using Revise, Soss, Plots, NamedTupleTools +using Soss, Plots, NamedTupleTools seed = WEAVE_ARGS[:seed] import Random; Random.seed!(seed) ``` diff --git a/examples/2019-11-07-demo.md b/examples/2019-11-07-demo.md index de5e8e5f..cec99fa5 100644 --- a/examples/2019-11-07-demo.md +++ b/examples/2019-11-07-demo.md @@ -13,18 +13,7 @@ weave("examples/2019-11-07-demo.jmd" ````julia -using Revise, Soss, Plots, NamedTupleTools -```` - - -```` -Error: ArgumentError: Package Revise not found in current path: -- Run `import Pkg; Pkg.add("Revise")` to install the Revise package. -```` - - - -````julia +using Soss, Plots, NamedTupleTools seed = WEAVE_ARGS[:seed] import Random; Random.seed!(seed) ```` @@ -46,52 +35,16 @@ mt = @model x begin Mix([Normal(yhat[j], σ), Normal(yhat[j],8σ)], [0.8,0.2]) end end; -```` - - -```` -Error: LoadError: UndefVarError: @model not defined -in expression starting at none:1 -```` - - - -````julia x = randn(100); truth = rand(mt(x=x)); ```` -```` -Error: UndefVarError: mt not defined -```` - - ````julia xx = range(extrema(truth.x)...,length=100) -```` - - -```` -Error: UndefVarError: truth not defined -```` - - - -````julia scatter(truth.x,truth.y, legend=false, c=1) -```` - - -```` -Error: UndefVarError: scatter not defined -```` - - - -````julia # plot!(xx, truth.α .+ truth.β .* xx, dpi=300,legend=false, lw=3, c=2) ```` @@ -111,12 +64,6 @@ end; ```` -```` -Error: LoadError: UndefVarError: @model not defined -in expression starting at none:1 -```` - - ````julia m(x=truth.x) @@ -124,29 +71,32 @@ m(x=truth.x) ```` -Error: UndefVarError: truth not defined -```` - - - -````julia -post = dynamicHMC(m(x=truth.x), (y=truth.y,)) -```` +Joint Distribution + Bound arguments: [x] + Variables: [σ, β, α, yhat, n, y] - -```` -Error: UndefVarError: truth not defined +@model x begin + σ ~ HalfNormal() + β ~ Normal() + α ~ Normal() + yhat = α .+ β .* x + n = length(x) + y ~ For(n) do j + Normal(yhat[j], σ) + end + end ```` ````julia +post = dynamicHMC(m(x=truth.x), (y=truth.y,)) ppost = particles(post) ```` ```` -Error: UndefVarError: particles not defined +(σ = 2.04 ± 0.15, β = 2.61 ± 0.2, α = 0.599 ± 0.21) ```` @@ -157,7 +107,10 @@ symlogpdf(m) |> foldConstants |> tolatex |> println ```` -Error: UndefVarError: symlogpdf not defined +- 0.5 α^{2} - 0.5 β^{2} - 1.0 σ^{2} + \sum_{Idx\left(j_{1}\right)=1}^{n} \l +eft(- \frac{0.5 \left({y}_{Idx\left(j_{1}\right)} - 1.0 {\hat{y}}_{Idx\left +(j_{1}\right)}\right)^{2}}{σ^{2}} - 1.0 \log{\left(σ \right)} - 0.918938533 +204673\right) - 3.67575413281869 ```` @@ -168,7 +121,10 @@ symlogpdf(m) |> expandSums |> foldConstants |> tolatex |> println ```` -Error: UndefVarError: symlogpdf not defined +- 1.0 \log{\left(σ \right)} n - 0.918938533204673 n - 0.5 α^{2} - 0.5 β^{2} + - 1.0 σ^{2} - 3.67575413281869 - \frac{0.5 \sum_{Idx\left(j_{1}\right)=1}^ +{n} \left({y}_{Idx\left(j_{1}\right)} - 1.0 {\hat{y}}_{Idx\left(j_{1}\right +)}\right)^{2}}{σ^{2}} ```` @@ -180,7 +136,8 @@ using BenchmarkTools ```` -Error: UndefVarError: m not defined +26.210 μs (633 allocations: 12.47 KiB) +-901.7607073245318 ```` @@ -191,32 +148,19 @@ Error: UndefVarError: m not defined ```` -Error: UndefVarError: m not defined +116.566 ns (1 allocation: 896 bytes) +-903.4977930382965 ```` ````julia eachplot(xx, ppost.α .+ ppost.β .* xx, lw=3, dpi=300, color=:black) -```` - - -```` -Error: UndefVarError: ppost not defined -```` - - - -````julia scatter!(truth.x,truth.y, legend=false, c=1) ```` -```` -Error: UndefVarError: scatter! not defined -```` - - +![](figures/2019-11-07-demo_11_1.png) ````julia pred = predictive(m, :α, :β, :σ) @@ -224,7 +168,13 @@ pred = predictive(m, :α, :β, :σ) ```` -Error: UndefVarError: predictive not defined +@model (x, α, β, σ) begin + yhat = α .+ β .* x + n = length(x) + y ~ For(n) do j + Normal(yhat[j], σ) + end + end ```` @@ -233,147 +183,42 @@ Error: UndefVarError: predictive not defined postpred = map(post) do θ delete(rand(pred(θ)((x=x,))), :n, :x) end |> particles -```` - - -```` -Error: UndefVarError: post not defined -```` - - - -````julia pvals = mean.(truth.y .> postpred.y) -```` - - -```` -Error: UndefVarError: truth not defined -```` - - - -````julia # PPC vs x scatter(truth.x, pvals, legend=false, dpi=300) -```` - - -```` -Error: UndefVarError: scatter not defined -```` - - - -````julia xlabel!("x") -```` - - -```` -Error: UndefVarError: xlabel! not defined -```` - - - -````julia ylabel!("Bayesian p-value") ```` -```` -Error: UndefVarError: ylabel! not defined -```` - - +![](figures/2019-11-07-demo_13_1.png) ````julia # # # # PPC vs y scatter(truth.y, pvals, legend=false, dpi=300) -```` - - -```` -Error: UndefVarError: scatter not defined -```` - - - -````julia xlabel!("y") -```` - - -```` -Error: UndefVarError: xlabel! not defined -```` - - - -````julia ylabel!("Bayesian p-value") ```` -```` -Error: UndefVarError: ylabel! not defined -```` - - +![](figures/2019-11-07-demo_14_1.png) ````julia using AverageShiftedHistograms -```` - - -```` -Error: ArgumentError: Package AverageShiftedHistograms not found in current - path: -- Run `import Pkg; Pkg.add("AverageShiftedHistograms")` to install the Aver -ageShiftedHistograms package. -```` - - - -````julia o = ash(pvals, rng=0:0.01:1, kernel=Kernels.cosine,m=8) -```` - - -```` -Error: UndefVarError: Kernels not defined -```` - - - -````julia plot(o, legend=false,dpi=300) -```` - - -```` -Error: UndefVarError: plot not defined -```` - - - -````julia xlabel!("Bayesian p-values") ```` -```` -Error: UndefVarError: xlabel! not defined -```` - - +![](figures/2019-11-07-demo_15_1.png) ````julia m2 = @model x begin @@ -391,67 +236,28 @@ end; ```` -```` -Error: LoadError: UndefVarError: @model not defined -in expression starting at none:1 -```` - - ````julia post2 = dynamicHMC(m2(x=truth.x), (y=truth.y,)) -```` - - -```` -Error: UndefVarError: truth not defined -```` - - - -````julia ppost2 = particles(post2) ```` ```` -Error: UndefVarError: particles not defined +(σ = 0.519 ± 0.094, νinv = 0.905 ± 0.18, β = 2.73 ± 0.069, α = 0.892 ± 0.06 +7) ```` ````julia eachplot(xx, ppost.α .+ ppost.β .* xx, lw=3, dpi=300, color=2) -```` - - -```` -Error: UndefVarError: ppost not defined -```` - - - -````julia eachplot!(xx, ppost2.α .+ ppost2.β .* xx, lw=3, dpi=300, color=:black) -```` - - -```` -Error: UndefVarError: ppost2 not defined -```` - - - -````julia scatter!(truth.x,truth.y, legend=false, c=1) ```` -```` -Error: UndefVarError: scatter! not defined -```` - - +![](figures/2019-11-07-demo_18_1.png) ````julia pred2 = predictive(m2, setdiff(stochastic(m2), [:y])...) @@ -459,7 +265,14 @@ pred2 = predictive(m2, setdiff(stochastic(m2), [:y])...) ```` -Error: UndefVarError: m2 not defined +@model (x, α, β, σ, νinv) begin + ν = 1 / νinv + yhat = α .+ β .* x + n = length(x) + y ~ For(n) do j + StudentT(ν, yhat[j], σ) + end + end ```` @@ -472,7 +285,15 @@ end |> particles ```` -Error: UndefVarError: post2 not defined +(α = 0.892 ± 0.067, β = 2.73 ± 0.069, σ = 0.519 ± 0.094, yhat = Particles{F +loat64,1000}[-2.7 ± 0.11, 1.26 ± 0.068, 4.59 ± 0.12, 1.01 ± 0.067, -2.48 ± +0.11, 2.26 ± 0.076, 1.49 ± 0.069, -0.811 ± 0.08, -3.54 ± 0.13, 2.3 ± 0.076 + … 5.25 ± 0.13, 0.975 ± 0.067, 0.9 ± 0.067, 4.53 ± 0.11, -0.985 ± 0.082, - +1.11 ± 0.084, 0.123 ± 0.07, 9.17 ± 0.22, 0.733 ± 0.067, -2.65 ± 0.11], y = +Particles{Float64,1000}[-2.7 ± 0.54, 1.24 ± 0.53, 4.57 ± 0.54, 1.0 ± 0.54, +-2.49 ± 0.55, 2.24 ± 0.52, 1.47 ± 0.52, -0.807 ± 0.55, -3.56 ± 0.55, 2.3 ± +0.53 … 5.27 ± 0.53, 0.976 ± 0.53, 0.883 ± 0.52, 4.52 ± 0.55, -1.01 ± 0.54 +, -1.12 ± 0.53, 0.0865 ± 0.54, 9.2 ± 0.55, 0.755 ± 0.53, -2.64 ± 0.55]) ```` @@ -483,7 +304,27 @@ pvals2 = mean.(truth.y .> post2pred.y) ```` -Error: UndefVarError: truth not defined +100-element Array{Float64,1}: + 0.738 + 0.644 + 0.42 + 0.0 + 0.119 + 0.988 + 0.961 + 0.543 + 0.266 + 0.269 + ⋮ + 0.142 + 0.96 + 0.357 + 0.987 + 0.306 + 0.417 + 0.466 + 0.0 + 0.501 ```` @@ -494,36 +335,12 @@ Error: UndefVarError: truth not defined # PPC vs x ````julia scatter(truth.x, pvals2, legend=false, dpi=300) -```` - - -```` -Error: UndefVarError: scatter not defined -```` - - - -````julia xlabel!("x") -```` - - -```` -Error: UndefVarError: xlabel! not defined -```` - - - -````julia ylabel!("Bayesian p-value") ```` -```` -Error: UndefVarError: ylabel! not defined -```` - - +![](figures/2019-11-07-demo_22_1.png) @@ -531,69 +348,21 @@ Error: UndefVarError: ylabel! not defined # PPC vs y ````julia scatter(truth.y, pvals2, legend=false, dpi=300) -```` - - -```` -Error: UndefVarError: scatter not defined -```` - - - -````julia xlabel!("y") -```` - - -```` -Error: UndefVarError: xlabel! not defined -```` - - - -````julia ylabel!("Bayesian p-value") ```` -```` -Error: UndefVarError: ylabel! not defined -```` - - +![](figures/2019-11-07-demo_23_1.png) ````julia o = ash(pvals2, rng=0:0.01:1, kernel=Kernels.cosine,m=8) -```` - - -```` -Error: UndefVarError: Kernels not defined -```` - - - -````julia plot(o, legend=false,dpi=300) -```` - - -```` -Error: UndefVarError: plot not defined -```` - - - -````julia xlabel!("Bayesian p-values") ```` -```` -Error: UndefVarError: xlabel! not defined -```` - - +![](figures/2019-11-07-demo_24_1.png) diff --git a/src/core/model.jl b/src/core/model.jl index 8106af40..605b6c45 100644 --- a/src/core/model.jl +++ b/src/core/model.jl @@ -22,6 +22,9 @@ bodytype(::Type{Model{A,B}}) where {A,B} = B getmodule(::Type{Model{A,B,M}}) where {A,B,M} = from_type(M) getmodule(::Model{A,B,M}) where {A,B,M} = from_type(M) +getmoduletypencoding(::Type{Model{A,B,M}}) where {A, B, M} = M +getmoduletypencoding(::Model{A,B,M}) where {A,B,M} = M + function Model(theModule::Module, args, vals, dists, retn) M = to_type(theModule) A = NamedTuple{Tuple(args)} @@ -184,4 +187,6 @@ Base.show(io::IO, m :: Model) = println(io, convert(Expr, m)) function findStatement(m::Model, x::Symbol) x ∈ keys(m.vals) && return Assign(x,m.vals[x]) x ∈ keys(m.dists) && return Sample(x,m.dists[x]) + x ∈ arguments(m) && return Arg(x) + error("statement not found") end \ No newline at end of file diff --git a/src/core/statement.jl b/src/core/statement.jl index 368cffa7..4760a43c 100644 --- a/src/core/statement.jl +++ b/src/core/statement.jl @@ -2,6 +2,10 @@ using MLStyle abstract type Statement end +struct Arg <: Statement + x :: Symbol +end + struct Assign <: Statement x :: Symbol rhs diff --git a/src/core/toposort.jl b/src/core/toposort.jl index e318659a..b3bf4f6d 100644 --- a/src/core/toposort.jl +++ b/src/core/toposort.jl @@ -33,5 +33,5 @@ end export toposortvars function toposortvars(m::Model) (g, _, names) = poset(m).D |> convert_simple - setdiff(map(v -> names[v], Graphs.topological_sort_by_dfs(g)), freeVariables(m)) + setdiff(map(v -> names[v], Graphs.topological_sort_by_dfs(g)), arguments(m)) end diff --git a/src/core/utils.jl b/src/core/utils.jl index bfa61bd5..6b0481b1 100644 --- a/src/core/utils.jl +++ b/src/core/utils.jl @@ -44,6 +44,8 @@ stochastic(m::Model) = keys(m.dists) export bound bound(m::Model) = keys(m.vals) +export bodyVariables +bodyVariables(m::Model) = setdiff(variables(m), arguments(m)) # TODO: Fix these broken methods # export observed @@ -55,10 +57,7 @@ bound(m::Model) = keys(m.vals) # observed(m) # ) -export freeVariables -function freeVariables(m::Model) - setdiff(arguments(m), stochastic(m)) -end + export foldall function foldall(leaf, branch; kwargs...) @@ -270,3 +269,5 @@ function tower(x) end return result end + +TypeLevel = GeneralizedGenerated.TypeLevel \ No newline at end of file diff --git a/src/importance.jl b/src/importance.jl index 7f8842ec..f9935392 100644 --- a/src/importance.jl +++ b/src/importance.jl @@ -3,14 +3,16 @@ using MonteCarloMeasurements export importanceSample @inline function importanceSample(p::JointDistribution, q::JointDistribution, _data) - return _importanceSample(p.model, p.args, q.model, q.args, _data) + return _importanceSample(getmoduletypencoding(p.model), p.model, p.args, q.model, q.args, _data) end -@gg function _importanceSample(p::Model, _pargs, q::Model, _qargs, _data) +@gg M function _importanceSample(_::Type{M}, p::Model, _pargs, q::Model, _qargs, _data) where M <: TypeLevel{Module} p = type2model(p) q = type2model(q) - sourceImportanceSample()(p,q) |> loadvals(_qargs, _data) |> loadvals(_pargs, NamedTuple()) + Expr(:let, + Expr(:(=), :M, from_type(M)), + sourceImportanceSample()(p,q) |> loadvals(_qargs, _data) |> loadvals(_pargs, NamedTuple())) end export sourceImportanceSample diff --git a/src/particles.jl b/src/particles.jl index c737de94..da067d57 100644 --- a/src/particles.jl +++ b/src/particles.jl @@ -59,15 +59,19 @@ parts(d::iid; N=1000) = map(1:d.size) do j parts(d.dist) end @inline function particles(m::JointDistribution) - return _particles(m.model, m.args) + return _particles(getmoduletypencoding(m.model), m.model, m.args) end -@gg function _particles(_m::Model, _args) - type2model(_m) |> sourceParticles() |> loadvals(_args, NamedTuple()) +@gg M function _particles(_::Type{M}, _m::Model, _args) where M <: TypeLevel{Module} + Expr(:let, + Expr(:(=), :M, from_type(M)), + type2model(_m) |> sourceParticles() |> loadvals(_args, NamedTuple())) end -@gg function _particles(_m::Model, _args::NamedTuple{()}) - type2model(_m) |> sourceParticles() +@gg M function _particles(_::Type{M}, _m::Model, _args::NamedTuple{()}) where M <: TypeLevel{Module} + Expr(:let, + Expr(:(=), :M, from_type(M)), + type2model(_m) |> sourceParticles()) end export sourceParticles diff --git a/src/primitives/likelihood-weighting.jl b/src/primitives/likelihood-weighting.jl index 9c3e1d9d..e535d48a 100644 --- a/src/primitives/likelihood-weighting.jl +++ b/src/primitives/likelihood-weighting.jl @@ -2,11 +2,13 @@ export weightedSample function weightedSample(m::JointDistribution, _data) - return _weightedSample(m.model, m.args, _data) + return _weightedSample(getmoduletypencoding(m.model), m.model, m.args, _data) end -@gg function _weightedSample(_m::Model, _args, _data) - type2model(_m) |> sourceWeightedSample(_data) |> loadvals(_args, _data) +@gg M function _weightedSample(_::Type{M}, _m::Model, _args, _data) where M <: TypeLevel{Module} + Expr(:let, + Expr(:(=), :M, from_type(M)), + type2model(_m) |> sourceWeightedSample(_data) |> loadvals(_args, _data)) end function sourceWeightedSample(_data) diff --git a/src/primitives/logpdf.jl b/src/primitives/logpdf.jl index ee256531..d5b380fe 100644 --- a/src/primitives/logpdf.jl +++ b/src/primitives/logpdf.jl @@ -1,12 +1,20 @@ export logpdf -function logpdf(m::JointDistribution,x, method=logpdf) - return method(m.model, m.args, x) +function logpdf(m::JointDistribution{A0,A,B,M},x) where {A0,A,B,M} + _logpdf(M, m.model, m.args, x) end -@gg function logpdf(_m::Model, _args, _data) - type2model(_m) |> sourceLogpdf() |> loadvals(_args, _data) +function logpdf(m::JointDistribution{A0,A,B,M},x, ::typeof(logpdf)) where {A0,A,B,M} + _logpdf(M, m.model, m.args, x) +end + + + +@gg M function _logpdf(_::Type{M}, _m::Model, _args, _data) where M <: TypeLevel{Module} + Expr(:let, + Expr(:(=), :M, from_type(M)), + type2model(_m) |> sourceLogpdf() |> loadvals(_args, _data)) end function sourceLogpdf() diff --git a/src/primitives/rand.jl b/src/primitives/rand.jl index dab21d19..305c2bd0 100644 --- a/src/primitives/rand.jl +++ b/src/primitives/rand.jl @@ -1,23 +1,26 @@ using GeneralizedGenerated export rand - EmptyNTtype = NamedTuple{(),Tuple{}} where T<:Tuple @inline function rand(m::JointDistribution) - return _rand(m.model, m.args) + return _rand(getmoduletypencoding(m.model), m.model, m.args) end @inline function rand(m::Model) - return _rand(m, NamedTuple()) + return _rand(getmoduletypencoding(m), m, NamedTuple()) end -@gg function _rand(_m::Model, _args) - type2model(_m) |> sourceRand() |> loadvals(_args, NamedTuple()) +@gg M function _rand(_::Type{M}, _m::Model, _args) where M <: TypeLevel{Module} + Expr(:let, + Expr(:(=), :M, from_type(M)), + type2model(_m) |> sourceRand() |> loadvals(_args, NamedTuple())) end -@gg function _rand(_m::Model, _args::NamedTuple{()}) - type2model(_m) |> sourceRand() +@gg M function _rand(_::Type{M}, _m::Model, _args::NamedTuple{()}) where M <: TypeLevel{Module} + Expr(:let, + Expr(:(=), :M, from_type(M)), + type2model(_m) |> sourceRand()) end export sourceRand diff --git a/src/primitives/xform.jl b/src/primitives/xform.jl index 85383263..773f243f 100644 --- a/src/primitives/xform.jl +++ b/src/primitives/xform.jl @@ -12,11 +12,13 @@ export xform function xform(m::JointDistribution{A, B}, _data) where {A,B} - return _xform(m.model, m.args, _data) + return _xform(getmoduletypencoding(m.model), m.model, m.args, _data) end -@gg function _xform(_m::Model{Asub,B}, _args::A, _data) where {Asub, A,B} - type2model(_m) |> sourceXform(_data) |> loadvals(_args, _data) +@gg M function _xform(_::Type{M}, _m::Model{Asub,B}, _args::A, _data) where {M <: TypeLevel{Module}, Asub, A,B} + Expr(:let, + Expr(:(=), :M, from_type(M)), + type2model(_m) |> sourceXform(_data) |> loadvals(_args, _data)) end # function xform(m::Model{EmptyNTtype, B}) where {B} @@ -53,7 +55,7 @@ function sourceXform(_data=NamedTuple()) wrap(kernel) = @q begin _result = NamedTuple() $kernel - as(_result) + $as(_result) end buildSource(_m, proc, wrap) |> flatten diff --git a/src/symbolic/codegen.jl b/src/symbolic/codegen.jl index c52aa281..f292b236 100644 --- a/src/symbolic/codegen.jl +++ b/src/symbolic/codegen.jl @@ -154,4 +154,9 @@ end # var"##add#407" += x # var"##add#407" += y # var"##add#407" -# end \ No newline at end of file +# end + + +function logpdf(m::JointDistribution{A0,A,B,M},x,::typeof(codegen)) where {A0,A,B,M} + codegen(M, m.model, m.args, x) +end \ No newline at end of file diff --git a/src/symbolic/reduce.jl b/src/symbolic/reduce.jl index 4714e8fa..ff0e3051 100644 --- a/src/symbolic/reduce.jl +++ b/src/symbolic/reduce.jl @@ -12,11 +12,13 @@ end export reduce function reduce(m::JointDistribution,x) - return _reduce(m.model, m.args, x) + return _reduce(getmoduletypencoding(m.model), m.model, m.args, x) end -@gg function _reduce(_m::Model, _args, _data) - type2model(_m) |> sourceReduce() |> loadvals(_args, _data) +@gg M function _reduce(_::Type{M}, _m::Model, _args, _data) where M <: TypeLevel{Module} + Expr(:let, + Expr(:(=), :M, from_type(M)), + type2model(_m) |> sourceReduce() |> loadvals(_args, _data)) end function sourceReduce() diff --git a/src/symbolic/symbolic.jl b/src/symbolic/symbolic.jl index 4ac10684..04531bbb 100644 --- a/src/symbolic/symbolic.jl +++ b/src/symbolic/symbolic.jl @@ -1,12 +1,22 @@ using MacroTools: @q +using GeneralizedGenerated import PyCall using MLStyle import SymPy -using SymPy: Sym, sympy, symbols, free_symbols -import SymPy.sympy +using SymPy: Sym, symbols, free_symbols +"""As type encoding a PyObject is unsafe without some hard +works with reference counting, we simply use a Julia proxy +to imitate the `sympy` module, e.g., + `sympy.attr = _pysympy.attr` +""" +struct PySymPyModule end +GeneralizedGenerated.NGG.@implement GeneralizedGenerated.NGG.Typeable{PySymPyModule} +const sympy = PySymPyModule() +Base.getproperty(::PySymPyModule, s::Symbol) = Base.getproperty(_pysympy, s) + const symfuncs = Dict() _pow(a,b) = Base.:^(float(a),b) @@ -15,7 +25,8 @@ function __init__() stats = PyCall.pyimport_conda("sympy.stats", "sympy") SymPy.import_from(stats) global stats = stats - + global _pysympy = SymPy.sympy + # for dist in [:Normal, :Cauchy, :Laplace, :Beta, :Uniform] @@ -26,7 +37,7 @@ function __init__() # end # Distributions.$dist(μ,σ) = $dist(promote(μ,σ)...) - # end + # end # end @@ -40,9 +51,11 @@ function __init__() )) @eval begin - @gg function codegen(_m::Model, _args, _data) + @gg M function codegen(_::Type{M}, _m::Model, _args, _data) where M <: TypeLevel{Module} f = _codegen(type2model(_m)) - :($f(_args, _data)) + Expr(:let, + Expr(:(=), :M, from_type(M)), + :($f(_args, _data))) end end @@ -62,17 +75,17 @@ for dist in [:Bernoulli] @eval begin logpdf(d::$dist, x::Sym) = logpdf($dist(sym.(Distributions.params(d))...), x) - end + end end -"Half" distributions +"Half" distributions for dist in [:Normal, :Cauchy] let half = Symbol(:Half, dist) @eval begin logpdf(d::$half, x::Sym) = 2 * logpdf($dist(0, sym.(d.σ)), x) - end + end end end @@ -80,17 +93,17 @@ for dist in [:Normal, :Cauchy, :Laplace, :Beta, :Uniform] @eval begin function Distributions.$dist(μ::Sym, σ::Sym) stats.$dist(:dist, μ,σ) |> SymPy.density - end + end Distributions.$dist(μ,σ) = $dist(promote(μ,σ)...) - end + end end export sym sym(s::Symbol) = sympy.IndexedBase(s, real=true) sym(s) = Base.convert(Sym, s) -function sym(expr::Expr) +function sym(expr::Expr) @match expr begin Expr(:call, f, args...) => :($f($(map(sym,args)...))) :($x[$j]) => begin @@ -105,7 +118,7 @@ function sym(expr::Expr) end end - + # export symlogpdf # # function symlogpdf(m::Model) @@ -126,7 +139,7 @@ end # # append!(result.args, exprs) - + # # # result # # push!(result.args, :(ctx,ℓ)) @@ -174,14 +187,14 @@ end # # x = sympy.IndexedBase(x) # # return :(Soss.sympy.Sum(logpdf($dist,$x[$j]), ($j,1,$n))) # # end - + # # f => begin # # @show f # # error("symlogpdf: bad argument") # # end # # end - + # # end # # _ => :(logpdf($(sym(d)), $(sym(x)))) @@ -229,10 +242,10 @@ end function expandMulSum(factors::NTuple{N,Sym}, limits::Sym...) where {N} limits == () && return prod(factors) - for fac in factors - for lim in limits + for fac in factors + for lim in limits (ix, ixlo, ixhi) = lim.args - if ix ∉ fac + if !insym(ix, fac) inSummand = prod(allbut(factors, fac)) inSum = expandSum(inSummand, lim) outLims = allbut(limits, lim) @@ -248,8 +261,7 @@ function atoms(s::Sym) union(result, map(x -> x.args, result)...) end -import Base.in -function Base.in(j::Sym, s::Sym) +function insym(j::Sym, s::Sym) j ∈ atoms(s) # for t in s.args # if j==t || in(j,t) @@ -261,8 +273,8 @@ end hasIdx(s::Sym) = any(startswith.(getproperty.(Soss.atoms(s), :name), "_j")) -function allbut(tup, x) - result = filter(collect(tup)) do v +function allbut(tup, x) + result = filter(collect(tup)) do v v ≠ x end tuple(result...) @@ -274,8 +286,8 @@ function maybeSum(t::Sym, limits::Sym...) for lim in limits (ix, ixlo, ixhi) = lim.args - ix ∈ t || return maybeSum(t * (ixhi - ixlo + 1), allbut(limits, lim)...) - end + insym(ix, t) || return maybeSum(t * (ixhi - ixlo + 1), allbut(limits, lim)...) + end return sympy.Sum(t, limits...) end @@ -327,12 +339,12 @@ end export sourceSymlogpdf function sourceSymlogpdf() function(_m::Model) - function proc(_m, st :: Assign) + function proc(_m, st :: Assign) # :($(st.x) = $(st.rhs)) - x = st.x + x = st.x xname = QuoteNode(x) - :($x = sympy.IndexedBase($xname)) - end + :($x = $sympy.IndexedBase($xname)) + end function proc(_m, st :: Sample) @q begin @@ -349,30 +361,30 @@ function sourceSymlogpdf() for x in variables(_m) xname = QuoteNode(x) - push!(q.args, :($x = sympy.IndexedBase($xname))) + push!(q.args, :($x = $sympy.IndexedBase($xname))) end for st in map(v -> findStatement(_m,v), toposortvars(_m)) typeof(st) == Sample || continue - x = st.x + x = st.x xname = QuoteNode(x) - rhs = st.rhs + rhs = st.rhs xsym = ifelse(rhs.args[1] ∈ [:For, :iid] - , :(sympy.IndexedBase($xname)) - , :(sym($xname)) + , :($sympy.IndexedBase($xname)) + , :($sym($xname)) ) # push!(q.args, :($x = $xsym)) end @q begin - $q + $q $kernel return _ℓ end end - + buildSource(_m, proc, wrap) |> flatten end @@ -393,10 +405,10 @@ export symlogpdf function symlogpdf(d::For{F,T,D,X}, x::Sym) where {F, N, J <: Union{Sym,Integer}, T <: NTuple{N,J}, D, X} js = symbols.(Symbol.(:_j,1:N), cls=sympy.Idx) x = sympy.IndexedBase(x) - result = symlogpdf(d.f(js...), x[js...]) + result = symlogpdf(d.f(js...), x[js...]) for k in N:-1:1 - result = sympy.Sum(result, (js[k], 1, d.θ[k])) + result = sympy.Sum(result, (js[k], 1, d.θ[k])) end result end @@ -421,7 +433,7 @@ symlogpdf(d::Beta, x::Sym) = symlogpdf(Beta(sym(d.α),sym(d.β)), x) logpdf(d::Sym, x::Sym) = symlogpdf(d,x) -function symlogpdf(d::Sym, x::Sym) +function symlogpdf(d::Sym, x::Sym) d.func result = d.pdf(x) |> log sympy.expand_log(result,force=true) @@ -430,15 +442,17 @@ end symlogpdf(d,x::Sym) = logpdf(d,x) function symlogpdf(m::JointDistribution) - return _symlogpdf(m.model) + return _symlogpdf(getmoduletypencoding(m.model), m.model) end function symlogpdf(m::Model) - return _symlogpdf(m) + return _symlogpdf(getmoduletypencoding(m), m) end -@gg function _symlogpdf(_m::Model) - type2model(_m) |> canonical |> sourceSymlogpdf() +@gg M function _symlogpdf(_::Type{M}, _m::Model) where M <: TypeLevel{Module} + Expr(:let, + Expr(:(=), :M, from_type(M)), + type2model(_m) |> canonical |> sourceSymlogpdf()) end @@ -459,14 +473,14 @@ end # # x # # julia> a = sympy.Sum(x[i], (i, 1, j)) -# # j -# # ___ -# # ╲ +# # j +# # ___ +# # ╲ # # ╲ x[i] -# # ╱ -# # ╱ -# # ‾‾‾ -# # i = 1 +# # ╱ +# # ╱ +# # ‾‾‾ +# # i = 1 # # julia> SymPy.walk_expression(a) # # :(Sum(Indexed(IndexedBase(x), i), (:i, 1, :j))) diff --git a/src/transforms/markovblanket.jl b/src/transforms/markovblanket.jl index 13a09663..efe16ec3 100644 --- a/src/transforms/markovblanket.jl +++ b/src/transforms/markovblanket.jl @@ -8,7 +8,7 @@ export children children(g::SimpleDigraph, v) = g.N[v] |> collect export partners -function partners(g::SimpleDigraph, v) +function partners(g::SimpleDigraph, v::Symbol) s = map(collect(children(g,v))) do x parents(g,x) end @@ -18,7 +18,83 @@ function partners(g::SimpleDigraph, v) setdiff(union(s...),[v]) |> collect end -markovBlanket(g,v) = [v] ∪ parents(g,v) ∪ children(g,v) ∪ partners(g,v) + +# function stochParents(m::Model, g::SimpleDigraph, v::Symbol, acc=Symbol[]) +# pars = parents(g,v) + +# result = union(pars, acc) +# for p in pars +# union!(result, _stochParents(m, g, findStatement(m,p))) +# end +# result +# end + +# _stochParents(m::Model, g::SimpleDigraph, st::Sample, acc=Symbol[]) = union([st.x], acc) +# _stochParents(m::Model, g::SimpleDigraph, st::Assign, acc=Symbol[]) = stochParents(m, g, st.x, union([st.x],acc)) +# _stochParents(m::Model, g::SimpleDigraph, st::Arg, acc=Symbol[]) = [st.x] +# _stochParents(m::Model, g::SimpleDigraph, st::Bool, acc=Symbol[]) = [] + + +#################### + +function stochChildren(m::Model, g::SimpleDigraph, v::Symbol, acc=Symbol[]) + pars = children(g,v) + + result = union(pars, acc) + for p in pars + union!(result, _stochChildren(m, g, findStatement(m,p))) + end + result +end + +_stochChildren(m::Model, g::SimpleDigraph, st::Sample, acc=Symbol[]) = union([st.x], acc) +_stochChildren(m::Model, g::SimpleDigraph, st::Assign, acc=Symbol[]) = stochChildren(m, g, st.x, union([st.x],acc)) +_stochChildren(m::Model, g::SimpleDigraph, st::Bool, acc=Symbol[]) = [] + + +######################## + +function stochPartners(m::Model, g::SimpleDigraph, v::Symbol) + s = map(collect(stochChildren(m,g,v))) do x + parents(g,x) + end + + isempty(s) && return [] + + setdiff(union(s...),[v]) |> collect +end + +function markovBlanketVars(m::Model, g::SimpleDigraph, v::Symbol) + markovBlanketVars(m,g,findStatement(m,v)) +end + +function markovBlanketVars(m::Model, g::SimpleDigraph,st::Sample) + p = poset(m) + g = digraph(m) + + body = Symbol[st.x] + + args = union(parents(g,st.x), stochPartners(m,g,st.x)) + for arg in args + union!(body, interval(p, arg, st.x)) + end + + ys = stochChildren(m,g,st.x) + union!(body, ys) + for y in ys + union!(body, interval(p, st.x, y)) + end + + setdiff!(args, body) + (args, body) +end + +function markovBlanketVars(m::Model, g::SimpleDigraph,st::Assign) + (parents(g,st.x), Symbol[st.x]) +end + + +# markovBlanket(g,v) = [v] ∪ stochParents(g,v) ∪ stochChildren(g,v) ∪ stochPartners(g,v) # function Base.convert(::Type{SimpleGraph}, d::SimpleDigraph) # g = SimpleGraph{Symbol}() @@ -41,14 +117,15 @@ function markovBlanket(m::Model, x :: Symbol) # newargs = (arguments(m) ∪ [x]) ∩ part # setdiff!(part, newargs) - newargs = parents(g,x) ∪ partners(g,x) + (args, vars) = markovBlanketVars(m,g,x) - m_init = Model(newargs, NamedTuple(), NamedTuple(), nothing) - m_init = merge(m_init, Model(findStatement(m,x))) - m = foldl(children(g,x); init= m_init) do m0,v - merge(m0, Model(findStatement(m, v))) + M = getmodule(m) + m_init = Model(M, args, NamedTuple(), NamedTuple(), nothing) + m_init = merge(m_init, Model(M,findStatement(m,x))) + m = foldl(vars; init= m_init) do m0,v + merge(m0, Model(M,findStatement(m, v))) end - m = merge(m, Model(findStatement(m,x))) + m = merge(m, Model(M,findStatement(m,x))) end