diff --git a/Project.toml b/Project.toml index 758e95243..23566ffaf 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +MeasureTheory = "eadaa1a4-d27c-401d-8699-e962e1bbc33b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index a46c941a1..4d64cccd6 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -129,5 +129,6 @@ include("prob_macro.jl") include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") +include("measuretheory.jl") end # module diff --git a/src/measuretheory.jl b/src/measuretheory.jl new file mode 100644 index 000000000..20efa8b30 --- /dev/null +++ b/src/measuretheory.jl @@ -0,0 +1,70 @@ +using MeasureTheory: MeasureTheory + +# src/compiler.jl +# Allow `AbstractMeasure` on the RHS of `~`. +check_tilde_rhs(x::MeasureTheory.AbstractMeasure) = x + +# src/utils.jl +# Linearization. +vectorize(d::MeasureTheory.AbstractMeasure, x::Real) = [x] +vectorize(d::MeasureTheory.AbstractMeasure, x::AbstractArray{<:Real}) = copy(vec(x)) + +function reconstruct(d::MeasureTheory.AbstractMeasure, x::AbstractVector{<:Real}) + return _reconstruct(d, x, MeasureTheory.sampletype(d)) +end + +# TODO: Higher dims. What to do? Do we have access to size, e.g. for `LKJ` we should have? +function _reconstruct( + d::MeasureTheory.AbstractMeasure, x::AbstractVector{<:Real}, ::Type{<:Real} +) + return x[1] +end +function _reconstruct( + d::MeasureTheory.AbstractMeasure, + x::AbstractVector{<:Real}, + ::Type{<:AbstractVector{<:Real}}, +) + return x +end + +# src/context_implementations.jl +# assume +function assume(dist::MeasureTheory.AbstractMeasure, vn::VarName, vi) + r = vi[vn] + # TODO: Transformed variables. + return r, MeasureTheory.logdensity(dist, r) +end + +function assume( + rng::Random.AbstractRNG, + sampler::Union{SampleFromPrior,SampleFromUniform}, + dist::MeasureTheory.AbstractMeasure, + vn::VarName, + vi, +) + if haskey(vi, vn) + # Always overwrite the parameters with new ones for `SampleFromUniform`. + if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") + unset_flag!(vi, vn, "del") + r = init(rng, dist, sampler) + vi[vn] = vectorize(dist, r) + settrans!(vi, false, vn) + setorder!(vi, vn, get_num_produce(vi)) + else + r = vi[vn] + end + else + r = init(rng, dist, sampler) + push!(vi, vn, r, dist, sampler) + settrans!(vi, false, vn) + end + + # TODO: Transformed variables. + return r, MeasureTheory.logdensity(dist, r) +end + +# observe +function observe(right::MeasureTheory.AbstractMeasure, left, vi) + increment_num_produce!(vi) + return MeasureTheory.logdensity(right, left) +end diff --git a/src/varinfo.jl b/src/varinfo.jl index 64c122dc2..450c121fe 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -43,7 +43,7 @@ barrier to make the rest of the sampling type stable. """ struct Metadata{ TIdcs<:Dict{<:VarName,Int}, - TDists<:AbstractVector{<:Distribution}, + TDists<:AbstractVector, TVN<:AbstractVector{<:VarName}, TVal<:AbstractVector{<:Real}, TGIds<:AbstractVector{Set{Selector}}, @@ -192,7 +192,7 @@ function Metadata() Vector{VarName}(), Vector{UnitRange{Int}}(), vals, - Vector{Distribution}(), + Vector(), Vector{Set{Selector}}(), Vector{Int}(), flags, @@ -1098,7 +1098,7 @@ end Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to the `VarInfo` `vi`. """ -function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) +function push!(vi::AbstractVarInfo, vn::VarName, r, dist) return push!(vi, vn, r, dist, Set{Selector}([])) end @@ -1110,12 +1110,10 @@ from a distribution `dist` to `VarInfo` `vi`. The sampler is passed here to invalidate its cache where defined. """ -function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler) +function push!(vi::AbstractVarInfo, vn::VarName, r, dist, spl::Sampler) return push!(vi, vn, r, dist, spl.selector) end -function push!( - vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler -) +function push!(vi::AbstractVarInfo, vn::VarName, r, dist, spl::AbstractSampler) return push!(vi, vn, r, dist) end @@ -1125,10 +1123,10 @@ end Push a new random variable `vn` with a sampled value `r` sampled with a sampler of selector `gid` from a distribution `dist` to `VarInfo` `vi`. """ -function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) +function push!(vi::AbstractVarInfo, vn::VarName, r, dist, gid::Selector) return push!(vi, vn, r, dist, Set([gid])) end -function push!(vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}) +function push!(vi::VarInfo, vn::VarName, r, dist, gidset::Set{Selector}) if vi isa UntypedVarInfo @assert ~(vn in keys(vi)) "[push!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset" elseif vi isa TypedVarInfo