From 14594d655a2e719ffc001afb2907b1f26c471e8f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 7 Jan 2022 20:54:56 +0000 Subject: [PATCH 001/101] performing linking in assume rather than implicitly in getindex --- src/context_implementations.jl | 18 +++++++++++++----- src/varinfo.jl | 6 ++++++ test/model.jl | 21 +++++++++++++++++++++ 3 files changed, 40 insertions(+), 5 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 20c4af446..bcb3ad463 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -194,7 +194,9 @@ end # fallback without sampler function assume(dist::Distribution, vn::VarName, vi) - r = vi[vn] + # x = vi[vn] + r_raw = getindex_raw(vi, vn) + r = maybe_invlink(vi, vn, dist, r_raw) return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi end @@ -215,7 +217,9 @@ function assume( settrans!(vi, false, vn) setorder!(vi, vn, get_num_produce(vi)) else - r = vi[vn] + # r = vi[vn] + r_raw = getindex_raw(vi, vn) + r = maybe_invlink(vi, vn, dist, r_raw) end else r = init(rng, dist, sampler) @@ -390,7 +394,9 @@ function dot_assume( # m .~ Normal() # # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - r = vi[vns] + # r = vi[vns] + r_raw = getindex_raw(vi, vns) + r = maybe_invlink(vi, vn, dist, r_raw) lp = sum(zip(vns, eachcol(r))) do (vn, ri) return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) end @@ -423,7 +429,8 @@ function dot_assume( # m .~ Normal() # # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - r = reshape(vi[vec(vns)], size(vns)) + r_raw = getindex_raw(vi, vec(vns)) + r = reshape(maybe_invlink.(Ref(vi), vns, dists, r_raw), size(vns)) lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) return r, lp, vi end @@ -467,7 +474,8 @@ function get_and_set_val!( setorder!(vi, vn, get_num_produce(vi)) end else - r = vi[vns] + r_raw = getindex_raw(vi, vns) + r = maybe_invlink(vi, vns, dist, r_raw) end else r = init(rng, dist, spl, n) diff --git a/src/varinfo.jl b/src/varinfo.jl index 9ce0414d6..c8ffca5df 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -866,6 +866,9 @@ end return expr end +maybe_invlink(vi, vn, dist, val) = istrans(vi, vn) ? Bijectors.invlink(dist, val) : val + + """ islinked(vi::VarInfo, spl::Union{Sampler, SampleFromPrior}) @@ -923,6 +926,9 @@ function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) end end +getindex_raw(vi::AbstractVarInfo, vn::VarName) = reconstruct(getdist(vi, vn), getval(vi, vn)) +getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}) = reconstruct(getdist(vi, first(vns)), getval(vi, vns)) + """ getindex(vi::VarInfo, spl::Union{SampleFromPrior, Sampler}) diff --git a/test/model.jl b/test/model.jl index 466a7d1f4..7e417eaa7 100644 --- a/test/model.jl +++ b/test/model.jl @@ -81,4 +81,25 @@ call_retval = model() @test !any(map(x -> x isa DynamicPPL.AbstractVarInfo, call_retval)) end + + @testset "Dynamic constraints" begin + @model function dynamic_constraints() + m ~ Normal() + x ~ truncated(Normal(), m, Inf) + end + + model = dynamic_constraints() + vi = VarInfo(model) + spl = SampleFromPrior() + link!(vi, spl) + + for i = 1:10 + # Sample with large variations. + r_raw = randn(length(vi[spl])) * 10 + vi[spl] = r_raw + @test vi[@varname(m)] == r_raw[1] + @test vi[@varname(x)] != r_raw[2] + model(vi) + end + end end From 0bc279f5848eb7c42794df9a0df6e490f74303d0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 7 Jan 2022 22:20:58 +0000 Subject: [PATCH 002/101] added istrans to SimpleVarInfo --- src/simple_varinfo.jl | 87 ++++++++++++++++--------------------------- 1 file changed, 32 insertions(+), 55 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 5cecda4b2..85896792b 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -129,12 +129,13 @@ ERROR: type NamedTuple has no field b struct SimpleVarInfo{NT,T} <: AbstractVarInfo values::NT logp::T + istrans::Bool end -SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T)) -SimpleVarInfo{T}(; kwargs...) where {T<:Real} = SimpleVarInfo{T}(NamedTuple(kwargs)) -SimpleVarInfo(; kwargs...) = SimpleVarInfo{Float64}(NamedTuple(kwargs)) -SimpleVarInfo(θ) = SimpleVarInfo{Float64}(θ) +SimpleVarInfo{T}(θ, istrans::Bool=false) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T), istrans) +SimpleVarInfo{T}(istrans::Bool=false; kwargs...) where {T<:Real} = SimpleVarInfo{T}(NamedTuple(kwargs), istrans) +SimpleVarInfo(istrans::Bool=false; kwargs...) = SimpleVarInfo{Float64}(NamedTuple(kwargs), istrans) +SimpleVarInfo(θ, istrans::Bool=false) = SimpleVarInfo{Float64}(θ, istrans) # Constructor from `Model`. SimpleVarInfo(model::Model, args...) = SimpleVarInfo{Float64}(model, args...) @@ -158,8 +159,8 @@ function BangBang.empty!!(vi::SimpleVarInfo) end getlogp(vi::SimpleVarInfo) = vi.logp -setlogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.values, logp) -acclogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.values, getlogp(vi) + logp) +setlogp!!(vi::SimpleVarInfo, logp) = Setfield.@set vi.logp = logp +acclogp!!(vi::SimpleVarInfo, logp) = Setfield.@set vi.logp = getlogp(vi) + logp """ keys(vi::SimpleVarInfo) @@ -179,7 +180,7 @@ function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) end function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo) - return print(io, "SimpleVarInfo(", svi.values, ", ", svi.logp, ")") + return print(io, "SimpleVarInfo(", svi.values, ", ", svi.logp, ", ", svi.istrans, ")") end # `NamedTuple` @@ -224,6 +225,11 @@ Base.getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.values # TODO: Should we do better? Base.getindex(vi::SimpleVarInfo, spl::Sampler) = vi.values +# Since we don't perform any transformations in `getindex` for `SimpleVarInfo` +# we simply call `getindex` in `getindex_raw`. +getindex_raw(vi::SimpleVarInfo, vn::VarName) = vi[vn] +getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}) = vi[vns] + Base.haskey(vi::SimpleVarInfo, vn::VarName) = _haskey(vi.values, vn) function _haskey(nt::NamedTuple, vn::VarName) # LHS: Ensure that `nt` indeed has the property we want. @@ -337,58 +343,21 @@ function Base.eltype( end # Context implementations -function assume(dist::Distribution, vn::VarName, vi::SimpleOrThreadSafeSimple) - left = vi[vn] - return left, Distributions.loglikelihood(dist, left), vi -end - +# NOTE: Evaluations, i.e. those without `rng` are shared with other +# implementations of `AbstractVarInfo`. function assume( rng::Random.AbstractRNG, - sampler::SampleFromPrior, + sampler::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, vi::SimpleOrThreadSafeSimple, ) value = init(rng, dist, sampler) - vi = BangBang.push!!(vi, vn, value, dist, sampler) - return value, Distributions.loglikelihood(dist, value), vi -end - -function dot_assume( - dist::MultivariateDistribution, - var::AbstractMatrix, - vns::AbstractVector{<:VarName}, - vi::SimpleOrThreadSafeSimple, -) - @assert length(dist) == size(var, 1) - # NOTE: We cannot work with `var` here because we might have a model of the form - # - # m = Vector{Float64}(undef, n) - # m .~ Normal() - # - # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - value = vi[vns] - lp = sum(zip(vns, eachcol(value))) do (vn, val) - return Distributions.logpdf(dist, val) - end - return value, lp, vi -end - -function dot_assume( - dists::Union{Distribution,AbstractArray{<:Distribution}}, - var::AbstractArray, - vns::AbstractArray{<:VarName}, - vi::SimpleOrThreadSafeSimple, -) - # NOTE: We cannot work with `var` here because we might have a model of the form - # - # m = Vector{Float64}(undef, n) - # m .~ Normal() - # - # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - value = vi[vns] - lp = sum(Distributions.logpdf.(dists, value)) - return value, lp, vi + # Transform if we're working in unconstrained space. + ist = istrans(vi, vn) + value_raw = ist ? Bijectors.link(dist, value) : value + vi = BangBang.push!!(vi, vn, value_raw, dist, sampler) + return value, Bijectors.logpdf_with_trans(dist, value, ist), vi end function dot_assume( @@ -401,15 +370,23 @@ function dot_assume( ) f = (vn, dist) -> init(rng, dist, spl) value = f.(vns, dists) - vi = BangBang.setindex!!(vi, value, vns) - lp = sum(Distributions.logpdf.(dists, value)) + + # Transform if we're working in transformed space. + ist = istrans(vi, first(vns)) + value_raw = ist ? link.(dist, value) : value + + # Update `vi` + vi = BangBang.setindex!!(vi, value_raw, vns) + + # Compute logp. + lp = sum(Bijectors.logpdf_with_trans.(dists, value, ist)) return value, lp, vi end # HACK: Allows us to re-use the implementation of `dot_tilde`, etc. for literals. increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing settrans!(vi::SimpleOrThreadSafeSimple, trans::Bool, vn::VarName) = nothing -istrans(::SimpleVarInfo, vn::VarName) = false +istrans(svi::SimpleVarInfo, vn::VarName) = svi.istrans """ values_as(varinfo[, Type]) From 81ee12e44bc2382936094b0dd2f344e2ca7a3630 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 7 Jan 2022 23:46:13 +0100 Subject: [PATCH 003/101] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/simple_varinfo.jl | 12 +++++++++--- src/varinfo.jl | 9 ++++++--- test/model.jl | 4 ++-- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 85896792b..f14642b02 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -132,9 +132,15 @@ struct SimpleVarInfo{NT,T} <: AbstractVarInfo istrans::Bool end -SimpleVarInfo{T}(θ, istrans::Bool=false) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T), istrans) -SimpleVarInfo{T}(istrans::Bool=false; kwargs...) where {T<:Real} = SimpleVarInfo{T}(NamedTuple(kwargs), istrans) -SimpleVarInfo(istrans::Bool=false; kwargs...) = SimpleVarInfo{Float64}(NamedTuple(kwargs), istrans) +function SimpleVarInfo{T}(θ, istrans::Bool=false) where {T<:Real} + return SimpleVarInfo{typeof(θ),T}(θ, zero(T), istrans) +end +function SimpleVarInfo{T}(istrans::Bool=false; kwargs...) where {T<:Real} + return SimpleVarInfo{T}(NamedTuple(kwargs), istrans) +end +function SimpleVarInfo(istrans::Bool=false; kwargs...) + return SimpleVarInfo{Float64}(NamedTuple(kwargs), istrans) +end SimpleVarInfo(θ, istrans::Bool=false) = SimpleVarInfo{Float64}(θ, istrans) # Constructor from `Model`. diff --git a/src/varinfo.jl b/src/varinfo.jl index c8ffca5df..78d3b486b 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -868,7 +868,6 @@ end maybe_invlink(vi, vn, dist, val) = istrans(vi, vn) ? Bijectors.invlink(dist, val) : val - """ islinked(vi::VarInfo, spl::Union{Sampler, SampleFromPrior}) @@ -926,8 +925,12 @@ function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) end end -getindex_raw(vi::AbstractVarInfo, vn::VarName) = reconstruct(getdist(vi, vn), getval(vi, vn)) -getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}) = reconstruct(getdist(vi, first(vns)), getval(vi, vns)) +function getindex_raw(vi::AbstractVarInfo, vn::VarName) + return reconstruct(getdist(vi, vn), getval(vi, vn)) +end +function getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}) + return reconstruct(getdist(vi, first(vns)), getval(vi, vns)) +end """ getindex(vi::VarInfo, spl::Union{SampleFromPrior, Sampler}) diff --git a/test/model.jl b/test/model.jl index 7e417eaa7..c2ab0d724 100644 --- a/test/model.jl +++ b/test/model.jl @@ -85,7 +85,7 @@ @testset "Dynamic constraints" begin @model function dynamic_constraints() m ~ Normal() - x ~ truncated(Normal(), m, Inf) + return x ~ truncated(Normal(), m, Inf) end model = dynamic_constraints() @@ -93,7 +93,7 @@ spl = SampleFromPrior() link!(vi, spl) - for i = 1:10 + for i in 1:10 # Sample with large variations. r_raw = randn(length(vi[spl])) * 10 vi[spl] = r_raw From d39f87dc494261b75e26bdf2e7b953ea185e774e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 7 Jan 2022 22:57:59 +0000 Subject: [PATCH 004/101] added a comment --- src/simple_varinfo.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index f14642b02..c450145f8 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -129,6 +129,7 @@ ERROR: type NamedTuple has no field b struct SimpleVarInfo{NT,T} <: AbstractVarInfo values::NT logp::T + # TODO: Should we put this in the type instead? istrans::Bool end From 23f34cc2f7272094cf8d91150b9b8e999cc36ed4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 7 Jan 2022 22:58:10 +0000 Subject: [PATCH 005/101] bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b21c4a91a..cd40c074a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.17.3" +version = "0.17.4" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From d3ec10822b5c36ef5e41133170798d134063f0f7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 9 Jan 2022 23:17:10 +0000 Subject: [PATCH 006/101] introduced settrans!! --- src/context_implementations.jl | 18 +++++++++--------- src/simple_varinfo.jl | 9 ++++++++- src/varinfo.jl | 24 +++++++++++++++--------- 3 files changed, 32 insertions(+), 19 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index bcb3ad463..f66c4d83f 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -55,7 +55,7 @@ end function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi) if haskey(context.vars, getsym(vn)) vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end return tilde_assume(PriorContext(), right, vn, vi) end @@ -64,7 +64,7 @@ function tilde_assume( ) if haskey(context.vars, getsym(vn)) vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end return tilde_assume(rng, PriorContext(), sampler, right, vn, vi) end @@ -72,7 +72,7 @@ end function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi) if haskey(context.vars, getsym(vn)) vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end return tilde_assume(LikelihoodContext(), right, vn, vi) end @@ -86,7 +86,7 @@ function tilde_assume( ) if haskey(context.vars, getsym(vn)) vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi) end @@ -214,7 +214,7 @@ function assume( unset_flag!(vi, vn, "del") r = init(rng, dist, sampler) vi[vn] = vectorize(dist, r) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) setorder!(vi, vn, get_num_produce(vi)) else # r = vi[vn] @@ -224,7 +224,7 @@ function assume( else r = init(rng, dist, sampler) push!!(vi, vn, r, dist, sampler) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi @@ -470,7 +470,7 @@ function get_and_set_val!( for i in 1:n vn = vns[i] vi[vn] = vectorize(dist, r[:, i]) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) setorder!(vi, vn, get_num_produce(vi)) end else @@ -482,7 +482,7 @@ function get_and_set_val!( for i in 1:n vn = vns[i] push!!(vi, vn, r[:, i], dist, spl) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end end return r @@ -505,7 +505,7 @@ function get_and_set_val!( vn = vns[i] dist = dists isa AbstractArray ? dists[i] : dists vi[vn] = vectorize(dist, r[i]) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) setorder!(vi, vn, get_num_produce(vi)) end else diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index c450145f8..be9fa7455 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -392,8 +392,15 @@ end # HACK: Allows us to re-use the implementation of `dot_tilde`, etc. for literals. increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing -settrans!(vi::SimpleOrThreadSafeSimple, trans::Bool, vn::VarName) = nothing + +# NOTE: We don't implement `settrans!!(vi, trans, vn)`. +settrans!!(vi::SimpleVarInfo, trans::Bool) = Setfield.@set vi.istrans = trans +function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans::Bool) + return Setfield.@set vi.varinfo = settrans!!(vi, trans) +end + istrans(svi::SimpleVarInfo, vn::VarName) = svi.istrans +istrans(svi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = svi.istrans """ values_as(varinfo[, Type]) diff --git a/src/varinfo.jl b/src/varinfo.jl index 78d3b486b..49f7b553d 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -358,12 +358,18 @@ Return the set of sampler selectors associated with `vn` in `vi`. getgid(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).gids[getidx(vi, vn)] """ - settrans!(vi::VarInfo, trans::Bool, vn::VarName) + settrans!!(vi::VarInfo, trans::Bool, vn::VarName) Set the `trans` flag value of `vn` in `vi`. """ -function settrans!(vi::AbstractVarInfo, trans::Bool, vn::VarName) - return trans ? set_flag!(vi, vn, "trans") : unset_flag!(vi, vn, "trans") +function settrans!!(vi::AbstractVarInfo, trans::Bool, vn::VarName) + if trans + set_flag!(vi, vn, "trans") + else + unset_flag!(vi, vn, "trans") + end + + return vi end """ @@ -749,7 +755,7 @@ function link!(vi::UntypedVarInfo, spl::Sampler) vectorize(dist, Bijectors.link(dist, reconstruct(dist, getval(vi, vn)))), vn, ) - settrans!(vi, true, vn) + settrans!!(vi, true, vn) end else @warn("[DynamicPPL] attempt to link a linked vi") @@ -785,7 +791,7 @@ end ), vn, ) - settrans!(vi, true, vn) + settrans!!(vi, true, vn) end else @warn("[DynamicPPL] attempt to link a linked vi") @@ -816,7 +822,7 @@ function invlink!(vi::UntypedVarInfo, spl::AbstractSampler) vectorize(dist, Bijectors.invlink(dist, reconstruct(dist, getval(vi, vn)))), vn, ) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end else @warn("[DynamicPPL] attempt to invlink an invlinked vi") @@ -854,7 +860,7 @@ end ), vn, ) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end else @warn("[DynamicPPL] attempt to invlink an invlinked vi") @@ -1420,7 +1426,7 @@ function _setval_kernel!(vi::VarInfo, vn::VarName, values, keys) if !isempty(indices) val = reduce(vcat, values[indices]) setval!(vi, val, vn) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end return indices @@ -1501,7 +1507,7 @@ function _setval_and_resample_kernel!(vi::VarInfo, vn::VarName, values, keys) if !isempty(indices) val = reduce(vcat, values[indices]) setval!(vi, val, vn) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) else # Ensures that we'll resample the variable corresponding to `vn` if we run # the model on `vi` again. From 81782c9c54c99e86768ea48cbe7f11e6e4139711 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 9 Jan 2022 23:50:33 +0000 Subject: [PATCH 007/101] added istrans(vi) and renamed all occurences of trans! to trans!! --- src/context_implementations.jl | 49 ++++++++++++++++++++++------------ src/simple_varinfo.jl | 5 ++-- src/varinfo.jl | 9 +++++++ 3 files changed, 44 insertions(+), 19 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index f66c4d83f..8509e4a09 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -213,18 +213,23 @@ function assume( if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") unset_flag!(vi, vn, "del") r = init(rng, dist, sampler) - vi[vn] = vectorize(dist, r) - settrans!!(vi, false, vn) + vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r)) setorder!(vi, vn, get_num_produce(vi)) else + # Otherwise we just extract it. # r = vi[vn] r_raw = getindex_raw(vi, vn) r = maybe_invlink(vi, vn, dist, r_raw) end else r = init(rng, dist, sampler) - push!!(vi, vn, r, dist, sampler) - settrans!!(vi, false, vn) + if istrans(vi) + push!!(vi, vn, link(dist, r), dist, sampler) + # By default `push!!` sets the transformed flag to `false`. + settrans!!(vi, true, vn) + else + push!!(vi, vn, r, dist, sampler) + end end return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi @@ -290,7 +295,7 @@ function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left, var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) + settrans!!.(Ref(vi), false, _vns) dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, vi) else dot_tilde_assume(LikelihoodContext(), right, left, vn, vi) @@ -309,7 +314,7 @@ function dot_tilde_assume( var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) + settrans!!.(Ref(vi), false, _vns) dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, vi) else dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi) @@ -330,7 +335,7 @@ function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) + settrans!!.(Ref(vi), false, _vns) dot_tilde_assume(PriorContext(), _right, _left, _vns, vi) else dot_tilde_assume(PriorContext(), right, left, vn, vi) @@ -349,7 +354,7 @@ function dot_tilde_assume( var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) + settrans!!.(Ref(vi), false, _vns) dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, vi) else dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, vi) @@ -469,8 +474,7 @@ function get_and_set_val!( r = init(rng, dist, spl, n) for i in 1:n vn = vns[i] - vi[vn] = vectorize(dist, r[:, i]) - settrans!!(vi, false, vn) + vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r[:, i])) setorder!(vi, vn, get_num_produce(vi)) end else @@ -481,8 +485,13 @@ function get_and_set_val!( r = init(rng, dist, spl, n) for i in 1:n vn = vns[i] - push!!(vi, vn, r[:, i], dist, spl) - settrans!!(vi, false, vn) + if istrans(vi) + push!!(vi, vn, maybe_link(vi, vn, dist, r[:, i]), dist, spl) + # `push!!` sets the trans-flag to `false` by default. + setttrans!!(vi, true, vn) + else + push!!(vi, vn, r[:, i], dist, spl) + end end end return r @@ -504,12 +513,13 @@ function get_and_set_val!( for i in eachindex(vns) vn = vns[i] dist = dists isa AbstractArray ? dists[i] : dists - vi[vn] = vectorize(dist, r[i]) - settrans!!(vi, false, vn) + vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r[i])) setorder!(vi, vn, get_num_produce(vi)) end else - r = reshape(vi[vec(vns)], size(vns)) + # r = reshape(vi[vec(vns)], size(vns)) + r_raw = getindex_raw(vi, vec(vns)) + r = maybe_invlink.(Ref(vi), vns, dists, reshape(r_raw, size(vns))) end else f = (vn, dist) -> init(rng, dist, spl) @@ -519,8 +529,13 @@ function get_and_set_val!( # 1. Figure out the broadcast size and use a `foreach`. # 2. Define an anonymous function which returns `nothing`, which # we then broadcast. This will allocate a vector of `nothing` though. - push!!.(Ref(vi), vns, r, dists, Ref(spl)) - settrans!.(Ref(vi), false, vns) + if istrans(vi) + push!!.(Ref(vi), vns, link.(Ref(vi), vns, dists, r), dists, Ref(spl)) + # `push!!` sets the trans-flag to `false` by default. + settrans!!.(Ref(vi), true, vns) + else + push!!.(Ref(vi), vns, r, dists, Ref(spl)) + end end return r end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index be9fa7455..fc7d9e80f 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -399,8 +399,9 @@ function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans::Bool) return Setfield.@set vi.varinfo = settrans!!(vi, trans) end -istrans(svi::SimpleVarInfo, vn::VarName) = svi.istrans -istrans(svi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = svi.istrans +istrans(vi::SimpleVarInfo) = vi.istrans +istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi) +istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) """ values_as(varinfo[, Type]) diff --git a/src/varinfo.jl b/src/varinfo.jl index 49f7b553d..d2b0cd6c4 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -644,6 +644,14 @@ function setgid!(vi::VarInfo, gid::Selector, vn::VarName) return push!(getmetadata(vi, vn).gids[getidx(vi, vn)], gid) end +""" + istrans(vi::AbstractVarInfo) + +Return `true` if `vi` is working in unconstrained space, and `false` +if `vi` is assuming realizations to be in support of the corresponding distributions. +""" +istrans(vi::AbstractVarInfo) = false + """ istrans(vi::VarInfo, vn::VarName) @@ -872,6 +880,7 @@ end return expr end +maybe_link(vi, vn, dist, val) = istrans(vi, vn) ? Bijectors.link(dist, val) : val maybe_invlink(vi, vn, dist, val) = istrans(vi, vn) ? Bijectors.invlink(dist, val) : val """ From 12bfb420a1fa9bc2e0572d2dd0d7a385c0fac356 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 10 Jan 2022 00:22:04 +0000 Subject: [PATCH 008/101] exclusively use settrans!! to set the istrans for SimpleVarInfo --- src/simple_varinfo.jl | 15 ++++++++------- src/varinfo.jl | 2 +- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index fc7d9e80f..9805e28ab 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -133,16 +133,17 @@ struct SimpleVarInfo{NT,T} <: AbstractVarInfo istrans::Bool end -function SimpleVarInfo{T}(θ, istrans::Bool=false) where {T<:Real} - return SimpleVarInfo{typeof(θ),T}(θ, zero(T), istrans) +SimpleVarInfo(values, logp) = SimpleVarInfo(values, logp, false) +function SimpleVarInfo{T}(θ) where {T<:Real} + return SimpleVarInfo{typeof(θ),T}(θ, zero(T), false) end -function SimpleVarInfo{T}(istrans::Bool=false; kwargs...) where {T<:Real} - return SimpleVarInfo{T}(NamedTuple(kwargs), istrans) +function SimpleVarInfo{T}(; kwargs...) where {T<:Real} + return SimpleVarInfo{T}(NamedTuple(kwargs)) end -function SimpleVarInfo(istrans::Bool=false; kwargs...) - return SimpleVarInfo{Float64}(NamedTuple(kwargs), istrans) +function SimpleVarInfo(; kwargs...) + return SimpleVarInfo{Float64}(NamedTuple(kwargs)) end -SimpleVarInfo(θ, istrans::Bool=false) = SimpleVarInfo{Float64}(θ, istrans) +SimpleVarInfo(θ) = SimpleVarInfo{Float64}(θ) # Constructor from `Model`. SimpleVarInfo(model::Model, args...) = SimpleVarInfo{Float64}(model, args...) diff --git a/src/varinfo.jl b/src/varinfo.jl index d2b0cd6c4..67ded7b82 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -650,7 +650,7 @@ end Return `true` if `vi` is working in unconstrained space, and `false` if `vi` is assuming realizations to be in support of the corresponding distributions. """ -istrans(vi::AbstractVarInfo) = false +istrans(vi::VarInfo) = false # `VarInfo` works in constrained space by default. """ istrans(vi::VarInfo, vn::VarName) From c2c2417f13df1d4591965051cd68f42b5e300c78 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 10 Jan 2022 13:51:05 +0000 Subject: [PATCH 009/101] removed usage of deprecated method in turing tests --- test/turing/varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/turing/varinfo.jl b/test/turing/varinfo.jl index 892433779..0db560956 100644 --- a/test/turing/varinfo.jl +++ b/test/turing/varinfo.jl @@ -9,7 +9,7 @@ ) if !haskey(vi, vn) r = rand(dist) - push!(vi, vn, r, dist, spl) + push!!(vi, vn, r, dist, spl) r elseif is_flagged(vi, vn, "del") unset_flag!(vi, vn, "del") From 2e2cb5cd42b9df58e8ff8ff21d7a9e390bf4a2f0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 13 Jan 2022 19:03:49 +0000 Subject: [PATCH 010/101] added docstring to settrans!! --- src/simple_varinfo.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 9805e28ab..7f185780d 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -395,6 +395,11 @@ end increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing # NOTE: We don't implement `settrans!!(vi, trans, vn)`. +""" + settrans!!(vi::AbstractVarInfo, trans::Bool) + +Return new instance of `vi` but with `istrans(vi, trans)` now evaluating to `true`. +""" settrans!!(vi::SimpleVarInfo, trans::Bool) = Setfield.@set vi.istrans = trans function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans::Bool) return Setfield.@set vi.varinfo = settrans!!(vi, trans) From f6c3fc42afd427b9668b151aa31a81d1297eb630 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 27 Jan 2022 13:22:37 +0000 Subject: [PATCH 011/101] include istrans flag in type of SimpleVarInfo instead --- src/simple_varinfo.jl | 38 +++++++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 7f185780d..c9c9b5ff6 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -126,16 +126,26 @@ ERROR: type NamedTuple has no field b [...] ``` """ -struct SimpleVarInfo{NT,T} <: AbstractVarInfo +struct SimpleVarInfo{NT,T,IsTrans} <: AbstractVarInfo values::NT logp::T - # TODO: Should we put this in the type instead? - istrans::Bool end -SimpleVarInfo(values, logp) = SimpleVarInfo(values, logp, false) +function Setfield.ConstructionBase.constructorof( + ::Type{<:SimpleVarInfo{<:Any,<:Any,IsTrans}} +) where {IsTrans} + return function SimpleVarInfo_constructor(values, logp) + return SimpleVarInfo{typeof(values),typeof(logp),IsTrans}(values, logp) + end +end + +SimpleVarInfo(values, logp, istrans::Bool) = SimpleVarInfo(values, logp, Val{istrans}()) +function SimpleVarInfo(values, logp, ::Val{IsTrans}) where {IsTrans} + return SimpleVarInfo{typeof(values),typeof(logp),IsTrans}(values, logp) +end + function SimpleVarInfo{T}(θ) where {T<:Real} - return SimpleVarInfo{typeof(θ),T}(θ, zero(T), false) + return SimpleVarInfo{typeof(θ),T,false}(θ, zero(T)) end function SimpleVarInfo{T}(; kwargs...) where {T<:Real} return SimpleVarInfo{T}(NamedTuple(kwargs)) @@ -187,8 +197,10 @@ function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) return vi end -function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo) - return print(io, "SimpleVarInfo(", svi.values, ", ", svi.logp, ", ", svi.istrans, ")") +function Base.show( + io::IO, ::MIME"text/plain", svi::SimpleVarInfo{<:Any,<:Any,IsTrans} +) where {IsTrans} + return print(io, "SimpleVarInfo{IsTrans=$(IsTrans)}(", svi.values, ", ", svi.logp, ")") end # `NamedTuple` @@ -339,8 +351,8 @@ function BangBang.push!!( return vi end -const SimpleOrThreadSafeSimple{T,V} = Union{ - SimpleVarInfo{T,V},ThreadSafeVarInfo{<:SimpleVarInfo{T,V}} +const SimpleOrThreadSafeSimple{T,V,IsTrans} = Union{ + SimpleVarInfo{T,V,IsTrans},ThreadSafeVarInfo{<:SimpleVarInfo{T,V,IsTrans}} } # Necessary for `matchingvalue` to work properly. @@ -396,16 +408,16 @@ increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing # NOTE: We don't implement `settrans!!(vi, trans, vn)`. """ - settrans!!(vi::AbstractVarInfo, trans::Bool) + settrans!!(vi::AbstractVarInfo, trans) Return new instance of `vi` but with `istrans(vi, trans)` now evaluating to `true`. """ -settrans!!(vi::SimpleVarInfo, trans::Bool) = Setfield.@set vi.istrans = trans -function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans::Bool) +settrans!!(vi::SimpleVarInfo, trans) = SimpleVarInfo(vi.values, vi.logp, trans) +function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) return Setfield.@set vi.varinfo = settrans!!(vi, trans) end -istrans(vi::SimpleVarInfo) = vi.istrans +istrans(vi::SimpleVarInfo{<:Any,<:Any,IsTrans}) where {IsTrans} = IsTrans istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi) istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) From d643a783b5f6a9dc316c69b5064b1e5cc4e249bf Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 27 Jan 2022 13:41:29 +0000 Subject: [PATCH 012/101] deprecated settrans! in favour of settrans!! --- src/DynamicPPL.jl | 2 ++ src/simple_varinfo.jl | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 82df0f008..21efbd72d 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -176,4 +176,6 @@ include("test_utils.jl") @deprecate acclogp!(vi, logp) acclogp!!(vi, logp) @deprecate resetlogp!(vi) resetlogp!!(vi) +@deprecate settrans!(vi::AbstractVarInfo, trans::Bool, vn::VarName) settrans!!(vi, trans, vn) + end # module diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index c9c9b5ff6..93c91e21e 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -139,7 +139,7 @@ function Setfield.ConstructionBase.constructorof( end end -SimpleVarInfo(values, logp, istrans::Bool) = SimpleVarInfo(values, logp, Val{istrans}()) +SimpleVarInfo(values, logp, istrans::Bool=false) = SimpleVarInfo(values, logp, Val{istrans}()) function SimpleVarInfo(values, logp, ::Val{IsTrans}) where {IsTrans} return SimpleVarInfo{typeof(values),typeof(logp),IsTrans}(values, logp) end From 3cab7d9ac7f6b64e337bed0b9711df359f05be93 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 27 Jan 2022 13:48:49 +0000 Subject: [PATCH 013/101] added some tests specifically for istrans --- test/varinfo.jl | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/test/varinfo.jl b/test/varinfo.jl index 2f7816024..32c90bf47 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -271,4 +271,47 @@ DynamicPPL.setval_and_resample!(vi, vi.metadata.x.vals, ks) @test vals_prev == vi.metadata.x.vals end + + @testset "istrans" begin + @model demo_constrained() = x ~ truncated(Normal(), 0, Inf) + model = demo_constrained() + vn = @varname(x) + dist = truncated(Normal(), 0, Inf) + + ### `VarInfo` + # Need to run once since we can't specify that we want to _sample_ + # in the unconstrained space for `VarInfo` without having `vn` + # present in the `varinfo`. + ## `UntypedVarInfo` + vi = VarInfo() + vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + vi = DynamicPPL.settrans!!(vi, true, vn) + # Sample in unconstrained space. + vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + x = Bijectors.invlink(dist, DynamicPPL.getindex_raw(vi, vn)) + @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + + ## `TypedVarInfo` + vi = VarInfo(model) + vi = DynamicPPL.settrans!!(vi, true, vn) + # Sample in unconstrained space. + vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + x = Bijectors.invlink(dist, DynamicPPL.getindex_raw(vi, vn)) + @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + + ### `SimpleVarInfo` + ## `SimpleVarInfo{<:NamedTuple}` + vi = DynamicPPL.settrans!!(SimpleVarInfo(), true) + # Sample in unconstrained space. + vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + x = Bijectors.invlink(dist, DynamicPPL.getindex_raw(vi, vn)) + @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + + ## `SimpleVarInfo{<:Dict}` + vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true) + # Sample in unconstrained space. + vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + x = Bijectors.invlink(dist, DynamicPPL.getindex_raw(vi, vn)) + @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + end end From 8b870dcff71f69dbe4c32d81a8ca1e9269e62364 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 27 Jan 2022 13:49:49 +0000 Subject: [PATCH 014/101] formatting --- src/DynamicPPL.jl | 4 +++- src/simple_varinfo.jl | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 21efbd72d..f0ede3f67 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -176,6 +176,8 @@ include("test_utils.jl") @deprecate acclogp!(vi, logp) acclogp!!(vi, logp) @deprecate resetlogp!(vi) resetlogp!!(vi) -@deprecate settrans!(vi::AbstractVarInfo, trans::Bool, vn::VarName) settrans!!(vi, trans, vn) +@deprecate settrans!(vi::AbstractVarInfo, trans::Bool, vn::VarName) settrans!!( + vi, trans, vn +) end # module diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 93c91e21e..ee1fa9999 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -139,7 +139,9 @@ function Setfield.ConstructionBase.constructorof( end end -SimpleVarInfo(values, logp, istrans::Bool=false) = SimpleVarInfo(values, logp, Val{istrans}()) +function SimpleVarInfo(values, logp, istrans::Bool=false) + return SimpleVarInfo(values, logp, Val{istrans}()) +end function SimpleVarInfo(values, logp, ::Val{IsTrans}) where {IsTrans} return SimpleVarInfo{typeof(values),typeof(logp),IsTrans}(values, logp) end From 0b304db2b61564b3e185988559a621e7966bb28a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 27 Jan 2022 14:32:56 +0000 Subject: [PATCH 015/101] fixed bugs for ThreadSafeVarInfo --- src/threadsafe.jl | 3 +++ src/varinfo.jl | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 6f020a352..b42cf82a5 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -65,6 +65,9 @@ getindex(vi::ThreadSafeVarInfo, spl::SampleFromUniform) = getindex(vi.varinfo, s getindex(vi::ThreadSafeVarInfo, vn::VarName) = getindex(vi.varinfo, vn) getindex(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) = getindex(vi.varinfo, vns) +getindex_raw(vi::ThreadSafeVarInfo, vn::VarName) = getindex_raw(vi.varinfo, vn) +getindex_raw(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) = getindex_raw(vi.varinfo, vns) + function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::AbstractSampler) return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) end diff --git a/src/varinfo.jl b/src/varinfo.jl index 67ded7b82..c3aaf8255 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -650,7 +650,7 @@ end Return `true` if `vi` is working in unconstrained space, and `false` if `vi` is assuming realizations to be in support of the corresponding distributions. """ -istrans(vi::VarInfo) = false # `VarInfo` works in constrained space by default. +istrans(vi::AbstractVarInfo) = false # `VarInfo` works in constrained space by default. """ istrans(vi::VarInfo, vn::VarName) From b146b116d5877215849d457e7d9a3e2771a00ee6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 27 Jan 2022 21:01:46 +0000 Subject: [PATCH 016/101] additional constructor for SimpleVarInfo --- src/simple_varinfo.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index ee1fa9999..153e43918 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -139,7 +139,8 @@ function Setfield.ConstructionBase.constructorof( end end -function SimpleVarInfo(values, logp, istrans::Bool=false) +SimpleVarInfo(values, logp) = SimpleVarInfo(values, logp, Val{false}()) +function SimpleVarInfo(values, logp, istrans::Bool) return SimpleVarInfo(values, logp, Val{istrans}()) end function SimpleVarInfo(values, logp, ::Val{IsTrans}) where {IsTrans} From d170d92d98494d2078afd39c6ce5f2b88d158299 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Feb 2022 13:28:41 +0000 Subject: [PATCH 017/101] Update src/DynamicPPL.jl Co-authored-by: David Widmann --- src/DynamicPPL.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index f0ede3f67..fc7a2fd0c 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -178,6 +178,6 @@ include("test_utils.jl") @deprecate settrans!(vi::AbstractVarInfo, trans::Bool, vn::VarName) settrans!!( vi, trans, vn -) +) false end # module From f46183bcb3cbfe4a29f38639bee769e81336502d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Feb 2022 13:30:13 +0000 Subject: [PATCH 018/101] added ConstructionBase.jl as dep --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index cd40c074a..7412279b2 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" @@ -22,6 +23,7 @@ AbstractPPL = "0.3" BangBang = "0.3" Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9, 0.10" ChainRulesCore = "0.9.7, 0.10, 1" +ConstructionBase = "1" Distributions = "0.23.8, 0.24, 0.25" MacroTools = "0.5.6" Setfield = "0.7.1, 0.8" From 7a78eec52a5e700cc1570618e78630bd850fb924 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Feb 2022 15:58:32 +0000 Subject: [PATCH 019/101] added constraint types and doctests --- src/simple_varinfo.jl | 108 ++++++++++++++++++++++++++++++++---------- 1 file changed, 82 insertions(+), 26 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 153e43918..fb5e916f1 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -1,5 +1,10 @@ +abstract type AbstractConstraint end + +struct Constrained <: AbstractConstraint end +struct Unconstrained <: AbstractConstraint end + """ - SimpleVarInfo{NT,T} <: AbstractVarInfo + SimpleVarInfo{NT,T,C} <: AbstractVarInfo A simple wrapper of the parameters with a `logp` field for accumulation of the logdensity. @@ -16,7 +21,7 @@ The major differences between this and `TypedVarInfo` are: # Examples ## General usage -```jldoctest; setup=:(using Distributions) +```jldoctest simplevarinfo-general; setup=:(using Distributions) julia> using StableRNGs julia> @model function demo() @@ -78,6 +83,60 @@ ERROR: KeyError: key x[1:2] not found [...] ``` +You can also sample in _unconstrained_ space: + +```jldoctest simplevarinfo-general +julia> @model demo_constrained() = x ~ Exponential() +demo_constrained (generic function with 2 methods) + +julia> m = demo_constrained(); + +julia> _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo(), ctx); + +julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞ +1.8632965762164932 + +julia> _, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(), true), ctx); + +julia> vi[@varname(x)] # (✓) -∞ < x < ∞ +-0.21080155351918753 + +julia> xs = [last(DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(), true), ctx))[@varname(x)] for i = 1:10]; + +julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! +true + +julia> # And with `Dict` of course! + _, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true), ctx); + +julia> vi[@varname(x)] # (✓) -∞ < x < ∞ +0.6225185067787314 + +julia> xs = [last(DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(), true), ctx))[@varname(x)] for i = 1:10]; + +julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! +true +``` + +Evaluation in unconstrained space of course also works: + +```jldoctest simplevarinfo-general +julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true) +Unconstrained SimpleVarInfo((x = -1.0,), 0.0) + +julia> # (✓) Positive probability mass on negative numbers! + getlogp(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) +-1.3678794411714423 + +julia> # While if we forget to make indicate that it's unconstrained/transformed: + vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false) +Constrained SimpleVarInfo((x = -1.0,), 0.0) + +julia> # (✓) No probability mass on negative numbers! + getlogp(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) +-Inf +``` + ## Indexing Using `NamedTuple` as underlying storage. @@ -126,29 +185,19 @@ ERROR: type NamedTuple has no field b [...] ``` """ -struct SimpleVarInfo{NT,T,IsTrans} <: AbstractVarInfo +struct SimpleVarInfo{NT,T,C<:AbstractConstraint} <: AbstractVarInfo + "underlying representation of the realization represented" values::NT + "holds the accumulated log-probability" logp::T + "represents whether it assumes variables to be constrained or unconstrained" + constraint::C end -function Setfield.ConstructionBase.constructorof( - ::Type{<:SimpleVarInfo{<:Any,<:Any,IsTrans}} -) where {IsTrans} - return function SimpleVarInfo_constructor(values, logp) - return SimpleVarInfo{typeof(values),typeof(logp),IsTrans}(values, logp) - end -end - -SimpleVarInfo(values, logp) = SimpleVarInfo(values, logp, Val{false}()) -function SimpleVarInfo(values, logp, istrans::Bool) - return SimpleVarInfo(values, logp, Val{istrans}()) -end -function SimpleVarInfo(values, logp, ::Val{IsTrans}) where {IsTrans} - return SimpleVarInfo{typeof(values),typeof(logp),IsTrans}(values, logp) -end +SimpleVarInfo(values, logp) = SimpleVarInfo(values, logp, Constrained()) function SimpleVarInfo{T}(θ) where {T<:Real} - return SimpleVarInfo{typeof(θ),T,false}(θ, zero(T)) + return SimpleVarInfo(θ, zero(T)) end function SimpleVarInfo{T}(; kwargs...) where {T<:Real} return SimpleVarInfo{T}(NamedTuple(kwargs)) @@ -201,9 +250,15 @@ function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) end function Base.show( - io::IO, ::MIME"text/plain", svi::SimpleVarInfo{<:Any,<:Any,IsTrans} -) where {IsTrans} - return print(io, "SimpleVarInfo{IsTrans=$(IsTrans)}(", svi.values, ", ", svi.logp, ")") + io::IO, ::MIME"text/plain", svi::SimpleVarInfo{<:Any,<:Any,<:Constrained} +) + return print(io, "Constrained SimpleVarInfo(", svi.values, ", ", svi.logp, ")") +end + +function Base.show( + io::IO, ::MIME"text/plain", svi::SimpleVarInfo{<:Any,<:Any,<:Unconstrained} +) + return print(io, "Unconstrained SimpleVarInfo(", svi.values, ", ", svi.logp, ")") end # `NamedTuple` @@ -354,8 +409,8 @@ function BangBang.push!!( return vi end -const SimpleOrThreadSafeSimple{T,V,IsTrans} = Union{ - SimpleVarInfo{T,V,IsTrans},ThreadSafeVarInfo{<:SimpleVarInfo{T,V,IsTrans}} +const SimpleOrThreadSafeSimple{T,V,C} = Union{ + SimpleVarInfo{T,V,C},ThreadSafeVarInfo{<:SimpleVarInfo{T,V,C}} } # Necessary for `matchingvalue` to work properly. @@ -415,12 +470,13 @@ increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing Return new instance of `vi` but with `istrans(vi, trans)` now evaluating to `true`. """ -settrans!!(vi::SimpleVarInfo, trans) = SimpleVarInfo(vi.values, vi.logp, trans) +settrans!!(vi::SimpleVarInfo, trans) = SimpleVarInfo(vi.values, vi.logp, trans ? Unconstrained() : Constrained()) function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) return Setfield.@set vi.varinfo = settrans!!(vi, trans) end -istrans(vi::SimpleVarInfo{<:Any,<:Any,IsTrans}) where {IsTrans} = IsTrans +istrans(vi::SimpleVarInfo{<:Any,<:Any,<:Constrained}) = false +istrans(vi::SimpleVarInfo{<:Any,<:Any,<:Unconstrained}) = true istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi) istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) From 70b3b70e7c29d230822a17e4d9c98fb46720447a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Feb 2022 15:59:04 +0000 Subject: [PATCH 020/101] added DocStringExtensions as a dep --- Project.toml | 2 ++ src/DynamicPPL.jl | 2 ++ src/simple_varinfo.jl | 3 +++ 3 files changed, 7 insertions(+) diff --git a/Project.toml b/Project.toml index 7412279b2..8f4cfa184 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -25,6 +26,7 @@ Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9, 0.10" ChainRulesCore = "0.9.7, 0.10, 1" ConstructionBase = "1" Distributions = "0.23.8, 0.24, 0.25" +DocStringExtensions = "0.8" MacroTools = "0.5.6" Setfield = "0.7.1, 0.8" ZygoteRules = "0.2" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index fc7a2fd0c..9a66bb57e 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -15,6 +15,8 @@ using Setfield: Setfield using Setfield: Setfield using BangBang: BangBang +using DocStringExtensions + using Random: Random import Base: diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index fb5e916f1..c0dcacd17 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -11,6 +11,9 @@ accumulation of the logdensity. Currently only implemented for `NT<:NamedTuple` and `NT<:Dict`. +# Fields +$(FIELDS) + # Notes The major differences between this and `TypedVarInfo` are: 1. `SimpleVarInfo` does not require linearization. From a03e8cf7aac44e7bd75d32a963f7d65eab974d14 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Feb 2022 16:00:50 +0000 Subject: [PATCH 021/101] formatting --- src/simple_varinfo.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index c0dcacd17..2d1f807fc 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -473,7 +473,9 @@ increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing Return new instance of `vi` but with `istrans(vi, trans)` now evaluating to `true`. """ -settrans!!(vi::SimpleVarInfo, trans) = SimpleVarInfo(vi.values, vi.logp, trans ? Unconstrained() : Constrained()) +function settrans!!(vi::SimpleVarInfo, trans) + return SimpleVarInfo(vi.values, vi.logp, trans ? Unconstrained() : Constrained()) +end function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) return Setfield.@set vi.varinfo = settrans!!(vi, trans) end From 793c931a0ec3b0247a0e9628c8ade314d8bfffc4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Feb 2022 16:07:31 +0000 Subject: [PATCH 022/101] remove redundant maybe_link --- src/context_implementations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 8509e4a09..3cc1c9def 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -486,7 +486,7 @@ function get_and_set_val!( for i in 1:n vn = vns[i] if istrans(vi) - push!!(vi, vn, maybe_link(vi, vn, dist, r[:, i]), dist, spl) + push!!(vi, vn, Bijectors.link(dist, r[:, i]), dist, spl) # `push!!` sets the trans-flag to `false` by default. setttrans!!(vi, true, vn) else From a9b12fdd82ad0f1089bc5ca705166b8139d909b7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Feb 2022 16:08:08 +0000 Subject: [PATCH 023/101] fixed typo --- src/context_implementations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 3cc1c9def..53db489f1 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -488,7 +488,7 @@ function get_and_set_val!( if istrans(vi) push!!(vi, vn, Bijectors.link(dist, r[:, i]), dist, spl) # `push!!` sets the trans-flag to `false` by default. - setttrans!!(vi, true, vn) + settrans!!(vi, true, vn) else push!!(vi, vn, r[:, i], dist, spl) end From be989612e1d9b5c3e025215a91d58dec90a65511 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 11 Feb 2022 16:35:27 +0000 Subject: [PATCH 024/101] moved a docstring --- src/simple_varinfo.jl | 5 ----- src/varinfo.jl | 11 +++++++++++ 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 2d1f807fc..e3a709fdc 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -468,11 +468,6 @@ end increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing # NOTE: We don't implement `settrans!!(vi, trans, vn)`. -""" - settrans!!(vi::AbstractVarInfo, trans) - -Return new instance of `vi` but with `istrans(vi, trans)` now evaluating to `true`. -""" function settrans!!(vi::SimpleVarInfo, trans) return SimpleVarInfo(vi.values, vi.logp, trans ? Unconstrained() : Constrained()) end diff --git a/src/varinfo.jl b/src/varinfo.jl index c3aaf8255..d849b0fd7 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -372,6 +372,17 @@ function settrans!!(vi::AbstractVarInfo, trans::Bool, vn::VarName) return vi end +""" + settrans!!(vi::AbstractVarInfo, trans) + +Return new instance of `vi` but with `istrans(vi, trans)` now evaluating to `true`. +""" +function settrans!!(vi::VarInfo, trans::Bool) + for vn in keys(vi) + settrans!!(vi, trans, vn) + end +end + """ syms(vi::VarInfo) From 3e1588bc01183fcef00cfc3aa35bb9fd058910cd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 11 Feb 2022 16:35:55 +0000 Subject: [PATCH 025/101] fixed bug in tets --- src/context_implementations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 53db489f1..9d3ba5aa9 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -401,7 +401,7 @@ function dot_assume( # in which case `var` will have `undef` elements, even if `m` is present in `vi`. # r = vi[vns] r_raw = getindex_raw(vi, vns) - r = maybe_invlink(vi, vn, dist, r_raw) + r = maybe_invlink(vi, vns, dist, r_raw) lp = sum(zip(vns, eachcol(r))) do (vn, ri) return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) end From 3139c62f41ccde2d0ad82d66b3f392fe276d6c56 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 11 Feb 2022 16:36:57 +0000 Subject: [PATCH 026/101] version bump --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c07179202..0c243767d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.17.4" +version = "0.17.5" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From 3610658285eea9c44a1136903a11f782ce33f9b6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 12 Feb 2022 18:31:37 +0000 Subject: [PATCH 027/101] added missing istrans impl --- src/varinfo.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/varinfo.jl b/src/varinfo.jl index d849b0fd7..e430a9967 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -670,6 +670,7 @@ Return true if `vn`'s values in `vi` are transformed to Euclidean space, and fal they are in the support of `vn`'s distribution. """ istrans(vi::AbstractVarInfo, vn::VarName) = is_flagged(vi, vn, "trans") +istrans(vi::AbstractVarInfo, vns::AbstractVector{<:VarName}) = all(istrans, vns) """ getlogp(vi::VarInfo) From 27171ad7d5427b977c530355a201ece9e8657661 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 13 Feb 2022 16:22:28 +0000 Subject: [PATCH 028/101] fixed bug with istrans --- src/varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index e430a9967..482dca6cb 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -670,7 +670,7 @@ Return true if `vn`'s values in `vi` are transformed to Euclidean space, and fal they are in the support of `vn`'s distribution. """ istrans(vi::AbstractVarInfo, vn::VarName) = is_flagged(vi, vn, "trans") -istrans(vi::AbstractVarInfo, vns::AbstractVector{<:VarName}) = all(istrans, vns) +istrans(vi::AbstractVarInfo, vns::AbstractVector{<:VarName}) = all(Base.Fix1(istrans, vi), vns) """ getlogp(vi::VarInfo) From cd2d9d69a18c0e18d7b4c92dcb1be2a0bb7ed297 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 13 Feb 2022 17:54:16 +0000 Subject: [PATCH 029/101] fixed issue with getindex_raw for VarInfo --- src/varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 482dca6cb..f8923856a 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -956,7 +956,7 @@ function getindex_raw(vi::AbstractVarInfo, vn::VarName) return reconstruct(getdist(vi, vn), getval(vi, vn)) end function getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}) - return reconstruct(getdist(vi, first(vns)), getval(vi, vns)) + return reconstruct(getdist(vi, first(vns)), getval(vi, vns), length(vns)) end """ From d948cb91c1c769afbdcbcd9501c5dbce5c5c064f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 13 Feb 2022 17:54:43 +0000 Subject: [PATCH 030/101] Update src/varinfo.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/varinfo.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index f8923856a..0503b3a10 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -670,7 +670,9 @@ Return true if `vn`'s values in `vi` are transformed to Euclidean space, and fal they are in the support of `vn`'s distribution. """ istrans(vi::AbstractVarInfo, vn::VarName) = is_flagged(vi, vn, "trans") -istrans(vi::AbstractVarInfo, vns::AbstractVector{<:VarName}) = all(Base.Fix1(istrans, vi), vns) +function istrans(vi::AbstractVarInfo, vns::AbstractVector{<:VarName}) + return all(Base.Fix1(istrans, vi), vns) +end """ getlogp(vi::VarInfo) From 26d2dbbc106ea6fa0462dd11e48389415dbde8eb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 22 Jun 2022 13:03:19 +0100 Subject: [PATCH 031/101] getindex of varinfo implementations now optionally takes a Distribution argument --- src/simple_varinfo.jl | 16 ++++++++++++++++ src/varinfo.jl | 28 ++++++++++++++++------------ 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index e3a709fdc..d51c32760 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -265,6 +265,16 @@ function Base.show( end # `NamedTuple` +function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution) + return maybe_invlink(vi, vn, dist, Base.getindex(vi, vn)) +end +function Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution) + vals_linked = map(vns) do vn + maybe_invlink(vi, vn, dist, Base.getindex(vi, vn)) + end + return reconstruct(dist, reduce(vcat, vals_linked), length(vns)) +end + Base.getindex(vi::SimpleVarInfo, vn::VarName) = get(vi.values, vn) # `Dict` @@ -309,7 +319,13 @@ Base.getindex(vi::SimpleVarInfo, spl::Sampler) = vi.values # Since we don't perform any transformations in `getindex` for `SimpleVarInfo` # we simply call `getindex` in `getindex_raw`. getindex_raw(vi::SimpleVarInfo, vn::VarName) = vi[vn] +getindex_raw(vi::SimpleVarInfo, vn::VarName, dist::Distribution) = reconstruct(dist, getindex_raw(vi, vn)) getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}) = vi[vns] +function getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution) + vals = getindex_raw(vi, vns) + # `reconstruct` expects a flattened `Vector` regardless of the type of `dist`, so we `vcat` everything. + return reconstruct(dist, reduce(vcat, vals), length(vns)) +end Base.haskey(vi::SimpleVarInfo, vn::VarName) = _haskey(vi.values, vn) function _haskey(nt::NamedTuple, vn::VarName) diff --git a/src/varinfo.jl b/src/varinfo.jl index 0503b3a10..6f3b39fe5 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -935,30 +935,34 @@ distribution(s). If the value(s) is (are) transformed to the Euclidean space, it is (they are) transformed back. """ -function getindex(vi::AbstractVarInfo, vn::VarName) +getindex(vi::AbstractVarInfo, vn::VarName) = getindex(vi, vn, getdist(vi, vn)) +function getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution) @assert haskey(vi, vn) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" - dist = getdist(vi, vn) + val = getindex_raw(vi, vn, dist) return if istrans(vi, vn) - Bijectors.invlink(dist, reconstruct(dist, getval(vi, vn))) + Bijectors.invlink(dist, val) else - reconstruct(dist, getval(vi, vn)) + val end end -function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) +getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) = getindex(vi, vns, getdist(vi, first(vns))) +function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distribution) @assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" - dist = getdist(vi, vns[1]) + val = getindex_raw(vi, vns, dist) return if istrans(vi, vns[1]) - Bijectors.invlink(dist, reconstruct(dist, getval(vi, vns), length(vns))) + Bijectors.invlink(dist, val) else - reconstruct(dist, getval(vi, vns), length(vns)) + val end end -function getindex_raw(vi::AbstractVarInfo, vn::VarName) - return reconstruct(getdist(vi, vn), getval(vi, vn)) +getindex_raw(vi::AbstractVarInfo, vn::VarName) = getindex_raw(vi, vn, getdist(vi, vn)) +function getindex_raw(vi::AbstractVarInfo, vn::VarName, dist::Distribution) + return reconstruct(dist, getval(vi, vn)) end -function getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}) - return reconstruct(getdist(vi, first(vns)), getval(vi, vns), length(vns)) +getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}) = getindex_raw(vi, vns, getdist(vi, first(vns))) +function getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distribution) + return reconstruct(dist, getval(vi, vns), length(vns)) end """ From 3fcba565a4f4d61e32a08df420c19431339df0eb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 22 Jun 2022 13:03:58 +0100 Subject: [PATCH 032/101] use get_index_raw with dist argument --- src/context_implementations.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 9d3ba5aa9..a9bf2db92 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -195,7 +195,7 @@ end # fallback without sampler function assume(dist::Distribution, vn::VarName, vi) # x = vi[vn] - r_raw = getindex_raw(vi, vn) + r_raw = getindex_raw(vi, vn, dist) r = maybe_invlink(vi, vn, dist, r_raw) return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi end @@ -218,7 +218,7 @@ function assume( else # Otherwise we just extract it. # r = vi[vn] - r_raw = getindex_raw(vi, vn) + r_raw = getindex_raw(vi, vn, dist) r = maybe_invlink(vi, vn, dist, r_raw) end else @@ -400,7 +400,7 @@ function dot_assume( # # in which case `var` will have `undef` elements, even if `m` is present in `vi`. # r = vi[vns] - r_raw = getindex_raw(vi, vns) + r_raw = getindex_raw(vi, vns, dist) r = maybe_invlink(vi, vns, dist, r_raw) lp = sum(zip(vns, eachcol(r))) do (vn, ri) return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) @@ -434,7 +434,7 @@ function dot_assume( # m .~ Normal() # # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - r_raw = getindex_raw(vi, vec(vns)) + r_raw = getindex_raw(vi, vec(vns), dists) r = reshape(maybe_invlink.(Ref(vi), vns, dists, r_raw), size(vns)) lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) return r, lp, vi From 83a9448dc22bd0e46dc29f81b552d85019539386 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 22 Jun 2022 13:04:09 +0100 Subject: [PATCH 033/101] added missing assume implementations for SimpleVarInfo --- src/simple_varinfo.jl | 50 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index d51c32760..b07b102ea 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -470,7 +470,55 @@ function dot_assume( # Transform if we're working in transformed space. ist = istrans(vi, first(vns)) - value_raw = ist ? link.(dist, value) : value + value_raw = ist ? link.(dists, value) : value + + # Update `vi` + vi = BangBang.setindex!!(vi, value_raw, vns) + + # Compute logp. + lp = sum(Bijectors.logpdf_with_trans.(dists, value, ist)) + return value, lp, vi +end + +function dot_assume( + rng, + spl::Union{SampleFromPrior,SampleFromUniform}, + dist::MultivariateDistribution, + vns::AbstractVector{<:VarName}, + var::AbstractMatrix, + vi::SimpleOrThreadSafeSimple, +) + @assert length(dist) == size(var, 1) + + # r = get_and_set_val!(rng, vi, vns, dist, spl) + n = length(vns) + value = init(rng, dist, spl, n) + + # Update `vi`. + for (vn, val) in zip(vns, eachcol(value)) + val_linked = maybe_link(vi, vn, dist, val) + vi = BangBang.setindex!!(vi, val_linked, vn) + end + + # Compute logp. + lp = sum(Bijectors.logpdf_with_trans(dist, value, istrans(vi, first(vns)))) + return value, lp, vi +end + +function dot_assume( + rng, + spl::Union{SampleFromPrior,SampleFromUniform}, + dists::Union{Distribution,AbstractArray{<:Distribution}}, + vns::AbstractArray{<:VarName}, + var::AbstractArray, + vi::SimpleOrThreadSafeSimple, +) + f = (vn, dist) -> init(rng, dist, spl) + value = f.(vns, dists) + + # Transform if we're working in transformed space. + ist = istrans(vi, first(vns)) + value_raw = ist ? link.(dists, value) : value # Update `vi` vi = BangBang.setindex!!(vi, value_raw, vns) From 356fa9ceae23dd2208c5b34faf3fc079f4c49cb9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 22 Jun 2022 13:04:27 +0100 Subject: [PATCH 034/101] fixed settrans!! for VarInfo --- src/varinfo.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/varinfo.jl b/src/varinfo.jl index 6f3b39fe5..194d475fc 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -381,6 +381,8 @@ function settrans!!(vi::VarInfo, trans::Bool) for vn in keys(vi) settrans!!(vi, trans, vn) end + + return vi end """ From 13f037ff2a1a036a81bb6efd2453d20b3876bf0c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 24 Jun 2022 16:47:35 +0100 Subject: [PATCH 035/101] formatting --- src/simple_varinfo.jl | 4 +++- src/varinfo.jl | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index b07b102ea..aad65dfc8 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -319,7 +319,9 @@ Base.getindex(vi::SimpleVarInfo, spl::Sampler) = vi.values # Since we don't perform any transformations in `getindex` for `SimpleVarInfo` # we simply call `getindex` in `getindex_raw`. getindex_raw(vi::SimpleVarInfo, vn::VarName) = vi[vn] -getindex_raw(vi::SimpleVarInfo, vn::VarName, dist::Distribution) = reconstruct(dist, getindex_raw(vi, vn)) +function getindex_raw(vi::SimpleVarInfo, vn::VarName, dist::Distribution) + return reconstruct(dist, getindex_raw(vi, vn)) +end getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}) = vi[vns] function getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution) vals = getindex_raw(vi, vns) diff --git a/src/varinfo.jl b/src/varinfo.jl index 194d475fc..13259a12e 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -962,7 +962,9 @@ getindex_raw(vi::AbstractVarInfo, vn::VarName) = getindex_raw(vi, vn, getdist(vi function getindex_raw(vi::AbstractVarInfo, vn::VarName, dist::Distribution) return reconstruct(dist, getval(vi, vn)) end -getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}) = getindex_raw(vi, vns, getdist(vi, first(vns))) +function getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}) + return getindex_raw(vi, vns, getdist(vi, first(vns))) +end function getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distribution) return reconstruct(dist, getval(vi, vns), length(vns)) end From c7544e087dcc078d86cfe4de75f0e032b998c8a9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 24 Jun 2022 16:47:45 +0100 Subject: [PATCH 036/101] fixed bug where constrained/unconstrained wasn't preserved in setindex!! for SimpleVarInfo --- src/simple_varinfo.jl | 29 +++-------------------------- src/varinfo.jl | 4 +++- 2 files changed, 6 insertions(+), 27 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index aad65dfc8..1c6d4043a 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -364,7 +364,7 @@ end function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName) # For `NamedTuple` we treat the symbol in `vn` as the _property_ to set. - return SimpleVarInfo(set!!(vi.values, vn, val), vi.logp) + return Setfield.@set vi.values = set!!(vi.values, vn, val) end # TODO: Specialize to handle certain cases, e.g. a collection of `VarName` with @@ -395,7 +395,7 @@ function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName vn_key = VarName(vn, keylens) BangBang.setindex!!(dict, set!!(dict[vn_key], child, val), vn_key) end - return SimpleVarInfo(dict_new, vi.logp) + return Setfield.@set vi.values = dict_new end # `NamedTuple` @@ -503,30 +503,7 @@ function dot_assume( end # Compute logp. - lp = sum(Bijectors.logpdf_with_trans(dist, value, istrans(vi, first(vns)))) - return value, lp, vi -end - -function dot_assume( - rng, - spl::Union{SampleFromPrior,SampleFromUniform}, - dists::Union{Distribution,AbstractArray{<:Distribution}}, - vns::AbstractArray{<:VarName}, - var::AbstractArray, - vi::SimpleOrThreadSafeSimple, -) - f = (vn, dist) -> init(rng, dist, spl) - value = f.(vns, dists) - - # Transform if we're working in transformed space. - ist = istrans(vi, first(vns)) - value_raw = ist ? link.(dists, value) : value - - # Update `vi` - vi = BangBang.setindex!!(vi, value_raw, vns) - - # Compute logp. - lp = sum(Bijectors.logpdf_with_trans.(dists, value, ist)) + lp = sum(Bijectors.logpdf_with_trans(dist, value, istrans(vi))) return value, lp, vi end diff --git a/src/varinfo.jl b/src/varinfo.jl index 13259a12e..582582b77 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -947,7 +947,9 @@ function getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution) val end end -getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) = getindex(vi, vns, getdist(vi, first(vns))) +function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) + return getindex(vi, vns, getdist(vi, first(vns))) +end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distribution) @assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" val = getindex_raw(vi, vns, dist) From d1dccf19b764b8876cf7ba35ae4d3aed0962be14 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 24 Jun 2022 16:48:27 +0100 Subject: [PATCH 037/101] hack to avoid type-instabilities for dot_assume with MultivariateDistribution --- src/utils.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index 821eba38e..2a7838e2f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -417,3 +417,14 @@ function splitlens(condition, lens) return current_parent, current_child, condition(current_parent) end + +# HACK(torfjelde): Avoids type-instability in `dot_assume` for `SimpleVarInfo`. +function BangBang.possible( + ::typeof(BangBang._setindex!), + ::C, + ::T, + ::Colon, + ::Integer +) where {C<:AbstractMatrix,T<:AbstractVector} + return BangBang.implements(setindex!, C) && promote_type(eltype(C), eltype(T)) <: eltype(C) +end From ff7ff4ab6e5da58fe5d700c4e85e8d6be44e90a9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 26 Jun 2022 14:01:10 +0100 Subject: [PATCH 038/101] style --- src/utils.jl | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 2a7838e2f..b200171a7 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -420,11 +420,8 @@ end # HACK(torfjelde): Avoids type-instability in `dot_assume` for `SimpleVarInfo`. function BangBang.possible( - ::typeof(BangBang._setindex!), - ::C, - ::T, - ::Colon, - ::Integer + ::typeof(BangBang._setindex!), ::C, ::T, ::Colon, ::Integer ) where {C<:AbstractMatrix,T<:AbstractVector} - return BangBang.implements(setindex!, C) && promote_type(eltype(C), eltype(T)) <: eltype(C) + return BangBang.implements(setindex!, C) && + promote_type(eltype(C), eltype(T)) <: eltype(C) end From 2f1a2ff68d7abf3be166bd5de691ab479f3131bf Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 26 Jun 2022 14:01:35 +0100 Subject: [PATCH 039/101] added keys implementations for the models in TestUtils to make testing AbstractVarInfo implementations easier --- src/test_utils.jl | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/test_utils.jl b/src/test_utils.jl index ca0fabc9a..bd4c0bedc 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -66,6 +66,9 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe)}, m) return loglikelihood(MvNormal(m, 0.25 * I), model.args.x) end +function Base.keys(model::Model{typeof(demo_dot_assume_dot_observe)}) + return [@varname(m[1]), @varname(m[2])] +end @model function demo_assume_index_observe( x=[10.0, 10.0], ::Type{TV}=Vector{Float64} @@ -85,6 +88,9 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_index_observe)}, m) return logpdf(MvNormal(m, 0.25 * I), model.args.x) end +function Base.keys(model::Model{typeof(demo_assume_index_observe)}) + return [@varname(m[1]), @varname(m[2])] +end @model function demo_assume_multivariate_observe(x=[10.0, 10.0]) # Multivariate `assume` and `observe` @@ -99,6 +105,9 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_multivariate_observe)}, m) return logpdf(MvNormal(m, 0.25 * I), model.args.x) end +function Base.keys(model::Model{typeof(demo_assume_multivariate_observe)}) + return [@varname(m)] +end @model function demo_dot_assume_observe_index( x=[10.0, 10.0], ::Type{TV}=Vector{Float64} @@ -118,6 +127,9 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index)}, m) return sum(logpdf.(Normal.(m, 0.5), model.args.x)) end +function Base.keys(model::Model{typeof(demo_dot_assume_observe_index)}) + return [@varname(m[1]), @varname(m[2])] +end # Using vector of `length` 1 here so the posterior of `m` is the same # as the others. @@ -134,6 +146,9 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_dot_observe)}, m) return sum(logpdf.(Normal.(m, 0.5), model.args.x)) end +function Base.keys(model::Model{typeof(demo_assume_dot_observe)}) + return [@varname(m)] +end @model function demo_assume_observe_literal() # `assume` and literal `observe` @@ -148,6 +163,9 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, m) return logpdf(MvNormal(m, 0.25 * I), [10.0, 10.0]) end +function Base.keys(model::Model{typeof(demo_assume_observe_literal)}) + return [@varname(m)] +end @model function demo_dot_assume_observe_index_literal(::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and literal `observe` with indexing @@ -165,6 +183,9 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, m) return sum(logpdf.(Normal.(m, 0.5), fill(10.0, length(m)))) end +function Base.keys(model::Model{typeof(demo_dot_assume_observe_index_literal)}) + return [@varname(m[1]), @varname(m[2])] +end @model function demo_assume_literal_dot_observe() # `assume` and literal `dot_observe` @@ -179,6 +200,9 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_literal_dot_observe)}, m) return logpdf(Normal(m, 0.5), 10.0) end +function Base.keys(model::Model{typeof(demo_assume_literal_dot_observe)}) + return [@varname(m)] +end @model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} m = TV(undef, 2) @@ -204,6 +228,9 @@ function loglikelihood_true( ) return sum(logpdf.(Normal.(m, 0.5), 10.0)) end +function Base.keys(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) + return [@varname(m[1]), @varname(m[2])] +end @model function _likelihood_dot_observe(m, x) return x ~ MvNormal(m, 0.25 * I) @@ -226,6 +253,9 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, m) return logpdf(MvNormal(m, 0.25 * I), model.args.x) end +function Base.keys(model::Model{typeof(demo_dot_assume_observe_submodel)}) + return [@varname(m[1]), @varname(m[2])] +end @model function demo_dot_assume_dot_observe_matrix( x=fill(10.0, 2, 1), ::Type{TV}=Vector{Float64} From d6311b74e0d62941361b959ab08d9d4b986b5306 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 26 Jun 2022 14:02:08 +0100 Subject: [PATCH 040/101] added additional test model which uses dot-assume on MultivariateDistribution --- src/test_utils.jl | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/test_utils.jl b/src/test_utils.jl index bd4c0bedc..faa390cea 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -274,6 +274,33 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, m) return loglikelihood(MvNormal(m, 0.25 * I), model.args.x) end +function Base.keys(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) + return [@varname(m[1]), @varname(m[2])] +end + +@model function demo_dot_assume_matrix_dot_observe_matrix( + x=fill(10.0, 2, 1), ::Type{TV}=Array{Float64} +) where {TV} + d = length(x) ÷ 2 + m = TV(undef, d, 2) + m .~ MvNormal(zeros(d), I) + + # Dotted observe for `Matrix`. + x .~ MvNormal(vec(m), 0.25 * I) + + return (; m=m, x=x, logp=getlogp(__varinfo__)) +end +function logprior_true(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, m) + return loglikelihood(Normal(), vec(m)) +end +function loglikelihood_true( + model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, m +) + return loglikelihood(MvNormal(vec(m), 0.25 * I), model.args.x) +end +function Base.keys(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) + return [@varname(m[:, 1]), @varname(m[:, 2])] +end const DEMO_MODELS = ( demo_dot_assume_dot_observe(), @@ -287,6 +314,7 @@ const DEMO_MODELS = ( demo_assume_submodel_observe_index_literal(), demo_dot_assume_observe_submodel(), demo_dot_assume_dot_observe_matrix(), + demo_dot_assume_matrix_dot_observe_matrix(), ) # TODO: Is this really the best/most convenient "default" test method? From ed2fa693d6d988a0cd0f7162000c1ece820cf815 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 26 Jun 2022 14:02:47 +0100 Subject: [PATCH 041/101] updated tests for SimpleVarInfo --- test/simple_varinfo.jl | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 7e9346450..5b7c6ba71 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -85,12 +85,14 @@ isunivariate = !haskey(svi_new, @varname(m[1])) # Realization for `m` should be different wp. 1. - if isunivariate - @test svi_new[@varname(m)] != m - else - @test svi_new[@varname(m[1])] != m[1] - @test svi_new[@varname(m[2])] != m[2] + for vn in keys(model) + # `VarName` functions similarly to `PropertyLens` so + # we just strip this part from `vn` to get a lens we can use + # to extract the corresponding value of `m`. + l = getlens(vn) + @test svi_new[vn] != get(m, l) end + # Logjoint should be non-zero wp. 1. @test getlogp(svi_new) != 0 @@ -103,26 +105,26 @@ end # Update the realizations in `svi_new`. - svi_eval = if isunivariate - DynamicPPL.setindex!!(svi_new, m_eval, @varname(m)) - else - DynamicPPL.setindex!!(svi_new, m_eval, [@varname(m[1]), @varname(m[2])]) + svi_eval = svi_new + for vn in keys(model) + l = getlens(vn) + svi_eval = DynamicPPL.setindex!!(svi_eval, get(m_eval, l), vn) end + # Reset the logp field. svi_eval = DynamicPPL.resetlogp!!(svi_eval) # Compute `logjoint` using the varinfo. logπ = logjoint(model, svi_eval) - # Extract the parameters from `svi_eval`. - m_vi = if isunivariate - svi_eval[@varname(m)] - else - svi_eval[[@varname(m[1]), @varname(m[2])]] + + # Values should not have changed. + for vn in keys(model) + l = getlens(vn) + @test svi_eval[vn] == get(m_eval, l) end - # These should not have changed. - @test m_vi == m_eval + # Compute the true `logjoint` and compare. - logπ_true = DynamicPPL.TestUtils.logjoint_true(model, m_vi) + logπ_true = DynamicPPL.TestUtils.logjoint_true(model, m_eval) @test logπ ≈ logπ_true end end From a82be563fd498907bc0526bda5a4d22b56999002 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 26 Jun 2022 14:05:32 +0100 Subject: [PATCH 042/101] added a no-op reconstruct for UnivariateDistribution --- src/utils.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/utils.jl b/src/utils.jl index b200171a7..ac9222818 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -183,6 +183,7 @@ vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r)) # otherwise we will have error for MatrixDistribution. # Note this is not the case for MultivariateDistribution so I guess this might be lack of # support for some types related to matrices (like PDMat). +reconstruct(d::UnivariateDistribution, val::Real) = val reconstruct(d::Distribution, val::AbstractVector) = reconstruct(size(d), val) reconstruct(::Tuple{}, val::AbstractVector) = val[1] reconstruct(s::NTuple{1}, val::AbstractVector) = copy(val) From 7aacee5e09fdb42de7007811da75974201d4b7d5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 27 Jun 2022 13:20:09 +0100 Subject: [PATCH 043/101] fixed tests for loglikelihoods --- test/loglikelihoods.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index 4d5003f03..0e7f9a3d9 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -1,13 +1,13 @@ @testset "loglikelihoods.jl" begin - for m in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS vi = VarInfo(m) - vns = vi.metadata.m.vns - if length(vns) == 1 && length(vi[vns[1]]) == 1 - # Only have one latent variable. - DynamicPPL.setval!(vi, [1.0], ["m"]) - else - DynamicPPL.setval!(vi, [1.0, 1.0], ["m[1]", "m[2]"]) + for vn in keys(m) + if vi[vn] isa Real + vi = DynamicPPL.setindex!!(vi, 1.0, vn) + else + vi = DynamicPPL.setindex!!(vi, ones(size(vi[vn])), vn) + end end lls = pointwise_loglikelihoods(m, vi) From 96f128fe4d7a281a49f87acf2244f73b2076455e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 27 Jun 2022 13:20:45 +0100 Subject: [PATCH 044/101] fixed dot_tilde_assume for LikelihoodContext --- src/context_implementations.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index a9bf2db92..9263dc441 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -320,13 +320,16 @@ function dot_tilde_assume( dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi) end end + function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi) - return dot_assume(NoDist.(right), left, vn, vi) + nodist = right isa Distribution ? NoDist(right) : NoDist.(right) + return dot_assume(nodist, left, vn, vi) end function dot_tilde_assume( rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, vi ) - return dot_assume(rng, sampler, NoDist.(right), vn, left, vi) + nodist = right isa Distribution ? NoDist(right) : NoDist.(right) + return dot_assume(rng, sampler, nodist, vn, left, vi) end # `PriorContext` From 2e88d087eeea267d241667fd495b743603870872 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 27 Jun 2022 13:21:03 +0100 Subject: [PATCH 045/101] removed some now redundant explicit calls to maybe_invlink --- src/context_implementations.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 9263dc441..0ebe4bedd 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -217,9 +217,7 @@ function assume( setorder!(vi, vn, get_num_produce(vi)) else # Otherwise we just extract it. - # r = vi[vn] - r_raw = getindex_raw(vi, vn, dist) - r = maybe_invlink(vi, vn, dist, r_raw) + r = vi[vn, dist] end else r = init(rng, dist, sampler) @@ -481,8 +479,7 @@ function get_and_set_val!( setorder!(vi, vn, get_num_produce(vi)) end else - r_raw = getindex_raw(vi, vns) - r = maybe_invlink(vi, vns, dist, r_raw) + r = vi[vns, dist] end else r = init(rng, dist, spl, n) From 0f9765bda684b27202982cf95d11e8de07304f62 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 27 Jun 2022 13:21:31 +0100 Subject: [PATCH 046/101] added impls of size and length for the wrapper distributions so they work for reconstruct --- src/distribution_wrappers.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index 4045cc089..07dc6f93f 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -13,6 +13,9 @@ end NamedDist(dist::Distribution, name::Symbol) = NamedDist(dist, VarName{name}()) +Base.length(dist::NamedDist) = Base.length(dist.dist) +Base.size(dist::NamedDist) = Base.size(dist.dist) + Distributions.logpdf(dist::NamedDist, x::Real) = Distributions.logpdf(dist.dist, x) function Distributions.logpdf(dist::NamedDist, x::AbstractArray{<:Real}) return Distributions.logpdf(dist.dist, x) @@ -24,12 +27,17 @@ function Distributions.loglikelihood(dist::NamedDist, x::AbstractArray{<:Real}) return Distributions.loglikelihood(dist.dist, x) end +Bijectors.bijector(d::NamedDist) = Bijectors.bijector(d.dist) + struct NoDist{variate,support,Td<:Distribution{variate,support}} <: Distribution{variate,support} dist::Td end NoDist(dist::NamedDist) = NamedDist(NoDist(dist.dist), dist.name) +Base.length(dist::NoDist) = Base.length(dist.dist) +Base.size(dist::NoDist) = Base.size(dist.dist) + Distributions.rand(rng::Random.AbstractRNG, d::NoDist) = rand(rng, d.dist) Distributions.logpdf(d::NoDist{<:Univariate}, ::Real) = 0 Distributions.logpdf(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0 From 116c95c0e695648a24e57b1d9581a9f9b087089a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 28 Jun 2022 17:52:49 +0100 Subject: [PATCH 047/101] bumped version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 305a0a0b3..7adb220b3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.19.2" +version = "0.19.3" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From d797e997399365e56b9d950e9367761ba85aff81 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 28 Jun 2022 17:53:46 +0100 Subject: [PATCH 048/101] removed redunant explict call to maybe_invlink --- src/context_implementations.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 0ebe4bedd..ddc62b639 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -194,9 +194,7 @@ end # fallback without sampler function assume(dist::Distribution, vn::VarName, vi) - # x = vi[vn] - r_raw = getindex_raw(vi, vn, dist) - r = maybe_invlink(vi, vn, dist, r_raw) + r = vi[vn, dist] return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi end From 44b2f66bd170e1e807c08aa7f2225f5b102151bc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 29 Jun 2022 10:54:45 +0100 Subject: [PATCH 049/101] added test model with array on RHS of a .~ statement --- src/test_utils.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/test_utils.jl b/src/test_utils.jl index faa390cea..172df1775 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -302,6 +302,25 @@ function Base.keys(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix return [@varname(m[:, 1]), @varname(m[:, 2])] end +@model function demo_dot_assume_array_dot_observe( + x=[10.0, 10.0], ::Type{TV}=Vector{Float64} +) where {TV} + # `dot_assume` and `observe` + m = TV(undef, length(x)) + m .~ [Normal() for _ in 1:length(x)] + x ~ MvNormal(m, 0.25 * I) + return (; m=m, x=x, logp=getlogp(__varinfo__)) +end +function logprior_true(model::Model{typeof(demo_dot_assume_array_dot_observe)}, m) + return loglikelihood(Normal(), m) +end +function loglikelihood_true(model::Model{typeof(demo_dot_assume_array_dot_observe)}, m) + return loglikelihood(MvNormal(m, 0.25 * I), model.args.x) +end +function Base.keys(model::Model{typeof(demo_dot_assume_array_dot_observe)}) + return [@varname(m[1]), @varname(m[2])] +end + const DEMO_MODELS = ( demo_dot_assume_dot_observe(), demo_assume_index_observe(), @@ -315,6 +334,7 @@ const DEMO_MODELS = ( demo_dot_assume_observe_submodel(), demo_dot_assume_dot_observe_matrix(), demo_dot_assume_matrix_dot_observe_matrix(), + demo_dot_assume_array_dot_observe(), ) # TODO: Is this really the best/most convenient "default" test method? From 81cd88145e4dd1c5b15170d9b13c90e8f73a4865 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 29 Jun 2022 10:55:14 +0100 Subject: [PATCH 050/101] improved some of the default implementations of dot_assume --- src/context_implementations.jl | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index ddc62b639..aead1dde1 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -398,9 +398,7 @@ function dot_assume( # m .~ Normal() # # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - # r = vi[vns] - r_raw = getindex_raw(vi, vns, dist) - r = maybe_invlink(vi, vns, dist, r_raw) + r = vi[vns, dist] lp = sum(zip(vns, eachcol(r))) do (vn, ri) return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) end @@ -422,20 +420,22 @@ function dot_assume( end function dot_assume( - dists::Union{Distribution,AbstractArray{<:Distribution}}, + dist::Distribution, var::AbstractArray, vns::AbstractArray{<:VarName}, vi +) + r = map(vn -> vi[vn, dist], vns) + lp = sum(Bijectors.logpdf_with_trans.(dist, r, map(Base.Fix1(istrans, vi), vns))) + return r, lp, vi +end + +function dot_assume( + dists::AbstractArray{<:Distribution}, var::AbstractArray, vns::AbstractArray{<:VarName}, vi, ) - # NOTE: We cannot work with `var` here because we might have a model of the form - # - # m = Vector{Float64}(undef, n) - # m .~ Normal() - # - # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - r_raw = getindex_raw(vi, vec(vns), dists) - r = reshape(maybe_invlink.(Ref(vi), vns, dists, r_raw), size(vns)) - lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) + @assert length(vns) == length(dists) == length(var) + r = map((vn, dist) -> vi[vn, dist], vns, dists) + lp = sum(Bijectors.logpdf_with_trans.(dists, r, map(Base.Fix1(istrans, vi), vns))) return r, lp, vi end @@ -449,7 +449,7 @@ function dot_assume( ) r = get_and_set_val!(rng, vi, vns, dists, spl) # Make sure `r` is not a matrix for multivariate distributions - lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) + lp = sum(Bijectors.logpdf_with_trans.(dists, r, map(Base.Fix1(istrans, vi), vns))) return r, lp, vi end function dot_assume(rng, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any) From 2e14abd633c544c723d1651425ea45e20132734a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 29 Jun 2022 11:07:58 +0100 Subject: [PATCH 051/101] removed unnecessary code in tests --- test/simple_varinfo.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 5b7c6ba71..092189984 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -79,11 +79,6 @@ # Sample a new varinfo! _, svi_new = DynamicPPL.evaluate!!(model, svi, SamplingContext()) - # If the `m[1]` varname doesn't exist, this is a univariate model. - # TODO: Find a better way of dealing with this that is not dependent - # on knowledge of internals of `model`. - isunivariate = !haskey(svi_new, @varname(m[1])) - # Realization for `m` should be different wp. 1. for vn in keys(model) # `VarName` functions similarly to `PropertyLens` so From 12adc83b793a8329bb091215ae6b6e0c49a5daec Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 29 Jun 2022 11:08:22 +0100 Subject: [PATCH 052/101] improved linking usage in assumes for SimpleVarInfo --- src/simple_varinfo.jl | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 1c6d4043a..c7780b872 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -453,10 +453,9 @@ function assume( ) value = init(rng, dist, sampler) # Transform if we're working in unconstrained space. - ist = istrans(vi, vn) - value_raw = ist ? Bijectors.link(dist, value) : value + value_raw = maybe_link(vi, vn, dist, value) vi = BangBang.push!!(vi, vn, value_raw, dist, sampler) - return value, Bijectors.logpdf_with_trans(dist, value, ist), vi + return value, Bijectors.logpdf_with_trans(dist, value, istrans(vi, vn)), vi end function dot_assume( @@ -471,14 +470,19 @@ function dot_assume( value = f.(vns, dists) # Transform if we're working in transformed space. - ist = istrans(vi, first(vns)) - value_raw = ist ? link.(dists, value) : value + value_raw = if dists isa Distribution + @assert length(vns) == length(value) + map((vn, val) -> maybe_link(vi, vn, dists, val), vns, value) + else + @assert length(vns) == length(dists) == length(value) + map((vn, dist, val) -> maybe_link(vi, vn, dist, val), vns, dists, value) + end # Update `vi` vi = BangBang.setindex!!(vi, value_raw, vns) # Compute logp. - lp = sum(Bijectors.logpdf_with_trans.(dists, value, ist)) + lp = sum(Bijectors.logpdf_with_trans.(dists, value, map(Base.Fix1(istrans, vi), vns))) return value, lp, vi end From f7501dfb4f17ddae7618708d539e704c38ad88fc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 30 Jun 2022 09:49:40 +0100 Subject: [PATCH 053/101] added model for testing dynamic constraints --- src/test_utils.jl | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/test_utils.jl b/src/test_utils.jl index 172df1775..a25c25df7 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -49,6 +49,28 @@ function logjoint_true(model::Model, args...) return logprior_true(model, args...) + loglikelihood_true(model, args...) end +""" + demo_dynamic_constraint() + +A model with variables `m` and `x` with `x` having support depending on `m`. +""" +@model function demo_dynamic_constraint() + m ~ Normal() + x ~ truncated(Normal(), m, Inf) + + return (m=m, x=x) +end + +function logprior_true(model::Model{typeof(demo_dynamic_constraint)}, m, x) + return logpdf(Normal(), m) + logpdf(truncated(Normal(), m, Inf)) +end +function loglikelihood_true(model::Model{typeof(demo_dynamic_constraint)}, m, x) + return zero(float(eltype(m))) +end +function Base.keys(model::Model{typeof(demo_dynamic_constraint)}) + return [@varname(m), @varname(x)] +end + # A collection of models for which the mean-of-means for the posterior should # be same. @model function demo_dot_assume_dot_observe( From abcabf49b4f1ef6c468e22a88ce679d3ba8b9b84 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 30 Jun 2022 10:48:18 +0100 Subject: [PATCH 054/101] added logjoint_true_with_logabsdet_jacobian to TestUtils --- src/test_utils.jl | 54 +++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 50 insertions(+), 4 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index a25c25df7..0b9a5526b 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -6,6 +6,8 @@ using LinearAlgebra using Distributions using Test +using Bijectors: Bijectors + """ logprior_true(model, θ) @@ -29,11 +31,11 @@ See also: [`logjoint_true`](@ref), [`logprior_true`](@ref). function loglikelihood_true end """ - logjoint_true(model, θ) + logjoint_true(model, args...) -Return the `logjoint` of `model` for `θ`. +Return the `logjoint` of `model` for `args...`. -Defaults to `logprior_true(model, θ) + loglikelihood_true(model, θ)`. +Defaults to `logprior_true(model, args...) + loglikelihood_true(model, args..)`. This should generally be implemented by hand for every specific `model` so that the returned value can be used as a ground-truth for testing things like: @@ -49,6 +51,42 @@ function logjoint_true(model::Model, args...) return logprior_true(model, args...) + loglikelihood_true(model, args...) end +""" + logjoint_true_with_logabsdet_jacobian(model::Model, args...) + +Return a tuple `(args_unconstrained, logjoint)` of `model` for `args...`. + +Unlike [`logjoint_true`](@ref), the returned logjoint computation includes the +log-absdet-jacobian adjustment, thus computing logjoint for the unconstrained variables. + +Note that `args` are assumed be in the support of `model`, while `args_unconstrained` +will be unconstrained. + +This should generally not be implemented directly, instead one should implement +[`logprior_true_with_logabsdet_jacobian`](@ref) for a given `model`. + +See also: [`logjoint_true`](@ref), [`logprior_true_with_logabsdet_jacobian`](@ref). +""" +function logjoint_true_with_logabsdet_jacobian(model::Model, args...) + args_unconstrained, lp = logprior_true_with_logabsdet_jacobian(model, args...) + return args_unconstrained, lp + loglikelihood_true(model, args...) +end + +""" + logprior_true_with_logabsdet_jacobian(model::Model, args...) + +Return a tuple `(args_unconstrained, logprior_unconstrained)` of `model` for `args...`. + +Unlike [`logprior_true`](@ref), the returned logprior computation includes the +log-absdet-jacobian adjustment, thus computing logprior for the unconstrained variables. + +Note that `args` are assumed be in the support of `model`, while `args_unconstrained` +will be unconstrained. + +See also: [`logprior_true`](@ref). +""" +function logprior_true_with_logabsdet_jacobian end + """ demo_dynamic_constraint() @@ -62,7 +100,7 @@ A model with variables `m` and `x` with `x` having support depending on `m`. end function logprior_true(model::Model{typeof(demo_dynamic_constraint)}, m, x) - return logpdf(Normal(), m) + logpdf(truncated(Normal(), m, Inf)) + return logpdf(Normal(), m) + logpdf(truncated(Normal(), m, Inf), x) end function loglikelihood_true(model::Model{typeof(demo_dynamic_constraint)}, m, x) return zero(float(eltype(m))) @@ -71,6 +109,14 @@ function Base.keys(model::Model{typeof(demo_dynamic_constraint)}) return [@varname(m), @varname(x)] end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_dynamic_constraint)}, m, x +) + b_x = Bijectors.bijector(truncated(Normal(), m, Inf)) + x_unconstrained, Δlogp = Bijectors.with_logabsdet_jacobian(b_x, x) + return (m=m, x=x_unconstrained), logprior_true(model, m, x) - Δlogp +end + # A collection of models for which the mean-of-means for the posterior should # be same. @model function demo_dot_assume_dot_observe( From fdee509a0dd605fc850863d68281aa55734ff079 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 30 Jun 2022 11:00:32 +0100 Subject: [PATCH 055/101] added test for dynamic constraints for SimpleVarInfo --- test/model.jl | 7 +------ test/simple_varinfo.jl | 31 +++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/test/model.jl b/test/model.jl index efe07b362..8eabeeba9 100644 --- a/test/model.jl +++ b/test/model.jl @@ -113,12 +113,7 @@ end end @testset "Dynamic constraints" begin - @model function dynamic_constraints() - m ~ Normal() - return x ~ truncated(Normal(), m, Inf) - end - - model = dynamic_constraints() + model = DynamicPPL.TestUtils.demo_dynamic_constraint() vi = VarInfo(model) spl = SampleFromPrior() link!(vi, spl) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 092189984..9494ae6c1 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -123,4 +123,35 @@ @test logπ ≈ logπ_true end end + + @testset "Dynamic constraints" begin + model = DynamicPPL.TestUtils.demo_dynamic_constraint() + + # Initialize. + svi = DynamicPPL.settrans!!(SimpleVarInfo(), true) + svi = last(DynamicPPL.evaluate!!(model, svi, SamplingContext())) + + # Sample with large variations in unconstrained space. + for i in 1:10 + for vn in keys(svi) + svi = DynamicPPL.setindex!!(svi, 10 * randn(), vn) + end + retval, svi = DynamicPPL.evaluate!!(model, svi, DefaultContext()) + @test retval.m == svi[@varname(m)] # `m` is unconstrained + @test retval.x ≠ svi[@varname(x)] # `x` is constrained depending on `m` + + retval_unconstrained, lp_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + model, retval.m, retval.x + ) + + # Realizations from model should all be equal to the unconstrained realization. + for vn in keys(model) + @test get(retval_unconstrained, vn) ≈ svi[vn] rtol = 1e-6 + end + + # `getlogp` should be equal to the logjoint with log-absdet-jac correction. + lp = getlogp(svi) + @test lp ≈ lp_true + end + end end From e974c8335bcd359490ff4e2eafffc343fd7cb96a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 30 Jun 2022 11:00:50 +0100 Subject: [PATCH 056/101] fixed keys implementation of SimpleVarInfo --- src/simple_varinfo.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index c7780b872..02565b64a 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -241,6 +241,7 @@ acclogp!!(vi::SimpleVarInfo, logp) = Setfield.@set vi.logp = getlogp(vi) + logp Return an iterator of keys present in `vi`. """ Base.keys(vi::SimpleVarInfo) = keys(vi.values) +Base.keys(vi::SimpleVarInfo{<:NamedTuple}) = map(k -> VarName{k}(), keys(vi.values)) function setlogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) vi.logp[] = logp @@ -311,8 +312,10 @@ end # HACK: Needed to disambiguiate. Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(Base.Fix1(getindex, vi), vns) -Base.getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.values -Base.getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.values +Base.getindex(vi::SimpleVarInfo, ::Colon) = vi.values +Base.getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi[:] +Base.getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi[:] + # TODO: Should we do better? Base.getindex(vi::SimpleVarInfo, spl::Sampler) = vi.values From 6c6d5f57fda1922cb017999151d4bb90e582340a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 30 Jun 2022 11:03:18 +0100 Subject: [PATCH 057/101] reverted unintended change --- src/simple_varinfo.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 02565b64a..92abc327f 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -312,9 +312,8 @@ end # HACK: Needed to disambiguiate. Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(Base.Fix1(getindex, vi), vns) -Base.getindex(vi::SimpleVarInfo, ::Colon) = vi.values -Base.getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi[:] -Base.getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi[:] +Base.getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.values +Base.getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.values # TODO: Should we do better? Base.getindex(vi::SimpleVarInfo, spl::Sampler) = vi.values From 801bd4caf110367b6f92f81e7fb226cf64e354b2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 11:22:50 +0100 Subject: [PATCH 058/101] renamed Base.keys(model) to varnames(model) in TestUtils --- src/test_utils.jl | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 0b9a5526b..f9130bd61 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -105,7 +105,7 @@ end function loglikelihood_true(model::Model{typeof(demo_dynamic_constraint)}, m, x) return zero(float(eltype(m))) end -function Base.keys(model::Model{typeof(demo_dynamic_constraint)}) +function varnames(model::Model{typeof(demo_dynamic_constraint)}) return [@varname(m), @varname(x)] end @@ -134,7 +134,7 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe)}, m) return loglikelihood(MvNormal(m, 0.25 * I), model.args.x) end -function Base.keys(model::Model{typeof(demo_dot_assume_dot_observe)}) +function varnames(model::Model{typeof(demo_dot_assume_dot_observe)}) return [@varname(m[1]), @varname(m[2])] end @@ -156,7 +156,7 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_index_observe)}, m) return logpdf(MvNormal(m, 0.25 * I), model.args.x) end -function Base.keys(model::Model{typeof(demo_assume_index_observe)}) +function varnames(model::Model{typeof(demo_assume_index_observe)}) return [@varname(m[1]), @varname(m[2])] end @@ -173,7 +173,7 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_multivariate_observe)}, m) return logpdf(MvNormal(m, 0.25 * I), model.args.x) end -function Base.keys(model::Model{typeof(demo_assume_multivariate_observe)}) +function varnames(model::Model{typeof(demo_assume_multivariate_observe)}) return [@varname(m)] end @@ -195,7 +195,7 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index)}, m) return sum(logpdf.(Normal.(m, 0.5), model.args.x)) end -function Base.keys(model::Model{typeof(demo_dot_assume_observe_index)}) +function varnames(model::Model{typeof(demo_dot_assume_observe_index)}) return [@varname(m[1]), @varname(m[2])] end @@ -214,7 +214,7 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_dot_observe)}, m) return sum(logpdf.(Normal.(m, 0.5), model.args.x)) end -function Base.keys(model::Model{typeof(demo_assume_dot_observe)}) +function varnames(model::Model{typeof(demo_assume_dot_observe)}) return [@varname(m)] end @@ -231,7 +231,7 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, m) return logpdf(MvNormal(m, 0.25 * I), [10.0, 10.0]) end -function Base.keys(model::Model{typeof(demo_assume_observe_literal)}) +function varnames(model::Model{typeof(demo_assume_observe_literal)}) return [@varname(m)] end @@ -251,7 +251,7 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, m) return sum(logpdf.(Normal.(m, 0.5), fill(10.0, length(m)))) end -function Base.keys(model::Model{typeof(demo_dot_assume_observe_index_literal)}) +function varnames(model::Model{typeof(demo_dot_assume_observe_index_literal)}) return [@varname(m[1]), @varname(m[2])] end @@ -268,7 +268,7 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_literal_dot_observe)}, m) return logpdf(Normal(m, 0.5), 10.0) end -function Base.keys(model::Model{typeof(demo_assume_literal_dot_observe)}) +function varnames(model::Model{typeof(demo_assume_literal_dot_observe)}) return [@varname(m)] end @@ -296,7 +296,7 @@ function loglikelihood_true( ) return sum(logpdf.(Normal.(m, 0.5), 10.0)) end -function Base.keys(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) +function varnames(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) return [@varname(m[1]), @varname(m[2])] end @@ -321,7 +321,7 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, m) return logpdf(MvNormal(m, 0.25 * I), model.args.x) end -function Base.keys(model::Model{typeof(demo_dot_assume_observe_submodel)}) +function varnames(model::Model{typeof(demo_dot_assume_observe_submodel)}) return [@varname(m[1]), @varname(m[2])] end @@ -342,7 +342,7 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, m) return loglikelihood(MvNormal(m, 0.25 * I), model.args.x) end -function Base.keys(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) +function varnames(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) return [@varname(m[1]), @varname(m[2])] end @@ -366,7 +366,7 @@ function loglikelihood_true( ) return loglikelihood(MvNormal(vec(m), 0.25 * I), model.args.x) end -function Base.keys(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) +function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) return [@varname(m[:, 1]), @varname(m[:, 2])] end @@ -385,7 +385,7 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_array_dot_observe)}, m) return loglikelihood(MvNormal(m, 0.25 * I), model.args.x) end -function Base.keys(model::Model{typeof(demo_dot_assume_array_dot_observe)}) +function varnames(model::Model{typeof(demo_dot_assume_array_dot_observe)}) return [@varname(m[1]), @varname(m[2])] end From 46f6f4c835e22976ab6333f8edfadb831cc8ff1b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 11:24:32 +0100 Subject: [PATCH 059/101] added default implementation and docstring for TestUtils.varnames --- src/test_utils.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/test_utils.jl b/src/test_utils.jl index f9130bd61..ea509e2de 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -87,6 +87,20 @@ See also: [`logprior_true`](@ref). """ function logprior_true_with_logabsdet_jacobian end +""" + varnames(model::Model) + +Return a collection of `VarName` as they are expected to appear in the model. + +Even though it is recommended to implement this by hand for a particular `Model`, +a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided. +""" +function varnames(model::Model) + return collect( + keys(last(DynamicPPL.evaluate!!(model, SimpleVarInfo(Dict()), SamplingContext()))) + ) +end + """ demo_dynamic_constraint() From bcb767b34666a75db34b3676e214738f1015afb5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 11:27:14 +0100 Subject: [PATCH 060/101] replace handwritten by DocStringExtensions --- src/simple_varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 92abc327f..e82b35486 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -4,7 +4,7 @@ struct Constrained <: AbstractConstraint end struct Unconstrained <: AbstractConstraint end """ - SimpleVarInfo{NT,T,C} <: AbstractVarInfo + $(TYPEDEF) A simple wrapper of the parameters with a `logp` field for accumulation of the logdensity. From c5be1c2f8e2856f2fe509c6be6850bab82e1ab07 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 11:32:53 +0100 Subject: [PATCH 061/101] Apply suggestions from @devmotion Co-authored-by: David Widmann --- src/context_implementations.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index aead1dde1..855717eea 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -423,7 +423,7 @@ function dot_assume( dist::Distribution, var::AbstractArray, vns::AbstractArray{<:VarName}, vi ) r = map(vn -> vi[vn, dist], vns) - lp = sum(Bijectors.logpdf_with_trans.(dist, r, map(Base.Fix1(istrans, vi), vns))) + lp = sum(Bijectors.logpdf_with_trans.(dist, r, istrans.((vi,), vns))) return r, lp, vi end @@ -435,7 +435,7 @@ function dot_assume( ) @assert length(vns) == length(dists) == length(var) r = map((vn, dist) -> vi[vn, dist], vns, dists) - lp = sum(Bijectors.logpdf_with_trans.(dists, r, map(Base.Fix1(istrans, vi), vns))) + lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi, ), vns))) return r, lp, vi end @@ -449,7 +449,7 @@ function dot_assume( ) r = get_and_set_val!(rng, vi, vns, dists, spl) # Make sure `r` is not a matrix for multivariate distributions - lp = sum(Bijectors.logpdf_with_trans.(dists, r, map(Base.Fix1(istrans, vi), vns))) + lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns))) return r, lp, vi end function dot_assume(rng, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any) From f266929153ba5e4b2e74ba41f90180e01eae7710 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 11:37:00 +0100 Subject: [PATCH 062/101] Update src/context_implementations.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/context_implementations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 855717eea..f81809664 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -435,7 +435,7 @@ function dot_assume( ) @assert length(vns) == length(dists) == length(var) r = map((vn, dist) -> vi[vn, dist], vns, dists) - lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi, ), vns))) + lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns))) return r, lp, vi end From c2dbbafbdd2f47bd9b1131bff52204f842e980fa Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 11:41:20 +0100 Subject: [PATCH 063/101] removed some asserts and use broadcast instead of map --- src/context_implementations.jl | 1 - src/simple_varinfo.jl | 6 ++---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index f81809664..5a8136a13 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -433,7 +433,6 @@ function dot_assume( vns::AbstractArray{<:VarName}, vi, ) - @assert length(vns) == length(dists) == length(var) r = map((vn, dist) -> vi[vn, dist], vns, dists) lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns))) return r, lp, vi diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index e82b35486..70a6a4770 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -473,11 +473,9 @@ function dot_assume( # Transform if we're working in transformed space. value_raw = if dists isa Distribution - @assert length(vns) == length(value) - map((vn, val) -> maybe_link(vi, vn, dists, val), vns, value) + maybe_link.((vi,), vns, (dists, ), value) else - @assert length(vns) == length(dists) == length(value) - map((vn, dist, val) -> maybe_link(vi, vn, dist, val), vns, dists, value) + maybe_link.((vi,), vns, dists, value) end # Update `vi` From 1abb46c92e8dd4d9682d7116927e836f0a93bef7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 11:43:12 +0100 Subject: [PATCH 064/101] replace map with broadcasting to ensure consistent behavior --- src/context_implementations.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 5a8136a13..24ee74439 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -422,8 +422,8 @@ end function dot_assume( dist::Distribution, var::AbstractArray, vns::AbstractArray{<:VarName}, vi ) - r = map(vn -> vi[vn, dist], vns) - lp = sum(Bijectors.logpdf_with_trans.(dist, r, istrans.((vi,), vns))) + r = getindex.((vi,), vns, (dist,)) + lp = sum(Bijectors.logpdf_with_trans.((dist,), r, istrans.((vi,), vns))) return r, lp, vi end @@ -433,7 +433,7 @@ function dot_assume( vns::AbstractArray{<:VarName}, vi, ) - r = map((vn, dist) -> vi[vn, dist], vns, dists) + r = getindex.((vi,), vns, dists) lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns))) return r, lp, vi end From 1086c6c8fd8fd4ffa65dcc5266e3eec1bffc1508 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 11:45:40 +0100 Subject: [PATCH 065/101] Update src/simple_varinfo.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/simple_varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 70a6a4770..4ca25b2f2 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -473,7 +473,7 @@ function dot_assume( # Transform if we're working in transformed space. value_raw = if dists isa Distribution - maybe_link.((vi,), vns, (dists, ), value) + maybe_link.((vi,), vns, (dists,), value) else maybe_link.((vi,), vns, dists, value) end From f2fb4a5c53c8fbbeafbb9d8395810530a1982bf7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 11:47:07 +0100 Subject: [PATCH 066/101] added a method nodist to allow broadcasting NoDist constructor --- src/context_implementations.jl | 4 ++-- src/distribution_wrappers.jl | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 24ee74439..a255eb02f 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -318,13 +318,13 @@ function dot_tilde_assume( end function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi) - nodist = right isa Distribution ? NoDist(right) : NoDist.(right) + nodist = nodist.(right) return dot_assume(nodist, left, vn, vi) end function dot_tilde_assume( rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, vi ) - nodist = right isa Distribution ? NoDist(right) : NoDist.(right) + nodist = nodist.(right) return dot_assume(rng, sampler, nodist, vn, left, vi) end diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index 07dc6f93f..fb4ecab64 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -35,6 +35,9 @@ struct NoDist{variate,support,Td<:Distribution{variate,support}} <: end NoDist(dist::NamedDist) = NamedDist(NoDist(dist.dist), dist.name) +# HACK(torfjelde): Useful to have a constructor we can use in broadcasting. +nodist(dist::Distribution) = NoDist(dist) + Base.length(dist::NoDist) = Base.length(dist.dist) Base.size(dist::NoDist) = Base.size(dist.dist) From 490d24eea599b8bc113d8be5a8af92e1be4d0edc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 11:52:50 +0100 Subject: [PATCH 067/101] updated some tests --- test/simple_varinfo.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 9494ae6c1..9847d6959 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -80,7 +80,7 @@ _, svi_new = DynamicPPL.evaluate!!(model, svi, SamplingContext()) # Realization for `m` should be different wp. 1. - for vn in keys(model) + for vn in DynamicPPL.TestUtils.varnames(model) # `VarName` functions similarly to `PropertyLens` so # we just strip this part from `vn` to get a lens we can use # to extract the corresponding value of `m`. @@ -101,7 +101,7 @@ # Update the realizations in `svi_new`. svi_eval = svi_new - for vn in keys(model) + for vn in DynamicPPL.TestUtils.varnames(model) l = getlens(vn) svi_eval = DynamicPPL.setindex!!(svi_eval, get(m_eval, l), vn) end @@ -113,7 +113,7 @@ logπ = logjoint(model, svi_eval) # Values should not have changed. - for vn in keys(model) + for vn in DynamicPPL.TestUtils.varnames(model) l = getlens(vn) @test svi_eval[vn] == get(m_eval, l) end @@ -145,7 +145,7 @@ ) # Realizations from model should all be equal to the unconstrained realization. - for vn in keys(model) + for vn in DynamicPPL.TestUtils.varnames(model) @test get(retval_unconstrained, vn) ≈ svi[vn] rtol = 1e-6 end From 6350ccdc1be1958cc97e71847a2e44bf04b513c4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 11:52:58 +0100 Subject: [PATCH 068/101] renamed AbstractConstraint to AbstractTransformation and its subtypes --- src/simple_varinfo.jl | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 4ca25b2f2..20f86222a 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -1,7 +1,7 @@ -abstract type AbstractConstraint end +abstract type AbstractTransformation end -struct Constrained <: AbstractConstraint end -struct Unconstrained <: AbstractConstraint end +struct NoTransformation <: AbstractTransformation end +struct DefaultTransformation <: AbstractTransformation end """ $(TYPEDEF) @@ -86,7 +86,7 @@ ERROR: KeyError: key x[1:2] not found [...] ``` -You can also sample in _unconstrained_ space: +You can also sample in _transformed_ space: ```jldoctest simplevarinfo-general julia> @model demo_constrained() = x ~ Exponential() @@ -121,19 +121,19 @@ julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true ``` -Evaluation in unconstrained space of course also works: +Evaluation in transformed space of course also works: ```jldoctest simplevarinfo-general julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true) -Unconstrained SimpleVarInfo((x = -1.0,), 0.0) +Transformed SimpleVarInfo((x = -1.0,), 0.0) julia> # (✓) Positive probability mass on negative numbers! getlogp(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) -1.3678794411714423 -julia> # While if we forget to make indicate that it's unconstrained/transformed: +julia> # While if we forget to make indicate that it's transformed: vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false) -Constrained SimpleVarInfo((x = -1.0,), 0.0) +SimpleVarInfo((x = -1.0,), 0.0) julia> # (✓) No probability mass on negative numbers! getlogp(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) @@ -188,16 +188,16 @@ ERROR: type NamedTuple has no field b [...] ``` """ -struct SimpleVarInfo{NT,T,C<:AbstractConstraint} <: AbstractVarInfo +struct SimpleVarInfo{NT,T,C<:AbstractTransformation} <: AbstractVarInfo "underlying representation of the realization represented" values::NT "holds the accumulated log-probability" logp::T - "represents whether it assumes variables to be constrained or unconstrained" - constraint::C + "represents whether it assumes variables to be transformed" + transformation::C end -SimpleVarInfo(values, logp) = SimpleVarInfo(values, logp, Constrained()) +SimpleVarInfo(values, logp) = SimpleVarInfo(values, logp, NoTransformation()) function SimpleVarInfo{T}(θ) where {T<:Real} return SimpleVarInfo(θ, zero(T)) @@ -254,15 +254,15 @@ function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) end function Base.show( - io::IO, ::MIME"text/plain", svi::SimpleVarInfo{<:Any,<:Any,<:Constrained} + io::IO, ::MIME"text/plain", svi::SimpleVarInfo{<:Any,<:Any,<:NoTransformation} ) - return print(io, "Constrained SimpleVarInfo(", svi.values, ", ", svi.logp, ")") + return print(io, "SimpleVarInfo(", svi.values, ", ", svi.logp, ")") end function Base.show( - io::IO, ::MIME"text/plain", svi::SimpleVarInfo{<:Any,<:Any,<:Unconstrained} + io::IO, ::MIME"text/plain", svi::SimpleVarInfo{<:Any,<:Any,<:DefaultTransformation} ) - return print(io, "Unconstrained SimpleVarInfo(", svi.values, ", ", svi.logp, ")") + return print(io, "Transformed SimpleVarInfo(", svi.values, ", ", svi.logp, ")") end # `NamedTuple` @@ -516,14 +516,14 @@ increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing # NOTE: We don't implement `settrans!!(vi, trans, vn)`. function settrans!!(vi::SimpleVarInfo, trans) - return SimpleVarInfo(vi.values, vi.logp, trans ? Unconstrained() : Constrained()) + return SimpleVarInfo(vi.values, vi.logp, trans ? DefaultTransformation() : NoTransformation()) end function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) return Setfield.@set vi.varinfo = settrans!!(vi, trans) end -istrans(vi::SimpleVarInfo{<:Any,<:Any,<:Constrained}) = false -istrans(vi::SimpleVarInfo{<:Any,<:Any,<:Unconstrained}) = true +istrans(vi::SimpleVarInfo{<:Any,<:Any,<:NoTransformation}) = false +istrans(vi::SimpleVarInfo{<:Any,<:Any,<:DefaultTransformation}) = true istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi) istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) From 951e4c36ba67f518671744282d59f5dc7d044b68 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 11:58:49 +0100 Subject: [PATCH 069/101] updated tests --- test/loglikelihoods.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index 0e7f9a3d9..eaf1e00bd 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -2,7 +2,7 @@ @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS vi = VarInfo(m) - for vn in keys(m) + for vn in DynamicPPL.TestUtils.varnames(m) if vi[vn] isa Real vi = DynamicPPL.setindex!!(vi, 1.0, vn) else From dcd92c926f28cff980fc5a91623037269d8af36c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 13:51:36 +0100 Subject: [PATCH 070/101] fixed nodist usage --- src/context_implementations.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index a255eb02f..2bef931a4 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -318,14 +318,12 @@ function dot_tilde_assume( end function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi) - nodist = nodist.(right) - return dot_assume(nodist, left, vn, vi) + return dot_assume(nodist.(right), left, vn, vi) end function dot_tilde_assume( rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, vi ) - nodist = nodist.(right) - return dot_assume(rng, sampler, nodist, vn, left, vi) + return dot_assume(rng, sampler, nodist.(right), vn, left, vi) end # `PriorContext` From 2922ffa352a8a928b8898b0bcf788bdbb3fab82a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 14:09:24 +0100 Subject: [PATCH 071/101] fixed implementation of nodist --- src/context_implementations.jl | 4 ++-- src/distribution_wrappers.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 2bef931a4..0068546d5 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -318,12 +318,12 @@ function dot_tilde_assume( end function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi) - return dot_assume(nodist.(right), left, vn, vi) + return dot_assume(nodist(right), left, vn, vi) end function dot_tilde_assume( rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, vi ) - return dot_assume(rng, sampler, nodist.(right), vn, left, vi) + return dot_assume(rng, sampler, nodist(right), vn, left, vi) end # `PriorContext` diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index fb4ecab64..9761123c0 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -35,8 +35,8 @@ struct NoDist{variate,support,Td<:Distribution{variate,support}} <: end NoDist(dist::NamedDist) = NamedDist(NoDist(dist.dist), dist.name) -# HACK(torfjelde): Useful to have a constructor we can use in broadcasting. nodist(dist::Distribution) = NoDist(dist) +nodist(dists::AbstractArray) = nodist.(dist) Base.length(dist::NoDist) = Base.length(dist.dist) Base.size(dist::NoDist) = Base.size(dist.dist) From 5266a4ba6fcc233887903a519f076e24c86ff35c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 14:16:04 +0100 Subject: [PATCH 072/101] fixed typo --- src/distribution_wrappers.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index 9761123c0..65479f035 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -36,7 +36,7 @@ end NoDist(dist::NamedDist) = NamedDist(NoDist(dist.dist), dist.name) nodist(dist::Distribution) = NoDist(dist) -nodist(dists::AbstractArray) = nodist.(dist) +nodist(dists::AbstractArray) = nodist.(dists) Base.length(dist::NoDist) = Base.length(dist.dist) Base.size(dist::NoDist) = Base.size(dist.dist) From 3c38710e138a9bf1d9f35e0500496df4b4f4c581 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 14:16:09 +0100 Subject: [PATCH 073/101] formatting --- src/simple_varinfo.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 20f86222a..b6f01fd08 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -516,7 +516,9 @@ increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing # NOTE: We don't implement `settrans!!(vi, trans, vn)`. function settrans!!(vi::SimpleVarInfo, trans) - return SimpleVarInfo(vi.values, vi.logp, trans ? DefaultTransformation() : NoTransformation()) + return SimpleVarInfo( + vi.values, vi.logp, trans ? DefaultTransformation() : NoTransformation() + ) end function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) return Setfield.@set vi.varinfo = settrans!!(vi, trans) From ba92f3f1b47ba414ea32736c7901d9e1bb368684 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 14:23:24 +0100 Subject: [PATCH 074/101] bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 7adb220b3..bfa13d956 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.19.3" +version = "0.19.4" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From 70c864c1b910e6008d1e02f31f36f5cbd7d7fe6d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 14:37:58 +0100 Subject: [PATCH 075/101] fixed ThreadsafeVarInfo --- src/threadsafe.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index b42cf82a5..7c7dd13ac 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -62,11 +62,24 @@ islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl getindex(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex(vi.varinfo, spl) getindex(vi::ThreadSafeVarInfo, spl::SampleFromPrior) = getindex(vi.varinfo, spl) getindex(vi::ThreadSafeVarInfo, spl::SampleFromUniform) = getindex(vi.varinfo, spl) + getindex(vi::ThreadSafeVarInfo, vn::VarName) = getindex(vi.varinfo, vn) +function getindex(vi::ThreadSafeVarInfo, vn::VarName, dist::Distribution) + return getindex(vi.varinfo, vn, dist) +end getindex(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) = getindex(vi.varinfo, vns) +function getindex(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}, dist::Distribution) + return getindex(vi.varinfo, vns, dist) +end getindex_raw(vi::ThreadSafeVarInfo, vn::VarName) = getindex_raw(vi.varinfo, vn) +function getindex_raw(vi::ThreadSafeVarInfo, vn::VarName, dist::Distribution) + return getindex_raw(vi.varinfo, vn, dist) +end getindex_raw(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) = getindex_raw(vi.varinfo, vns) +function getindex_raw(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}, dist::Distribution) + return getindex_raw(vi.varinfo, vns, dist) +end function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::AbstractSampler) return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) From 66f41a936a3d3172fee9d0320d4eaeb6e78f5c84 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 17:08:15 +0100 Subject: [PATCH 076/101] Apply suggestions from code review Co-authored-by: David Widmann --- src/context_implementations.jl | 14 +++++++------- src/simple_varinfo.jl | 14 +++++++------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 0068546d5..742082346 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -291,7 +291,7 @@ function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left, var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!!.(Ref(vi), false, _vns) + settrans!!.((vi,), false, _vns) dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, vi) else dot_tilde_assume(LikelihoodContext(), right, left, vn, vi) @@ -310,7 +310,7 @@ function dot_tilde_assume( var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!!.(Ref(vi), false, _vns) + settrans!!.((vi,), false, _vns) dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, vi) else dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi) @@ -332,7 +332,7 @@ function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!!.(Ref(vi), false, _vns) + settrans!!.((vi,), false, _vns) dot_tilde_assume(PriorContext(), _right, _left, _vns, vi) else dot_tilde_assume(PriorContext(), right, left, vn, vi) @@ -351,7 +351,7 @@ function dot_tilde_assume( var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!!.(Ref(vi), false, _vns) + settrans!!.((vi,), false, _vns) dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, vi) else dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, vi) @@ -525,11 +525,11 @@ function get_and_set_val!( # 2. Define an anonymous function which returns `nothing`, which # we then broadcast. This will allocate a vector of `nothing` though. if istrans(vi) - push!!.(Ref(vi), vns, link.(Ref(vi), vns, dists, r), dists, Ref(spl)) + push!!.((vi,), vns, link.((vi,), vns, dists, r), dists, (spl,)) # `push!!` sets the trans-flag to `false` by default. - settrans!!.(Ref(vi), true, vns) + settrans!!.((vi,), true, vns) else - push!!.(Ref(vi), vns, r, dists, Ref(spl)) + push!!.((vi,), vns, r, dists, (spl,)) end end return r diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index b6f01fd08..74a0bac20 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -267,13 +267,13 @@ end # `NamedTuple` function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution) - return maybe_invlink(vi, vn, dist, Base.getindex(vi, vn)) + return maybe_invlink(vi, vn, dist, getindex(vi, vn)) end function Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution) - vals_linked = map(vns) do vn - maybe_invlink(vi, vn, dist, Base.getindex(vi, vn)) + vals_linked = mapreduce(vcat, vns) do vn + getindex(vi, vn, dist) end - return reconstruct(dist, reduce(vcat, vals_linked), length(vns)) + return reconstruct(dist, vals_linked, length(vns)) end Base.getindex(vi::SimpleVarInfo, vn::VarName) = get(vi.values, vn) @@ -326,9 +326,9 @@ function getindex_raw(vi::SimpleVarInfo, vn::VarName, dist::Distribution) end getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}) = vi[vns] function getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution) - vals = getindex_raw(vi, vns) # `reconstruct` expects a flattened `Vector` regardless of the type of `dist`, so we `vcat` everything. - return reconstruct(dist, reduce(vcat, vals), length(vns)) + vals = mapreduce(Base.Fix1(getindex_raw, vi), vcat, vns) + return reconstruct(dist, vals, length(vns)) end Base.haskey(vi::SimpleVarInfo, vn::VarName) = _haskey(vi.values, vn) @@ -482,7 +482,7 @@ function dot_assume( vi = BangBang.setindex!!(vi, value_raw, vns) # Compute logp. - lp = sum(Bijectors.logpdf_with_trans.(dists, value, map(Base.Fix1(istrans, vi), vns))) + lp = sum(Bijectors.logpdf_with_trans.(dists, value, istrans.((vi,), vns))) return value, lp, vi end From eb2d6b5ae2b4350f5df88c527ff95abaf62103f8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 17:37:32 +0100 Subject: [PATCH 077/101] allow type-stable settrans!! for SimpleVarInfo --- src/simple_varinfo.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 74a0bac20..dd546fb3f 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -516,16 +516,16 @@ increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing # NOTE: We don't implement `settrans!!(vi, trans, vn)`. function settrans!!(vi::SimpleVarInfo, trans) - return SimpleVarInfo( - vi.values, vi.logp, trans ? DefaultTransformation() : NoTransformation() - ) + return settrans!!(vi, trans ? DefaultTransformation() : NoTransformation()) +end +function settrans!!(vi::SimpleVarInfo, transformation::AbstractTransformation) + return Setfield.@set vi.transformation = transformation end function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) - return Setfield.@set vi.varinfo = settrans!!(vi, trans) + return Setfield.@set vi.varinfo = settrans!!(vi.varinfo, trans) end -istrans(vi::SimpleVarInfo{<:Any,<:Any,<:NoTransformation}) = false -istrans(vi::SimpleVarInfo{<:Any,<:Any,<:DefaultTransformation}) = true +istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi) istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) From e8cdb91e7096e1570ee447fcb98e089f220d136f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 17:45:01 +0100 Subject: [PATCH 078/101] use maybe_invlink in getindex for VarInfo --- src/varinfo.jl | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 582582b77..b7048b64b 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -941,11 +941,7 @@ getindex(vi::AbstractVarInfo, vn::VarName) = getindex(vi, vn, getdist(vi, vn)) function getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution) @assert haskey(vi, vn) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" val = getindex_raw(vi, vn, dist) - return if istrans(vi, vn) - Bijectors.invlink(dist, val) - else - val - end + return maybe_invlink(vi, vn, dist, val) end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) return getindex(vi, vns, getdist(vi, first(vns))) @@ -953,11 +949,7 @@ end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distribution) @assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" val = getindex_raw(vi, vns, dist) - return if istrans(vi, vns[1]) - Bijectors.invlink(dist, val) - else - val - end + return maybe_invlink.((vi,), vns, (dist,), val) end getindex_raw(vi::AbstractVarInfo, vn::VarName) = getindex_raw(vi, vn, getdist(vi, vn)) From 359d384eae0847e8271ab60493fdb0684137295b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 17:45:11 +0100 Subject: [PATCH 079/101] added comment to warn about buggy behavior --- src/varinfo.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/varinfo.jl b/src/varinfo.jl index b7048b64b..938fa6285 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -944,6 +944,8 @@ function getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution) return maybe_invlink(vi, vn, dist, val) end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) + # NOTE(torfjelde): Using `getdist(vi, first(vns))` won't be correct in cases + # such as `x .~ [Normal(), Exponential()]`. return getindex(vi, vns, getdist(vi, first(vns))) end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distribution) From ab0a99bf8e5b86cce12a407b376b6ea197d98a22 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 17:46:51 +0100 Subject: [PATCH 080/101] Update src/context_implementations.jl Co-authored-by: David Widmann --- src/context_implementations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 742082346..4454701b8 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -514,7 +514,7 @@ function get_and_set_val!( else # r = reshape(vi[vec(vns)], size(vns)) r_raw = getindex_raw(vi, vec(vns)) - r = maybe_invlink.(Ref(vi), vns, dists, reshape(r_raw, size(vns))) + r = maybe_invlink.((vi,), vns, dists, reshape(r_raw, size(vns))) end else f = (vn, dist) -> init(rng, dist, spl) From dd109134bda4574ba1a9f4ec2cda8a44f40c2947 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 17:50:16 +0100 Subject: [PATCH 081/101] just fix potential bug in getindex for VarInfo --- src/varinfo.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 938fa6285..58f737785 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -944,9 +944,7 @@ function getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution) return maybe_invlink(vi, vn, dist, val) end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) - # NOTE(torfjelde): Using `getdist(vi, first(vns))` won't be correct in cases - # such as `x .~ [Normal(), Exponential()]`. - return getindex(vi, vns, getdist(vi, first(vns))) + return getindex.((vi,), vns, getdist.((vi,), vns)) end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distribution) @assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" From 18d28ccc8647720551d6d15cf8e742474d5deeeb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 17:51:48 +0100 Subject: [PATCH 082/101] revert previous change because it likely introduces bugs --- src/varinfo.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 58f737785..938fa6285 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -944,7 +944,9 @@ function getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution) return maybe_invlink(vi, vn, dist, val) end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) - return getindex.((vi,), vns, getdist.((vi,), vns)) + # NOTE(torfjelde): Using `getdist(vi, first(vns))` won't be correct in cases + # such as `x .~ [Normal(), Exponential()]`. + return getindex(vi, vns, getdist(vi, first(vns))) end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distribution) @assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" From 32b7aab6744c52af7ea4b01d9fa7916e38442a5c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 17:53:51 +0100 Subject: [PATCH 083/101] elaborate in comment regarding potential bug --- src/varinfo.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/varinfo.jl b/src/varinfo.jl index 938fa6285..4748643f3 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -946,6 +946,9 @@ end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) # NOTE(torfjelde): Using `getdist(vi, first(vns))` won't be correct in cases # such as `x .~ [Normal(), Exponential()]`. + # BUT we also can't fix this here because this will lead to "incorrect" + # behavior if `vns` arose from something like `x .~ MvNormal(zeros(2), I)`, + # where by "incorrect" we mean there exists pieces of code expecting this behavior. return getindex(vi, vns, getdist(vi, first(vns))) end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distribution) From f782fe25b1b992e496e823a707688803f9c6eb99 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 17:55:31 +0100 Subject: [PATCH 084/101] added error message to dot_assume --- src/simple_varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index dd546fb3f..e7a389ee8 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -494,7 +494,7 @@ function dot_assume( var::AbstractMatrix, vi::SimpleOrThreadSafeSimple, ) - @assert length(dist) == size(var, 1) + @assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))" # r = get_and_set_val!(rng, vi, vns, dist, spl) n = length(vns) From 7d3493dc53c9f27adf93f498df5dc376f74c364d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 17:56:06 +0100 Subject: [PATCH 085/101] added error message to dot_assume again --- src/context_implementations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 4454701b8..6271b5d8c 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -389,7 +389,7 @@ function dot_assume( vns::AbstractVector{<:VarName}, vi::AbstractVarInfo, ) - @assert length(dist) == size(var, 1) + @assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))" # NOTE: We cannot work with `var` here because we might have a model of the form # # m = Vector{Float64}(undef, n) From f0f981b77c329b792bfc9d2f3786f454cc3fec23 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 00:56:21 +0100 Subject: [PATCH 086/101] added _protect_dists method to help with broadcasting of NoDist --- src/context_implementations.jl | 4 ++-- src/distribution_wrappers.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 6271b5d8c..eaf77f07f 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -318,12 +318,12 @@ function dot_tilde_assume( end function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi) - return dot_assume(nodist(right), left, vn, vi) + return dot_assume(NoDist.(_protect_dists(right)), left, vn, vi) end function dot_tilde_assume( rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, vi ) - return dot_assume(rng, sampler, nodist(right), vn, left, vi) + return dot_assume(rng, sampler, NoDist.(_protect_dists(right)), vn, left, vi) end # `PriorContext` diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index 65479f035..ee58b107a 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -35,8 +35,8 @@ struct NoDist{variate,support,Td<:Distribution{variate,support}} <: end NoDist(dist::NamedDist) = NamedDist(NoDist(dist.dist), dist.name) -nodist(dist::Distribution) = NoDist(dist) -nodist(dists::AbstractArray) = nodist.(dists) +_protect_dists(x) = x +_protect_dists(x::Distribution) = tuple(x) Base.length(dist::NoDist) = Base.length(dist.dist) Base.size(dist::NoDist) = Base.size(dist.dist) From 1e0b946f0d240ca45db22e618118a3cbece5f080 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 00:58:59 +0100 Subject: [PATCH 087/101] simplified show for SimpleVarInfo --- src/simple_varinfo.jl | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index e7a389ee8..dc6eeae3e 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -254,15 +254,13 @@ function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) end function Base.show( - io::IO, ::MIME"text/plain", svi::SimpleVarInfo{<:Any,<:Any,<:NoTransformation} + io::IO, ::MIME"text/plain", svi::SimpleVarInfo ) - return print(io, "SimpleVarInfo(", svi.values, ", ", svi.logp, ")") -end + if !(svi.transform isa NoTransformation) + print(io, "Transformed ") + end -function Base.show( - io::IO, ::MIME"text/plain", svi::SimpleVarInfo{<:Any,<:Any,<:DefaultTransformation} -) - return print(io, "Transformed SimpleVarInfo(", svi.values, ", ", svi.logp, ")") + return print(io, "SimpleVarInfo(", svi.values, ", ", svi.logp, ")") end # `NamedTuple` From faa0e4210f2e5c060841ca67d469d76caf28e2b4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 00:59:22 +0100 Subject: [PATCH 088/101] styling --- src/simple_varinfo.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index dc6eeae3e..4fb0208c9 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -253,9 +253,7 @@ function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) return vi end -function Base.show( - io::IO, ::MIME"text/plain", svi::SimpleVarInfo -) +function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo) if !(svi.transform isa NoTransformation) print(io, "Transformed ") end From 9e7f493be1315b5faaef93f84b1adde23105357d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 01:06:16 +0100 Subject: [PATCH 089/101] fixed bug in show for SimpleVarInfo --- src/simple_varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 4fb0208c9..5b9edefdf 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -254,7 +254,7 @@ function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) end function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo) - if !(svi.transform isa NoTransformation) + if !(svi.transformation isa NoTransformation) print(io, "Transformed ") end From 0a9383b1c3609ac5f310e798e1d7b4119cb6b74a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 01:14:45 +0100 Subject: [PATCH 090/101] Revert "added _protect_dists method to help with broadcasting of NoDist" This reverts commit f0f981b77c329b792bfc9d2f3786f454cc3fec23. --- src/context_implementations.jl | 4 ++-- src/distribution_wrappers.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index eaf77f07f..6271b5d8c 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -318,12 +318,12 @@ function dot_tilde_assume( end function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi) - return dot_assume(NoDist.(_protect_dists(right)), left, vn, vi) + return dot_assume(nodist(right), left, vn, vi) end function dot_tilde_assume( rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, vi ) - return dot_assume(rng, sampler, NoDist.(_protect_dists(right)), vn, left, vi) + return dot_assume(rng, sampler, nodist(right), vn, left, vi) end # `PriorContext` diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index ee58b107a..65479f035 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -35,8 +35,8 @@ struct NoDist{variate,support,Td<:Distribution{variate,support}} <: end NoDist(dist::NamedDist) = NamedDist(NoDist(dist.dist), dist.name) -_protect_dists(x) = x -_protect_dists(x::Distribution) = tuple(x) +nodist(dist::Distribution) = NoDist(dist) +nodist(dists::AbstractArray) = nodist.(dists) Base.length(dist::NoDist) = Base.length(dist.dist) Base.size(dist::NoDist) = Base.size(dist.dist) From d8b0a75a7d56baa75589327db40e2456c5be1dd4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 01:33:16 +0100 Subject: [PATCH 091/101] fixed getindex with vector of varnames for AbstractVarInfo --- src/varinfo.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 4748643f3..22728ba9a 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -944,7 +944,7 @@ function getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution) return maybe_invlink(vi, vn, dist, val) end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) - # NOTE(torfjelde): Using `getdist(vi, first(vns))` won't be correct in cases + # FIXME(torfjelde): Using `getdist(vi, first(vns))` won't be correct in cases # such as `x .~ [Normal(), Exponential()]`. # BUT we also can't fix this here because this will lead to "incorrect" # behavior if `vns` arose from something like `x .~ MvNormal(zeros(2), I)`, @@ -953,8 +953,10 @@ function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distribution) @assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" - val = getindex_raw(vi, vns, dist) - return maybe_invlink.((vi,), vns, (dist,), val) + vals_linked = mapreduce(vcat, vns) do vn + getindex(vi, vn, dist) + end + return reconstruct(dist, vals_linked, length(vns)) end getindex_raw(vi::AbstractVarInfo, vn::VarName) = getindex_raw(vi, vn, getdist(vi, vn)) From 400f90fb1ffee7af54639371d8f5c3c3778e2ebf Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 20 Jul 2022 09:39:19 +0200 Subject: [PATCH 092/101] Improvements to TestUtils (follow-up from #360) (#415) * added example_values and posterior_mean_values methods to models in TestUtils * demo models in TestUtils are now a bit more complex, including constrained variables * added logprior_true_with_logabsdet_jacobian for demo models * fixed mistakes in a couple of models in TestUtils * moved varnames method which creates iterator of leaf varnames into TestUtils and starting using this in test_continuous_models * updated docstring for test_sampler_demo_models * renamed varnames to varname_leaves and renamed keys(model) to varnames(model) * added test_sampler_on_models as a generalization of test_sampler_demo_models * updated docs * added docs for TestUtils.DEMO_MODELS * updated some tests * fixed docstrings * fixed docstrings * imprvoed docstring * improved docstrings * fixed tests of pointwise_loglikelihoods * Apply suggestions from code review Co-authored-by: David Widmann * renamed posterior_mean_values to posterior_mean * made demo models a bit more complex, now including different observations * Update docs/src/api.md Co-authored-by: David Widmann * reduce number of method definitions by defining some useful type unions in TestUtils * removed unnecessary method * fixed a couple of loglikelihood_true definitions * style * added tests for logprior and loglikelihood computation for SimpleVarInfo * fixed implementation of logpdf_with_trans for NoDist * removed unused variable * added test for transformed values for the logprior_true and loglikelihood_true methods * renamed test_sampler_on_models to test_sampler * updated docs * share implementation of example_values * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added marginal_mean_of_samples according to suggestions * removed example_values in favour of rand with NamedTuple * updated docs Co-authored-by: David Widmann Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/src/api.md | 22 +- src/distribution_wrappers.jl | 20 +- src/test_utils.jl | 581 ++++++++++++++++++++++++----------- test/contexts.jl | 26 +- test/loglikelihoods.jl | 22 +- test/simple_varinfo.jl | 62 ++-- 6 files changed, 498 insertions(+), 235 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 133b86e9b..809e6c49e 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -103,8 +103,14 @@ NamedDist DynamicPPL provides several demo models and helpers for testing samplers in the `DynamicPPL.TestUtils` submodule. ```@docs -DynamicPPL.TestUtils.test_sampler_demo_models +DynamicPPL.TestUtils.test_sampler +DynamicPPL.TestUtils.test_sampler_on_demo_models DynamicPPL.TestUtils.test_sampler_continuous +DynamicPPL.TestUtils.marginal_mean_of_samples +``` + +```@docs +DynamicPPL.TestUtils.DEMO_MODELS ``` For every demo model, one can define the true log prior, log likelihood, and log joint probabilities. @@ -115,6 +121,20 @@ DynamicPPL.TestUtils.loglikelihood_true DynamicPPL.TestUtils.logjoint_true ``` +And in the case where the model includes constrained variables, it can also be useful to define + +```@docs +DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian +DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian +``` + +Finally, the following methods can also be of use: + +```@docs +DynamicPPL.TestUtils.varnames +DynamicPPL.TestUtils.posterior_mean +``` + ## Advanced ### Variable names diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index 65479f035..d8968a68e 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -51,9 +51,21 @@ Distributions.logpdf(d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}) = 0 Distributions.minimum(d::NoDist) = minimum(d.dist) Distributions.maximum(d::NoDist) = maximum(d.dist) -Bijectors.logpdf_with_trans(d::NoDist{<:Univariate}, ::Real) = 0 -Bijectors.logpdf_with_trans(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0 -function Bijectors.logpdf_with_trans(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}) +Bijectors.logpdf_with_trans(d::NoDist{<:Univariate}, ::Real, ::Bool) = 0 +function Bijectors.logpdf_with_trans( + d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}, ::Bool +) + return 0 +end +function Bijectors.logpdf_with_trans( + d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}, ::Bool +) return zeros(Int, size(x, 2)) end -Bijectors.logpdf_with_trans(d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}) = 0 +function Bijectors.logpdf_with_trans( + d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}, ::Bool +) + return 0 +end + +Bijectors.bijector(d::NoDist) = Bijectors.bijector(d.dist) diff --git a/src/test_utils.jl b/src/test_utils.jl index ea509e2de..ef314fa91 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -6,12 +6,35 @@ using LinearAlgebra using Distributions using Test +using Random: Random using Bijectors: Bijectors +using Setfield: Setfield """ - logprior_true(model, θ) + varname_leaves(vn::VarName, val) -Return the `logprior` of `model` for `θ`. +Return iterator over all varnames that are represented by `vn` on `val`, +e.g. `varname_leaves(@varname(x), rand(2))` results in an iterator over `[@varname(x[1]), @varname(x[2])]`. +""" +varname_leaves(vn::VarName, val::Real) = [vn] +function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) + return ( + VarName(vn, DynamicPPL.getlens(vn) ∘ Setfield.IndexLens(Tuple(I))) for + I in CartesianIndices(val) + ) +end +function varname_leaves(vn::VarName, val::AbstractArray) + return Iterators.flatten( + varname_leaves( + VarName(vn, DynamicPPL.getlens(vn) ∘ Setfield.IndexLens(Tuple(I))), val[I] + ) for I in CartesianIndices(val) + ) +end + +""" + logprior_true(model, args...) + +Return the `logprior` of `model` for `args`. This should generally be implemented by hand for every specific `model`. @@ -20,9 +43,9 @@ See also: [`logjoint_true`](@ref), [`loglikelihood_true`](@ref). function logprior_true end """ - loglikelihood_true(model, θ) + loglikelihood_true(model, args...) -Return the `loglikelihood` of `model` for `θ`. +Return the `loglikelihood` of `model` for `args`. This should generally be implemented by hand for every specific `model`. @@ -33,7 +56,7 @@ function loglikelihood_true end """ logjoint_true(model, args...) -Return the `logjoint` of `model` for `args...`. +Return the `logjoint` of `model` for `args`. Defaults to `logprior_true(model, args...) + loglikelihood_true(model, args..)`. @@ -54,7 +77,7 @@ end """ logjoint_true_with_logabsdet_jacobian(model::Model, args...) -Return a tuple `(args_unconstrained, logjoint)` of `model` for `args...`. +Return a tuple `(args_unconstrained, logjoint)` of `model` for `args`. Unlike [`logjoint_true`](@ref), the returned logjoint computation includes the log-absdet-jacobian adjustment, thus computing logjoint for the unconstrained variables. @@ -101,6 +124,17 @@ function varnames(model::Model) ) end +""" + posterior_mean(model::Model) + +Return a `NamedTuple` compatible with `varnames(model)` where the values represent +the posterior mean under `model`. + +"Compatible" means that a `varname` from `varnames(model)` can be used to extract the +corresponding value using `get`, e.g. `get(posterior_mean(model), varname)`. +""" +function posterior_mean end + """ demo_dynamic_constraint() @@ -122,7 +156,6 @@ end function varnames(model::Model{typeof(demo_dynamic_constraint)}) return [@varname(m), @varname(x)] end - function logprior_true_with_logabsdet_jacobian( model::Model{typeof(demo_dynamic_constraint)}, m, x ) @@ -131,278 +164,471 @@ function logprior_true_with_logabsdet_jacobian( return (m=m, x=x_unconstrained), logprior_true(model, m, x) - Δlogp end -# A collection of models for which the mean-of-means for the posterior should -# be same. +# A collection of models for which the posterior should be "similar". +# Some utility methods for these. +function _demo_logprior_true_with_logabsdet_jacobian(model, s, m) + b = Bijectors.bijector(InverseGamma(2, 3)) + s_unconstrained = b.(s) + Δlogp = sum(Base.Fix1(Bijectors.logabsdetjac, b), s) + return (s=s_unconstrained, m=m), logprior_true(model, s, m) - Δlogp +end + @model function demo_dot_assume_dot_observe( - x=[10.0, 10.0], ::Type{TV}=Vector{Float64} + x=[1.5, 2.0], ::Type{TV}=Vector{Float64} ) where {TV} # `dot_assume` and `observe` + s = TV(undef, length(x)) m = TV(undef, length(x)) - m .~ Normal() - x ~ MvNormal(m, 0.25 * I) - return (; m=m, x=x, logp=getlogp(__varinfo__)) + s .~ InverseGamma(2, 3) + m .~ Normal.(0, sqrt.(s)) + + x ~ MvNormal(m, Diagonal(s)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) +end +function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe)}, s, m) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe)}, m) - return loglikelihood(Normal(), m) +function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe)}, s, m) + return loglikelihood(MvNormal(m, Diagonal(s)), model.args.x) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe)}, m) - return loglikelihood(MvNormal(m, 0.25 * I), model.args.x) +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_dot_assume_dot_observe)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_dot_assume_dot_observe)}) - return [@varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end @model function demo_assume_index_observe( - x=[10.0, 10.0], ::Type{TV}=Vector{Float64} + x=[1.5, 2.0], ::Type{TV}=Vector{Float64} ) where {TV} # `assume` with indexing and `observe` + s = TV(undef, length(x)) + for i in eachindex(s) + s[i] ~ InverseGamma(2, 3) + end m = TV(undef, length(x)) for i in eachindex(m) - m[i] ~ Normal() + m[i] ~ Normal(0, sqrt(s[i])) end - x ~ MvNormal(m, 0.25 * I) + x ~ MvNormal(m, Diagonal(s)) - return (; m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) +end +function logprior_true(model::Model{typeof(demo_assume_index_observe)}, s, m) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function logprior_true(model::Model{typeof(demo_assume_index_observe)}, m) - return loglikelihood(Normal(), m) +function loglikelihood_true(model::Model{typeof(demo_assume_index_observe)}, s, m) + return logpdf(MvNormal(m, Diagonal(s)), model.args.x) end -function loglikelihood_true(model::Model{typeof(demo_assume_index_observe)}, m) - return logpdf(MvNormal(m, 0.25 * I), model.args.x) +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_assume_index_observe)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_assume_index_observe)}) - return [@varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end -@model function demo_assume_multivariate_observe(x=[10.0, 10.0]) +@model function demo_assume_multivariate_observe(x=[1.5, 2.0]) # Multivariate `assume` and `observe` - m ~ MvNormal(zero(x), I) - x ~ MvNormal(m, 0.25 * I) + s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) + m ~ MvNormal(zero(x), Diagonal(s)) + x ~ MvNormal(m, Diagonal(s)) - return (; m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) +end +function logprior_true(model::Model{typeof(demo_assume_multivariate_observe)}, s, m) + s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) + m_dist = MvNormal(zero(model.args.x), Diagonal(s)) + return logpdf(s_dist, s) + logpdf(m_dist, m) end -function logprior_true(model::Model{typeof(demo_assume_multivariate_observe)}, m) - return logpdf(MvNormal(zero(model.args.x), I), m) +function loglikelihood_true(model::Model{typeof(demo_assume_multivariate_observe)}, s, m) + return logpdf(MvNormal(m, Diagonal(s)), model.args.x) end -function loglikelihood_true(model::Model{typeof(demo_assume_multivariate_observe)}, m) - return logpdf(MvNormal(m, 0.25 * I), model.args.x) +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_assume_multivariate_observe)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_assume_multivariate_observe)}) - return [@varname(m)] + return [@varname(s), @varname(m)] end @model function demo_dot_assume_observe_index( - x=[10.0, 10.0], ::Type{TV}=Vector{Float64} + x=[1.5, 2.0], ::Type{TV}=Vector{Float64} ) where {TV} # `dot_assume` and `observe` with indexing + s = TV(undef, length(x)) + s .~ InverseGamma(2, 3) m = TV(undef, length(x)) - m .~ Normal() + m .~ Normal.(0, sqrt.(s)) for i in eachindex(x) - x[i] ~ Normal(m[i], 0.5) + x[i] ~ Normal(m[i], sqrt(s[i])) end - return (; m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_dot_assume_observe_index)}, m) - return loglikelihood(Normal(), m) +function logprior_true(model::Model{typeof(demo_dot_assume_observe_index)}, s, m) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index)}, m) - return sum(logpdf.(Normal.(m, 0.5), model.args.x)) +function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index)}, s, m) + return sum(logpdf.(Normal.(m, sqrt.(s)), model.args.x)) +end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_dot_assume_observe_index)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_dot_assume_observe_index)}) - return [@varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end # Using vector of `length` 1 here so the posterior of `m` is the same # as the others. -@model function demo_assume_dot_observe(x=[10.0]) +@model function demo_assume_dot_observe(x=[1.5, 2.0]) # `assume` and `dot_observe` - m ~ Normal() - x .~ Normal(m, 0.5) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + x .~ Normal(m, sqrt(s)) - return (; m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_assume_dot_observe)}, m) - return logpdf(Normal(), m) +function logprior_true(model::Model{typeof(demo_assume_dot_observe)}, s, m) + return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) end -function loglikelihood_true(model::Model{typeof(demo_assume_dot_observe)}, m) - return sum(logpdf.(Normal.(m, 0.5), model.args.x)) +function loglikelihood_true(model::Model{typeof(demo_assume_dot_observe)}, s, m) + return sum(logpdf.(Normal.(m, sqrt.(s)), model.args.x)) +end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_assume_dot_observe)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_assume_dot_observe)}) - return [@varname(m)] + return [@varname(s), @varname(m)] end @model function demo_assume_observe_literal() # `assume` and literal `observe` - m ~ MvNormal(zeros(2), I) - [10.0, 10.0] ~ MvNormal(m, 0.25 * I) + s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) + m ~ MvNormal(zeros(2), Diagonal(s)) + [1.5, 2.0] ~ MvNormal(m, Diagonal(s)) - return (; m=m, x=[10.0, 10.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) +end +function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) + s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) + m_dist = MvNormal(zeros(2), Diagonal(s)) + return logpdf(s_dist, s) + logpdf(m_dist, m) end -function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, m) - return logpdf(MvNormal(zeros(2), I), m) +function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) + return logpdf(MvNormal(m, Diagonal(s)), [1.5, 2.0]) end -function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, m) - return logpdf(MvNormal(m, 0.25 * I), [10.0, 10.0]) +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_assume_observe_literal)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_assume_observe_literal)}) - return [@varname(m)] + return [@varname(s), @varname(m)] end @model function demo_dot_assume_observe_index_literal(::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and literal `observe` with indexing + s = TV(undef, 2) m = TV(undef, 2) - m .~ Normal() - for i in eachindex(m) - 10.0 ~ Normal(m[i], 0.5) - end + s .~ InverseGamma(2, 3) + m .~ Normal.(0, sqrt.(s)) + + 1.5 ~ Normal(m[1], sqrt(s[1])) + 2.0 ~ Normal(m[2], sqrt(s[2])) - return (; m=m, x=fill(10.0, length(m)), logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) +end +function logprior_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, s, m) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function logprior_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, m) - return loglikelihood(Normal(), m) +function loglikelihood_true( + model::Model{typeof(demo_dot_assume_observe_index_literal)}, s, m +) + return sum(logpdf.(Normal.(m, sqrt.(s)), [1.5, 2.0])) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, m) - return sum(logpdf.(Normal.(m, 0.5), fill(10.0, length(m)))) +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_dot_assume_observe_index_literal)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_dot_assume_observe_index_literal)}) - return [@varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end @model function demo_assume_literal_dot_observe() # `assume` and literal `dot_observe` - m ~ Normal() - [10.0] .~ Normal(m, 0.5) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + [1.5, 2.0] .~ Normal(m, sqrt(s)) - return (; m=m, x=[10.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) +end +function logprior_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, m) + return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) end -function logprior_true(model::Model{typeof(demo_assume_literal_dot_observe)}, m) - return logpdf(Normal(), m) +function loglikelihood_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, m) + return loglikelihood(Normal(m, sqrt(s)), [1.5, 2.0]) end -function loglikelihood_true(model::Model{typeof(demo_assume_literal_dot_observe)}, m) - return logpdf(Normal(m, 0.5), 10.0) +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_assume_literal_dot_observe)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_assume_literal_dot_observe)}) - return [@varname(m)] + return [@varname(s), @varname(m)] end @model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} + s = TV(undef, 2) + s .~ InverseGamma(2, 3) m = TV(undef, 2) - m .~ Normal() + m .~ Normal.(0, sqrt.(s)) - return m + return s, m end @model function demo_assume_submodel_observe_index_literal() # Submodel prior - @submodel m = _prior_dot_assume() - for i in eachindex(m) - 10.0 ~ Normal(m[i], 0.5) - end + @submodel s, m = _prior_dot_assume() + 1.5 ~ Normal(m[1], sqrt(s[1])) + 2.0 ~ Normal(m[2], sqrt(s[2])) - return (; m=m, x=[10.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_assume_submodel_observe_index_literal)}, m) - return loglikelihood(Normal(), m) +function logprior_true( + model::Model{typeof(demo_assume_submodel_observe_index_literal)}, s, m +) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end function loglikelihood_true( - model::Model{typeof(demo_assume_submodel_observe_index_literal)}, m + model::Model{typeof(demo_assume_submodel_observe_index_literal)}, s, m ) - return sum(logpdf.(Normal.(m, 0.5), 10.0)) + return sum(logpdf.(Normal.(m, sqrt.(s)), [1.5, 2.0])) +end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_assume_submodel_observe_index_literal)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) - return [@varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end -@model function _likelihood_dot_observe(m, x) - return x ~ MvNormal(m, 0.25 * I) +@model function _likelihood_mltivariate_observe(s, m, x) + return x ~ MvNormal(m, Diagonal(s)) end @model function demo_dot_assume_observe_submodel( - x=[10.0, 10.0], ::Type{TV}=Vector{Float64} + x=[1.5, 2.0], ::Type{TV}=Vector{Float64} ) where {TV} + s = TV(undef, length(x)) + s .~ InverseGamma(2, 3) m = TV(undef, length(x)) - m .~ Normal() + m .~ Normal.(0, sqrt.(s)) # Submodel likelihood - @submodel _likelihood_dot_observe(m, x) + @submodel _likelihood_mltivariate_observe(s, m, x) - return (; m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, m) - return loglikelihood(Normal(), m) +function logprior_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, s, m) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, m) - return logpdf(MvNormal(m, 0.25 * I), model.args.x) +function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, s, m) + return logpdf(MvNormal(m, Diagonal(s)), model.args.x) +end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_dot_assume_observe_submodel)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_dot_assume_observe_submodel)}) - return [@varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end @model function demo_dot_assume_dot_observe_matrix( - x=fill(10.0, 2, 1), ::Type{TV}=Vector{Float64} + x=transpose([1.5 2.0;]), ::Type{TV}=Vector{Float64} ) where {TV} + s = TV(undef, length(x)) + s .~ InverseGamma(2, 3) m = TV(undef, length(x)) - m .~ Normal() + m .~ Normal.(0, sqrt.(s)) # Dotted observe for `Matrix`. - x .~ MvNormal(m, 0.25 * I) + x .~ MvNormal(m, Diagonal(s)) - return (; m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) +end +function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, s, m) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, m) - return loglikelihood(Normal(), m) +function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, s, m) + return sum(logpdf.(Normal.(m, sqrt.(s)), model.args.x)) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, m) - return loglikelihood(MvNormal(m, 0.25 * I), model.args.x) +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) - return [@varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end @model function demo_dot_assume_matrix_dot_observe_matrix( - x=fill(10.0, 2, 1), ::Type{TV}=Array{Float64} + x=transpose([1.5 2.0;]), ::Type{TV}=Array{Float64} ) where {TV} + n = length(x) d = length(x) ÷ 2 - m = TV(undef, d, 2) - m .~ MvNormal(zeros(d), I) + s = TV(undef, d, 2) + s .~ product_distribution([InverseGamma(2, 3) for _ in 1:d]) + s_vec = vec(s) + m ~ MvNormal(zeros(n), Diagonal(s_vec)) # Dotted observe for `Matrix`. - x .~ MvNormal(vec(m), 0.25 * I) + x .~ MvNormal(m, Diagonal(s_vec)) - return (; m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, m) - return loglikelihood(Normal(), vec(m)) +function logprior_true( + model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m +) + n = length(model.args.x) + s_vec = vec(s) + return loglikelihood(InverseGamma(2, 3), s_vec) + + logpdf(MvNormal(zeros(n), Diagonal(s_vec)), m) end function loglikelihood_true( - model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, m + model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m ) - return loglikelihood(MvNormal(vec(m), 0.25 * I), model.args.x) + return loglikelihood(MvNormal(m, Diagonal(vec(s))), model.args.x) end -function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) - return [@varname(m[:, 1]), @varname(m[:, 2])] +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end +function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) + return [@varname(s[:, 1]), @varname(s[:, 2]), @varname(m)] +end + +const DemoModels = Union{ + Model{typeof(demo_dot_assume_dot_observe)}, + Model{typeof(demo_assume_index_observe)}, + Model{typeof(demo_assume_multivariate_observe)}, + Model{typeof(demo_dot_assume_observe_index)}, + Model{typeof(demo_assume_dot_observe)}, + Model{typeof(demo_assume_literal_dot_observe)}, + Model{typeof(demo_assume_observe_literal)}, + Model{typeof(demo_dot_assume_observe_index_literal)}, + Model{typeof(demo_assume_submodel_observe_index_literal)}, + Model{typeof(demo_dot_assume_observe_submodel)}, + Model{typeof(demo_dot_assume_dot_observe_matrix)}, + Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, +} + +# We require demo models to have explict impleentations of `rand` since we want +# these to be considered as ground truth. +function Random.rand(rng::Random.AbstractRNG, ::Type{NamedTuple}, model::DemoModels) + return error("demo models requires explicit implementation of rand") +end + +const UnivariateAssumeDemoModels = Union{ + Model{typeof(demo_assume_dot_observe)},Model{typeof(demo_assume_literal_dot_observe)} +} +function posterior_mean(model::UnivariateAssumeDemoModels) + return (s=49 / 24, m=7 / 6) +end +function Random.rand( + rng::Random.AbstractRNG, ::Type{NamedTuple}, model::UnivariateAssumeDemoModels +) + s = rand(rng, InverseGamma(2, 3)) + m = rand(rng, Normal(0, sqrt(s))) + + return (s=s, m=m) +end + +const MultivariateAssumeDemoModels = Union{ + Model{typeof(demo_dot_assume_dot_observe)}, + Model{typeof(demo_assume_index_observe)}, + Model{typeof(demo_assume_multivariate_observe)}, + Model{typeof(demo_dot_assume_observe_index)}, + Model{typeof(demo_assume_observe_literal)}, + Model{typeof(demo_dot_assume_observe_index_literal)}, + Model{typeof(demo_assume_submodel_observe_index_literal)}, + Model{typeof(demo_dot_assume_observe_submodel)}, + Model{typeof(demo_dot_assume_dot_observe_matrix)}, + Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, +} +function posterior_mean(model::MultivariateAssumeDemoModels) + # Get some containers to fill. + vals = Random.rand(model) + + vals.s[1] = 19 / 8 + vals.m[1] = 3 / 4 + + vals.s[2] = 8 / 3 + vals.m[2] = 1 + + return vals +end +function Random.rand( + rng::Random.AbstractRNG, ::Type{NamedTuple}, model::MultivariateAssumeDemoModels +) + # Get template values from `model`. + retval = model(rng) + vals = (s=retval.s, m=retval.m) + # Fill containers with realizations from prior. + for i in LinearIndices(vals.s) + vals.s[i] = rand(rng, InverseGamma(2, 3)) + vals.m[i] = rand(rng, Normal(0, sqrt(vals.s[i]))) + end -@model function demo_dot_assume_array_dot_observe( - x=[10.0, 10.0], ::Type{TV}=Vector{Float64} -) where {TV} - # `dot_assume` and `observe` - m = TV(undef, length(x)) - m .~ [Normal() for _ in 1:length(x)] - x ~ MvNormal(m, 0.25 * I) - return (; m=m, x=x, logp=getlogp(__varinfo__)) -end -function logprior_true(model::Model{typeof(demo_dot_assume_array_dot_observe)}, m) - return loglikelihood(Normal(), m) -end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_array_dot_observe)}, m) - return loglikelihood(MvNormal(m, 0.25 * I), model.args.x) -end -function varnames(model::Model{typeof(demo_dot_assume_array_dot_observe)}) - return [@varname(m[1]), @varname(m[2])] + return vals end +""" +A collection of models corresponding to the posterior distribution defined by +the generative process + + s ~ InverseGamma(2, 3) + m ~ Normal(0, √s) + 1.5 ~ Normal(m, √s) + 2.0 ~ Normal(m, √s) + +or by + + s[1] ~ InverseGamma(2, 3) + s[2] ~ InverseGamma(2, 3) + m[1] ~ Normal(0, √s) + m[2] ~ Normal(0, √s) + 1.5 ~ Normal(m[1], √s[1]) + 2.0 ~ Normal(m[2], √s[2]) + +These are examples of a Normal-InverseGamma conjugate prior with Normal likelihood, +for which the posterior is known in closed form. + +In particular, for the univariate model (the former one): + + mean(s) == 49 / 24 + mean(m) == 7 / 6 + +And for the multivariate one (the latter one): + + mean(s[1]) == 19 / 8 + mean(m[1]) == 3 / 4 + mean(s[2]) == 8 / 3 + mean(m[2]) == 1 + +""" const DEMO_MODELS = ( demo_dot_assume_dot_observe(), demo_assume_index_observe(), @@ -416,65 +642,78 @@ const DEMO_MODELS = ( demo_dot_assume_observe_submodel(), demo_dot_assume_dot_observe_matrix(), demo_dot_assume_matrix_dot_observe_matrix(), - demo_dot_assume_array_dot_observe(), ) -# TODO: Is this really the best/most convenient "default" test method? """ - test_sampler_demo_models(meanfunction, sampler, args...; kwargs...) + marginal_mean_of_samples(chain, varname) -Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`. +Return the mean of variable represented by `varname` in `chain`. +""" +marginal_mean_of_samples(chain, varname) = mean(Array(chain[Symbol(varname)])) + +""" + test_sampler(models, sampler, args...; kwargs...) -In short, this method iterators through `demo_models`, calls `AbstractMCMC.sample` on the -`model` and `sampler` to produce a `chain`, and then checks `meanfunction(chain)` against `target` -provided in `kwargs...`. +Test that `sampler` produces correct marginal posterior means on each model in `models`. + +In short, this method iterates through `models`, calls `AbstractMCMC.sample` on the +`model` and `sampler` to produce a `chain`, and then checks `marginal_mean_of_samples(chain, vn)` +for every (leaf) varname `vn` against the corresponding value returned by +[`posterior_mean`](@ref) for each model. + +To change how comparison is done for a particular `chain` type, one can overload +[`marginal_mean_of_samples`](@ref) for the corresponding type. # Arguments -- `meanfunction`: A callable which computes the mean of the marginal means from the - chain resulting from the `sample` call. +- `models`: A collection of instaces of [`DynamicPPL.Model`](@ref) to test on. - `sampler`: The `AbstractMCMC.AbstractSampler` to test. - `args...`: Arguments forwarded to `sample`. # Keyword arguments -- `target`: Value to compare result of `meanfunction(chain)` to. - `atol=1e-1`: Absolute tolerance used in `@test`. - `rtol=1e-3`: Relative tolerance used in `@test`. - `kwargs...`: Keyword arguments forwarded to `sample`. """ -function test_sampler_demo_models( - meanfunction, - sampler::AbstractMCMC.AbstractSampler, - args...; - target=8.0, - atol=1e-1, - rtol=1e-3, - kwargs..., +function test_sampler( + models, sampler::AbstractMCMC.AbstractSampler, args...; atol=1e-1, rtol=1e-3, kwargs... ) - @testset "$(nameof(typeof(sampler))) on $(nameof(m))" for model in DEMO_MODELS + @testset "$(typeof(sampler)) on $(nameof(model))" for model in models chain = AbstractMCMC.sample(model, sampler, args...; kwargs...) - μ = meanfunction(chain) - @test μ ≈ target atol = atol rtol = rtol + target_values = posterior_mean(model) + for vn in varnames(model) + # We want to compare elementwise which can be achieved by + # extracting the leaves of the `VarName` and the corresponding value. + for vn_leaf in varname_leaves(vn, get(target_values, vn)) + target_value = get(target_values, vn_leaf) + chain_mean_value = marginal_mean_of_samples(chain, vn_leaf) + @test chain_mean_value ≈ target_value atol = atol rtol = rtol + end + end end end """ - test_sampler_continuous([meanfunction, ]sampler, args...; kwargs...) + test_sampler_on_demo_models(meanfunction, sampler, args...; kwargs...) -Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`. +Test `sampler` on every model in [`DEMO_MODELS`](@ref). -As of right now, this is just an alias for [`test_sampler_demo_models`](@ref). +This is just a proxy for `test_sampler(meanfunction, DEMO_MODELS, sampler, args...; kwargs...)`. """ -function test_sampler_continuous( - meanfunction, sampler::AbstractMCMC.AbstractSampler, args...; kwargs... +function test_sampler_on_demo_models( + sampler::AbstractMCMC.AbstractSampler, args...; kwargs... ) - return test_sampler_demo_models(meanfunction, sampler, args...; kwargs...) + return test_sampler(DEMO_MODELS, sampler, args...; kwargs...) end +""" + test_sampler_continuous(sampler, args...; kwargs...) + +Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`. + +As of right now, this is just an alias for [`test_sampler_on_demo_models`](@ref). +""" function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; kwargs...) - # Default for `MCMCChains.Chains`. - return test_sampler_continuous(sampler, args...; kwargs...) do chain - mean(Array(chain)) - end + return test_sampler_on_demo_models(sampler, args...; kwargs...) end end diff --git a/test/contexts.jl b/test/contexts.jl index 65629afec..edcf5d0f3 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -57,26 +57,6 @@ function remove_prefix(vn::VarName) ) end -""" - varnames(vn::VarName, val) - -Return iterator over all varnames that are represented by `vn` on `val`, -e.g. `varnames(@varname(x), rand(2))` results in an iterator over `[@varname(x[1]), @varname(x[2])]`. -""" -varnames(vn::VarName, val::Real) = [vn] -function varnames(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) - return ( - VarName(vn, getlens(vn) ∘ Setfield.IndexLens(Tuple(I))) for - I in CartesianIndices(val) - ) -end -function varnames(vn::VarName, val::AbstractArray) - return Iterators.flatten( - varnames(VarName(vn, getlens(vn) ∘ Setfield.IndexLens(Tuple(I))), val[I]) for - I in CartesianIndices(val) - ) -end - @testset "contexts.jl" begin child_contexts = [DefaultContext(), PriorContext(), LikelihoodContext()] @@ -185,7 +165,8 @@ end vn_without_prefix = remove_prefix(vn) # Let's check elementwise. - for vn_child in varnames(vn_without_prefix, val) + for vn_child in + DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) if get(val, getlens(vn_child)) === missing @test contextual_isassumption(context, vn_child) else @@ -217,7 +198,8 @@ end # `ConditionContext` with the conditioned variable. vn_without_prefix = remove_prefix(vn) - for vn_child in varnames(vn_without_prefix, val) + for vn_child in + DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) # `vn_child` should be in `context`. @test hasvalue_nested(context, vn_child) # Value should be the same as extracted above. diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index eaf1e00bd..b390997af 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -1,15 +1,14 @@ @testset "loglikelihoods.jl" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - vi = VarInfo(m) + example_values = rand(NamedTuple, m) + # Instantiate a `VarInfo` with the example values. + vi = VarInfo(m) for vn in DynamicPPL.TestUtils.varnames(m) - if vi[vn] isa Real - vi = DynamicPPL.setindex!!(vi, 1.0, vn) - else - vi = DynamicPPL.setindex!!(vi, ones(size(vi[vn])), vn) - end + vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) end + # Compute the pointwise loglikelihoods. lls = pointwise_loglikelihoods(m, vi) if isempty(lls) @@ -17,14 +16,9 @@ continue end - loglikelihood = if length(keys(lls)) == 1 && length(m.args.x) == 1 - # Only have one observation, so we need to double it - # for comparison with other models. - 2 * sum(lls[first(keys(lls))]) - else - sum(sum, values(lls)) - end + loglikelihood = sum(sum, values(lls)) + loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(m, example_values...) - @test loglikelihood ≈ -324.45158270528947 + @test loglikelihood ≈ loglikelihood_true end end diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 9847d6959..6a8c545ca 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -62,18 +62,18 @@ DynamicPPL.TestUtils.DEMO_MODELS # We might need to pre-allocate for the variable `m`, so we need # to see whether this is the case. - m = model().m - svi_nt = if m isa AbstractArray - SimpleVarInfo((m=similar(m),)) - else - SimpleVarInfo() - end + svi_nt = SimpleVarInfo(rand(NamedTuple, model)) svi_dict = SimpleVarInfo(VarInfo(model), Dict) - @testset "$(nameof(typeof(svi.values)))" for svi in (svi_nt, svi_dict) + @testset "$(nameof(typeof(DynamicPPL.values_as(svi))))" for svi in ( + svi_nt, + svi_dict, + DynamicPPL.settrans!!(svi_nt, true), + DynamicPPL.settrans!!(svi_dict, true), + ) # Random seed is set in each `@testset`, so we need to sample # a new realization for `m` here. - m = model().m + retval = model() ### Sampling ### # Sample a new varinfo! @@ -81,29 +81,43 @@ # Realization for `m` should be different wp. 1. for vn in DynamicPPL.TestUtils.varnames(model) - # `VarName` functions similarly to `PropertyLens` so - # we just strip this part from `vn` to get a lens we can use - # to extract the corresponding value of `m`. - l = getlens(vn) - @test svi_new[vn] != get(m, l) + @test svi_new[vn] != get(retval, vn) end # Logjoint should be non-zero wp. 1. @test getlogp(svi_new) != 0 ### Evaluation ### - # Sample some random testing values. - m_eval = if m isa AbstractArray - randn!(similar(m)) + values_eval_constrained = rand(NamedTuple, model) + if DynamicPPL.istrans(svi) + _values_prior, logpri_true = DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian( + model, values_eval_constrained... + ) + values_eval, logπ_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + model, values_eval_constrained... + ) + # Make sure that these two computation paths provide the same + # transformed values. + @test values_eval == _values_prior else - randn(eltype(m)) + logpri_true = DynamicPPL.TestUtils.logprior_true( + model, values_eval_constrained... + ) + logπ_true = DynamicPPL.TestUtils.logjoint_true( + model, values_eval_constrained... + ) + values_eval = values_eval_constrained end + # No logabsdet-jacobian correction needed for the likelihood. + loglik_true = DynamicPPL.TestUtils.loglikelihood_true( + model, values_eval_constrained... + ) + # Update the realizations in `svi_new`. svi_eval = svi_new for vn in DynamicPPL.TestUtils.varnames(model) - l = getlens(vn) - svi_eval = DynamicPPL.setindex!!(svi_eval, get(m_eval, l), vn) + svi_eval = DynamicPPL.setindex!!(svi_eval, get(values_eval, vn), vn) end # Reset the logp field. @@ -111,15 +125,17 @@ # Compute `logjoint` using the varinfo. logπ = logjoint(model, svi_eval) + logpri = logprior(model, svi_eval) + loglik = loglikelihood(model, svi_eval) # Values should not have changed. for vn in DynamicPPL.TestUtils.varnames(model) - l = getlens(vn) - @test svi_eval[vn] == get(m_eval, l) + @test svi_eval[vn] == get(values_eval, vn) end - # Compute the true `logjoint` and compare. - logπ_true = DynamicPPL.TestUtils.logjoint_true(model, m_eval) + # Compare log-probability computations. + @test logpri ≈ logpri_true + @test loglik ≈ loglik_true @test logπ ≈ logπ_true end end From 9241acd1246883a99e8f33d9c0f4c93aef0a07bf Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 21 Jul 2022 09:09:58 +0100 Subject: [PATCH 093/101] fixed tests for distribution_wrappers --- test/distribution_wrappers.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/distribution_wrappers.jl b/test/distribution_wrappers.jl index 350ce6014..8bb692783 100644 --- a/test/distribution_wrappers.jl +++ b/test/distribution_wrappers.jl @@ -9,5 +9,5 @@ @test minimum(nd) == -Inf @test maximum(nd) == Inf @test logpdf(nd, 15.0) == 0 - @test Bijectors.logpdf_with_trans(nd, 0) == 0 + @test Bijectors.logpdf_with_trans(nd, 30, true) == 0 end From 947e5c6ca4ff19ad09ccfe9153b680bffd9f3ee2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 21 Jul 2022 13:18:47 +0100 Subject: [PATCH 094/101] upper bound Distributions because tests are sooooo slow due to deprecations --- test/Project.toml | 2 +- test/turing/Project.toml | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 6a3726474..edba29e5a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -23,7 +23,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" AbstractMCMC = "2.1, 3.0, 4" AbstractPPL = "0.5.1" Bijectors = "0.9.5, 0.10" -Distributions = "0.25" +Distributions = "<0.25.65" DistributionsAD = "0.6.3" Documenter = "0.26.1, 0.27" ForwardDiff = "0.10.12" diff --git a/test/turing/Project.toml b/test/turing/Project.toml index 41a873172..e8713dae2 100644 --- a/test/turing/Project.toml +++ b/test/turing/Project.toml @@ -1,10 +1,12 @@ [deps] +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] +Distributions = "<0.25.65" DynamicPPL = "0.19" Turing = "0.21" julia = "1.3" From 6f9be0d742a798ac5c7e2daf73618f00665d7a55 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Thu, 21 Jul 2022 17:33:55 +0100 Subject: [PATCH 095/101] Update bors.toml --- bors.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bors.toml b/bors.toml index 302d9e1d5..273e80ce6 100644 --- a/bors.toml +++ b/bors.toml @@ -17,4 +17,5 @@ required_approvals = 1 use_squash_merge = true # Uncomment this to use a four hour timeout. # The default is one hour. -timeout_sec = 14400 +# timeout_sec = 14400 +timeout_sec = 43200 From 5c5b9cef9fcfeef7c1e9d67a9d052c311c366ae4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 22 Jul 2022 10:23:16 +0100 Subject: [PATCH 096/101] Revert "Update bors.toml" This reverts commit 6f9be0d742a798ac5c7e2daf73618f00665d7a55. --- bors.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/bors.toml b/bors.toml index 273e80ce6..302d9e1d5 100644 --- a/bors.toml +++ b/bors.toml @@ -17,5 +17,4 @@ required_approvals = 1 use_squash_merge = true # Uncomment this to use a four hour timeout. # The default is one hour. -# timeout_sec = 14400 -timeout_sec = 43200 +timeout_sec = 14400 From e0797cc883ddcafb834f6cb80854ee42109f9662 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 22 Jul 2022 10:23:33 +0100 Subject: [PATCH 097/101] Revert "upper bound Distributions because tests are sooooo slow due to deprecations" This reverts commit 947e5c6ca4ff19ad09ccfe9153b680bffd9f3ee2. --- test/Project.toml | 2 +- test/turing/Project.toml | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index edba29e5a..6a3726474 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -23,7 +23,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" AbstractMCMC = "2.1, 3.0, 4" AbstractPPL = "0.5.1" Bijectors = "0.9.5, 0.10" -Distributions = "<0.25.65" +Distributions = "0.25" DistributionsAD = "0.6.3" Documenter = "0.26.1, 0.27" ForwardDiff = "0.10.12" diff --git a/test/turing/Project.toml b/test/turing/Project.toml index e8713dae2..41a873172 100644 --- a/test/turing/Project.toml +++ b/test/turing/Project.toml @@ -1,12 +1,10 @@ [deps] -Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] -Distributions = "<0.25.65" DynamicPPL = "0.19" Turing = "0.21" julia = "1.3" From 4b0e0e1e8795699299a30ef75bcfe7e7b46ffc4d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 22 Jul 2022 10:28:46 +0100 Subject: [PATCH 098/101] switch of deprecation warnings from integration tests for now --- .github/workflows/IntegrationTest.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/IntegrationTest.yml b/.github/workflows/IntegrationTest.yml index a15cf6e8e..aa4ac1d0b 100644 --- a/.github/workflows/IntegrationTest.yml +++ b/.github/workflows/IntegrationTest.yml @@ -40,7 +40,7 @@ jobs: # force it to use this PR's version of the package Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps Pkg.update() - Pkg.test() # resolver may fail with test time deps + Pkg.test(julia_args=["--depwarn=no"]) # resolver may fail with test time deps catch err err isa Pkg.Resolve.ResolverError || rethrow() # If we can't resolve that means this is incompatible by SemVer and this is fine From 5a73c87a58a595767cf7ee09d9066c19267f6332 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 23 Jul 2022 10:55:34 +0100 Subject: [PATCH 099/101] bump supported Julia version to 1.6 --- .github/workflows/CI.yml | 2 +- Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 2cdde32d0..f1da30a13 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -16,7 +16,7 @@ jobs: strategy: matrix: version: - - '1.3' # minimum supported version + - '1.6' # minimum supported version - '1' # current stable version os: - ubuntu-latest diff --git a/Project.toml b/Project.toml index bfa13d956..fc2cb6f41 100644 --- a/Project.toml +++ b/Project.toml @@ -30,4 +30,4 @@ DocStringExtensions = "0.8" MacroTools = "0.5.6" Setfield = "0.7.1, 0.8" ZygoteRules = "0.2" -julia = "1.3" +julia = "1.6" From bb43021bee11d2f4349863a6dee053248b77ad0c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 23 Jul 2022 10:55:53 +0100 Subject: [PATCH 100/101] added ability to filter varnames to check in TestUtils.test_sampler --- src/test_utils.jl | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index ef314fa91..bcc649675 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -670,17 +670,25 @@ To change how comparison is done for a particular `chain` type, one can overload - `args...`: Arguments forwarded to `sample`. # Keyword arguments +- `varnames_filter`: A filter to apply to `varnames(model)`, allowing comparison for only + a subset of the varnames. - `atol=1e-1`: Absolute tolerance used in `@test`. - `rtol=1e-3`: Relative tolerance used in `@test`. - `kwargs...`: Keyword arguments forwarded to `sample`. """ function test_sampler( - models, sampler::AbstractMCMC.AbstractSampler, args...; atol=1e-1, rtol=1e-3, kwargs... + models, + sampler::AbstractMCMC.AbstractSampler, + args...; + varnames_filter=Returns(true), + atol=1e-1, + rtol=1e-3, + kwargs..., ) @testset "$(typeof(sampler)) on $(nameof(model))" for model in models chain = AbstractMCMC.sample(model, sampler, args...; kwargs...) target_values = posterior_mean(model) - for vn in varnames(model) + for vn in filter(varnames_filter, varnames(model)) # We want to compare elementwise which can be achieved by # extracting the leaves of the `VarName` and the corresponding value. for vn_leaf in varname_leaves(vn, get(target_values, vn)) From d5a48f8514157ecb06ce1e473513cde21d58eecb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 23 Jul 2022 11:06:07 +0100 Subject: [PATCH 101/101] bump minor version --- Project.toml | 2 +- test/Project.toml | 2 +- test/turing/Project.toml | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index fc2cb6f41..89aa66be2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.19.4" +version = "0.20.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/test/Project.toml b/test/Project.toml index 6a3726474..9408cf14d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -33,4 +33,4 @@ Setfield = "0.7.1, 0.8" StableRNGs = "1" Tracker = "0.2.11" Zygote = "0.5.4, 0.6" -julia = "1.3" +julia = "1.6" diff --git a/test/turing/Project.toml b/test/turing/Project.toml index 41a873172..39223b95e 100644 --- a/test/turing/Project.toml +++ b/test/turing/Project.toml @@ -5,6 +5,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] -DynamicPPL = "0.19" +DynamicPPL = "0.20" Turing = "0.21" -julia = "1.3" +julia = "1.6"