From 19335dd0aaed3d5a4a4338f88599532330442f4d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 27 Jul 2021 01:44:22 +0100 Subject: [PATCH 1/5] relax some unnecessary type-constraints on VarInfo --- src/varinfo.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index a226506f4..f81bf6da8 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -43,7 +43,7 @@ barrier to make the rest of the sampling type stable. """ struct Metadata{ TIdcs<:Dict{<:VarName,Int}, - TDists<:AbstractVector{<:Distribution}, + TDists<:AbstractVector, TVN<:AbstractVector{<:VarName}, TVal<:AbstractVector{<:Real}, TGIds<:AbstractVector{Set{Selector}}, @@ -192,7 +192,7 @@ function Metadata() Vector{VarName}(), Vector{UnitRange{Int}}(), vals, - Vector{Distribution}(), + Vector(), Vector{Set{Selector}}(), Vector{Int}(), flags, @@ -1096,7 +1096,7 @@ end Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to the `VarInfo` `vi`. """ -function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) +function push!(vi::AbstractVarInfo, vn::VarName, r, dist) return push!(vi, vn, r, dist, Set{Selector}([])) end @@ -1108,11 +1108,11 @@ from a distribution `dist` to `VarInfo` `vi`. The sampler is passed here to invalidate its cache where defined. """ -function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler) +function push!(vi::AbstractVarInfo, vn::VarName, r, dist, spl::Sampler) return push!(vi, vn, r, dist, spl.selector) end function push!( - vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler + vi::AbstractVarInfo, vn::VarName, r, dist, spl::AbstractSampler ) return push!(vi, vn, r, dist) end @@ -1123,10 +1123,10 @@ end Push a new random variable `vn` with a sampled value `r` sampled with a sampler of selector `gid` from a distribution `dist` to `VarInfo` `vi`. """ -function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) +function push!(vi::AbstractVarInfo, vn::VarName, r, dist, gid::Selector) return push!(vi, vn, r, dist, Set([gid])) end -function push!(vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}) +function push!(vi::VarInfo, vn::VarName, r, dist, gidset::Set{Selector}) if vi isa UntypedVarInfo @assert ~(vn in keys(vi)) "[push!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset" elseif vi isa TypedVarInfo From a68aeeda83be679426d92455f8c50ddc19a7ef0c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 27 Jul 2021 01:45:33 +0100 Subject: [PATCH 2/5] added simple impl of sampling and evaluation for MeasureTheory --- Project.toml | 1 + src/measuretheory.jl | 62 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+) create mode 100644 src/measuretheory.jl diff --git a/Project.toml b/Project.toml index 99c211ccc..9b3d0e6bb 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +MeasureTheory = "eadaa1a4-d27c-401d-8699-e962e1bbc33b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" diff --git a/src/measuretheory.jl b/src/measuretheory.jl new file mode 100644 index 000000000..bcb23188f --- /dev/null +++ b/src/measuretheory.jl @@ -0,0 +1,62 @@ +using MeasureTheory: MeasureTheory + +# src/compiler.jl +# Allow `AbstractMeasure` on the RHS of `~`. +check_tilde_rhs(x::MeasureTheory.AbstractMeasure) = x + +# src/utils.jl +# Linearization. +vectorize(d::MeasureTheory.AbstractMeasure, x::Real) = [x] +vectorize(d::MeasureTheory.AbstractMeasure, x::AbstractArray{<:Real}) = copy(vec(x)) + +function reconstruct(d::MeasureTheory.AbstractMeasure, x::AbstractVector{<:Real}) + return _reconstruct(d, x, sampletype(d)) +end + +# TODO: Higher dims. What to do? Do we have access to size, e.g. for `LKJ` we should have? +_reconstruct(d::MeasureTheory.AbstractMeasure, x::AbstractVector{<:Real}, ::Type{<:Real}) = x[1] +function _reconstruct(d::MeasureTheory.AbstractMeasure, x::AbstractVector{<:Real}, ::Type{<:AbstractVector{<:Real}}) + return x +end + +# src/context_implementations.jl +# assume +function assume(dist::MeasureTheory.AbstractMeasure, vn::VarName, vi) + r = vi[vn] + # TODO: Transformed variables. + return r, MeasureTheory.logdensity(dist, r) +end + +function assume( + rng::Random.AbstractRNG, + sampler::Union{SampleFromPrior,SampleFromUniform}, + dist::MeasureTheory.AbstractMeasure, + vn::VarName, + vi, +) + if haskey(vi, vn) + # Always overwrite the parameters with new ones for `SampleFromUniform`. + if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") + unset_flag!(vi, vn, "del") + r = init(rng, dist, sampler) + vi[vn] = vectorize(dist, r) + settrans!(vi, false, vn) + setorder!(vi, vn, get_num_produce(vi)) + else + r = vi[vn] + end + else + r = init(rng, dist, sampler) + push!(vi, vn, r, dist, sampler) + settrans!(vi, false, vn) + end + + # TODO: Transformed variables. + return r, MeasureTheory.logdensity(dist, r) +end + +# observe +function observe(right::MeasureTheory.AbstractMeasure, left, vi) + increment_num_produce!(vi) + return MeasureTheory.logdensity(right, left) +end From 267e9fd03e2cb19f92093ac1e4f3097ea1883eab Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 27 Jul 2021 01:46:40 +0100 Subject: [PATCH 3/5] formatting --- src/measuretheory.jl | 12 ++++++++++-- src/varinfo.jl | 4 +--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/measuretheory.jl b/src/measuretheory.jl index bcb23188f..7fb1448a1 100644 --- a/src/measuretheory.jl +++ b/src/measuretheory.jl @@ -14,8 +14,16 @@ function reconstruct(d::MeasureTheory.AbstractMeasure, x::AbstractVector{<:Real} end # TODO: Higher dims. What to do? Do we have access to size, e.g. for `LKJ` we should have? -_reconstruct(d::MeasureTheory.AbstractMeasure, x::AbstractVector{<:Real}, ::Type{<:Real}) = x[1] -function _reconstruct(d::MeasureTheory.AbstractMeasure, x::AbstractVector{<:Real}, ::Type{<:AbstractVector{<:Real}}) +function _reconstruct( + d::MeasureTheory.AbstractMeasure, x::AbstractVector{<:Real}, ::Type{<:Real} +) + return x[1] +end +function _reconstruct( + d::MeasureTheory.AbstractMeasure, + x::AbstractVector{<:Real}, + ::Type{<:AbstractVector{<:Real}}, +) return x end diff --git a/src/varinfo.jl b/src/varinfo.jl index f81bf6da8..b3ec2e453 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1111,9 +1111,7 @@ The sampler is passed here to invalidate its cache where defined. function push!(vi::AbstractVarInfo, vn::VarName, r, dist, spl::Sampler) return push!(vi, vn, r, dist, spl.selector) end -function push!( - vi::AbstractVarInfo, vn::VarName, r, dist, spl::AbstractSampler -) +function push!(vi::AbstractVarInfo, vn::VarName, r, dist, spl::AbstractSampler) return push!(vi, vn, r, dist) end From 4443ec832f64b91ad9ef59ecca5f7de2a114dca0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 27 Jul 2021 01:57:27 +0100 Subject: [PATCH 4/5] include measuretheory --- src/DynamicPPL.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index a46c941a1..4d64cccd6 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -129,5 +129,6 @@ include("prob_macro.jl") include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") +include("measuretheory.jl") end # module From 4714178e92dd6bb5cfebcefe582677a6f2098867 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 27 Jul 2021 02:02:18 +0100 Subject: [PATCH 5/5] forgot a namespace specification --- src/measuretheory.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/measuretheory.jl b/src/measuretheory.jl index 7fb1448a1..20efa8b30 100644 --- a/src/measuretheory.jl +++ b/src/measuretheory.jl @@ -10,7 +10,7 @@ vectorize(d::MeasureTheory.AbstractMeasure, x::Real) = [x] vectorize(d::MeasureTheory.AbstractMeasure, x::AbstractArray{<:Real}) = copy(vec(x)) function reconstruct(d::MeasureTheory.AbstractMeasure, x::AbstractVector{<:Real}) - return _reconstruct(d, x, sampletype(d)) + return _reconstruct(d, x, MeasureTheory.sampletype(d)) end # TODO: Higher dims. What to do? Do we have access to size, e.g. for `LKJ` we should have?