From 14594d655a2e719ffc001afb2907b1f26c471e8f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 7 Jan 2022 20:54:56 +0000 Subject: [PATCH 001/221] performing linking in assume rather than implicitly in getindex --- src/context_implementations.jl | 18 +++++++++++++----- src/varinfo.jl | 6 ++++++ test/model.jl | 21 +++++++++++++++++++++ 3 files changed, 40 insertions(+), 5 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 20c4af446..bcb3ad463 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -194,7 +194,9 @@ end # fallback without sampler function assume(dist::Distribution, vn::VarName, vi) - r = vi[vn] + # x = vi[vn] + r_raw = getindex_raw(vi, vn) + r = maybe_invlink(vi, vn, dist, r_raw) return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi end @@ -215,7 +217,9 @@ function assume( settrans!(vi, false, vn) setorder!(vi, vn, get_num_produce(vi)) else - r = vi[vn] + # r = vi[vn] + r_raw = getindex_raw(vi, vn) + r = maybe_invlink(vi, vn, dist, r_raw) end else r = init(rng, dist, sampler) @@ -390,7 +394,9 @@ function dot_assume( # m .~ Normal() # # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - r = vi[vns] + # r = vi[vns] + r_raw = getindex_raw(vi, vns) + r = maybe_invlink(vi, vn, dist, r_raw) lp = sum(zip(vns, eachcol(r))) do (vn, ri) return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) end @@ -423,7 +429,8 @@ function dot_assume( # m .~ Normal() # # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - r = reshape(vi[vec(vns)], size(vns)) + r_raw = getindex_raw(vi, vec(vns)) + r = reshape(maybe_invlink.(Ref(vi), vns, dists, r_raw), size(vns)) lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) return r, lp, vi end @@ -467,7 +474,8 @@ function get_and_set_val!( setorder!(vi, vn, get_num_produce(vi)) end else - r = vi[vns] + r_raw = getindex_raw(vi, vns) + r = maybe_invlink(vi, vns, dist, r_raw) end else r = init(rng, dist, spl, n) diff --git a/src/varinfo.jl b/src/varinfo.jl index 9ce0414d6..c8ffca5df 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -866,6 +866,9 @@ end return expr end +maybe_invlink(vi, vn, dist, val) = istrans(vi, vn) ? Bijectors.invlink(dist, val) : val + + """ islinked(vi::VarInfo, spl::Union{Sampler, SampleFromPrior}) @@ -923,6 +926,9 @@ function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) end end +getindex_raw(vi::AbstractVarInfo, vn::VarName) = reconstruct(getdist(vi, vn), getval(vi, vn)) +getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}) = reconstruct(getdist(vi, first(vns)), getval(vi, vns)) + """ getindex(vi::VarInfo, spl::Union{SampleFromPrior, Sampler}) diff --git a/test/model.jl b/test/model.jl index 466a7d1f4..7e417eaa7 100644 --- a/test/model.jl +++ b/test/model.jl @@ -81,4 +81,25 @@ call_retval = model() @test !any(map(x -> x isa DynamicPPL.AbstractVarInfo, call_retval)) end + + @testset "Dynamic constraints" begin + @model function dynamic_constraints() + m ~ Normal() + x ~ truncated(Normal(), m, Inf) + end + + model = dynamic_constraints() + vi = VarInfo(model) + spl = SampleFromPrior() + link!(vi, spl) + + for i = 1:10 + # Sample with large variations. + r_raw = randn(length(vi[spl])) * 10 + vi[spl] = r_raw + @test vi[@varname(m)] == r_raw[1] + @test vi[@varname(x)] != r_raw[2] + model(vi) + end + end end From 0bc279f5848eb7c42794df9a0df6e490f74303d0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 7 Jan 2022 22:20:58 +0000 Subject: [PATCH 002/221] added istrans to SimpleVarInfo --- src/simple_varinfo.jl | 87 ++++++++++++++++--------------------------- 1 file changed, 32 insertions(+), 55 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 5cecda4b2..85896792b 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -129,12 +129,13 @@ ERROR: type NamedTuple has no field b struct SimpleVarInfo{NT,T} <: AbstractVarInfo values::NT logp::T + istrans::Bool end -SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T)) -SimpleVarInfo{T}(; kwargs...) where {T<:Real} = SimpleVarInfo{T}(NamedTuple(kwargs)) -SimpleVarInfo(; kwargs...) = SimpleVarInfo{Float64}(NamedTuple(kwargs)) -SimpleVarInfo(θ) = SimpleVarInfo{Float64}(θ) +SimpleVarInfo{T}(θ, istrans::Bool=false) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T), istrans) +SimpleVarInfo{T}(istrans::Bool=false; kwargs...) where {T<:Real} = SimpleVarInfo{T}(NamedTuple(kwargs), istrans) +SimpleVarInfo(istrans::Bool=false; kwargs...) = SimpleVarInfo{Float64}(NamedTuple(kwargs), istrans) +SimpleVarInfo(θ, istrans::Bool=false) = SimpleVarInfo{Float64}(θ, istrans) # Constructor from `Model`. SimpleVarInfo(model::Model, args...) = SimpleVarInfo{Float64}(model, args...) @@ -158,8 +159,8 @@ function BangBang.empty!!(vi::SimpleVarInfo) end getlogp(vi::SimpleVarInfo) = vi.logp -setlogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.values, logp) -acclogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.values, getlogp(vi) + logp) +setlogp!!(vi::SimpleVarInfo, logp) = Setfield.@set vi.logp = logp +acclogp!!(vi::SimpleVarInfo, logp) = Setfield.@set vi.logp = getlogp(vi) + logp """ keys(vi::SimpleVarInfo) @@ -179,7 +180,7 @@ function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) end function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo) - return print(io, "SimpleVarInfo(", svi.values, ", ", svi.logp, ")") + return print(io, "SimpleVarInfo(", svi.values, ", ", svi.logp, ", ", svi.istrans, ")") end # `NamedTuple` @@ -224,6 +225,11 @@ Base.getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.values # TODO: Should we do better? Base.getindex(vi::SimpleVarInfo, spl::Sampler) = vi.values +# Since we don't perform any transformations in `getindex` for `SimpleVarInfo` +# we simply call `getindex` in `getindex_raw`. +getindex_raw(vi::SimpleVarInfo, vn::VarName) = vi[vn] +getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}) = vi[vns] + Base.haskey(vi::SimpleVarInfo, vn::VarName) = _haskey(vi.values, vn) function _haskey(nt::NamedTuple, vn::VarName) # LHS: Ensure that `nt` indeed has the property we want. @@ -337,58 +343,21 @@ function Base.eltype( end # Context implementations -function assume(dist::Distribution, vn::VarName, vi::SimpleOrThreadSafeSimple) - left = vi[vn] - return left, Distributions.loglikelihood(dist, left), vi -end - +# NOTE: Evaluations, i.e. those without `rng` are shared with other +# implementations of `AbstractVarInfo`. function assume( rng::Random.AbstractRNG, - sampler::SampleFromPrior, + sampler::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, vi::SimpleOrThreadSafeSimple, ) value = init(rng, dist, sampler) - vi = BangBang.push!!(vi, vn, value, dist, sampler) - return value, Distributions.loglikelihood(dist, value), vi -end - -function dot_assume( - dist::MultivariateDistribution, - var::AbstractMatrix, - vns::AbstractVector{<:VarName}, - vi::SimpleOrThreadSafeSimple, -) - @assert length(dist) == size(var, 1) - # NOTE: We cannot work with `var` here because we might have a model of the form - # - # m = Vector{Float64}(undef, n) - # m .~ Normal() - # - # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - value = vi[vns] - lp = sum(zip(vns, eachcol(value))) do (vn, val) - return Distributions.logpdf(dist, val) - end - return value, lp, vi -end - -function dot_assume( - dists::Union{Distribution,AbstractArray{<:Distribution}}, - var::AbstractArray, - vns::AbstractArray{<:VarName}, - vi::SimpleOrThreadSafeSimple, -) - # NOTE: We cannot work with `var` here because we might have a model of the form - # - # m = Vector{Float64}(undef, n) - # m .~ Normal() - # - # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - value = vi[vns] - lp = sum(Distributions.logpdf.(dists, value)) - return value, lp, vi + # Transform if we're working in unconstrained space. + ist = istrans(vi, vn) + value_raw = ist ? Bijectors.link(dist, value) : value + vi = BangBang.push!!(vi, vn, value_raw, dist, sampler) + return value, Bijectors.logpdf_with_trans(dist, value, ist), vi end function dot_assume( @@ -401,15 +370,23 @@ function dot_assume( ) f = (vn, dist) -> init(rng, dist, spl) value = f.(vns, dists) - vi = BangBang.setindex!!(vi, value, vns) - lp = sum(Distributions.logpdf.(dists, value)) + + # Transform if we're working in transformed space. + ist = istrans(vi, first(vns)) + value_raw = ist ? link.(dist, value) : value + + # Update `vi` + vi = BangBang.setindex!!(vi, value_raw, vns) + + # Compute logp. + lp = sum(Bijectors.logpdf_with_trans.(dists, value, ist)) return value, lp, vi end # HACK: Allows us to re-use the implementation of `dot_tilde`, etc. for literals. increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing settrans!(vi::SimpleOrThreadSafeSimple, trans::Bool, vn::VarName) = nothing -istrans(::SimpleVarInfo, vn::VarName) = false +istrans(svi::SimpleVarInfo, vn::VarName) = svi.istrans """ values_as(varinfo[, Type]) From 81ee12e44bc2382936094b0dd2f344e2ca7a3630 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 7 Jan 2022 23:46:13 +0100 Subject: [PATCH 003/221] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/simple_varinfo.jl | 12 +++++++++--- src/varinfo.jl | 9 ++++++--- test/model.jl | 4 ++-- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 85896792b..f14642b02 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -132,9 +132,15 @@ struct SimpleVarInfo{NT,T} <: AbstractVarInfo istrans::Bool end -SimpleVarInfo{T}(θ, istrans::Bool=false) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T), istrans) -SimpleVarInfo{T}(istrans::Bool=false; kwargs...) where {T<:Real} = SimpleVarInfo{T}(NamedTuple(kwargs), istrans) -SimpleVarInfo(istrans::Bool=false; kwargs...) = SimpleVarInfo{Float64}(NamedTuple(kwargs), istrans) +function SimpleVarInfo{T}(θ, istrans::Bool=false) where {T<:Real} + return SimpleVarInfo{typeof(θ),T}(θ, zero(T), istrans) +end +function SimpleVarInfo{T}(istrans::Bool=false; kwargs...) where {T<:Real} + return SimpleVarInfo{T}(NamedTuple(kwargs), istrans) +end +function SimpleVarInfo(istrans::Bool=false; kwargs...) + return SimpleVarInfo{Float64}(NamedTuple(kwargs), istrans) +end SimpleVarInfo(θ, istrans::Bool=false) = SimpleVarInfo{Float64}(θ, istrans) # Constructor from `Model`. diff --git a/src/varinfo.jl b/src/varinfo.jl index c8ffca5df..78d3b486b 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -868,7 +868,6 @@ end maybe_invlink(vi, vn, dist, val) = istrans(vi, vn) ? Bijectors.invlink(dist, val) : val - """ islinked(vi::VarInfo, spl::Union{Sampler, SampleFromPrior}) @@ -926,8 +925,12 @@ function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) end end -getindex_raw(vi::AbstractVarInfo, vn::VarName) = reconstruct(getdist(vi, vn), getval(vi, vn)) -getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}) = reconstruct(getdist(vi, first(vns)), getval(vi, vns)) +function getindex_raw(vi::AbstractVarInfo, vn::VarName) + return reconstruct(getdist(vi, vn), getval(vi, vn)) +end +function getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}) + return reconstruct(getdist(vi, first(vns)), getval(vi, vns)) +end """ getindex(vi::VarInfo, spl::Union{SampleFromPrior, Sampler}) diff --git a/test/model.jl b/test/model.jl index 7e417eaa7..c2ab0d724 100644 --- a/test/model.jl +++ b/test/model.jl @@ -85,7 +85,7 @@ @testset "Dynamic constraints" begin @model function dynamic_constraints() m ~ Normal() - x ~ truncated(Normal(), m, Inf) + return x ~ truncated(Normal(), m, Inf) end model = dynamic_constraints() @@ -93,7 +93,7 @@ spl = SampleFromPrior() link!(vi, spl) - for i = 1:10 + for i in 1:10 # Sample with large variations. r_raw = randn(length(vi[spl])) * 10 vi[spl] = r_raw From d39f87dc494261b75e26bdf2e7b953ea185e774e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 7 Jan 2022 22:57:59 +0000 Subject: [PATCH 004/221] added a comment --- src/simple_varinfo.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index f14642b02..c450145f8 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -129,6 +129,7 @@ ERROR: type NamedTuple has no field b struct SimpleVarInfo{NT,T} <: AbstractVarInfo values::NT logp::T + # TODO: Should we put this in the type instead? istrans::Bool end From 23f34cc2f7272094cf8d91150b9b8e999cc36ed4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 7 Jan 2022 22:58:10 +0000 Subject: [PATCH 005/221] bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b21c4a91a..cd40c074a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.17.3" +version = "0.17.4" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From d3ec10822b5c36ef5e41133170798d134063f0f7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 9 Jan 2022 23:17:10 +0000 Subject: [PATCH 006/221] introduced settrans!! --- src/context_implementations.jl | 18 +++++++++--------- src/simple_varinfo.jl | 9 ++++++++- src/varinfo.jl | 24 +++++++++++++++--------- 3 files changed, 32 insertions(+), 19 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index bcb3ad463..f66c4d83f 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -55,7 +55,7 @@ end function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi) if haskey(context.vars, getsym(vn)) vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end return tilde_assume(PriorContext(), right, vn, vi) end @@ -64,7 +64,7 @@ function tilde_assume( ) if haskey(context.vars, getsym(vn)) vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end return tilde_assume(rng, PriorContext(), sampler, right, vn, vi) end @@ -72,7 +72,7 @@ end function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi) if haskey(context.vars, getsym(vn)) vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end return tilde_assume(LikelihoodContext(), right, vn, vi) end @@ -86,7 +86,7 @@ function tilde_assume( ) if haskey(context.vars, getsym(vn)) vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi) end @@ -214,7 +214,7 @@ function assume( unset_flag!(vi, vn, "del") r = init(rng, dist, sampler) vi[vn] = vectorize(dist, r) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) setorder!(vi, vn, get_num_produce(vi)) else # r = vi[vn] @@ -224,7 +224,7 @@ function assume( else r = init(rng, dist, sampler) push!!(vi, vn, r, dist, sampler) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi @@ -470,7 +470,7 @@ function get_and_set_val!( for i in 1:n vn = vns[i] vi[vn] = vectorize(dist, r[:, i]) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) setorder!(vi, vn, get_num_produce(vi)) end else @@ -482,7 +482,7 @@ function get_and_set_val!( for i in 1:n vn = vns[i] push!!(vi, vn, r[:, i], dist, spl) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end end return r @@ -505,7 +505,7 @@ function get_and_set_val!( vn = vns[i] dist = dists isa AbstractArray ? dists[i] : dists vi[vn] = vectorize(dist, r[i]) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) setorder!(vi, vn, get_num_produce(vi)) end else diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index c450145f8..be9fa7455 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -392,8 +392,15 @@ end # HACK: Allows us to re-use the implementation of `dot_tilde`, etc. for literals. increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing -settrans!(vi::SimpleOrThreadSafeSimple, trans::Bool, vn::VarName) = nothing + +# NOTE: We don't implement `settrans!!(vi, trans, vn)`. +settrans!!(vi::SimpleVarInfo, trans::Bool) = Setfield.@set vi.istrans = trans +function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans::Bool) + return Setfield.@set vi.varinfo = settrans!!(vi, trans) +end + istrans(svi::SimpleVarInfo, vn::VarName) = svi.istrans +istrans(svi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = svi.istrans """ values_as(varinfo[, Type]) diff --git a/src/varinfo.jl b/src/varinfo.jl index 78d3b486b..49f7b553d 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -358,12 +358,18 @@ Return the set of sampler selectors associated with `vn` in `vi`. getgid(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).gids[getidx(vi, vn)] """ - settrans!(vi::VarInfo, trans::Bool, vn::VarName) + settrans!!(vi::VarInfo, trans::Bool, vn::VarName) Set the `trans` flag value of `vn` in `vi`. """ -function settrans!(vi::AbstractVarInfo, trans::Bool, vn::VarName) - return trans ? set_flag!(vi, vn, "trans") : unset_flag!(vi, vn, "trans") +function settrans!!(vi::AbstractVarInfo, trans::Bool, vn::VarName) + if trans + set_flag!(vi, vn, "trans") + else + unset_flag!(vi, vn, "trans") + end + + return vi end """ @@ -749,7 +755,7 @@ function link!(vi::UntypedVarInfo, spl::Sampler) vectorize(dist, Bijectors.link(dist, reconstruct(dist, getval(vi, vn)))), vn, ) - settrans!(vi, true, vn) + settrans!!(vi, true, vn) end else @warn("[DynamicPPL] attempt to link a linked vi") @@ -785,7 +791,7 @@ end ), vn, ) - settrans!(vi, true, vn) + settrans!!(vi, true, vn) end else @warn("[DynamicPPL] attempt to link a linked vi") @@ -816,7 +822,7 @@ function invlink!(vi::UntypedVarInfo, spl::AbstractSampler) vectorize(dist, Bijectors.invlink(dist, reconstruct(dist, getval(vi, vn)))), vn, ) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end else @warn("[DynamicPPL] attempt to invlink an invlinked vi") @@ -854,7 +860,7 @@ end ), vn, ) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end else @warn("[DynamicPPL] attempt to invlink an invlinked vi") @@ -1420,7 +1426,7 @@ function _setval_kernel!(vi::VarInfo, vn::VarName, values, keys) if !isempty(indices) val = reduce(vcat, values[indices]) setval!(vi, val, vn) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end return indices @@ -1501,7 +1507,7 @@ function _setval_and_resample_kernel!(vi::VarInfo, vn::VarName, values, keys) if !isempty(indices) val = reduce(vcat, values[indices]) setval!(vi, val, vn) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) else # Ensures that we'll resample the variable corresponding to `vn` if we run # the model on `vi` again. From 81782c9c54c99e86768ea48cbe7f11e6e4139711 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 9 Jan 2022 23:50:33 +0000 Subject: [PATCH 007/221] added istrans(vi) and renamed all occurences of trans! to trans!! --- src/context_implementations.jl | 49 ++++++++++++++++++++++------------ src/simple_varinfo.jl | 5 ++-- src/varinfo.jl | 9 +++++++ 3 files changed, 44 insertions(+), 19 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index f66c4d83f..8509e4a09 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -213,18 +213,23 @@ function assume( if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") unset_flag!(vi, vn, "del") r = init(rng, dist, sampler) - vi[vn] = vectorize(dist, r) - settrans!!(vi, false, vn) + vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r)) setorder!(vi, vn, get_num_produce(vi)) else + # Otherwise we just extract it. # r = vi[vn] r_raw = getindex_raw(vi, vn) r = maybe_invlink(vi, vn, dist, r_raw) end else r = init(rng, dist, sampler) - push!!(vi, vn, r, dist, sampler) - settrans!!(vi, false, vn) + if istrans(vi) + push!!(vi, vn, link(dist, r), dist, sampler) + # By default `push!!` sets the transformed flag to `false`. + settrans!!(vi, true, vn) + else + push!!(vi, vn, r, dist, sampler) + end end return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi @@ -290,7 +295,7 @@ function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left, var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) + settrans!!.(Ref(vi), false, _vns) dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, vi) else dot_tilde_assume(LikelihoodContext(), right, left, vn, vi) @@ -309,7 +314,7 @@ function dot_tilde_assume( var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) + settrans!!.(Ref(vi), false, _vns) dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, vi) else dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi) @@ -330,7 +335,7 @@ function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) + settrans!!.(Ref(vi), false, _vns) dot_tilde_assume(PriorContext(), _right, _left, _vns, vi) else dot_tilde_assume(PriorContext(), right, left, vn, vi) @@ -349,7 +354,7 @@ function dot_tilde_assume( var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) + settrans!!.(Ref(vi), false, _vns) dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, vi) else dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, vi) @@ -469,8 +474,7 @@ function get_and_set_val!( r = init(rng, dist, spl, n) for i in 1:n vn = vns[i] - vi[vn] = vectorize(dist, r[:, i]) - settrans!!(vi, false, vn) + vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r[:, i])) setorder!(vi, vn, get_num_produce(vi)) end else @@ -481,8 +485,13 @@ function get_and_set_val!( r = init(rng, dist, spl, n) for i in 1:n vn = vns[i] - push!!(vi, vn, r[:, i], dist, spl) - settrans!!(vi, false, vn) + if istrans(vi) + push!!(vi, vn, maybe_link(vi, vn, dist, r[:, i]), dist, spl) + # `push!!` sets the trans-flag to `false` by default. + setttrans!!(vi, true, vn) + else + push!!(vi, vn, r[:, i], dist, spl) + end end end return r @@ -504,12 +513,13 @@ function get_and_set_val!( for i in eachindex(vns) vn = vns[i] dist = dists isa AbstractArray ? dists[i] : dists - vi[vn] = vectorize(dist, r[i]) - settrans!!(vi, false, vn) + vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r[i])) setorder!(vi, vn, get_num_produce(vi)) end else - r = reshape(vi[vec(vns)], size(vns)) + # r = reshape(vi[vec(vns)], size(vns)) + r_raw = getindex_raw(vi, vec(vns)) + r = maybe_invlink.(Ref(vi), vns, dists, reshape(r_raw, size(vns))) end else f = (vn, dist) -> init(rng, dist, spl) @@ -519,8 +529,13 @@ function get_and_set_val!( # 1. Figure out the broadcast size and use a `foreach`. # 2. Define an anonymous function which returns `nothing`, which # we then broadcast. This will allocate a vector of `nothing` though. - push!!.(Ref(vi), vns, r, dists, Ref(spl)) - settrans!.(Ref(vi), false, vns) + if istrans(vi) + push!!.(Ref(vi), vns, link.(Ref(vi), vns, dists, r), dists, Ref(spl)) + # `push!!` sets the trans-flag to `false` by default. + settrans!!.(Ref(vi), true, vns) + else + push!!.(Ref(vi), vns, r, dists, Ref(spl)) + end end return r end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index be9fa7455..fc7d9e80f 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -399,8 +399,9 @@ function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans::Bool) return Setfield.@set vi.varinfo = settrans!!(vi, trans) end -istrans(svi::SimpleVarInfo, vn::VarName) = svi.istrans -istrans(svi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = svi.istrans +istrans(vi::SimpleVarInfo) = vi.istrans +istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi) +istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) """ values_as(varinfo[, Type]) diff --git a/src/varinfo.jl b/src/varinfo.jl index 49f7b553d..d2b0cd6c4 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -644,6 +644,14 @@ function setgid!(vi::VarInfo, gid::Selector, vn::VarName) return push!(getmetadata(vi, vn).gids[getidx(vi, vn)], gid) end +""" + istrans(vi::AbstractVarInfo) + +Return `true` if `vi` is working in unconstrained space, and `false` +if `vi` is assuming realizations to be in support of the corresponding distributions. +""" +istrans(vi::AbstractVarInfo) = false + """ istrans(vi::VarInfo, vn::VarName) @@ -872,6 +880,7 @@ end return expr end +maybe_link(vi, vn, dist, val) = istrans(vi, vn) ? Bijectors.link(dist, val) : val maybe_invlink(vi, vn, dist, val) = istrans(vi, vn) ? Bijectors.invlink(dist, val) : val """ From 12bfb420a1fa9bc2e0572d2dd0d7a385c0fac356 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 10 Jan 2022 00:22:04 +0000 Subject: [PATCH 008/221] exclusively use settrans!! to set the istrans for SimpleVarInfo --- src/simple_varinfo.jl | 15 ++++++++------- src/varinfo.jl | 2 +- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index fc7d9e80f..9805e28ab 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -133,16 +133,17 @@ struct SimpleVarInfo{NT,T} <: AbstractVarInfo istrans::Bool end -function SimpleVarInfo{T}(θ, istrans::Bool=false) where {T<:Real} - return SimpleVarInfo{typeof(θ),T}(θ, zero(T), istrans) +SimpleVarInfo(values, logp) = SimpleVarInfo(values, logp, false) +function SimpleVarInfo{T}(θ) where {T<:Real} + return SimpleVarInfo{typeof(θ),T}(θ, zero(T), false) end -function SimpleVarInfo{T}(istrans::Bool=false; kwargs...) where {T<:Real} - return SimpleVarInfo{T}(NamedTuple(kwargs), istrans) +function SimpleVarInfo{T}(; kwargs...) where {T<:Real} + return SimpleVarInfo{T}(NamedTuple(kwargs)) end -function SimpleVarInfo(istrans::Bool=false; kwargs...) - return SimpleVarInfo{Float64}(NamedTuple(kwargs), istrans) +function SimpleVarInfo(; kwargs...) + return SimpleVarInfo{Float64}(NamedTuple(kwargs)) end -SimpleVarInfo(θ, istrans::Bool=false) = SimpleVarInfo{Float64}(θ, istrans) +SimpleVarInfo(θ) = SimpleVarInfo{Float64}(θ) # Constructor from `Model`. SimpleVarInfo(model::Model, args...) = SimpleVarInfo{Float64}(model, args...) diff --git a/src/varinfo.jl b/src/varinfo.jl index d2b0cd6c4..67ded7b82 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -650,7 +650,7 @@ end Return `true` if `vi` is working in unconstrained space, and `false` if `vi` is assuming realizations to be in support of the corresponding distributions. """ -istrans(vi::AbstractVarInfo) = false +istrans(vi::VarInfo) = false # `VarInfo` works in constrained space by default. """ istrans(vi::VarInfo, vn::VarName) From c2c2417f13df1d4591965051cd68f42b5e300c78 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 10 Jan 2022 13:51:05 +0000 Subject: [PATCH 009/221] removed usage of deprecated method in turing tests --- test/turing/varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/turing/varinfo.jl b/test/turing/varinfo.jl index 892433779..0db560956 100644 --- a/test/turing/varinfo.jl +++ b/test/turing/varinfo.jl @@ -9,7 +9,7 @@ ) if !haskey(vi, vn) r = rand(dist) - push!(vi, vn, r, dist, spl) + push!!(vi, vn, r, dist, spl) r elseif is_flagged(vi, vn, "del") unset_flag!(vi, vn, "del") From 2e2cb5cd42b9df58e8ff8ff21d7a9e390bf4a2f0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 13 Jan 2022 19:03:49 +0000 Subject: [PATCH 010/221] added docstring to settrans!! --- src/simple_varinfo.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 9805e28ab..7f185780d 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -395,6 +395,11 @@ end increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing # NOTE: We don't implement `settrans!!(vi, trans, vn)`. +""" + settrans!!(vi::AbstractVarInfo, trans::Bool) + +Return new instance of `vi` but with `istrans(vi, trans)` now evaluating to `true`. +""" settrans!!(vi::SimpleVarInfo, trans::Bool) = Setfield.@set vi.istrans = trans function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans::Bool) return Setfield.@set vi.varinfo = settrans!!(vi, trans) From f6c3fc42afd427b9668b151aa31a81d1297eb630 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 27 Jan 2022 13:22:37 +0000 Subject: [PATCH 011/221] include istrans flag in type of SimpleVarInfo instead --- src/simple_varinfo.jl | 38 +++++++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 7f185780d..c9c9b5ff6 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -126,16 +126,26 @@ ERROR: type NamedTuple has no field b [...] ``` """ -struct SimpleVarInfo{NT,T} <: AbstractVarInfo +struct SimpleVarInfo{NT,T,IsTrans} <: AbstractVarInfo values::NT logp::T - # TODO: Should we put this in the type instead? - istrans::Bool end -SimpleVarInfo(values, logp) = SimpleVarInfo(values, logp, false) +function Setfield.ConstructionBase.constructorof( + ::Type{<:SimpleVarInfo{<:Any,<:Any,IsTrans}} +) where {IsTrans} + return function SimpleVarInfo_constructor(values, logp) + return SimpleVarInfo{typeof(values),typeof(logp),IsTrans}(values, logp) + end +end + +SimpleVarInfo(values, logp, istrans::Bool) = SimpleVarInfo(values, logp, Val{istrans}()) +function SimpleVarInfo(values, logp, ::Val{IsTrans}) where {IsTrans} + return SimpleVarInfo{typeof(values),typeof(logp),IsTrans}(values, logp) +end + function SimpleVarInfo{T}(θ) where {T<:Real} - return SimpleVarInfo{typeof(θ),T}(θ, zero(T), false) + return SimpleVarInfo{typeof(θ),T,false}(θ, zero(T)) end function SimpleVarInfo{T}(; kwargs...) where {T<:Real} return SimpleVarInfo{T}(NamedTuple(kwargs)) @@ -187,8 +197,10 @@ function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) return vi end -function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo) - return print(io, "SimpleVarInfo(", svi.values, ", ", svi.logp, ", ", svi.istrans, ")") +function Base.show( + io::IO, ::MIME"text/plain", svi::SimpleVarInfo{<:Any,<:Any,IsTrans} +) where {IsTrans} + return print(io, "SimpleVarInfo{IsTrans=$(IsTrans)}(", svi.values, ", ", svi.logp, ")") end # `NamedTuple` @@ -339,8 +351,8 @@ function BangBang.push!!( return vi end -const SimpleOrThreadSafeSimple{T,V} = Union{ - SimpleVarInfo{T,V},ThreadSafeVarInfo{<:SimpleVarInfo{T,V}} +const SimpleOrThreadSafeSimple{T,V,IsTrans} = Union{ + SimpleVarInfo{T,V,IsTrans},ThreadSafeVarInfo{<:SimpleVarInfo{T,V,IsTrans}} } # Necessary for `matchingvalue` to work properly. @@ -396,16 +408,16 @@ increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing # NOTE: We don't implement `settrans!!(vi, trans, vn)`. """ - settrans!!(vi::AbstractVarInfo, trans::Bool) + settrans!!(vi::AbstractVarInfo, trans) Return new instance of `vi` but with `istrans(vi, trans)` now evaluating to `true`. """ -settrans!!(vi::SimpleVarInfo, trans::Bool) = Setfield.@set vi.istrans = trans -function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans::Bool) +settrans!!(vi::SimpleVarInfo, trans) = SimpleVarInfo(vi.values, vi.logp, trans) +function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) return Setfield.@set vi.varinfo = settrans!!(vi, trans) end -istrans(vi::SimpleVarInfo) = vi.istrans +istrans(vi::SimpleVarInfo{<:Any,<:Any,IsTrans}) where {IsTrans} = IsTrans istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi) istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) From d643a783b5f6a9dc316c69b5064b1e5cc4e249bf Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 27 Jan 2022 13:41:29 +0000 Subject: [PATCH 012/221] deprecated settrans! in favour of settrans!! --- src/DynamicPPL.jl | 2 ++ src/simple_varinfo.jl | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 82df0f008..21efbd72d 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -176,4 +176,6 @@ include("test_utils.jl") @deprecate acclogp!(vi, logp) acclogp!!(vi, logp) @deprecate resetlogp!(vi) resetlogp!!(vi) +@deprecate settrans!(vi::AbstractVarInfo, trans::Bool, vn::VarName) settrans!!(vi, trans, vn) + end # module diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index c9c9b5ff6..93c91e21e 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -139,7 +139,7 @@ function Setfield.ConstructionBase.constructorof( end end -SimpleVarInfo(values, logp, istrans::Bool) = SimpleVarInfo(values, logp, Val{istrans}()) +SimpleVarInfo(values, logp, istrans::Bool=false) = SimpleVarInfo(values, logp, Val{istrans}()) function SimpleVarInfo(values, logp, ::Val{IsTrans}) where {IsTrans} return SimpleVarInfo{typeof(values),typeof(logp),IsTrans}(values, logp) end From 3cab7d9ac7f6b64e337bed0b9711df359f05be93 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 27 Jan 2022 13:48:49 +0000 Subject: [PATCH 013/221] added some tests specifically for istrans --- test/varinfo.jl | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/test/varinfo.jl b/test/varinfo.jl index 2f7816024..32c90bf47 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -271,4 +271,47 @@ DynamicPPL.setval_and_resample!(vi, vi.metadata.x.vals, ks) @test vals_prev == vi.metadata.x.vals end + + @testset "istrans" begin + @model demo_constrained() = x ~ truncated(Normal(), 0, Inf) + model = demo_constrained() + vn = @varname(x) + dist = truncated(Normal(), 0, Inf) + + ### `VarInfo` + # Need to run once since we can't specify that we want to _sample_ + # in the unconstrained space for `VarInfo` without having `vn` + # present in the `varinfo`. + ## `UntypedVarInfo` + vi = VarInfo() + vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + vi = DynamicPPL.settrans!!(vi, true, vn) + # Sample in unconstrained space. + vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + x = Bijectors.invlink(dist, DynamicPPL.getindex_raw(vi, vn)) + @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + + ## `TypedVarInfo` + vi = VarInfo(model) + vi = DynamicPPL.settrans!!(vi, true, vn) + # Sample in unconstrained space. + vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + x = Bijectors.invlink(dist, DynamicPPL.getindex_raw(vi, vn)) + @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + + ### `SimpleVarInfo` + ## `SimpleVarInfo{<:NamedTuple}` + vi = DynamicPPL.settrans!!(SimpleVarInfo(), true) + # Sample in unconstrained space. + vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + x = Bijectors.invlink(dist, DynamicPPL.getindex_raw(vi, vn)) + @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + + ## `SimpleVarInfo{<:Dict}` + vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true) + # Sample in unconstrained space. + vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + x = Bijectors.invlink(dist, DynamicPPL.getindex_raw(vi, vn)) + @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + end end From 8b870dcff71f69dbe4c32d81a8ca1e9269e62364 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 27 Jan 2022 13:49:49 +0000 Subject: [PATCH 014/221] formatting --- src/DynamicPPL.jl | 4 +++- src/simple_varinfo.jl | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 21efbd72d..f0ede3f67 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -176,6 +176,8 @@ include("test_utils.jl") @deprecate acclogp!(vi, logp) acclogp!!(vi, logp) @deprecate resetlogp!(vi) resetlogp!!(vi) -@deprecate settrans!(vi::AbstractVarInfo, trans::Bool, vn::VarName) settrans!!(vi, trans, vn) +@deprecate settrans!(vi::AbstractVarInfo, trans::Bool, vn::VarName) settrans!!( + vi, trans, vn +) end # module diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 93c91e21e..ee1fa9999 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -139,7 +139,9 @@ function Setfield.ConstructionBase.constructorof( end end -SimpleVarInfo(values, logp, istrans::Bool=false) = SimpleVarInfo(values, logp, Val{istrans}()) +function SimpleVarInfo(values, logp, istrans::Bool=false) + return SimpleVarInfo(values, logp, Val{istrans}()) +end function SimpleVarInfo(values, logp, ::Val{IsTrans}) where {IsTrans} return SimpleVarInfo{typeof(values),typeof(logp),IsTrans}(values, logp) end From 0b304db2b61564b3e185988559a621e7966bb28a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 27 Jan 2022 14:32:56 +0000 Subject: [PATCH 015/221] fixed bugs for ThreadSafeVarInfo --- src/threadsafe.jl | 3 +++ src/varinfo.jl | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 6f020a352..b42cf82a5 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -65,6 +65,9 @@ getindex(vi::ThreadSafeVarInfo, spl::SampleFromUniform) = getindex(vi.varinfo, s getindex(vi::ThreadSafeVarInfo, vn::VarName) = getindex(vi.varinfo, vn) getindex(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) = getindex(vi.varinfo, vns) +getindex_raw(vi::ThreadSafeVarInfo, vn::VarName) = getindex_raw(vi.varinfo, vn) +getindex_raw(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) = getindex_raw(vi.varinfo, vns) + function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::AbstractSampler) return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) end diff --git a/src/varinfo.jl b/src/varinfo.jl index 67ded7b82..c3aaf8255 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -650,7 +650,7 @@ end Return `true` if `vi` is working in unconstrained space, and `false` if `vi` is assuming realizations to be in support of the corresponding distributions. """ -istrans(vi::VarInfo) = false # `VarInfo` works in constrained space by default. +istrans(vi::AbstractVarInfo) = false # `VarInfo` works in constrained space by default. """ istrans(vi::VarInfo, vn::VarName) From b146b116d5877215849d457e7d9a3e2771a00ee6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 27 Jan 2022 21:01:46 +0000 Subject: [PATCH 016/221] additional constructor for SimpleVarInfo --- src/simple_varinfo.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index ee1fa9999..153e43918 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -139,7 +139,8 @@ function Setfield.ConstructionBase.constructorof( end end -function SimpleVarInfo(values, logp, istrans::Bool=false) +SimpleVarInfo(values, logp) = SimpleVarInfo(values, logp, Val{false}()) +function SimpleVarInfo(values, logp, istrans::Bool) return SimpleVarInfo(values, logp, Val{istrans}()) end function SimpleVarInfo(values, logp, ::Val{IsTrans}) where {IsTrans} From d170d92d98494d2078afd39c6ce5f2b88d158299 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Feb 2022 13:28:41 +0000 Subject: [PATCH 017/221] Update src/DynamicPPL.jl Co-authored-by: David Widmann --- src/DynamicPPL.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index f0ede3f67..fc7a2fd0c 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -178,6 +178,6 @@ include("test_utils.jl") @deprecate settrans!(vi::AbstractVarInfo, trans::Bool, vn::VarName) settrans!!( vi, trans, vn -) +) false end # module From f46183bcb3cbfe4a29f38639bee769e81336502d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Feb 2022 13:30:13 +0000 Subject: [PATCH 018/221] added ConstructionBase.jl as dep --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index cd40c074a..7412279b2 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" @@ -22,6 +23,7 @@ AbstractPPL = "0.3" BangBang = "0.3" Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9, 0.10" ChainRulesCore = "0.9.7, 0.10, 1" +ConstructionBase = "1" Distributions = "0.23.8, 0.24, 0.25" MacroTools = "0.5.6" Setfield = "0.7.1, 0.8" From 7a78eec52a5e700cc1570618e78630bd850fb924 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Feb 2022 15:58:32 +0000 Subject: [PATCH 019/221] added constraint types and doctests --- src/simple_varinfo.jl | 108 ++++++++++++++++++++++++++++++++---------- 1 file changed, 82 insertions(+), 26 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 153e43918..fb5e916f1 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -1,5 +1,10 @@ +abstract type AbstractConstraint end + +struct Constrained <: AbstractConstraint end +struct Unconstrained <: AbstractConstraint end + """ - SimpleVarInfo{NT,T} <: AbstractVarInfo + SimpleVarInfo{NT,T,C} <: AbstractVarInfo A simple wrapper of the parameters with a `logp` field for accumulation of the logdensity. @@ -16,7 +21,7 @@ The major differences between this and `TypedVarInfo` are: # Examples ## General usage -```jldoctest; setup=:(using Distributions) +```jldoctest simplevarinfo-general; setup=:(using Distributions) julia> using StableRNGs julia> @model function demo() @@ -78,6 +83,60 @@ ERROR: KeyError: key x[1:2] not found [...] ``` +You can also sample in _unconstrained_ space: + +```jldoctest simplevarinfo-general +julia> @model demo_constrained() = x ~ Exponential() +demo_constrained (generic function with 2 methods) + +julia> m = demo_constrained(); + +julia> _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo(), ctx); + +julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞ +1.8632965762164932 + +julia> _, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(), true), ctx); + +julia> vi[@varname(x)] # (✓) -∞ < x < ∞ +-0.21080155351918753 + +julia> xs = [last(DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(), true), ctx))[@varname(x)] for i = 1:10]; + +julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! +true + +julia> # And with `Dict` of course! + _, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true), ctx); + +julia> vi[@varname(x)] # (✓) -∞ < x < ∞ +0.6225185067787314 + +julia> xs = [last(DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(), true), ctx))[@varname(x)] for i = 1:10]; + +julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! +true +``` + +Evaluation in unconstrained space of course also works: + +```jldoctest simplevarinfo-general +julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true) +Unconstrained SimpleVarInfo((x = -1.0,), 0.0) + +julia> # (✓) Positive probability mass on negative numbers! + getlogp(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) +-1.3678794411714423 + +julia> # While if we forget to make indicate that it's unconstrained/transformed: + vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false) +Constrained SimpleVarInfo((x = -1.0,), 0.0) + +julia> # (✓) No probability mass on negative numbers! + getlogp(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) +-Inf +``` + ## Indexing Using `NamedTuple` as underlying storage. @@ -126,29 +185,19 @@ ERROR: type NamedTuple has no field b [...] ``` """ -struct SimpleVarInfo{NT,T,IsTrans} <: AbstractVarInfo +struct SimpleVarInfo{NT,T,C<:AbstractConstraint} <: AbstractVarInfo + "underlying representation of the realization represented" values::NT + "holds the accumulated log-probability" logp::T + "represents whether it assumes variables to be constrained or unconstrained" + constraint::C end -function Setfield.ConstructionBase.constructorof( - ::Type{<:SimpleVarInfo{<:Any,<:Any,IsTrans}} -) where {IsTrans} - return function SimpleVarInfo_constructor(values, logp) - return SimpleVarInfo{typeof(values),typeof(logp),IsTrans}(values, logp) - end -end - -SimpleVarInfo(values, logp) = SimpleVarInfo(values, logp, Val{false}()) -function SimpleVarInfo(values, logp, istrans::Bool) - return SimpleVarInfo(values, logp, Val{istrans}()) -end -function SimpleVarInfo(values, logp, ::Val{IsTrans}) where {IsTrans} - return SimpleVarInfo{typeof(values),typeof(logp),IsTrans}(values, logp) -end +SimpleVarInfo(values, logp) = SimpleVarInfo(values, logp, Constrained()) function SimpleVarInfo{T}(θ) where {T<:Real} - return SimpleVarInfo{typeof(θ),T,false}(θ, zero(T)) + return SimpleVarInfo(θ, zero(T)) end function SimpleVarInfo{T}(; kwargs...) where {T<:Real} return SimpleVarInfo{T}(NamedTuple(kwargs)) @@ -201,9 +250,15 @@ function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) end function Base.show( - io::IO, ::MIME"text/plain", svi::SimpleVarInfo{<:Any,<:Any,IsTrans} -) where {IsTrans} - return print(io, "SimpleVarInfo{IsTrans=$(IsTrans)}(", svi.values, ", ", svi.logp, ")") + io::IO, ::MIME"text/plain", svi::SimpleVarInfo{<:Any,<:Any,<:Constrained} +) + return print(io, "Constrained SimpleVarInfo(", svi.values, ", ", svi.logp, ")") +end + +function Base.show( + io::IO, ::MIME"text/plain", svi::SimpleVarInfo{<:Any,<:Any,<:Unconstrained} +) + return print(io, "Unconstrained SimpleVarInfo(", svi.values, ", ", svi.logp, ")") end # `NamedTuple` @@ -354,8 +409,8 @@ function BangBang.push!!( return vi end -const SimpleOrThreadSafeSimple{T,V,IsTrans} = Union{ - SimpleVarInfo{T,V,IsTrans},ThreadSafeVarInfo{<:SimpleVarInfo{T,V,IsTrans}} +const SimpleOrThreadSafeSimple{T,V,C} = Union{ + SimpleVarInfo{T,V,C},ThreadSafeVarInfo{<:SimpleVarInfo{T,V,C}} } # Necessary for `matchingvalue` to work properly. @@ -415,12 +470,13 @@ increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing Return new instance of `vi` but with `istrans(vi, trans)` now evaluating to `true`. """ -settrans!!(vi::SimpleVarInfo, trans) = SimpleVarInfo(vi.values, vi.logp, trans) +settrans!!(vi::SimpleVarInfo, trans) = SimpleVarInfo(vi.values, vi.logp, trans ? Unconstrained() : Constrained()) function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) return Setfield.@set vi.varinfo = settrans!!(vi, trans) end -istrans(vi::SimpleVarInfo{<:Any,<:Any,IsTrans}) where {IsTrans} = IsTrans +istrans(vi::SimpleVarInfo{<:Any,<:Any,<:Constrained}) = false +istrans(vi::SimpleVarInfo{<:Any,<:Any,<:Unconstrained}) = true istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi) istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) From 70b3b70e7c29d230822a17e4d9c98fb46720447a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Feb 2022 15:59:04 +0000 Subject: [PATCH 020/221] added DocStringExtensions as a dep --- Project.toml | 2 ++ src/DynamicPPL.jl | 2 ++ src/simple_varinfo.jl | 3 +++ 3 files changed, 7 insertions(+) diff --git a/Project.toml b/Project.toml index 7412279b2..8f4cfa184 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -25,6 +26,7 @@ Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9, 0.10" ChainRulesCore = "0.9.7, 0.10, 1" ConstructionBase = "1" Distributions = "0.23.8, 0.24, 0.25" +DocStringExtensions = "0.8" MacroTools = "0.5.6" Setfield = "0.7.1, 0.8" ZygoteRules = "0.2" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index fc7a2fd0c..9a66bb57e 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -15,6 +15,8 @@ using Setfield: Setfield using Setfield: Setfield using BangBang: BangBang +using DocStringExtensions + using Random: Random import Base: diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index fb5e916f1..c0dcacd17 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -11,6 +11,9 @@ accumulation of the logdensity. Currently only implemented for `NT<:NamedTuple` and `NT<:Dict`. +# Fields +$(FIELDS) + # Notes The major differences between this and `TypedVarInfo` are: 1. `SimpleVarInfo` does not require linearization. From a03e8cf7aac44e7bd75d32a963f7d65eab974d14 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Feb 2022 16:00:50 +0000 Subject: [PATCH 021/221] formatting --- src/simple_varinfo.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index c0dcacd17..2d1f807fc 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -473,7 +473,9 @@ increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing Return new instance of `vi` but with `istrans(vi, trans)` now evaluating to `true`. """ -settrans!!(vi::SimpleVarInfo, trans) = SimpleVarInfo(vi.values, vi.logp, trans ? Unconstrained() : Constrained()) +function settrans!!(vi::SimpleVarInfo, trans) + return SimpleVarInfo(vi.values, vi.logp, trans ? Unconstrained() : Constrained()) +end function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) return Setfield.@set vi.varinfo = settrans!!(vi, trans) end From 793c931a0ec3b0247a0e9628c8ade314d8bfffc4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Feb 2022 16:07:31 +0000 Subject: [PATCH 022/221] remove redundant maybe_link --- src/context_implementations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 8509e4a09..3cc1c9def 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -486,7 +486,7 @@ function get_and_set_val!( for i in 1:n vn = vns[i] if istrans(vi) - push!!(vi, vn, maybe_link(vi, vn, dist, r[:, i]), dist, spl) + push!!(vi, vn, Bijectors.link(dist, r[:, i]), dist, spl) # `push!!` sets the trans-flag to `false` by default. setttrans!!(vi, true, vn) else From a9b12fdd82ad0f1089bc5ca705166b8139d909b7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Feb 2022 16:08:08 +0000 Subject: [PATCH 023/221] fixed typo --- src/context_implementations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 3cc1c9def..53db489f1 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -488,7 +488,7 @@ function get_and_set_val!( if istrans(vi) push!!(vi, vn, Bijectors.link(dist, r[:, i]), dist, spl) # `push!!` sets the trans-flag to `false` by default. - setttrans!!(vi, true, vn) + settrans!!(vi, true, vn) else push!!(vi, vn, r[:, i], dist, spl) end From be989612e1d9b5c3e025215a91d58dec90a65511 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 11 Feb 2022 16:35:27 +0000 Subject: [PATCH 024/221] moved a docstring --- src/simple_varinfo.jl | 5 ----- src/varinfo.jl | 11 +++++++++++ 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 2d1f807fc..e3a709fdc 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -468,11 +468,6 @@ end increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing # NOTE: We don't implement `settrans!!(vi, trans, vn)`. -""" - settrans!!(vi::AbstractVarInfo, trans) - -Return new instance of `vi` but with `istrans(vi, trans)` now evaluating to `true`. -""" function settrans!!(vi::SimpleVarInfo, trans) return SimpleVarInfo(vi.values, vi.logp, trans ? Unconstrained() : Constrained()) end diff --git a/src/varinfo.jl b/src/varinfo.jl index c3aaf8255..d849b0fd7 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -372,6 +372,17 @@ function settrans!!(vi::AbstractVarInfo, trans::Bool, vn::VarName) return vi end +""" + settrans!!(vi::AbstractVarInfo, trans) + +Return new instance of `vi` but with `istrans(vi, trans)` now evaluating to `true`. +""" +function settrans!!(vi::VarInfo, trans::Bool) + for vn in keys(vi) + settrans!!(vi, trans, vn) + end +end + """ syms(vi::VarInfo) From 3e1588bc01183fcef00cfc3aa35bb9fd058910cd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 11 Feb 2022 16:35:55 +0000 Subject: [PATCH 025/221] fixed bug in tets --- src/context_implementations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 53db489f1..9d3ba5aa9 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -401,7 +401,7 @@ function dot_assume( # in which case `var` will have `undef` elements, even if `m` is present in `vi`. # r = vi[vns] r_raw = getindex_raw(vi, vns) - r = maybe_invlink(vi, vn, dist, r_raw) + r = maybe_invlink(vi, vns, dist, r_raw) lp = sum(zip(vns, eachcol(r))) do (vn, ri) return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) end From 3139c62f41ccde2d0ad82d66b3f392fe276d6c56 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 11 Feb 2022 16:36:57 +0000 Subject: [PATCH 026/221] version bump --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c07179202..0c243767d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.17.4" +version = "0.17.5" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From 3610658285eea9c44a1136903a11f782ce33f9b6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 12 Feb 2022 18:31:37 +0000 Subject: [PATCH 027/221] added missing istrans impl --- src/varinfo.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/varinfo.jl b/src/varinfo.jl index d849b0fd7..e430a9967 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -670,6 +670,7 @@ Return true if `vn`'s values in `vi` are transformed to Euclidean space, and fal they are in the support of `vn`'s distribution. """ istrans(vi::AbstractVarInfo, vn::VarName) = is_flagged(vi, vn, "trans") +istrans(vi::AbstractVarInfo, vns::AbstractVector{<:VarName}) = all(istrans, vns) """ getlogp(vi::VarInfo) From 27171ad7d5427b977c530355a201ece9e8657661 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 13 Feb 2022 16:22:28 +0000 Subject: [PATCH 028/221] fixed bug with istrans --- src/varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index e430a9967..482dca6cb 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -670,7 +670,7 @@ Return true if `vn`'s values in `vi` are transformed to Euclidean space, and fal they are in the support of `vn`'s distribution. """ istrans(vi::AbstractVarInfo, vn::VarName) = is_flagged(vi, vn, "trans") -istrans(vi::AbstractVarInfo, vns::AbstractVector{<:VarName}) = all(istrans, vns) +istrans(vi::AbstractVarInfo, vns::AbstractVector{<:VarName}) = all(Base.Fix1(istrans, vi), vns) """ getlogp(vi::VarInfo) From cd2d9d69a18c0e18d7b4c92dcb1be2a0bb7ed297 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 13 Feb 2022 17:54:16 +0000 Subject: [PATCH 029/221] fixed issue with getindex_raw for VarInfo --- src/varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 482dca6cb..f8923856a 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -956,7 +956,7 @@ function getindex_raw(vi::AbstractVarInfo, vn::VarName) return reconstruct(getdist(vi, vn), getval(vi, vn)) end function getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}) - return reconstruct(getdist(vi, first(vns)), getval(vi, vns)) + return reconstruct(getdist(vi, first(vns)), getval(vi, vns), length(vns)) end """ From d948cb91c1c769afbdcbcd9501c5dbce5c5c064f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 13 Feb 2022 17:54:43 +0000 Subject: [PATCH 030/221] Update src/varinfo.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/varinfo.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index f8923856a..0503b3a10 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -670,7 +670,9 @@ Return true if `vn`'s values in `vi` are transformed to Euclidean space, and fal they are in the support of `vn`'s distribution. """ istrans(vi::AbstractVarInfo, vn::VarName) = is_flagged(vi, vn, "trans") -istrans(vi::AbstractVarInfo, vns::AbstractVector{<:VarName}) = all(Base.Fix1(istrans, vi), vns) +function istrans(vi::AbstractVarInfo, vns::AbstractVector{<:VarName}) + return all(Base.Fix1(istrans, vi), vns) +end """ getlogp(vi::VarInfo) From 26d2dbbc106ea6fa0462dd11e48389415dbde8eb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 22 Jun 2022 13:03:19 +0100 Subject: [PATCH 031/221] getindex of varinfo implementations now optionally takes a Distribution argument --- src/simple_varinfo.jl | 16 ++++++++++++++++ src/varinfo.jl | 28 ++++++++++++++++------------ 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index e3a709fdc..d51c32760 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -265,6 +265,16 @@ function Base.show( end # `NamedTuple` +function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution) + return maybe_invlink(vi, vn, dist, Base.getindex(vi, vn)) +end +function Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution) + vals_linked = map(vns) do vn + maybe_invlink(vi, vn, dist, Base.getindex(vi, vn)) + end + return reconstruct(dist, reduce(vcat, vals_linked), length(vns)) +end + Base.getindex(vi::SimpleVarInfo, vn::VarName) = get(vi.values, vn) # `Dict` @@ -309,7 +319,13 @@ Base.getindex(vi::SimpleVarInfo, spl::Sampler) = vi.values # Since we don't perform any transformations in `getindex` for `SimpleVarInfo` # we simply call `getindex` in `getindex_raw`. getindex_raw(vi::SimpleVarInfo, vn::VarName) = vi[vn] +getindex_raw(vi::SimpleVarInfo, vn::VarName, dist::Distribution) = reconstruct(dist, getindex_raw(vi, vn)) getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}) = vi[vns] +function getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution) + vals = getindex_raw(vi, vns) + # `reconstruct` expects a flattened `Vector` regardless of the type of `dist`, so we `vcat` everything. + return reconstruct(dist, reduce(vcat, vals), length(vns)) +end Base.haskey(vi::SimpleVarInfo, vn::VarName) = _haskey(vi.values, vn) function _haskey(nt::NamedTuple, vn::VarName) diff --git a/src/varinfo.jl b/src/varinfo.jl index 0503b3a10..6f3b39fe5 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -935,30 +935,34 @@ distribution(s). If the value(s) is (are) transformed to the Euclidean space, it is (they are) transformed back. """ -function getindex(vi::AbstractVarInfo, vn::VarName) +getindex(vi::AbstractVarInfo, vn::VarName) = getindex(vi, vn, getdist(vi, vn)) +function getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution) @assert haskey(vi, vn) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" - dist = getdist(vi, vn) + val = getindex_raw(vi, vn, dist) return if istrans(vi, vn) - Bijectors.invlink(dist, reconstruct(dist, getval(vi, vn))) + Bijectors.invlink(dist, val) else - reconstruct(dist, getval(vi, vn)) + val end end -function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) +getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) = getindex(vi, vns, getdist(vi, first(vns))) +function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distribution) @assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" - dist = getdist(vi, vns[1]) + val = getindex_raw(vi, vns, dist) return if istrans(vi, vns[1]) - Bijectors.invlink(dist, reconstruct(dist, getval(vi, vns), length(vns))) + Bijectors.invlink(dist, val) else - reconstruct(dist, getval(vi, vns), length(vns)) + val end end -function getindex_raw(vi::AbstractVarInfo, vn::VarName) - return reconstruct(getdist(vi, vn), getval(vi, vn)) +getindex_raw(vi::AbstractVarInfo, vn::VarName) = getindex_raw(vi, vn, getdist(vi, vn)) +function getindex_raw(vi::AbstractVarInfo, vn::VarName, dist::Distribution) + return reconstruct(dist, getval(vi, vn)) end -function getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}) - return reconstruct(getdist(vi, first(vns)), getval(vi, vns), length(vns)) +getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}) = getindex_raw(vi, vns, getdist(vi, first(vns))) +function getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distribution) + return reconstruct(dist, getval(vi, vns), length(vns)) end """ From 3fcba565a4f4d61e32a08df420c19431339df0eb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 22 Jun 2022 13:03:58 +0100 Subject: [PATCH 032/221] use get_index_raw with dist argument --- src/context_implementations.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 9d3ba5aa9..a9bf2db92 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -195,7 +195,7 @@ end # fallback without sampler function assume(dist::Distribution, vn::VarName, vi) # x = vi[vn] - r_raw = getindex_raw(vi, vn) + r_raw = getindex_raw(vi, vn, dist) r = maybe_invlink(vi, vn, dist, r_raw) return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi end @@ -218,7 +218,7 @@ function assume( else # Otherwise we just extract it. # r = vi[vn] - r_raw = getindex_raw(vi, vn) + r_raw = getindex_raw(vi, vn, dist) r = maybe_invlink(vi, vn, dist, r_raw) end else @@ -400,7 +400,7 @@ function dot_assume( # # in which case `var` will have `undef` elements, even if `m` is present in `vi`. # r = vi[vns] - r_raw = getindex_raw(vi, vns) + r_raw = getindex_raw(vi, vns, dist) r = maybe_invlink(vi, vns, dist, r_raw) lp = sum(zip(vns, eachcol(r))) do (vn, ri) return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) @@ -434,7 +434,7 @@ function dot_assume( # m .~ Normal() # # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - r_raw = getindex_raw(vi, vec(vns)) + r_raw = getindex_raw(vi, vec(vns), dists) r = reshape(maybe_invlink.(Ref(vi), vns, dists, r_raw), size(vns)) lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) return r, lp, vi From 83a9448dc22bd0e46dc29f81b552d85019539386 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 22 Jun 2022 13:04:09 +0100 Subject: [PATCH 033/221] added missing assume implementations for SimpleVarInfo --- src/simple_varinfo.jl | 50 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index d51c32760..b07b102ea 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -470,7 +470,55 @@ function dot_assume( # Transform if we're working in transformed space. ist = istrans(vi, first(vns)) - value_raw = ist ? link.(dist, value) : value + value_raw = ist ? link.(dists, value) : value + + # Update `vi` + vi = BangBang.setindex!!(vi, value_raw, vns) + + # Compute logp. + lp = sum(Bijectors.logpdf_with_trans.(dists, value, ist)) + return value, lp, vi +end + +function dot_assume( + rng, + spl::Union{SampleFromPrior,SampleFromUniform}, + dist::MultivariateDistribution, + vns::AbstractVector{<:VarName}, + var::AbstractMatrix, + vi::SimpleOrThreadSafeSimple, +) + @assert length(dist) == size(var, 1) + + # r = get_and_set_val!(rng, vi, vns, dist, spl) + n = length(vns) + value = init(rng, dist, spl, n) + + # Update `vi`. + for (vn, val) in zip(vns, eachcol(value)) + val_linked = maybe_link(vi, vn, dist, val) + vi = BangBang.setindex!!(vi, val_linked, vn) + end + + # Compute logp. + lp = sum(Bijectors.logpdf_with_trans(dist, value, istrans(vi, first(vns)))) + return value, lp, vi +end + +function dot_assume( + rng, + spl::Union{SampleFromPrior,SampleFromUniform}, + dists::Union{Distribution,AbstractArray{<:Distribution}}, + vns::AbstractArray{<:VarName}, + var::AbstractArray, + vi::SimpleOrThreadSafeSimple, +) + f = (vn, dist) -> init(rng, dist, spl) + value = f.(vns, dists) + + # Transform if we're working in transformed space. + ist = istrans(vi, first(vns)) + value_raw = ist ? link.(dists, value) : value # Update `vi` vi = BangBang.setindex!!(vi, value_raw, vns) From 356fa9ceae23dd2208c5b34faf3fc079f4c49cb9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 22 Jun 2022 13:04:27 +0100 Subject: [PATCH 034/221] fixed settrans!! for VarInfo --- src/varinfo.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/varinfo.jl b/src/varinfo.jl index 6f3b39fe5..194d475fc 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -381,6 +381,8 @@ function settrans!!(vi::VarInfo, trans::Bool) for vn in keys(vi) settrans!!(vi, trans, vn) end + + return vi end """ From 13f037ff2a1a036a81bb6efd2453d20b3876bf0c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 24 Jun 2022 16:47:35 +0100 Subject: [PATCH 035/221] formatting --- src/simple_varinfo.jl | 4 +++- src/varinfo.jl | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index b07b102ea..aad65dfc8 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -319,7 +319,9 @@ Base.getindex(vi::SimpleVarInfo, spl::Sampler) = vi.values # Since we don't perform any transformations in `getindex` for `SimpleVarInfo` # we simply call `getindex` in `getindex_raw`. getindex_raw(vi::SimpleVarInfo, vn::VarName) = vi[vn] -getindex_raw(vi::SimpleVarInfo, vn::VarName, dist::Distribution) = reconstruct(dist, getindex_raw(vi, vn)) +function getindex_raw(vi::SimpleVarInfo, vn::VarName, dist::Distribution) + return reconstruct(dist, getindex_raw(vi, vn)) +end getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}) = vi[vns] function getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution) vals = getindex_raw(vi, vns) diff --git a/src/varinfo.jl b/src/varinfo.jl index 194d475fc..13259a12e 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -962,7 +962,9 @@ getindex_raw(vi::AbstractVarInfo, vn::VarName) = getindex_raw(vi, vn, getdist(vi function getindex_raw(vi::AbstractVarInfo, vn::VarName, dist::Distribution) return reconstruct(dist, getval(vi, vn)) end -getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}) = getindex_raw(vi, vns, getdist(vi, first(vns))) +function getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}) + return getindex_raw(vi, vns, getdist(vi, first(vns))) +end function getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distribution) return reconstruct(dist, getval(vi, vns), length(vns)) end From c7544e087dcc078d86cfe4de75f0e032b998c8a9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 24 Jun 2022 16:47:45 +0100 Subject: [PATCH 036/221] fixed bug where constrained/unconstrained wasn't preserved in setindex!! for SimpleVarInfo --- src/simple_varinfo.jl | 29 +++-------------------------- src/varinfo.jl | 4 +++- 2 files changed, 6 insertions(+), 27 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index aad65dfc8..1c6d4043a 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -364,7 +364,7 @@ end function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName) # For `NamedTuple` we treat the symbol in `vn` as the _property_ to set. - return SimpleVarInfo(set!!(vi.values, vn, val), vi.logp) + return Setfield.@set vi.values = set!!(vi.values, vn, val) end # TODO: Specialize to handle certain cases, e.g. a collection of `VarName` with @@ -395,7 +395,7 @@ function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName vn_key = VarName(vn, keylens) BangBang.setindex!!(dict, set!!(dict[vn_key], child, val), vn_key) end - return SimpleVarInfo(dict_new, vi.logp) + return Setfield.@set vi.values = dict_new end # `NamedTuple` @@ -503,30 +503,7 @@ function dot_assume( end # Compute logp. - lp = sum(Bijectors.logpdf_with_trans(dist, value, istrans(vi, first(vns)))) - return value, lp, vi -end - -function dot_assume( - rng, - spl::Union{SampleFromPrior,SampleFromUniform}, - dists::Union{Distribution,AbstractArray{<:Distribution}}, - vns::AbstractArray{<:VarName}, - var::AbstractArray, - vi::SimpleOrThreadSafeSimple, -) - f = (vn, dist) -> init(rng, dist, spl) - value = f.(vns, dists) - - # Transform if we're working in transformed space. - ist = istrans(vi, first(vns)) - value_raw = ist ? link.(dists, value) : value - - # Update `vi` - vi = BangBang.setindex!!(vi, value_raw, vns) - - # Compute logp. - lp = sum(Bijectors.logpdf_with_trans.(dists, value, ist)) + lp = sum(Bijectors.logpdf_with_trans(dist, value, istrans(vi))) return value, lp, vi end diff --git a/src/varinfo.jl b/src/varinfo.jl index 13259a12e..582582b77 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -947,7 +947,9 @@ function getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution) val end end -getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) = getindex(vi, vns, getdist(vi, first(vns))) +function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) + return getindex(vi, vns, getdist(vi, first(vns))) +end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distribution) @assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" val = getindex_raw(vi, vns, dist) From d1dccf19b764b8876cf7ba35ae4d3aed0962be14 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 24 Jun 2022 16:48:27 +0100 Subject: [PATCH 037/221] hack to avoid type-instabilities for dot_assume with MultivariateDistribution --- src/utils.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index 821eba38e..2a7838e2f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -417,3 +417,14 @@ function splitlens(condition, lens) return current_parent, current_child, condition(current_parent) end + +# HACK(torfjelde): Avoids type-instability in `dot_assume` for `SimpleVarInfo`. +function BangBang.possible( + ::typeof(BangBang._setindex!), + ::C, + ::T, + ::Colon, + ::Integer +) where {C<:AbstractMatrix,T<:AbstractVector} + return BangBang.implements(setindex!, C) && promote_type(eltype(C), eltype(T)) <: eltype(C) +end From ff7ff4ab6e5da58fe5d700c4e85e8d6be44e90a9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 26 Jun 2022 14:01:10 +0100 Subject: [PATCH 038/221] style --- src/utils.jl | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 2a7838e2f..b200171a7 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -420,11 +420,8 @@ end # HACK(torfjelde): Avoids type-instability in `dot_assume` for `SimpleVarInfo`. function BangBang.possible( - ::typeof(BangBang._setindex!), - ::C, - ::T, - ::Colon, - ::Integer + ::typeof(BangBang._setindex!), ::C, ::T, ::Colon, ::Integer ) where {C<:AbstractMatrix,T<:AbstractVector} - return BangBang.implements(setindex!, C) && promote_type(eltype(C), eltype(T)) <: eltype(C) + return BangBang.implements(setindex!, C) && + promote_type(eltype(C), eltype(T)) <: eltype(C) end From 2f1a2ff68d7abf3be166bd5de691ab479f3131bf Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 26 Jun 2022 14:01:35 +0100 Subject: [PATCH 039/221] added keys implementations for the models in TestUtils to make testing AbstractVarInfo implementations easier --- src/test_utils.jl | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/test_utils.jl b/src/test_utils.jl index ca0fabc9a..bd4c0bedc 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -66,6 +66,9 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe)}, m) return loglikelihood(MvNormal(m, 0.25 * I), model.args.x) end +function Base.keys(model::Model{typeof(demo_dot_assume_dot_observe)}) + return [@varname(m[1]), @varname(m[2])] +end @model function demo_assume_index_observe( x=[10.0, 10.0], ::Type{TV}=Vector{Float64} @@ -85,6 +88,9 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_index_observe)}, m) return logpdf(MvNormal(m, 0.25 * I), model.args.x) end +function Base.keys(model::Model{typeof(demo_assume_index_observe)}) + return [@varname(m[1]), @varname(m[2])] +end @model function demo_assume_multivariate_observe(x=[10.0, 10.0]) # Multivariate `assume` and `observe` @@ -99,6 +105,9 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_multivariate_observe)}, m) return logpdf(MvNormal(m, 0.25 * I), model.args.x) end +function Base.keys(model::Model{typeof(demo_assume_multivariate_observe)}) + return [@varname(m)] +end @model function demo_dot_assume_observe_index( x=[10.0, 10.0], ::Type{TV}=Vector{Float64} @@ -118,6 +127,9 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index)}, m) return sum(logpdf.(Normal.(m, 0.5), model.args.x)) end +function Base.keys(model::Model{typeof(demo_dot_assume_observe_index)}) + return [@varname(m[1]), @varname(m[2])] +end # Using vector of `length` 1 here so the posterior of `m` is the same # as the others. @@ -134,6 +146,9 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_dot_observe)}, m) return sum(logpdf.(Normal.(m, 0.5), model.args.x)) end +function Base.keys(model::Model{typeof(demo_assume_dot_observe)}) + return [@varname(m)] +end @model function demo_assume_observe_literal() # `assume` and literal `observe` @@ -148,6 +163,9 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, m) return logpdf(MvNormal(m, 0.25 * I), [10.0, 10.0]) end +function Base.keys(model::Model{typeof(demo_assume_observe_literal)}) + return [@varname(m)] +end @model function demo_dot_assume_observe_index_literal(::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and literal `observe` with indexing @@ -165,6 +183,9 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, m) return sum(logpdf.(Normal.(m, 0.5), fill(10.0, length(m)))) end +function Base.keys(model::Model{typeof(demo_dot_assume_observe_index_literal)}) + return [@varname(m[1]), @varname(m[2])] +end @model function demo_assume_literal_dot_observe() # `assume` and literal `dot_observe` @@ -179,6 +200,9 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_literal_dot_observe)}, m) return logpdf(Normal(m, 0.5), 10.0) end +function Base.keys(model::Model{typeof(demo_assume_literal_dot_observe)}) + return [@varname(m)] +end @model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} m = TV(undef, 2) @@ -204,6 +228,9 @@ function loglikelihood_true( ) return sum(logpdf.(Normal.(m, 0.5), 10.0)) end +function Base.keys(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) + return [@varname(m[1]), @varname(m[2])] +end @model function _likelihood_dot_observe(m, x) return x ~ MvNormal(m, 0.25 * I) @@ -226,6 +253,9 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, m) return logpdf(MvNormal(m, 0.25 * I), model.args.x) end +function Base.keys(model::Model{typeof(demo_dot_assume_observe_submodel)}) + return [@varname(m[1]), @varname(m[2])] +end @model function demo_dot_assume_dot_observe_matrix( x=fill(10.0, 2, 1), ::Type{TV}=Vector{Float64} From d6311b74e0d62941361b959ab08d9d4b986b5306 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 26 Jun 2022 14:02:08 +0100 Subject: [PATCH 040/221] added additional test model which uses dot-assume on MultivariateDistribution --- src/test_utils.jl | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/test_utils.jl b/src/test_utils.jl index bd4c0bedc..faa390cea 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -274,6 +274,33 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, m) return loglikelihood(MvNormal(m, 0.25 * I), model.args.x) end +function Base.keys(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) + return [@varname(m[1]), @varname(m[2])] +end + +@model function demo_dot_assume_matrix_dot_observe_matrix( + x=fill(10.0, 2, 1), ::Type{TV}=Array{Float64} +) where {TV} + d = length(x) ÷ 2 + m = TV(undef, d, 2) + m .~ MvNormal(zeros(d), I) + + # Dotted observe for `Matrix`. + x .~ MvNormal(vec(m), 0.25 * I) + + return (; m=m, x=x, logp=getlogp(__varinfo__)) +end +function logprior_true(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, m) + return loglikelihood(Normal(), vec(m)) +end +function loglikelihood_true( + model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, m +) + return loglikelihood(MvNormal(vec(m), 0.25 * I), model.args.x) +end +function Base.keys(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) + return [@varname(m[:, 1]), @varname(m[:, 2])] +end const DEMO_MODELS = ( demo_dot_assume_dot_observe(), @@ -287,6 +314,7 @@ const DEMO_MODELS = ( demo_assume_submodel_observe_index_literal(), demo_dot_assume_observe_submodel(), demo_dot_assume_dot_observe_matrix(), + demo_dot_assume_matrix_dot_observe_matrix(), ) # TODO: Is this really the best/most convenient "default" test method? From ed2fa693d6d988a0cd0f7162000c1ece820cf815 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 26 Jun 2022 14:02:47 +0100 Subject: [PATCH 041/221] updated tests for SimpleVarInfo --- test/simple_varinfo.jl | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 7e9346450..5b7c6ba71 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -85,12 +85,14 @@ isunivariate = !haskey(svi_new, @varname(m[1])) # Realization for `m` should be different wp. 1. - if isunivariate - @test svi_new[@varname(m)] != m - else - @test svi_new[@varname(m[1])] != m[1] - @test svi_new[@varname(m[2])] != m[2] + for vn in keys(model) + # `VarName` functions similarly to `PropertyLens` so + # we just strip this part from `vn` to get a lens we can use + # to extract the corresponding value of `m`. + l = getlens(vn) + @test svi_new[vn] != get(m, l) end + # Logjoint should be non-zero wp. 1. @test getlogp(svi_new) != 0 @@ -103,26 +105,26 @@ end # Update the realizations in `svi_new`. - svi_eval = if isunivariate - DynamicPPL.setindex!!(svi_new, m_eval, @varname(m)) - else - DynamicPPL.setindex!!(svi_new, m_eval, [@varname(m[1]), @varname(m[2])]) + svi_eval = svi_new + for vn in keys(model) + l = getlens(vn) + svi_eval = DynamicPPL.setindex!!(svi_eval, get(m_eval, l), vn) end + # Reset the logp field. svi_eval = DynamicPPL.resetlogp!!(svi_eval) # Compute `logjoint` using the varinfo. logπ = logjoint(model, svi_eval) - # Extract the parameters from `svi_eval`. - m_vi = if isunivariate - svi_eval[@varname(m)] - else - svi_eval[[@varname(m[1]), @varname(m[2])]] + + # Values should not have changed. + for vn in keys(model) + l = getlens(vn) + @test svi_eval[vn] == get(m_eval, l) end - # These should not have changed. - @test m_vi == m_eval + # Compute the true `logjoint` and compare. - logπ_true = DynamicPPL.TestUtils.logjoint_true(model, m_vi) + logπ_true = DynamicPPL.TestUtils.logjoint_true(model, m_eval) @test logπ ≈ logπ_true end end From a82be563fd498907bc0526bda5a4d22b56999002 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 26 Jun 2022 14:05:32 +0100 Subject: [PATCH 042/221] added a no-op reconstruct for UnivariateDistribution --- src/utils.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/utils.jl b/src/utils.jl index b200171a7..ac9222818 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -183,6 +183,7 @@ vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r)) # otherwise we will have error for MatrixDistribution. # Note this is not the case for MultivariateDistribution so I guess this might be lack of # support for some types related to matrices (like PDMat). +reconstruct(d::UnivariateDistribution, val::Real) = val reconstruct(d::Distribution, val::AbstractVector) = reconstruct(size(d), val) reconstruct(::Tuple{}, val::AbstractVector) = val[1] reconstruct(s::NTuple{1}, val::AbstractVector) = copy(val) From 7aacee5e09fdb42de7007811da75974201d4b7d5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 27 Jun 2022 13:20:09 +0100 Subject: [PATCH 043/221] fixed tests for loglikelihoods --- test/loglikelihoods.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index 4d5003f03..0e7f9a3d9 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -1,13 +1,13 @@ @testset "loglikelihoods.jl" begin - for m in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS vi = VarInfo(m) - vns = vi.metadata.m.vns - if length(vns) == 1 && length(vi[vns[1]]) == 1 - # Only have one latent variable. - DynamicPPL.setval!(vi, [1.0], ["m"]) - else - DynamicPPL.setval!(vi, [1.0, 1.0], ["m[1]", "m[2]"]) + for vn in keys(m) + if vi[vn] isa Real + vi = DynamicPPL.setindex!!(vi, 1.0, vn) + else + vi = DynamicPPL.setindex!!(vi, ones(size(vi[vn])), vn) + end end lls = pointwise_loglikelihoods(m, vi) From 96f128fe4d7a281a49f87acf2244f73b2076455e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 27 Jun 2022 13:20:45 +0100 Subject: [PATCH 044/221] fixed dot_tilde_assume for LikelihoodContext --- src/context_implementations.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index a9bf2db92..9263dc441 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -320,13 +320,16 @@ function dot_tilde_assume( dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi) end end + function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi) - return dot_assume(NoDist.(right), left, vn, vi) + nodist = right isa Distribution ? NoDist(right) : NoDist.(right) + return dot_assume(nodist, left, vn, vi) end function dot_tilde_assume( rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, vi ) - return dot_assume(rng, sampler, NoDist.(right), vn, left, vi) + nodist = right isa Distribution ? NoDist(right) : NoDist.(right) + return dot_assume(rng, sampler, nodist, vn, left, vi) end # `PriorContext` From 2e88d087eeea267d241667fd495b743603870872 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 27 Jun 2022 13:21:03 +0100 Subject: [PATCH 045/221] removed some now redundant explicit calls to maybe_invlink --- src/context_implementations.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 9263dc441..0ebe4bedd 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -217,9 +217,7 @@ function assume( setorder!(vi, vn, get_num_produce(vi)) else # Otherwise we just extract it. - # r = vi[vn] - r_raw = getindex_raw(vi, vn, dist) - r = maybe_invlink(vi, vn, dist, r_raw) + r = vi[vn, dist] end else r = init(rng, dist, sampler) @@ -481,8 +479,7 @@ function get_and_set_val!( setorder!(vi, vn, get_num_produce(vi)) end else - r_raw = getindex_raw(vi, vns) - r = maybe_invlink(vi, vns, dist, r_raw) + r = vi[vns, dist] end else r = init(rng, dist, spl, n) From 0f9765bda684b27202982cf95d11e8de07304f62 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 27 Jun 2022 13:21:31 +0100 Subject: [PATCH 046/221] added impls of size and length for the wrapper distributions so they work for reconstruct --- src/distribution_wrappers.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index 4045cc089..07dc6f93f 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -13,6 +13,9 @@ end NamedDist(dist::Distribution, name::Symbol) = NamedDist(dist, VarName{name}()) +Base.length(dist::NamedDist) = Base.length(dist.dist) +Base.size(dist::NamedDist) = Base.size(dist.dist) + Distributions.logpdf(dist::NamedDist, x::Real) = Distributions.logpdf(dist.dist, x) function Distributions.logpdf(dist::NamedDist, x::AbstractArray{<:Real}) return Distributions.logpdf(dist.dist, x) @@ -24,12 +27,17 @@ function Distributions.loglikelihood(dist::NamedDist, x::AbstractArray{<:Real}) return Distributions.loglikelihood(dist.dist, x) end +Bijectors.bijector(d::NamedDist) = Bijectors.bijector(d.dist) + struct NoDist{variate,support,Td<:Distribution{variate,support}} <: Distribution{variate,support} dist::Td end NoDist(dist::NamedDist) = NamedDist(NoDist(dist.dist), dist.name) +Base.length(dist::NoDist) = Base.length(dist.dist) +Base.size(dist::NoDist) = Base.size(dist.dist) + Distributions.rand(rng::Random.AbstractRNG, d::NoDist) = rand(rng, d.dist) Distributions.logpdf(d::NoDist{<:Univariate}, ::Real) = 0 Distributions.logpdf(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0 From 116c95c0e695648a24e57b1d9581a9f9b087089a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 28 Jun 2022 17:52:49 +0100 Subject: [PATCH 047/221] bumped version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 305a0a0b3..7adb220b3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.19.2" +version = "0.19.3" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From d797e997399365e56b9d950e9367761ba85aff81 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 28 Jun 2022 17:53:46 +0100 Subject: [PATCH 048/221] removed redunant explict call to maybe_invlink --- src/context_implementations.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 0ebe4bedd..ddc62b639 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -194,9 +194,7 @@ end # fallback without sampler function assume(dist::Distribution, vn::VarName, vi) - # x = vi[vn] - r_raw = getindex_raw(vi, vn, dist) - r = maybe_invlink(vi, vn, dist, r_raw) + r = vi[vn, dist] return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi end From 44b2f66bd170e1e807c08aa7f2225f5b102151bc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 29 Jun 2022 10:54:45 +0100 Subject: [PATCH 049/221] added test model with array on RHS of a .~ statement --- src/test_utils.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/test_utils.jl b/src/test_utils.jl index faa390cea..172df1775 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -302,6 +302,25 @@ function Base.keys(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix return [@varname(m[:, 1]), @varname(m[:, 2])] end +@model function demo_dot_assume_array_dot_observe( + x=[10.0, 10.0], ::Type{TV}=Vector{Float64} +) where {TV} + # `dot_assume` and `observe` + m = TV(undef, length(x)) + m .~ [Normal() for _ in 1:length(x)] + x ~ MvNormal(m, 0.25 * I) + return (; m=m, x=x, logp=getlogp(__varinfo__)) +end +function logprior_true(model::Model{typeof(demo_dot_assume_array_dot_observe)}, m) + return loglikelihood(Normal(), m) +end +function loglikelihood_true(model::Model{typeof(demo_dot_assume_array_dot_observe)}, m) + return loglikelihood(MvNormal(m, 0.25 * I), model.args.x) +end +function Base.keys(model::Model{typeof(demo_dot_assume_array_dot_observe)}) + return [@varname(m[1]), @varname(m[2])] +end + const DEMO_MODELS = ( demo_dot_assume_dot_observe(), demo_assume_index_observe(), @@ -315,6 +334,7 @@ const DEMO_MODELS = ( demo_dot_assume_observe_submodel(), demo_dot_assume_dot_observe_matrix(), demo_dot_assume_matrix_dot_observe_matrix(), + demo_dot_assume_array_dot_observe(), ) # TODO: Is this really the best/most convenient "default" test method? From 81cd88145e4dd1c5b15170d9b13c90e8f73a4865 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 29 Jun 2022 10:55:14 +0100 Subject: [PATCH 050/221] improved some of the default implementations of dot_assume --- src/context_implementations.jl | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index ddc62b639..aead1dde1 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -398,9 +398,7 @@ function dot_assume( # m .~ Normal() # # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - # r = vi[vns] - r_raw = getindex_raw(vi, vns, dist) - r = maybe_invlink(vi, vns, dist, r_raw) + r = vi[vns, dist] lp = sum(zip(vns, eachcol(r))) do (vn, ri) return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) end @@ -422,20 +420,22 @@ function dot_assume( end function dot_assume( - dists::Union{Distribution,AbstractArray{<:Distribution}}, + dist::Distribution, var::AbstractArray, vns::AbstractArray{<:VarName}, vi +) + r = map(vn -> vi[vn, dist], vns) + lp = sum(Bijectors.logpdf_with_trans.(dist, r, map(Base.Fix1(istrans, vi), vns))) + return r, lp, vi +end + +function dot_assume( + dists::AbstractArray{<:Distribution}, var::AbstractArray, vns::AbstractArray{<:VarName}, vi, ) - # NOTE: We cannot work with `var` here because we might have a model of the form - # - # m = Vector{Float64}(undef, n) - # m .~ Normal() - # - # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - r_raw = getindex_raw(vi, vec(vns), dists) - r = reshape(maybe_invlink.(Ref(vi), vns, dists, r_raw), size(vns)) - lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) + @assert length(vns) == length(dists) == length(var) + r = map((vn, dist) -> vi[vn, dist], vns, dists) + lp = sum(Bijectors.logpdf_with_trans.(dists, r, map(Base.Fix1(istrans, vi), vns))) return r, lp, vi end @@ -449,7 +449,7 @@ function dot_assume( ) r = get_and_set_val!(rng, vi, vns, dists, spl) # Make sure `r` is not a matrix for multivariate distributions - lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) + lp = sum(Bijectors.logpdf_with_trans.(dists, r, map(Base.Fix1(istrans, vi), vns))) return r, lp, vi end function dot_assume(rng, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any) From 2e14abd633c544c723d1651425ea45e20132734a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 29 Jun 2022 11:07:58 +0100 Subject: [PATCH 051/221] removed unnecessary code in tests --- test/simple_varinfo.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 5b7c6ba71..092189984 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -79,11 +79,6 @@ # Sample a new varinfo! _, svi_new = DynamicPPL.evaluate!!(model, svi, SamplingContext()) - # If the `m[1]` varname doesn't exist, this is a univariate model. - # TODO: Find a better way of dealing with this that is not dependent - # on knowledge of internals of `model`. - isunivariate = !haskey(svi_new, @varname(m[1])) - # Realization for `m` should be different wp. 1. for vn in keys(model) # `VarName` functions similarly to `PropertyLens` so From 12adc83b793a8329bb091215ae6b6e0c49a5daec Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 29 Jun 2022 11:08:22 +0100 Subject: [PATCH 052/221] improved linking usage in assumes for SimpleVarInfo --- src/simple_varinfo.jl | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 1c6d4043a..c7780b872 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -453,10 +453,9 @@ function assume( ) value = init(rng, dist, sampler) # Transform if we're working in unconstrained space. - ist = istrans(vi, vn) - value_raw = ist ? Bijectors.link(dist, value) : value + value_raw = maybe_link(vi, vn, dist, value) vi = BangBang.push!!(vi, vn, value_raw, dist, sampler) - return value, Bijectors.logpdf_with_trans(dist, value, ist), vi + return value, Bijectors.logpdf_with_trans(dist, value, istrans(vi, vn)), vi end function dot_assume( @@ -471,14 +470,19 @@ function dot_assume( value = f.(vns, dists) # Transform if we're working in transformed space. - ist = istrans(vi, first(vns)) - value_raw = ist ? link.(dists, value) : value + value_raw = if dists isa Distribution + @assert length(vns) == length(value) + map((vn, val) -> maybe_link(vi, vn, dists, val), vns, value) + else + @assert length(vns) == length(dists) == length(value) + map((vn, dist, val) -> maybe_link(vi, vn, dist, val), vns, dists, value) + end # Update `vi` vi = BangBang.setindex!!(vi, value_raw, vns) # Compute logp. - lp = sum(Bijectors.logpdf_with_trans.(dists, value, ist)) + lp = sum(Bijectors.logpdf_with_trans.(dists, value, map(Base.Fix1(istrans, vi), vns))) return value, lp, vi end From f7501dfb4f17ddae7618708d539e704c38ad88fc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 30 Jun 2022 09:49:40 +0100 Subject: [PATCH 053/221] added model for testing dynamic constraints --- src/test_utils.jl | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/test_utils.jl b/src/test_utils.jl index 172df1775..a25c25df7 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -49,6 +49,28 @@ function logjoint_true(model::Model, args...) return logprior_true(model, args...) + loglikelihood_true(model, args...) end +""" + demo_dynamic_constraint() + +A model with variables `m` and `x` with `x` having support depending on `m`. +""" +@model function demo_dynamic_constraint() + m ~ Normal() + x ~ truncated(Normal(), m, Inf) + + return (m=m, x=x) +end + +function logprior_true(model::Model{typeof(demo_dynamic_constraint)}, m, x) + return logpdf(Normal(), m) + logpdf(truncated(Normal(), m, Inf)) +end +function loglikelihood_true(model::Model{typeof(demo_dynamic_constraint)}, m, x) + return zero(float(eltype(m))) +end +function Base.keys(model::Model{typeof(demo_dynamic_constraint)}) + return [@varname(m), @varname(x)] +end + # A collection of models for which the mean-of-means for the posterior should # be same. @model function demo_dot_assume_dot_observe( From abcabf49b4f1ef6c468e22a88ce679d3ba8b9b84 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 30 Jun 2022 10:48:18 +0100 Subject: [PATCH 054/221] added logjoint_true_with_logabsdet_jacobian to TestUtils --- src/test_utils.jl | 54 +++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 50 insertions(+), 4 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index a25c25df7..0b9a5526b 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -6,6 +6,8 @@ using LinearAlgebra using Distributions using Test +using Bijectors: Bijectors + """ logprior_true(model, θ) @@ -29,11 +31,11 @@ See also: [`logjoint_true`](@ref), [`logprior_true`](@ref). function loglikelihood_true end """ - logjoint_true(model, θ) + logjoint_true(model, args...) -Return the `logjoint` of `model` for `θ`. +Return the `logjoint` of `model` for `args...`. -Defaults to `logprior_true(model, θ) + loglikelihood_true(model, θ)`. +Defaults to `logprior_true(model, args...) + loglikelihood_true(model, args..)`. This should generally be implemented by hand for every specific `model` so that the returned value can be used as a ground-truth for testing things like: @@ -49,6 +51,42 @@ function logjoint_true(model::Model, args...) return logprior_true(model, args...) + loglikelihood_true(model, args...) end +""" + logjoint_true_with_logabsdet_jacobian(model::Model, args...) + +Return a tuple `(args_unconstrained, logjoint)` of `model` for `args...`. + +Unlike [`logjoint_true`](@ref), the returned logjoint computation includes the +log-absdet-jacobian adjustment, thus computing logjoint for the unconstrained variables. + +Note that `args` are assumed be in the support of `model`, while `args_unconstrained` +will be unconstrained. + +This should generally not be implemented directly, instead one should implement +[`logprior_true_with_logabsdet_jacobian`](@ref) for a given `model`. + +See also: [`logjoint_true`](@ref), [`logprior_true_with_logabsdet_jacobian`](@ref). +""" +function logjoint_true_with_logabsdet_jacobian(model::Model, args...) + args_unconstrained, lp = logprior_true_with_logabsdet_jacobian(model, args...) + return args_unconstrained, lp + loglikelihood_true(model, args...) +end + +""" + logprior_true_with_logabsdet_jacobian(model::Model, args...) + +Return a tuple `(args_unconstrained, logprior_unconstrained)` of `model` for `args...`. + +Unlike [`logprior_true`](@ref), the returned logprior computation includes the +log-absdet-jacobian adjustment, thus computing logprior for the unconstrained variables. + +Note that `args` are assumed be in the support of `model`, while `args_unconstrained` +will be unconstrained. + +See also: [`logprior_true`](@ref). +""" +function logprior_true_with_logabsdet_jacobian end + """ demo_dynamic_constraint() @@ -62,7 +100,7 @@ A model with variables `m` and `x` with `x` having support depending on `m`. end function logprior_true(model::Model{typeof(demo_dynamic_constraint)}, m, x) - return logpdf(Normal(), m) + logpdf(truncated(Normal(), m, Inf)) + return logpdf(Normal(), m) + logpdf(truncated(Normal(), m, Inf), x) end function loglikelihood_true(model::Model{typeof(demo_dynamic_constraint)}, m, x) return zero(float(eltype(m))) @@ -71,6 +109,14 @@ function Base.keys(model::Model{typeof(demo_dynamic_constraint)}) return [@varname(m), @varname(x)] end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_dynamic_constraint)}, m, x +) + b_x = Bijectors.bijector(truncated(Normal(), m, Inf)) + x_unconstrained, Δlogp = Bijectors.with_logabsdet_jacobian(b_x, x) + return (m=m, x=x_unconstrained), logprior_true(model, m, x) - Δlogp +end + # A collection of models for which the mean-of-means for the posterior should # be same. @model function demo_dot_assume_dot_observe( From fdee509a0dd605fc850863d68281aa55734ff079 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 30 Jun 2022 11:00:32 +0100 Subject: [PATCH 055/221] added test for dynamic constraints for SimpleVarInfo --- test/model.jl | 7 +------ test/simple_varinfo.jl | 31 +++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/test/model.jl b/test/model.jl index efe07b362..8eabeeba9 100644 --- a/test/model.jl +++ b/test/model.jl @@ -113,12 +113,7 @@ end end @testset "Dynamic constraints" begin - @model function dynamic_constraints() - m ~ Normal() - return x ~ truncated(Normal(), m, Inf) - end - - model = dynamic_constraints() + model = DynamicPPL.TestUtils.demo_dynamic_constraint() vi = VarInfo(model) spl = SampleFromPrior() link!(vi, spl) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 092189984..9494ae6c1 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -123,4 +123,35 @@ @test logπ ≈ logπ_true end end + + @testset "Dynamic constraints" begin + model = DynamicPPL.TestUtils.demo_dynamic_constraint() + + # Initialize. + svi = DynamicPPL.settrans!!(SimpleVarInfo(), true) + svi = last(DynamicPPL.evaluate!!(model, svi, SamplingContext())) + + # Sample with large variations in unconstrained space. + for i in 1:10 + for vn in keys(svi) + svi = DynamicPPL.setindex!!(svi, 10 * randn(), vn) + end + retval, svi = DynamicPPL.evaluate!!(model, svi, DefaultContext()) + @test retval.m == svi[@varname(m)] # `m` is unconstrained + @test retval.x ≠ svi[@varname(x)] # `x` is constrained depending on `m` + + retval_unconstrained, lp_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + model, retval.m, retval.x + ) + + # Realizations from model should all be equal to the unconstrained realization. + for vn in keys(model) + @test get(retval_unconstrained, vn) ≈ svi[vn] rtol = 1e-6 + end + + # `getlogp` should be equal to the logjoint with log-absdet-jac correction. + lp = getlogp(svi) + @test lp ≈ lp_true + end + end end From e974c8335bcd359490ff4e2eafffc343fd7cb96a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 30 Jun 2022 11:00:50 +0100 Subject: [PATCH 056/221] fixed keys implementation of SimpleVarInfo --- src/simple_varinfo.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index c7780b872..02565b64a 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -241,6 +241,7 @@ acclogp!!(vi::SimpleVarInfo, logp) = Setfield.@set vi.logp = getlogp(vi) + logp Return an iterator of keys present in `vi`. """ Base.keys(vi::SimpleVarInfo) = keys(vi.values) +Base.keys(vi::SimpleVarInfo{<:NamedTuple}) = map(k -> VarName{k}(), keys(vi.values)) function setlogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) vi.logp[] = logp @@ -311,8 +312,10 @@ end # HACK: Needed to disambiguiate. Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(Base.Fix1(getindex, vi), vns) -Base.getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.values -Base.getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.values +Base.getindex(vi::SimpleVarInfo, ::Colon) = vi.values +Base.getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi[:] +Base.getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi[:] + # TODO: Should we do better? Base.getindex(vi::SimpleVarInfo, spl::Sampler) = vi.values From 6c6d5f57fda1922cb017999151d4bb90e582340a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 30 Jun 2022 11:03:18 +0100 Subject: [PATCH 057/221] reverted unintended change --- src/simple_varinfo.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 02565b64a..92abc327f 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -312,9 +312,8 @@ end # HACK: Needed to disambiguiate. Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(Base.Fix1(getindex, vi), vns) -Base.getindex(vi::SimpleVarInfo, ::Colon) = vi.values -Base.getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi[:] -Base.getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi[:] +Base.getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.values +Base.getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.values # TODO: Should we do better? Base.getindex(vi::SimpleVarInfo, spl::Sampler) = vi.values From 5d5bc8868b06689383cc1a0a95fa216882160375 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 30 Jun 2022 14:15:18 +0100 Subject: [PATCH 058/221] added example_values and posterior_mean_values methods to models in TestUtils --- src/test_utils.jl | 164 ++++++++++++++++++++++++++++++++++++++++- test/simple_varinfo.jl | 38 ++++------ 2 files changed, 175 insertions(+), 27 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 0b9a5526b..621bbfe30 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -6,6 +6,7 @@ using LinearAlgebra using Distributions using Test +using Random: Random using Bijectors: Bijectors """ @@ -87,6 +88,21 @@ See also: [`logprior_true`](@ref). """ function logprior_true_with_logabsdet_jacobian end +""" + example_values(model::Model) + +Return a `NamedTuple` compatible with `keys(model)` with values in support of `model`. +""" +example_values(model::Model) = example_values(Random.GLOBAL_RNG, model) + +""" + posterior_mean_values(model::Model) + +Return a `NamedTuple` compatible with `keys(model)` where the values represent +the posterior mean under `model`. +""" +function posterior_mean_values end + """ demo_dynamic_constraint() @@ -108,7 +124,12 @@ end function Base.keys(model::Model{typeof(demo_dynamic_constraint)}) return [@varname(m), @varname(x)] end - +function example_values( + rng::Random.AbstractRNG, model::Model{typeof(demo_dynamic_constraint)} +) + m = rand(rng, Normal()) + return (m=m, x=rand(rng, truncated(Normal(), m, Inf))) +end function logprior_true_with_logabsdet_jacobian( model::Model{typeof(demo_dynamic_constraint)}, m, x ) @@ -137,6 +158,16 @@ end function Base.keys(model::Model{typeof(demo_dot_assume_dot_observe)}) return [@varname(m[1]), @varname(m[2])] end +function example_values( + rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_dot_observe)} +) + return (m=rand(rng, Normal(), length(model.args.x)),) +end +function posterior_mean_values(model::Model{typeof(demo_dot_assume_dot_observe)}) + vals = example_values(model) + vals.m .= 8 + return vals +end @model function demo_assume_index_observe( x=[10.0, 10.0], ::Type{TV}=Vector{Float64} @@ -159,6 +190,16 @@ end function Base.keys(model::Model{typeof(demo_assume_index_observe)}) return [@varname(m[1]), @varname(m[2])] end +function example_values( + rng::Random.AbstractRNG, model::Model{typeof(demo_assume_index_observe)} +) + return (m=rand(rng, Normal(), length(model.args.x)),) +end +function posterior_mean_values(model::Model{typeof(demo_assume_index_observe)}) + vals = example_values(model) + vals.m .= 8 + return vals +end @model function demo_assume_multivariate_observe(x=[10.0, 10.0]) # Multivariate `assume` and `observe` @@ -176,6 +217,16 @@ end function Base.keys(model::Model{typeof(demo_assume_multivariate_observe)}) return [@varname(m)] end +function example_values( + rng::Random.AbstractRNG, model::Model{typeof(demo_assume_multivariate_observe)} +) + return (m=rand(rng, MvNormal(zero(model.args.x), I)),) +end +function posterior_mean_values(model::Model{typeof(demo_assume_multivariate_observe)}) + vals = example_values(model) + vals.m .= 8 + return vals +end @model function demo_dot_assume_observe_index( x=[10.0, 10.0], ::Type{TV}=Vector{Float64} @@ -198,6 +249,16 @@ end function Base.keys(model::Model{typeof(demo_dot_assume_observe_index)}) return [@varname(m[1]), @varname(m[2])] end +function example_values( + rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_observe_index)} +) + return (m=rand(rng, Normal(), length(model.args.x)),) +end +function posterior_mean_values(model::Model{typeof(demo_dot_assume_observe_index)}) + vals = example_values(model) + vals.m .= 8 + return vals +end # Using vector of `length` 1 here so the posterior of `m` is the same # as the others. @@ -217,6 +278,14 @@ end function Base.keys(model::Model{typeof(demo_assume_dot_observe)}) return [@varname(m)] end +function example_values( + rng::Random.AbstractRNG, model::Model{typeof(demo_assume_dot_observe)} +) + return (m=rand(rng, Normal()),) +end +function posterior_mean_values(model::Model{typeof(demo_assume_dot_observe)}) + return (m=8.0,) +end @model function demo_assume_observe_literal() # `assume` and literal `observe` @@ -234,6 +303,16 @@ end function Base.keys(model::Model{typeof(demo_assume_observe_literal)}) return [@varname(m)] end +function example_values( + rng::Random.AbstractRNG, model::Model{typeof(demo_assume_observe_literal)} +) + return (m=rand(rng, MvNormal(zeros(2), I)),) +end +function posterior_mean_values(model::Model{typeof(demo_assume_observe_literal)}) + vals = example_values(model) + vals.m .= 8 + return vals +end @model function demo_dot_assume_observe_index_literal(::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and literal `observe` with indexing @@ -254,6 +333,16 @@ end function Base.keys(model::Model{typeof(demo_dot_assume_observe_index_literal)}) return [@varname(m[1]), @varname(m[2])] end +function example_values( + rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_observe_index_literal)} +) + return (m=rand(rng, Normal(), 2),) +end +function posterior_mean_values(model::Model{typeof(demo_dot_assume_observe_index_literal)}) + vals = example_values(model) + vals.m .= 8 + return vals +end @model function demo_assume_literal_dot_observe() # `assume` and literal `dot_observe` @@ -271,6 +360,14 @@ end function Base.keys(model::Model{typeof(demo_assume_literal_dot_observe)}) return [@varname(m)] end +function example_values( + rng::Random.AbstractRNG, model::Model{typeof(demo_assume_literal_dot_observe)} +) + return (m=rand(rng, Normal()),) +end +function posterior_mean_values(model::Model{typeof(demo_assume_literal_dot_observe)}) + return (m=8.0,) +end @model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} m = TV(undef, 2) @@ -299,6 +396,19 @@ end function Base.keys(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) return [@varname(m[1]), @varname(m[2])] end +function example_values( + rng::Random.AbstractRNG, + model::Model{typeof(demo_assume_submodel_observe_index_literal)}, +) + return (m=rand(rng, Normal(), 2),) +end +function posterior_mean_values( + model::Model{typeof(demo_assume_submodel_observe_index_literal)} +) + vals = example_values(model) + vals.m .= 8 + return vals +end @model function _likelihood_dot_observe(m, x) return x ~ MvNormal(m, 0.25 * I) @@ -324,6 +434,16 @@ end function Base.keys(model::Model{typeof(demo_dot_assume_observe_submodel)}) return [@varname(m[1]), @varname(m[2])] end +function example_values( + rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_observe_submodel)} +) + return (m=rand(rng, Normal(), length(model.args.x)),) +end +function posterior_mean_values(model::Model{typeof(demo_dot_assume_observe_submodel)}) + vals = example_values(model) + vals.m .= 8 + return vals +end @model function demo_dot_assume_dot_observe_matrix( x=fill(10.0, 2, 1), ::Type{TV}=Vector{Float64} @@ -345,6 +465,16 @@ end function Base.keys(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) return [@varname(m[1]), @varname(m[2])] end +function example_values( + rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_dot_observe_matrix)} +) + return (m=rand(rng, Normal(), length(model.args.x)),) +end +function posterior_mean_values(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) + vals = example_values(model) + vals.m .= 8 + return vals +end @model function demo_dot_assume_matrix_dot_observe_matrix( x=fill(10.0, 2, 1), ::Type{TV}=Array{Float64} @@ -369,6 +499,19 @@ end function Base.keys(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) return [@varname(m[:, 1]), @varname(m[:, 2])] end +function example_values( + rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)} +) + d = length(model.args.x) ÷ 2 + return (m=rand(rng, MvNormal(zeros(d), I), 2),) +end +function posterior_mean_values( + model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)} +) + vals = example_values(model) + vals.m .= 8 + return vals +end @model function demo_dot_assume_array_dot_observe( x=[10.0, 10.0], ::Type{TV}=Vector{Float64} @@ -388,6 +531,16 @@ end function Base.keys(model::Model{typeof(demo_dot_assume_array_dot_observe)}) return [@varname(m[1]), @varname(m[2])] end +function example_values( + rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_array_dot_observe)} +) + return (m=rand(rng, Normal(), length(model.args.x)),) +end +function posterior_mean_values(model::Model{typeof(demo_dot_assume_array_dot_observe)}) + vals = example_values(model) + vals.m .= 8 + return vals +end const DEMO_MODELS = ( demo_dot_assume_dot_observe(), @@ -431,7 +584,6 @@ function test_sampler_demo_models( meanfunction, sampler::AbstractMCMC.AbstractSampler, args...; - target=8.0, atol=1e-1, rtol=1e-3, kwargs..., @@ -439,7 +591,11 @@ function test_sampler_demo_models( @testset "$(nameof(typeof(sampler))) on $(nameof(m))" for model in DEMO_MODELS chain = AbstractMCMC.sample(model, sampler, args...; kwargs...) μ = meanfunction(chain) - @test μ ≈ target atol = atol rtol = rtol + target_values = posterior_mean_values(model) + for vn in keys(model) + target = get(target_values, vn) + @test μ ≈ target atol = atol rtol = rtol + end end end @@ -458,7 +614,7 @@ end function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; kwargs...) # Default for `MCMCChains.Chains`. - return test_sampler_continuous(sampler, args...; kwargs...) do chain + return test_sampler_continuous(sampler, args...; kwargs...) do chain, vn mean(Array(chain)) end end diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 9494ae6c1..2620be405 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -62,18 +62,20 @@ DynamicPPL.TestUtils.DEMO_MODELS # We might need to pre-allocate for the variable `m`, so we need # to see whether this is the case. - m = model().m - svi_nt = if m isa AbstractArray - SimpleVarInfo((m=similar(m),)) - else - SimpleVarInfo() - end + svi_nt = SimpleVarInfo(DynamicPPL.TestUtils.example_values(model)) svi_dict = SimpleVarInfo(VarInfo(model), Dict) - @testset "$(nameof(typeof(svi.values)))" for svi in (svi_nt, svi_dict) + @testset "$(nameof(typeof(DynamicPPL.values_as(svi))))" for svi in ( + svi_nt, + svi_dict, + DynamicPPL.settrans!!(svi_nt, true), + DynamicPPL.settrans!!(svi_dict, true), + ) + Random.seed!(42) + # Random seed is set in each `@testset`, so we need to sample # a new realization for `m` here. - m = model().m + retval = model() ### Sampling ### # Sample a new varinfo! @@ -81,11 +83,7 @@ # Realization for `m` should be different wp. 1. for vn in keys(model) - # `VarName` functions similarly to `PropertyLens` so - # we just strip this part from `vn` to get a lens we can use - # to extract the corresponding value of `m`. - l = getlens(vn) - @test svi_new[vn] != get(m, l) + @test svi_new[vn] != get(retval, vn) end # Logjoint should be non-zero wp. 1. @@ -93,17 +91,12 @@ ### Evaluation ### # Sample some random testing values. - m_eval = if m isa AbstractArray - randn!(similar(m)) - else - randn(eltype(m)) - end + values_eval = DynamicPPL.TestUtils.example_values(model) # Update the realizations in `svi_new`. svi_eval = svi_new for vn in keys(model) - l = getlens(vn) - svi_eval = DynamicPPL.setindex!!(svi_eval, get(m_eval, l), vn) + svi_eval = DynamicPPL.setindex!!(svi_eval, get(values_eval, vn), vn) end # Reset the logp field. @@ -114,12 +107,11 @@ # Values should not have changed. for vn in keys(model) - l = getlens(vn) - @test svi_eval[vn] == get(m_eval, l) + @test svi_eval[vn] == get(values_eval, vn) end # Compute the true `logjoint` and compare. - logπ_true = DynamicPPL.TestUtils.logjoint_true(model, m_eval) + logπ_true = DynamicPPL.TestUtils.logjoint_true(model, values_eval...) @test logπ ≈ logπ_true end end From 0498336481b3e92463ab9d849768b18e76129c53 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 30 Jun 2022 15:20:42 +0100 Subject: [PATCH 059/221] demo models in TestUtils are now a bit more complex, including constrained variables --- src/test_utils.jl | 389 ++++++++++++++++++++++++----------------- test/simple_varinfo.jl | 6 +- 2 files changed, 233 insertions(+), 162 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 621bbfe30..e6f11e0ab 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -141,122 +141,157 @@ end # A collection of models for which the mean-of-means for the posterior should # be same. @model function demo_dot_assume_dot_observe( - x=[10.0, 10.0], ::Type{TV}=Vector{Float64} + x=[1.5, 1.5], ::Type{TV}=Vector{Float64} ) where {TV} # `dot_assume` and `observe` + s = TV(undef, length(x)) m = TV(undef, length(x)) - m .~ Normal() - x ~ MvNormal(m, 0.25 * I) - return (; m=m, x=x, logp=getlogp(__varinfo__)) + s .~ InverseGamma(2, 3) + m .~ Normal.(0, sqrt.(s)) + + x ~ MvNormal(m, Diagonal(s)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe)}, m) - return loglikelihood(Normal(), m) +function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe)}, s, m) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe)}, m) - return loglikelihood(MvNormal(m, 0.25 * I), model.args.x) +function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe)}, s, m) + return loglikelihood(MvNormal(m, Diagonal(s)), model.args.x) end function Base.keys(model::Model{typeof(demo_dot_assume_dot_observe)}) - return [@varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_dot_observe)} ) - return (m=rand(rng, Normal(), length(model.args.x)),) + n = length(model.args.x) + s = rand(rng, InverseGamma(2, 3), n) + m = similar(s) + for i in eachindex(m, s) + m[i] = rand(rng, Normal(0, sqrt(s[i]))) + end + return (s=s, m=m) end function posterior_mean_values(model::Model{typeof(demo_dot_assume_dot_observe)}) vals = example_values(model) - vals.m .= 8 + vals.s .= 2.375 + vals.m .= 0.75 return vals end @model function demo_assume_index_observe( - x=[10.0, 10.0], ::Type{TV}=Vector{Float64} + x=[1.5, 1.5], ::Type{TV}=Vector{Float64} ) where {TV} # `assume` with indexing and `observe` + s = TV(undef, length(x)) + for i in eachindex(s) + s[i] ~ InverseGamma(2, 3) + end m = TV(undef, length(x)) for i in eachindex(m) - m[i] ~ Normal() + m[i] ~ Normal(0, sqrt(s[i])) end - x ~ MvNormal(m, 0.25 * I) + x ~ MvNormal(m, Diagonal(s)) - return (; m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_assume_index_observe)}, m) - return loglikelihood(Normal(), m) +function logprior_true(model::Model{typeof(demo_assume_index_observe)}, s, m) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function loglikelihood_true(model::Model{typeof(demo_assume_index_observe)}, m) - return logpdf(MvNormal(m, 0.25 * I), model.args.x) +function loglikelihood_true(model::Model{typeof(demo_assume_index_observe)}, s, m) + return logpdf(MvNormal(m, Diagonal(s)), model.args.x) end function Base.keys(model::Model{typeof(demo_assume_index_observe)}) - return [@varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_assume_index_observe)} ) - return (m=rand(rng, Normal(), length(model.args.x)),) + n = length(model.args.x) + s = rand(rng, InverseGamma(2, 3), n) + m = similar(s) + for i in eachindex(m, s) + m[i] = rand(rng, Normal(0, sqrt(s[i]))) + end + return (s=s, m=m) end function posterior_mean_values(model::Model{typeof(demo_assume_index_observe)}) vals = example_values(model) - vals.m .= 8 + vals.s .= 2.375 + vals.m .= 0.75 return vals end @model function demo_assume_multivariate_observe(x=[10.0, 10.0]) # Multivariate `assume` and `observe` - m ~ MvNormal(zero(x), I) - x ~ MvNormal(m, 0.25 * I) + s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) + m ~ MvNormal(zero(x), Diagonal(s)) + x ~ MvNormal(m, Diagonal(s)) - return (; m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_assume_multivariate_observe)}, m) - return logpdf(MvNormal(zero(model.args.x), I), m) +function logprior_true(model::Model{typeof(demo_assume_multivariate_observe)}, s, m) + s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) + m_dist = MvNormal(zero(model.args.x), Diagonal(s)) + return logpdf(s_dist, s) + logpdf(m_dist, m) end -function loglikelihood_true(model::Model{typeof(demo_assume_multivariate_observe)}, m) - return logpdf(MvNormal(m, 0.25 * I), model.args.x) +function loglikelihood_true(model::Model{typeof(demo_assume_multivariate_observe)}, s, m) + return logpdf(MvNormal(m, Diagonal(s)), model.args.x) end function Base.keys(model::Model{typeof(demo_assume_multivariate_observe)}) - return [@varname(m)] + return [@varname(s), @varname(m)] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_assume_multivariate_observe)} ) - return (m=rand(rng, MvNormal(zero(model.args.x), I)),) + s = rand(rng, product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])) + return (s=s, m=rand(rng, MvNormal(zero(model.args.x), Diagonal(s)))) end function posterior_mean_values(model::Model{typeof(demo_assume_multivariate_observe)}) vals = example_values(model) - vals.m .= 8 + vals.s .= 2.375 + vals.m .= 0.75 return vals end @model function demo_dot_assume_observe_index( - x=[10.0, 10.0], ::Type{TV}=Vector{Float64} + x=[1.5, 1.5], ::Type{TV}=Vector{Float64} ) where {TV} # `dot_assume` and `observe` with indexing + s = TV(undef, length(x)) + s .~ InverseGamma(2, 3) m = TV(undef, length(x)) - m .~ Normal() + m .~ Normal.(0, sqrt.(s)) for i in eachindex(x) - x[i] ~ Normal(m[i], 0.5) + x[i] ~ Normal(m[i], sqrt(s[i])) end - return (; m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_dot_assume_observe_index)}, m) - return loglikelihood(Normal(), m) +function logprior_true(model::Model{typeof(demo_dot_assume_observe_index)}, s, m) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index)}, m) - return sum(logpdf.(Normal.(m, 0.5), model.args.x)) +function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index)}, s, m) + return sum(logpdf.(Normal.(m, sqrt.(s)), model.args.x)) end function Base.keys(model::Model{typeof(demo_dot_assume_observe_index)}) - return [@varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_observe_index)} ) - return (m=rand(rng, Normal(), length(model.args.x)),) + n = length(model.args.x) + s = rand(rng, InverseGamma(2, 3), n) + m = similar(s) + for i in eachindex(m, s) + m[i] = rand(rng, Normal(0, sqrt(s[i]))) + end + return (s=s, m=m) end function posterior_mean_values(model::Model{typeof(demo_dot_assume_observe_index)}) vals = example_values(model) - vals.m .= 8 + vals.s .= 2.375 + vals.m .= 0.75 return vals end @@ -264,281 +299,314 @@ end # as the others. @model function demo_assume_dot_observe(x=[10.0]) # `assume` and `dot_observe` - m ~ Normal() - x .~ Normal(m, 0.5) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + x .~ Normal(m, sqrt(s)) - return (; m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_assume_dot_observe)}, m) - return logpdf(Normal(), m) +function logprior_true(model::Model{typeof(demo_assume_dot_observe)}, s, m) + return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) end -function loglikelihood_true(model::Model{typeof(demo_assume_dot_observe)}, m) - return sum(logpdf.(Normal.(m, 0.5), model.args.x)) +function loglikelihood_true(model::Model{typeof(demo_assume_dot_observe)}, s, m) + return sum(logpdf.(Normal.(m, sqrt.(s)), model.args.x)) end function Base.keys(model::Model{typeof(demo_assume_dot_observe)}) - return [@varname(m)] + return [@varname(s), @varname(m)] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_assume_dot_observe)} ) - return (m=rand(rng, Normal()),) + s = rand(rng, InverseGamma(2, 3)) + m = rand(rng, Normal(0, sqrt(s))) + return (s=s, m=m) end function posterior_mean_values(model::Model{typeof(demo_assume_dot_observe)}) - return (m=8.0,) + return (s=2.375, m=0.75) end @model function demo_assume_observe_literal() # `assume` and literal `observe` - m ~ MvNormal(zeros(2), I) - [10.0, 10.0] ~ MvNormal(m, 0.25 * I) + s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) + m ~ MvNormal(zeros(2), Diagonal(s)) + [1.5, 1.5] ~ MvNormal(m, Diagonal(s)) - return (; m=m, x=[10.0, 10.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 1.5], logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, m) - return logpdf(MvNormal(zeros(2), I), m) +function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) + s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) + m_dist = MvNormal(zeros(2), Diagonal(s)) + return logpdf(s_dist, s) + logpdf(m_dist, m) end -function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, m) - return logpdf(MvNormal(m, 0.25 * I), [10.0, 10.0]) +function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) + return logpdf(MvNormal(m, Diagonal(s)), [1.5, 1.5]) end function Base.keys(model::Model{typeof(demo_assume_observe_literal)}) - return [@varname(m)] + return [@varname(s), @varname(m)] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_assume_observe_literal)} ) - return (m=rand(rng, MvNormal(zeros(2), I)),) + s = rand(rng, product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])) + return (s=s, m=rand(rng, MvNormal(zeros(2), Diagonal(s)))) end function posterior_mean_values(model::Model{typeof(demo_assume_observe_literal)}) vals = example_values(model) - vals.m .= 8 + vals.s .= 2.375 + vals.m .= 0.75 return vals end @model function demo_dot_assume_observe_index_literal(::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and literal `observe` with indexing + s = TV(undef, 2) m = TV(undef, 2) - m .~ Normal() + s .~ InverseGamma(2, 3) + m .~ Normal.(0, sqrt.(s)) + for i in eachindex(m) - 10.0 ~ Normal(m[i], 0.5) + 1.5 ~ Normal(m[i], sqrt(s[i])) end - return (; m=m, x=fill(10.0, length(m)), logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=fill(1.5, length(m)), logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, m) - return loglikelihood(Normal(), m) +function logprior_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, s, m) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, m) - return sum(logpdf.(Normal.(m, 0.5), fill(10.0, length(m)))) +function loglikelihood_true( + model::Model{typeof(demo_dot_assume_observe_index_literal)}, s, m +) + return sum(logpdf.(Normal.(m, sqrt.(s)), fill(1.5, length(m)))) end function Base.keys(model::Model{typeof(demo_dot_assume_observe_index_literal)}) - return [@varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_observe_index_literal)} ) - return (m=rand(rng, Normal(), 2),) + n = 2 + s = rand(rng, InverseGamma(2, 3), n) + m = similar(s) + for i in eachindex(m, s) + m[i] = rand(rng, Normal(0, sqrt(s[i]))) + end + return (s=s, m=m) end function posterior_mean_values(model::Model{typeof(demo_dot_assume_observe_index_literal)}) vals = example_values(model) - vals.m .= 8 + vals.s .= 2.375 + vals.m .= 0.75 return vals end @model function demo_assume_literal_dot_observe() # `assume` and literal `dot_observe` - m ~ Normal() - [10.0] .~ Normal(m, 0.5) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + [1.5] .~ Normal(m, sqrt(s)) - return (; m=m, x=[10.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5], logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_assume_literal_dot_observe)}, m) - return logpdf(Normal(), m) +function logprior_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, m) + return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) end -function loglikelihood_true(model::Model{typeof(demo_assume_literal_dot_observe)}, m) - return logpdf(Normal(m, 0.5), 10.0) +function loglikelihood_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, m) + return logpdf(Normal(m, sqrt(s)), 1.5) end function Base.keys(model::Model{typeof(demo_assume_literal_dot_observe)}) - return [@varname(m)] + return [@varname(s), @varname(m)] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_assume_literal_dot_observe)} ) - return (m=rand(rng, Normal()),) + s = rand(rng, InverseGamma(2, 3)) + m = rand(rng, Normal(0, sqrt(s))) + return (s=s, m=m) end function posterior_mean_values(model::Model{typeof(demo_assume_literal_dot_observe)}) - return (m=8.0,) + return (s=2.375, m=0.75) end @model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} + s = TV(undef, 2) + s .~ InverseGamma(2, 3) m = TV(undef, 2) - m .~ Normal() + m .~ Normal.(0, sqrt.(s)) - return m + return s, m end @model function demo_assume_submodel_observe_index_literal() # Submodel prior - @submodel m = _prior_dot_assume() - for i in eachindex(m) - 10.0 ~ Normal(m[i], 0.5) + @submodel s, m = _prior_dot_assume() + for i in eachindex(m, s) + 1.5 ~ Normal(m[i], sqrt(s[i])) end - return (; m=m, x=[10.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 1.5], logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_assume_submodel_observe_index_literal)}, m) - return loglikelihood(Normal(), m) +function logprior_true( + model::Model{typeof(demo_assume_submodel_observe_index_literal)}, s, m +) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end function loglikelihood_true( - model::Model{typeof(demo_assume_submodel_observe_index_literal)}, m + model::Model{typeof(demo_assume_submodel_observe_index_literal)}, s, m ) - return sum(logpdf.(Normal.(m, 0.5), 10.0)) + return sum(logpdf.(Normal.(m, sqrt.(s)), 1.5)) end function Base.keys(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) - return [@varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_assume_submodel_observe_index_literal)}, ) - return (m=rand(rng, Normal(), 2),) + n = 2 + s = rand(rng, InverseGamma(2, 3), n) + m = similar(s) + for i in eachindex(m, s) + m[i] = rand(rng, Normal(0, sqrt(s[i]))) + end + return (s=s, m=m) end function posterior_mean_values( model::Model{typeof(demo_assume_submodel_observe_index_literal)} ) vals = example_values(model) - vals.m .= 8 + vals.s .= 2.375 + vals.m .= 0.75 return vals end -@model function _likelihood_dot_observe(m, x) - return x ~ MvNormal(m, 0.25 * I) +@model function _likelihood_mltivariate_observe(s, m, x) + return x ~ MvNormal(m, Diagonal(s)) end @model function demo_dot_assume_observe_submodel( - x=[10.0, 10.0], ::Type{TV}=Vector{Float64} + x=[1.5, 1.5], ::Type{TV}=Vector{Float64} ) where {TV} + s = TV(undef, length(x)) + s .~ InverseGamma(2, 3) m = TV(undef, length(x)) - m .~ Normal() + m .~ Normal.(0, sqrt.(s)) # Submodel likelihood - @submodel _likelihood_dot_observe(m, x) + @submodel _likelihood_mltivariate_observe(s, m, x) - return (; m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, m) - return loglikelihood(Normal(), m) +function logprior_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, s, m) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, m) - return logpdf(MvNormal(m, 0.25 * I), model.args.x) +function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, s, m) + return logpdf(MvNormal(m, Diagonal(s)), model.args.x) end function Base.keys(model::Model{typeof(demo_dot_assume_observe_submodel)}) - return [@varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_observe_submodel)} ) - return (m=rand(rng, Normal(), length(model.args.x)),) + n = length(model.args.x) + s = rand(rng, InverseGamma(2, 3), n) + m = similar(s) + for i in eachindex(m, s) + m[i] = rand(rng, Normal(0, sqrt(s[i]))) + end + return (s=s, m=m) end function posterior_mean_values(model::Model{typeof(demo_dot_assume_observe_submodel)}) vals = example_values(model) - vals.m .= 8 + vals.s .= 2.375 + vals.m .= 0.75 return vals end @model function demo_dot_assume_dot_observe_matrix( - x=fill(10.0, 2, 1), ::Type{TV}=Vector{Float64} + x=fill(1.5, 2, 1), ::Type{TV}=Vector{Float64} ) where {TV} + s = TV(undef, length(x)) + s .~ InverseGamma(2, 3) m = TV(undef, length(x)) - m .~ Normal() + m .~ Normal.(0, sqrt.(s)) # Dotted observe for `Matrix`. - x .~ MvNormal(m, 0.25 * I) + x .~ MvNormal(m, Diagonal(s)) - return (; m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, m) - return loglikelihood(Normal(), m) +function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, s, m) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, m) - return loglikelihood(MvNormal(m, 0.25 * I), model.args.x) +function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, s, m) + return sum(logpdf.(Normal.(m, sqrt.(s)), model.args.x)) end function Base.keys(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) - return [@varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_dot_observe_matrix)} ) - return (m=rand(rng, Normal(), length(model.args.x)),) + n = length(model.args.x) + s = rand(rng, InverseGamma(2, 3), n) + m = similar(s) + for i in eachindex(m, s) + m[i] = rand(rng, Normal(0, sqrt(s[i]))) + end + return (s=s, m=m) end function posterior_mean_values(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) vals = example_values(model) - vals.m .= 8 + vals.s .= 2.375 + vals.m .= 0.75 return vals end @model function demo_dot_assume_matrix_dot_observe_matrix( - x=fill(10.0, 2, 1), ::Type{TV}=Array{Float64} + x=fill(1.5, 2, 1), ::Type{TV}=Array{Float64} ) where {TV} d = length(x) ÷ 2 + s = TV(undef, d, 2) + s .~ product_distribution([InverseGamma(2, 3) for _ in 1:d]) m = TV(undef, d, 2) m .~ MvNormal(zeros(d), I) # Dotted observe for `Matrix`. - x .~ MvNormal(vec(m), 0.25 * I) + x .~ MvNormal(vec(m), Diagonal(vec(s))) - return (; m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, m) - return loglikelihood(Normal(), vec(m)) +function logprior_true( + model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m +) + return loglikelihood(InverseGamma(2, 3), vec(s)) + loglikelihood(Normal(), vec(m)) end function loglikelihood_true( - model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, m + model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m ) - return loglikelihood(MvNormal(vec(m), 0.25 * I), model.args.x) + return loglikelihood(MvNormal(vec(m), Diagonal(vec(s))), model.args.x) end function Base.keys(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) - return [@varname(m[:, 1]), @varname(m[:, 2])] + return [@varname(s[:, 1]), @varname(s[:, 2]), @varname(m[:, 1]), @varname(m[:, 2])] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)} ) d = length(model.args.x) ÷ 2 - return (m=rand(rng, MvNormal(zeros(d), I), 2),) + s = rand(rng, product_distribution([InverseGamma(2, 3) for _ in 1:d]), 2) + m = similar(s) + for i in 1:size(m, 2) + m[:, i] = rand(rng, MvNormal(zeros(d), Diagonal(vec(s[:, i])))) + end + return (s=s, m=m) end function posterior_mean_values( model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)} ) vals = example_values(model) - vals.m .= 8 - return vals -end - -@model function demo_dot_assume_array_dot_observe( - x=[10.0, 10.0], ::Type{TV}=Vector{Float64} -) where {TV} - # `dot_assume` and `observe` - m = TV(undef, length(x)) - m .~ [Normal() for _ in 1:length(x)] - x ~ MvNormal(m, 0.25 * I) - return (; m=m, x=x, logp=getlogp(__varinfo__)) -end -function logprior_true(model::Model{typeof(demo_dot_assume_array_dot_observe)}, m) - return loglikelihood(Normal(), m) -end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_array_dot_observe)}, m) - return loglikelihood(MvNormal(m, 0.25 * I), model.args.x) -end -function Base.keys(model::Model{typeof(demo_dot_assume_array_dot_observe)}) - return [@varname(m[1]), @varname(m[2])] -end -function example_values( - rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_array_dot_observe)} -) - return (m=rand(rng, Normal(), length(model.args.x)),) -end -function posterior_mean_values(model::Model{typeof(demo_dot_assume_array_dot_observe)}) - vals = example_values(model) - vals.m .= 8 + vals.s .= 2.375 + vals.m .= 0.75 return vals end @@ -555,7 +623,6 @@ const DEMO_MODELS = ( demo_dot_assume_observe_submodel(), demo_dot_assume_dot_observe_matrix(), demo_dot_assume_matrix_dot_observe_matrix(), - demo_dot_assume_array_dot_observe(), ) # TODO: Is this really the best/most convenient "default" test method? @@ -590,6 +657,8 @@ function test_sampler_demo_models( ) @testset "$(nameof(typeof(sampler))) on $(nameof(m))" for model in DEMO_MODELS chain = AbstractMCMC.sample(model, sampler, args...; kwargs...) + # TODO(torfjelde): Move `meanfunction` into loop below, and have it also + # take `vn` as input. μ = meanfunction(chain) target_values = posterior_mean_values(model) for vn in keys(model) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 2620be405..9a2c8a134 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -111,8 +111,10 @@ end # Compute the true `logjoint` and compare. - logπ_true = DynamicPPL.TestUtils.logjoint_true(model, values_eval...) - @test logπ ≈ logπ_true + if !DynamicPPL.istrans(svi) + logπ_true = DynamicPPL.TestUtils.logjoint_true(model, values_eval...) + @test logπ ≈ logπ_true + end end end From f86f264b5d591e2c3b387fec8b3558ba1486964a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 09:04:32 +0100 Subject: [PATCH 060/221] added logprior_true_with_logabsdet_jacobian for demo models --- src/test_utils.jl | 71 ++++++++++++++++++++++++++++++++++++++++-- test/simple_varinfo.jl | 20 ++++++------ 2 files changed, 80 insertions(+), 11 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index e6f11e0ab..6bcc58a31 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -138,8 +138,15 @@ function logprior_true_with_logabsdet_jacobian( return (m=m, x=x_unconstrained), logprior_true(model, m, x) - Δlogp end -# A collection of models for which the mean-of-means for the posterior should -# be same. +# A collection of models for which the posterior should be "similar". +# Some utility methods for these. +function _demo_logprior_true_with_logabsdet_jacobian(model, s, m) + b = Bijectors.bijector(InverseGamma(2, 3)) + s_unconstrained = b.(s) + Δlogp = sum(Base.Fix1(Bijectors.logabsdetjac, b).(s)) + return (s=s_unconstrained, m=m), logprior_true(model, s, m) - Δlogp +end + @model function demo_dot_assume_dot_observe( x=[1.5, 1.5], ::Type{TV}=Vector{Float64} ) where {TV} @@ -158,6 +165,11 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe)}, s, m) return loglikelihood(MvNormal(m, Diagonal(s)), model.args.x) end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_dot_assume_dot_observe)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end function Base.keys(model::Model{typeof(demo_dot_assume_dot_observe)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end @@ -201,6 +213,11 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_index_observe)}, s, m) return logpdf(MvNormal(m, Diagonal(s)), model.args.x) end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_assume_index_observe)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end function Base.keys(model::Model{typeof(demo_assume_index_observe)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end @@ -238,6 +255,11 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_multivariate_observe)}, s, m) return logpdf(MvNormal(m, Diagonal(s)), model.args.x) end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_assume_multivariate_observe)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end function Base.keys(model::Model{typeof(demo_assume_multivariate_observe)}) return [@varname(s), @varname(m)] end @@ -274,6 +296,11 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index)}, s, m) return sum(logpdf.(Normal.(m, sqrt.(s)), model.args.x)) end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_dot_assume_observe_index)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end function Base.keys(model::Model{typeof(demo_dot_assume_observe_index)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end @@ -311,6 +338,11 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_dot_observe)}, s, m) return sum(logpdf.(Normal.(m, sqrt.(s)), model.args.x)) end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_assume_dot_observe)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end function Base.keys(model::Model{typeof(demo_assume_dot_observe)}) return [@varname(s), @varname(m)] end @@ -341,6 +373,11 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) return logpdf(MvNormal(m, Diagonal(s)), [1.5, 1.5]) end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_assume_observe_literal)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end function Base.keys(model::Model{typeof(demo_assume_observe_literal)}) return [@varname(s), @varname(m)] end @@ -378,6 +415,11 @@ function loglikelihood_true( ) return sum(logpdf.(Normal.(m, sqrt.(s)), fill(1.5, length(m)))) end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_dot_assume_observe_index_literal)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end function Base.keys(model::Model{typeof(demo_dot_assume_observe_index_literal)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end @@ -413,6 +455,11 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, m) return logpdf(Normal(m, sqrt(s)), 1.5) end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_assume_literal_dot_observe)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end function Base.keys(model::Model{typeof(demo_assume_literal_dot_observe)}) return [@varname(s), @varname(m)] end @@ -455,6 +502,11 @@ function loglikelihood_true( ) return sum(logpdf.(Normal.(m, sqrt.(s)), 1.5)) end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_assume_submodel_observe_index_literal)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end function Base.keys(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end @@ -502,6 +554,11 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, s, m) return logpdf(MvNormal(m, Diagonal(s)), model.args.x) end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_dot_assume_observe_submodel)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end function Base.keys(model::Model{typeof(demo_dot_assume_observe_submodel)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end @@ -542,6 +599,11 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, s, m) return sum(logpdf.(Normal.(m, sqrt.(s)), model.args.x)) end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end function Base.keys(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end @@ -587,6 +649,11 @@ function loglikelihood_true( ) return loglikelihood(MvNormal(vec(m), Diagonal(vec(s))), model.args.x) end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end function Base.keys(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) return [@varname(s[:, 1]), @varname(s[:, 2]), @varname(m[:, 1]), @varname(m[:, 2])] end diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 9a2c8a134..ed5919f5a 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -71,8 +71,6 @@ DynamicPPL.settrans!!(svi_nt, true), DynamicPPL.settrans!!(svi_dict, true), ) - Random.seed!(42) - # Random seed is set in each `@testset`, so we need to sample # a new realization for `m` here. retval = model() @@ -90,8 +88,15 @@ @test getlogp(svi_new) != 0 ### Evaluation ### - # Sample some random testing values. - values_eval = DynamicPPL.TestUtils.example_values(model) + values_eval_constrained = DynamicPPL.TestUtils.example_values(model) + if DynamicPPL.istrans(svi) + values_eval, logπ_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + model, values_eval_constrained... + ) + else + values_eval = values_eval_constrained + logπ_true = DynamicPPL.TestUtils.logjoint_true(model, values_eval...) + end # Update the realizations in `svi_new`. svi_eval = svi_new @@ -110,11 +115,8 @@ @test svi_eval[vn] == get(values_eval, vn) end - # Compute the true `logjoint` and compare. - if !DynamicPPL.istrans(svi) - logπ_true = DynamicPPL.TestUtils.logjoint_true(model, values_eval...) - @test logπ ≈ logπ_true - end + # Compare `logjoint` computations. + @test logπ ≈ logπ_true end end From 0d31137f90ee74fb1803a1f45eb2002c2e832ad7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 09:52:02 +0100 Subject: [PATCH 061/221] fixed mistakes in a couple of models in TestUtils --- src/test_utils.jl | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 6bcc58a31..cdc3da191 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -239,7 +239,7 @@ function posterior_mean_values(model::Model{typeof(demo_assume_index_observe)}) return vals end -@model function demo_assume_multivariate_observe(x=[10.0, 10.0]) +@model function demo_assume_multivariate_observe(x=[1.5, 1.5]) # Multivariate `assume` and `observe` s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) m ~ MvNormal(zero(x), Diagonal(s)) @@ -324,7 +324,7 @@ end # Using vector of `length` 1 here so the posterior of `m` is the same # as the others. -@model function demo_assume_dot_observe(x=[10.0]) +@model function demo_assume_dot_observe(x=[1.5]) # `assume` and `dot_observe` s ~ InverseGamma(2, 3) m ~ Normal(0, sqrt(s)) @@ -628,26 +628,29 @@ end @model function demo_dot_assume_matrix_dot_observe_matrix( x=fill(1.5, 2, 1), ::Type{TV}=Array{Float64} ) where {TV} + n = length(x) d = length(x) ÷ 2 s = TV(undef, d, 2) s .~ product_distribution([InverseGamma(2, 3) for _ in 1:d]) - m = TV(undef, d, 2) - m .~ MvNormal(zeros(d), I) + s_vec = vec(s) + m ~ MvNormal(zeros(n), Diagonal(s_vec)) # Dotted observe for `Matrix`. - x .~ MvNormal(vec(m), Diagonal(vec(s))) + x .~ MvNormal(m, Diagonal(s_vec)) return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end function logprior_true( model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m ) - return loglikelihood(InverseGamma(2, 3), vec(s)) + loglikelihood(Normal(), vec(m)) + n = length(model.args.x) + s_vec = vec(s) + return loglikelihood(InverseGamma(2, 3), s_vec) + logpdf(MvNormal(zeros(n), s_vec), m) end function loglikelihood_true( model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m ) - return loglikelihood(MvNormal(vec(m), Diagonal(vec(s))), model.args.x) + return loglikelihood(MvNormal(m, Diagonal(vec(s))), model.args.x) end function logprior_true_with_logabsdet_jacobian( model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m @@ -655,17 +658,15 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function Base.keys(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) - return [@varname(s[:, 1]), @varname(s[:, 2]), @varname(m[:, 1]), @varname(m[:, 2])] + return [@varname(s[:, 1]), @varname(s[:, 2]), @varname(m[1]), @varname(m[2])] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)} ) - d = length(model.args.x) ÷ 2 + n = length(model.args.x) + d = n ÷ 2 s = rand(rng, product_distribution([InverseGamma(2, 3) for _ in 1:d]), 2) - m = similar(s) - for i in 1:size(m, 2) - m[:, i] = rand(rng, MvNormal(zeros(d), Diagonal(vec(s[:, i])))) - end + m = rand(rng, MvNormal(zeros(n), Diagonal(vec(s)))) return (s=s, m=m) end function posterior_mean_values( From c52630b37761f1bd19e13eac7f5b18f8b6d086b6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 09:52:24 +0100 Subject: [PATCH 062/221] moved varnames method which creates iterator of leaf varnames into TestUtils and starting using this in test_continuous_models --- src/test_utils.jl | 39 ++++++++++++++++++++++++++++++++------- test/contexts.jl | 23 ++--------------------- 2 files changed, 34 insertions(+), 28 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index cdc3da191..1a6aff5f9 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -8,6 +8,28 @@ using Test using Random: Random using Bijectors: Bijectors +using Setfield: Setfield + +""" + varnames(vn::VarName, val) + +Return iterator over all varnames that are represented by `vn` on `val`, +e.g. `varnames(@varname(x), rand(2))` results in an iterator over `[@varname(x[1]), @varname(x[2])]`. +""" +varnames(vn::VarName, val::Real) = [vn] +function varnames(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) + return ( + VarName(vn, DynamicPPL.getlens(vn) ∘ Setfield.IndexLens(Tuple(I))) for + I in CartesianIndices(val) + ) +end +function varnames(vn::VarName, val::AbstractArray) + return Iterators.flatten( + varnames( + VarName(vn, DynamicPPL.getlens(vn) ∘ Setfield.IndexLens(Tuple(I))), val[I] + ) for I in CartesianIndices(val) + ) +end """ logprior_true(model, θ) @@ -723,15 +745,17 @@ function test_sampler_demo_models( rtol=1e-3, kwargs..., ) - @testset "$(nameof(typeof(sampler))) on $(nameof(m))" for model in DEMO_MODELS + @testset "$(typeof(sampler)) on $(nameof(model))" for model in DEMO_MODELS chain = AbstractMCMC.sample(model, sampler, args...; kwargs...) - # TODO(torfjelde): Move `meanfunction` into loop below, and have it also - # take `vn` as input. - μ = meanfunction(chain) target_values = posterior_mean_values(model) for vn in keys(model) - target = get(target_values, vn) - @test μ ≈ target atol = atol rtol = rtol + # We want to compare elementwise which can be achieved by + # extracting the leaves of the `VarName` and the corresponding value. + for vn_leaf in varnames(vn, get(target_values, vn)) + target_value = get(target_values, vn_leaf) + chain_mean_value = meanfunction(chain, vn_leaf) + @test chain_mean_value ≈ target_value atol = atol rtol = rtol + end end end end @@ -752,7 +776,8 @@ end function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; kwargs...) # Default for `MCMCChains.Chains`. return test_sampler_continuous(sampler, args...; kwargs...) do chain, vn - mean(Array(chain)) + # HACK(torfjelde): This assumes that we can index into `chain` with `Symbol(vn)`. + mean(Array(chain[Symbol(vn)])) end end diff --git a/test/contexts.jl b/test/contexts.jl index 65629afec..24b039852 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -57,25 +57,6 @@ function remove_prefix(vn::VarName) ) end -""" - varnames(vn::VarName, val) - -Return iterator over all varnames that are represented by `vn` on `val`, -e.g. `varnames(@varname(x), rand(2))` results in an iterator over `[@varname(x[1]), @varname(x[2])]`. -""" -varnames(vn::VarName, val::Real) = [vn] -function varnames(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) - return ( - VarName(vn, getlens(vn) ∘ Setfield.IndexLens(Tuple(I))) for - I in CartesianIndices(val) - ) -end -function varnames(vn::VarName, val::AbstractArray) - return Iterators.flatten( - varnames(VarName(vn, getlens(vn) ∘ Setfield.IndexLens(Tuple(I))), val[I]) for - I in CartesianIndices(val) - ) -end @testset "contexts.jl" begin child_contexts = [DefaultContext(), PriorContext(), LikelihoodContext()] @@ -185,7 +166,7 @@ end vn_without_prefix = remove_prefix(vn) # Let's check elementwise. - for vn_child in varnames(vn_without_prefix, val) + for vn_child in DynamicPPL.TestUtils.varnames(vn_without_prefix, val) if get(val, getlens(vn_child)) === missing @test contextual_isassumption(context, vn_child) else @@ -217,7 +198,7 @@ end # `ConditionContext` with the conditioned variable. vn_without_prefix = remove_prefix(vn) - for vn_child in varnames(vn_without_prefix, val) + for vn_child in DynamicPPL.TestUtils.varnames(vn_without_prefix, val) # `vn_child` should be in `context`. @test hasvalue_nested(context, vn_child) # Value should be the same as extracted above. From fff060c74ad8f804835bf0283c163bea3c59c9e7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 09:57:51 +0100 Subject: [PATCH 063/221] updated docstring for test_sampler_demo_models --- src/test_utils.jl | 6 +++--- test/contexts.jl | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 1a6aff5f9..69579073b 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -722,8 +722,9 @@ const DEMO_MODELS = ( Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`. In short, this method iterators through `demo_models`, calls `AbstractMCMC.sample` on the -`model` and `sampler` to produce a `chain`, and then checks `meanfunction(chain)` against `target` -provided in `kwargs...`. +`model` and `sampler` to produce a `chain`, and then checks `meanfunction(chain, vn)` +for every (leaf) varname `vn` against the corresponding value returned by +[`posterior_mean_values`](@ref) for each model. # Arguments - `meanfunction`: A callable which computes the mean of the marginal means from the @@ -732,7 +733,6 @@ provided in `kwargs...`. - `args...`: Arguments forwarded to `sample`. # Keyword arguments -- `target`: Value to compare result of `meanfunction(chain)` to. - `atol=1e-1`: Absolute tolerance used in `@test`. - `rtol=1e-3`: Relative tolerance used in `@test`. - `kwargs...`: Keyword arguments forwarded to `sample`. diff --git a/test/contexts.jl b/test/contexts.jl index 24b039852..ef916b18c 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -57,7 +57,6 @@ function remove_prefix(vn::VarName) ) end - @testset "contexts.jl" begin child_contexts = [DefaultContext(), PriorContext(), LikelihoodContext()] From e21958c5bdc27c7899c22af61532b25421b9ecfc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 10:47:30 +0100 Subject: [PATCH 064/221] renamed varnames to varname_leaves and renamed keys(model) to varnames(model) --- src/test_utils.jl | 66 +++++++++++++++++++++++++++--------------- test/simple_varinfo.jl | 8 ++--- 2 files changed, 47 insertions(+), 27 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 69579073b..78d279067 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -11,21 +11,21 @@ using Bijectors: Bijectors using Setfield: Setfield """ - varnames(vn::VarName, val) + varname_leaves(vn::VarName, val) Return iterator over all varnames that are represented by `vn` on `val`, -e.g. `varnames(@varname(x), rand(2))` results in an iterator over `[@varname(x[1]), @varname(x[2])]`. +e.g. `varname_leaves(@varname(x), rand(2))` results in an iterator over `[@varname(x[1]), @varname(x[2])]`. """ -varnames(vn::VarName, val::Real) = [vn] -function varnames(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) +varname_leaves(vn::VarName, val::Real) = [vn] +function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) return ( VarName(vn, DynamicPPL.getlens(vn) ∘ Setfield.IndexLens(Tuple(I))) for I in CartesianIndices(val) ) end -function varnames(vn::VarName, val::AbstractArray) +function varname_leaves(vn::VarName, val::AbstractArray) return Iterators.flatten( - varnames( + varname_leaves( VarName(vn, DynamicPPL.getlens(vn) ∘ Setfield.IndexLens(Tuple(I))), val[I] ) for I in CartesianIndices(val) ) @@ -110,18 +110,38 @@ See also: [`logprior_true`](@ref). """ function logprior_true_with_logabsdet_jacobian end +""" + varnames(model::Model) + +Return a collection of `VarName` as they are expected to appear in the model. + +Even though it is recommended to implement this by hand for a particular `Model`, +a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided. +""" +function varnames(model::Model) + return collect( + keys(last(DynamicPPL.evaluate!!(model, SimpleVarInfo(Dict()), SamplingContext()))) + ) +end + """ example_values(model::Model) -Return a `NamedTuple` compatible with `keys(model)` with values in support of `model`. +Return a `NamedTuple` compatible with `varnames(model)` with values in support of `model`. + +Compatible means that a `varname` from `varnames(model)` can be used to extract the +corresponding value using the call `get(example_values(model), varname)`. """ example_values(model::Model) = example_values(Random.GLOBAL_RNG, model) """ posterior_mean_values(model::Model) -Return a `NamedTuple` compatible with `keys(model)` where the values represent +Return a `NamedTuple` compatible with `varnames(model)` where the values represent the posterior mean under `model`. + +Compatible means that a `varname` from `varnames(model)` can be used to extract the +corresponding value using the call `get(posterior_mean_values(model), varname)`. """ function posterior_mean_values end @@ -143,7 +163,7 @@ end function loglikelihood_true(model::Model{typeof(demo_dynamic_constraint)}, m, x) return zero(float(eltype(m))) end -function Base.keys(model::Model{typeof(demo_dynamic_constraint)}) +function varnames(model::Model{typeof(demo_dynamic_constraint)}) return [@varname(m), @varname(x)] end function example_values( @@ -192,7 +212,7 @@ function logprior_true_with_logabsdet_jacobian( ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function Base.keys(model::Model{typeof(demo_dot_assume_dot_observe)}) +function varnames(model::Model{typeof(demo_dot_assume_dot_observe)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( @@ -240,7 +260,7 @@ function logprior_true_with_logabsdet_jacobian( ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function Base.keys(model::Model{typeof(demo_assume_index_observe)}) +function varnames(model::Model{typeof(demo_assume_index_observe)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( @@ -282,7 +302,7 @@ function logprior_true_with_logabsdet_jacobian( ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function Base.keys(model::Model{typeof(demo_assume_multivariate_observe)}) +function varnames(model::Model{typeof(demo_assume_multivariate_observe)}) return [@varname(s), @varname(m)] end function example_values( @@ -323,7 +343,7 @@ function logprior_true_with_logabsdet_jacobian( ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function Base.keys(model::Model{typeof(demo_dot_assume_observe_index)}) +function varnames(model::Model{typeof(demo_dot_assume_observe_index)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( @@ -365,7 +385,7 @@ function logprior_true_with_logabsdet_jacobian( ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function Base.keys(model::Model{typeof(demo_assume_dot_observe)}) +function varnames(model::Model{typeof(demo_assume_dot_observe)}) return [@varname(s), @varname(m)] end function example_values( @@ -400,7 +420,7 @@ function logprior_true_with_logabsdet_jacobian( ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function Base.keys(model::Model{typeof(demo_assume_observe_literal)}) +function varnames(model::Model{typeof(demo_assume_observe_literal)}) return [@varname(s), @varname(m)] end function example_values( @@ -442,7 +462,7 @@ function logprior_true_with_logabsdet_jacobian( ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function Base.keys(model::Model{typeof(demo_dot_assume_observe_index_literal)}) +function varnames(model::Model{typeof(demo_dot_assume_observe_index_literal)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( @@ -482,7 +502,7 @@ function logprior_true_with_logabsdet_jacobian( ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function Base.keys(model::Model{typeof(demo_assume_literal_dot_observe)}) +function varnames(model::Model{typeof(demo_assume_literal_dot_observe)}) return [@varname(s), @varname(m)] end function example_values( @@ -529,7 +549,7 @@ function logprior_true_with_logabsdet_jacobian( ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function Base.keys(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) +function varnames(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( @@ -581,7 +601,7 @@ function logprior_true_with_logabsdet_jacobian( ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function Base.keys(model::Model{typeof(demo_dot_assume_observe_submodel)}) +function varnames(model::Model{typeof(demo_dot_assume_observe_submodel)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( @@ -626,7 +646,7 @@ function logprior_true_with_logabsdet_jacobian( ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function Base.keys(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) +function varnames(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( @@ -679,7 +699,7 @@ function logprior_true_with_logabsdet_jacobian( ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function Base.keys(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) +function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) return [@varname(s[:, 1]), @varname(s[:, 2]), @varname(m[1]), @varname(m[2])] end function example_values( @@ -748,10 +768,10 @@ function test_sampler_demo_models( @testset "$(typeof(sampler)) on $(nameof(model))" for model in DEMO_MODELS chain = AbstractMCMC.sample(model, sampler, args...; kwargs...) target_values = posterior_mean_values(model) - for vn in keys(model) + for vn in varnames(model) # We want to compare elementwise which can be achieved by # extracting the leaves of the `VarName` and the corresponding value. - for vn_leaf in varnames(vn, get(target_values, vn)) + for vn_leaf in varname_leaves(vn, get(target_values, vn)) target_value = get(target_values, vn_leaf) chain_mean_value = meanfunction(chain, vn_leaf) @test chain_mean_value ≈ target_value atol = atol rtol = rtol diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index ed5919f5a..5e598217a 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -80,7 +80,7 @@ _, svi_new = DynamicPPL.evaluate!!(model, svi, SamplingContext()) # Realization for `m` should be different wp. 1. - for vn in keys(model) + for vn in DynamicPPL.TestUtils.varnames(model) @test svi_new[vn] != get(retval, vn) end @@ -100,7 +100,7 @@ # Update the realizations in `svi_new`. svi_eval = svi_new - for vn in keys(model) + for vn in DynamicPPL.TestUtils.varnames(model) svi_eval = DynamicPPL.setindex!!(svi_eval, get(values_eval, vn), vn) end @@ -111,7 +111,7 @@ logπ = logjoint(model, svi_eval) # Values should not have changed. - for vn in keys(model) + for vn in DynamicPPL.TestUtils.varnames(model) @test svi_eval[vn] == get(values_eval, vn) end @@ -141,7 +141,7 @@ ) # Realizations from model should all be equal to the unconstrained realization. - for vn in keys(model) + for vn in DynamicPPL.TestUtils.varnames(model) @test get(retval_unconstrained, vn) ≈ svi[vn] rtol = 1e-6 end From 9669345f2b89e6ac50cf17323e2143fca6f4d2b7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 10:48:12 +0100 Subject: [PATCH 065/221] added test_sampler_on_models as a generalization of test_sampler_demo_models --- src/test_utils.jl | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 78d279067..75661f24e 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -735,13 +735,12 @@ const DEMO_MODELS = ( demo_dot_assume_matrix_dot_observe_matrix(), ) -# TODO: Is this really the best/most convenient "default" test method? """ - test_sampler_demo_models(meanfunction, sampler, args...; kwargs...) + test_sampler_on_models(meanfunction, models, sampler, args...; kwargs...) -Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`. +Test that `sampler` produces correct marginal posterior means on each model in `models`. -In short, this method iterators through `demo_models`, calls `AbstractMCMC.sample` on the +In short, this method iterates through `models`, calls `AbstractMCMC.sample` on the `model` and `sampler` to produce a `chain`, and then checks `meanfunction(chain, vn)` for every (leaf) varname `vn` against the corresponding value returned by [`posterior_mean_values`](@ref) for each model. @@ -749,6 +748,7 @@ for every (leaf) varname `vn` against the corresponding value returned by # Arguments - `meanfunction`: A callable which computes the mean of the marginal means from the chain resulting from the `sample` call. +- `models`: A collection of instaces of [`DynamicPPL.Model`](@ref) to test on. - `sampler`: The `AbstractMCMC.AbstractSampler` to test. - `args...`: Arguments forwarded to `sample`. @@ -757,15 +757,16 @@ for every (leaf) varname `vn` against the corresponding value returned by - `rtol=1e-3`: Relative tolerance used in `@test`. - `kwargs...`: Keyword arguments forwarded to `sample`. """ -function test_sampler_demo_models( +function test_sampler_on_models( meanfunction, + models, sampler::AbstractMCMC.AbstractSampler, args...; atol=1e-1, rtol=1e-3, kwargs..., ) - @testset "$(typeof(sampler)) on $(nameof(model))" for model in DEMO_MODELS + @testset "$(typeof(sampler)) on $(nameof(model))" for model in models chain = AbstractMCMC.sample(model, sampler, args...; kwargs...) target_values = posterior_mean_values(model) for vn in varnames(model) @@ -780,17 +781,30 @@ function test_sampler_demo_models( end end +""" + test_sampler_on_demo_models(meanfunction, sampler, args...; kwargs...) + +Test `sampler` on every model in [`DEMO_MODELS`](@ref). + +This is just a proxy for `test_sampler_on_models(meanfunction, DEMO_MODELS, sampler, args...; kwargs...)`. +""" +function test_sampler_on_demo_models( + meanfunction, sampler::AbstractMCMC.AbstractSampler, args...; kwargs... +) + return test_sampler_on_models(meanfunction, DEMO_MODELS, sampler, args...; kwargs...) +end + """ test_sampler_continuous([meanfunction, ]sampler, args...; kwargs...) Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`. -As of right now, this is just an alias for [`test_sampler_demo_models`](@ref). +As of right now, this is just an alias for [`test_sampler_on_demo_models`](@ref). """ function test_sampler_continuous( meanfunction, sampler::AbstractMCMC.AbstractSampler, args...; kwargs... ) - return test_sampler_demo_models(meanfunction, sampler, args...; kwargs...) + return test_sampler_on_demo_models(meanfunction, sampler, args...; kwargs...) end function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; kwargs...) From 7e027356415d6b3c19b4a0e1cd9cfdb249aa2ae2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 10:48:40 +0100 Subject: [PATCH 066/221] updated docs --- docs/src/api.md | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index 133b86e9b..debad2944 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -103,10 +103,15 @@ NamedDist DynamicPPL provides several demo models and helpers for testing samplers in the `DynamicPPL.TestUtils` submodule. ```@docs -DynamicPPL.TestUtils.test_sampler_demo_models +DynamicPPL.TestUtils.test_sampler_on_models +DynamicPPL.TestUtils.test_sampler_on_demo_models DynamicPPL.TestUtils.test_sampler_continuous ``` +```@docs +DynamicPPL.TestUtils.DEMO_MODELS +``` + For every demo model, one can define the true log prior, log likelihood, and log joint probabilities. ```@docs @@ -115,6 +120,21 @@ DynamicPPL.TestUtils.loglikelihood_true DynamicPPL.TestUtils.logjoint_true ``` +And in the case where the model might include constrained variables, it can also be useful to define + +```@docs +DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian +DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian +``` + +Finally, the following methods can also be of use: + +```@docs +DynamicPPL.TestUtils.varnames +DynamicPPL.TestUtils.example_values +DynamicPPL.TestUtils.posterior_mean_values +``` + ## Advanced ### Variable names From a412029736905657c543ac56ea03be0695352acc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 10:48:44 +0100 Subject: [PATCH 067/221] added docs for TestUtils.DEMO_MODELS --- src/test_utils.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/test_utils.jl b/src/test_utils.jl index 75661f24e..3f6ff5cf1 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -720,6 +720,21 @@ function posterior_mean_values( return vals end +""" +A collection of models corresponding to the posterior distribution defined by +the generative process + + s ~ InverseGamma(2, 3) + m ~ Normal(0, √s) + 1.5 ~ Normal(m, √s) + +_or_ a product of such distributions. + +The posterior for both `s` and `m` here is known in closed form. In particular, + + mean(s) == 19 / 8 + mean(m) == 3 / 4 +""" const DEMO_MODELS = ( demo_dot_assume_dot_observe(), demo_assume_index_observe(), From f3818c30ef79e106187e7e25c16d0dd459203331 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 10:51:12 +0100 Subject: [PATCH 068/221] updated some tests --- test/contexts.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/contexts.jl b/test/contexts.jl index ef916b18c..edcf5d0f3 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -165,7 +165,8 @@ end vn_without_prefix = remove_prefix(vn) # Let's check elementwise. - for vn_child in DynamicPPL.TestUtils.varnames(vn_without_prefix, val) + for vn_child in + DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) if get(val, getlens(vn_child)) === missing @test contextual_isassumption(context, vn_child) else @@ -197,7 +198,8 @@ end # `ConditionContext` with the conditioned variable. vn_without_prefix = remove_prefix(vn) - for vn_child in DynamicPPL.TestUtils.varnames(vn_without_prefix, val) + for vn_child in + DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) # `vn_child` should be in `context`. @test hasvalue_nested(context, vn_child) # Value should be the same as extracted above. From 8b799a4f1665a8570d5f52ec262fc40830c0aea7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 10:53:23 +0100 Subject: [PATCH 069/221] fixed docstrings --- src/test_utils.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 3f6ff5cf1..0cecd9d61 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -32,9 +32,9 @@ function varname_leaves(vn::VarName, val::AbstractArray) end """ - logprior_true(model, θ) + logprior_true(model, args...) -Return the `logprior` of `model` for `θ`. +Return the `logprior` of `model` for `args...`. This should generally be implemented by hand for every specific `model`. @@ -43,9 +43,9 @@ See also: [`logjoint_true`](@ref), [`loglikelihood_true`](@ref). function logprior_true end """ - loglikelihood_true(model, θ) + loglikelihood_true(model, args...) -Return the `loglikelihood` of `model` for `θ`. +Return the `loglikelihood` of `model` for `args...`. This should generally be implemented by hand for every specific `model`. From 93cb298ee1ffd921e3841c7e62a2d9cc093a78ef Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 10:53:55 +0100 Subject: [PATCH 070/221] fixed docstrings --- src/test_utils.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 0cecd9d61..b015ce8fa 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -34,7 +34,7 @@ end """ logprior_true(model, args...) -Return the `logprior` of `model` for `args...`. +Return the `logprior` of `model` for `args`. This should generally be implemented by hand for every specific `model`. @@ -45,7 +45,7 @@ function logprior_true end """ loglikelihood_true(model, args...) -Return the `loglikelihood` of `model` for `args...`. +Return the `loglikelihood` of `model` for `args`. This should generally be implemented by hand for every specific `model`. @@ -56,7 +56,7 @@ function loglikelihood_true end """ logjoint_true(model, args...) -Return the `logjoint` of `model` for `args...`. +Return the `logjoint` of `model` for `args`. Defaults to `logprior_true(model, args...) + loglikelihood_true(model, args..)`. @@ -77,7 +77,7 @@ end """ logjoint_true_with_logabsdet_jacobian(model::Model, args...) -Return a tuple `(args_unconstrained, logjoint)` of `model` for `args...`. +Return a tuple `(args_unconstrained, logjoint)` of `model` for `args`. Unlike [`logjoint_true`](@ref), the returned logjoint computation includes the log-absdet-jacobian adjustment, thus computing logjoint for the unconstrained variables. From ba5852b85eed5a8a13ffd3c0ef76346927670d63 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 10:54:55 +0100 Subject: [PATCH 071/221] imprvoed docstring --- src/test_utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index b015ce8fa..fd42e86ab 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -129,7 +129,7 @@ end Return a `NamedTuple` compatible with `varnames(model)` with values in support of `model`. -Compatible means that a `varname` from `varnames(model)` can be used to extract the +\"Compatible\" means that a `varname` from `varnames(model)` can be used to extract the corresponding value using the call `get(example_values(model), varname)`. """ example_values(model::Model) = example_values(Random.GLOBAL_RNG, model) @@ -140,7 +140,7 @@ example_values(model::Model) = example_values(Random.GLOBAL_RNG, model) Return a `NamedTuple` compatible with `varnames(model)` where the values represent the posterior mean under `model`. -Compatible means that a `varname` from `varnames(model)` can be used to extract the +\"Compatible\" means that a `varname` from `varnames(model)` can be used to extract the corresponding value using the call `get(posterior_mean_values(model), varname)`. """ function posterior_mean_values end From 328f7134b6ce2be0f332b6ca38f7173b49f8f9b5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 10:57:03 +0100 Subject: [PATCH 072/221] improved docstrings --- src/test_utils.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index fd42e86ab..1eca56350 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -129,8 +129,8 @@ end Return a `NamedTuple` compatible with `varnames(model)` with values in support of `model`. -\"Compatible\" means that a `varname` from `varnames(model)` can be used to extract the -corresponding value using the call `get(example_values(model), varname)`. +"Compatible" means that a `varname` from `varnames(model)` can be used to extract the +corresponding value using `get`, e.g. `get(example_values(model), varname)`. """ example_values(model::Model) = example_values(Random.GLOBAL_RNG, model) @@ -140,8 +140,8 @@ example_values(model::Model) = example_values(Random.GLOBAL_RNG, model) Return a `NamedTuple` compatible with `varnames(model)` where the values represent the posterior mean under `model`. -\"Compatible\" means that a `varname` from `varnames(model)` can be used to extract the -corresponding value using the call `get(posterior_mean_values(model), varname)`. +"Compatible" means that a `varname` from `varnames(model)` can be used to extract the +corresponding value using `get`, e.g. `get(posterior_mean_values(model), varname)`. """ function posterior_mean_values end From 801bd4caf110367b6f92f81e7fb226cf64e354b2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 11:22:50 +0100 Subject: [PATCH 073/221] renamed Base.keys(model) to varnames(model) in TestUtils --- src/test_utils.jl | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 0b9a5526b..f9130bd61 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -105,7 +105,7 @@ end function loglikelihood_true(model::Model{typeof(demo_dynamic_constraint)}, m, x) return zero(float(eltype(m))) end -function Base.keys(model::Model{typeof(demo_dynamic_constraint)}) +function varnames(model::Model{typeof(demo_dynamic_constraint)}) return [@varname(m), @varname(x)] end @@ -134,7 +134,7 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe)}, m) return loglikelihood(MvNormal(m, 0.25 * I), model.args.x) end -function Base.keys(model::Model{typeof(demo_dot_assume_dot_observe)}) +function varnames(model::Model{typeof(demo_dot_assume_dot_observe)}) return [@varname(m[1]), @varname(m[2])] end @@ -156,7 +156,7 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_index_observe)}, m) return logpdf(MvNormal(m, 0.25 * I), model.args.x) end -function Base.keys(model::Model{typeof(demo_assume_index_observe)}) +function varnames(model::Model{typeof(demo_assume_index_observe)}) return [@varname(m[1]), @varname(m[2])] end @@ -173,7 +173,7 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_multivariate_observe)}, m) return logpdf(MvNormal(m, 0.25 * I), model.args.x) end -function Base.keys(model::Model{typeof(demo_assume_multivariate_observe)}) +function varnames(model::Model{typeof(demo_assume_multivariate_observe)}) return [@varname(m)] end @@ -195,7 +195,7 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index)}, m) return sum(logpdf.(Normal.(m, 0.5), model.args.x)) end -function Base.keys(model::Model{typeof(demo_dot_assume_observe_index)}) +function varnames(model::Model{typeof(demo_dot_assume_observe_index)}) return [@varname(m[1]), @varname(m[2])] end @@ -214,7 +214,7 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_dot_observe)}, m) return sum(logpdf.(Normal.(m, 0.5), model.args.x)) end -function Base.keys(model::Model{typeof(demo_assume_dot_observe)}) +function varnames(model::Model{typeof(demo_assume_dot_observe)}) return [@varname(m)] end @@ -231,7 +231,7 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, m) return logpdf(MvNormal(m, 0.25 * I), [10.0, 10.0]) end -function Base.keys(model::Model{typeof(demo_assume_observe_literal)}) +function varnames(model::Model{typeof(demo_assume_observe_literal)}) return [@varname(m)] end @@ -251,7 +251,7 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, m) return sum(logpdf.(Normal.(m, 0.5), fill(10.0, length(m)))) end -function Base.keys(model::Model{typeof(demo_dot_assume_observe_index_literal)}) +function varnames(model::Model{typeof(demo_dot_assume_observe_index_literal)}) return [@varname(m[1]), @varname(m[2])] end @@ -268,7 +268,7 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_literal_dot_observe)}, m) return logpdf(Normal(m, 0.5), 10.0) end -function Base.keys(model::Model{typeof(demo_assume_literal_dot_observe)}) +function varnames(model::Model{typeof(demo_assume_literal_dot_observe)}) return [@varname(m)] end @@ -296,7 +296,7 @@ function loglikelihood_true( ) return sum(logpdf.(Normal.(m, 0.5), 10.0)) end -function Base.keys(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) +function varnames(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) return [@varname(m[1]), @varname(m[2])] end @@ -321,7 +321,7 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, m) return logpdf(MvNormal(m, 0.25 * I), model.args.x) end -function Base.keys(model::Model{typeof(demo_dot_assume_observe_submodel)}) +function varnames(model::Model{typeof(demo_dot_assume_observe_submodel)}) return [@varname(m[1]), @varname(m[2])] end @@ -342,7 +342,7 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, m) return loglikelihood(MvNormal(m, 0.25 * I), model.args.x) end -function Base.keys(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) +function varnames(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) return [@varname(m[1]), @varname(m[2])] end @@ -366,7 +366,7 @@ function loglikelihood_true( ) return loglikelihood(MvNormal(vec(m), 0.25 * I), model.args.x) end -function Base.keys(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) +function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) return [@varname(m[:, 1]), @varname(m[:, 2])] end @@ -385,7 +385,7 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_array_dot_observe)}, m) return loglikelihood(MvNormal(m, 0.25 * I), model.args.x) end -function Base.keys(model::Model{typeof(demo_dot_assume_array_dot_observe)}) +function varnames(model::Model{typeof(demo_dot_assume_array_dot_observe)}) return [@varname(m[1]), @varname(m[2])] end From 46f6f4c835e22976ab6333f8edfadb831cc8ff1b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 11:24:32 +0100 Subject: [PATCH 074/221] added default implementation and docstring for TestUtils.varnames --- src/test_utils.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/test_utils.jl b/src/test_utils.jl index f9130bd61..ea509e2de 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -87,6 +87,20 @@ See also: [`logprior_true`](@ref). """ function logprior_true_with_logabsdet_jacobian end +""" + varnames(model::Model) + +Return a collection of `VarName` as they are expected to appear in the model. + +Even though it is recommended to implement this by hand for a particular `Model`, +a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided. +""" +function varnames(model::Model) + return collect( + keys(last(DynamicPPL.evaluate!!(model, SimpleVarInfo(Dict()), SamplingContext()))) + ) +end + """ demo_dynamic_constraint() From bcb767b34666a75db34b3676e214738f1015afb5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 11:27:14 +0100 Subject: [PATCH 075/221] replace handwritten by DocStringExtensions --- src/simple_varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 92abc327f..e82b35486 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -4,7 +4,7 @@ struct Constrained <: AbstractConstraint end struct Unconstrained <: AbstractConstraint end """ - SimpleVarInfo{NT,T,C} <: AbstractVarInfo + $(TYPEDEF) A simple wrapper of the parameters with a `logp` field for accumulation of the logdensity. From c5be1c2f8e2856f2fe509c6be6850bab82e1ab07 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 11:32:53 +0100 Subject: [PATCH 076/221] Apply suggestions from @devmotion Co-authored-by: David Widmann --- src/context_implementations.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index aead1dde1..855717eea 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -423,7 +423,7 @@ function dot_assume( dist::Distribution, var::AbstractArray, vns::AbstractArray{<:VarName}, vi ) r = map(vn -> vi[vn, dist], vns) - lp = sum(Bijectors.logpdf_with_trans.(dist, r, map(Base.Fix1(istrans, vi), vns))) + lp = sum(Bijectors.logpdf_with_trans.(dist, r, istrans.((vi,), vns))) return r, lp, vi end @@ -435,7 +435,7 @@ function dot_assume( ) @assert length(vns) == length(dists) == length(var) r = map((vn, dist) -> vi[vn, dist], vns, dists) - lp = sum(Bijectors.logpdf_with_trans.(dists, r, map(Base.Fix1(istrans, vi), vns))) + lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi, ), vns))) return r, lp, vi end @@ -449,7 +449,7 @@ function dot_assume( ) r = get_and_set_val!(rng, vi, vns, dists, spl) # Make sure `r` is not a matrix for multivariate distributions - lp = sum(Bijectors.logpdf_with_trans.(dists, r, map(Base.Fix1(istrans, vi), vns))) + lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns))) return r, lp, vi end function dot_assume(rng, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any) From f266929153ba5e4b2e74ba41f90180e01eae7710 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 11:37:00 +0100 Subject: [PATCH 077/221] Update src/context_implementations.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/context_implementations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 855717eea..f81809664 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -435,7 +435,7 @@ function dot_assume( ) @assert length(vns) == length(dists) == length(var) r = map((vn, dist) -> vi[vn, dist], vns, dists) - lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi, ), vns))) + lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns))) return r, lp, vi end From c2dbbafbdd2f47bd9b1131bff52204f842e980fa Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 11:41:20 +0100 Subject: [PATCH 078/221] removed some asserts and use broadcast instead of map --- src/context_implementations.jl | 1 - src/simple_varinfo.jl | 6 ++---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index f81809664..5a8136a13 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -433,7 +433,6 @@ function dot_assume( vns::AbstractArray{<:VarName}, vi, ) - @assert length(vns) == length(dists) == length(var) r = map((vn, dist) -> vi[vn, dist], vns, dists) lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns))) return r, lp, vi diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index e82b35486..70a6a4770 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -473,11 +473,9 @@ function dot_assume( # Transform if we're working in transformed space. value_raw = if dists isa Distribution - @assert length(vns) == length(value) - map((vn, val) -> maybe_link(vi, vn, dists, val), vns, value) + maybe_link.((vi,), vns, (dists, ), value) else - @assert length(vns) == length(dists) == length(value) - map((vn, dist, val) -> maybe_link(vi, vn, dist, val), vns, dists, value) + maybe_link.((vi,), vns, dists, value) end # Update `vi` From 1abb46c92e8dd4d9682d7116927e836f0a93bef7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 11:43:12 +0100 Subject: [PATCH 079/221] replace map with broadcasting to ensure consistent behavior --- src/context_implementations.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 5a8136a13..24ee74439 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -422,8 +422,8 @@ end function dot_assume( dist::Distribution, var::AbstractArray, vns::AbstractArray{<:VarName}, vi ) - r = map(vn -> vi[vn, dist], vns) - lp = sum(Bijectors.logpdf_with_trans.(dist, r, istrans.((vi,), vns))) + r = getindex.((vi,), vns, (dist,)) + lp = sum(Bijectors.logpdf_with_trans.((dist,), r, istrans.((vi,), vns))) return r, lp, vi end @@ -433,7 +433,7 @@ function dot_assume( vns::AbstractArray{<:VarName}, vi, ) - r = map((vn, dist) -> vi[vn, dist], vns, dists) + r = getindex.((vi,), vns, dists) lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns))) return r, lp, vi end From 1086c6c8fd8fd4ffa65dcc5266e3eec1bffc1508 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 11:45:40 +0100 Subject: [PATCH 080/221] Update src/simple_varinfo.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/simple_varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 70a6a4770..4ca25b2f2 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -473,7 +473,7 @@ function dot_assume( # Transform if we're working in transformed space. value_raw = if dists isa Distribution - maybe_link.((vi,), vns, (dists, ), value) + maybe_link.((vi,), vns, (dists,), value) else maybe_link.((vi,), vns, dists, value) end From f2fb4a5c53c8fbbeafbb9d8395810530a1982bf7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 11:47:07 +0100 Subject: [PATCH 081/221] added a method nodist to allow broadcasting NoDist constructor --- src/context_implementations.jl | 4 ++-- src/distribution_wrappers.jl | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 24ee74439..a255eb02f 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -318,13 +318,13 @@ function dot_tilde_assume( end function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi) - nodist = right isa Distribution ? NoDist(right) : NoDist.(right) + nodist = nodist.(right) return dot_assume(nodist, left, vn, vi) end function dot_tilde_assume( rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, vi ) - nodist = right isa Distribution ? NoDist(right) : NoDist.(right) + nodist = nodist.(right) return dot_assume(rng, sampler, nodist, vn, left, vi) end diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index 07dc6f93f..fb4ecab64 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -35,6 +35,9 @@ struct NoDist{variate,support,Td<:Distribution{variate,support}} <: end NoDist(dist::NamedDist) = NamedDist(NoDist(dist.dist), dist.name) +# HACK(torfjelde): Useful to have a constructor we can use in broadcasting. +nodist(dist::Distribution) = NoDist(dist) + Base.length(dist::NoDist) = Base.length(dist.dist) Base.size(dist::NoDist) = Base.size(dist.dist) From 490d24eea599b8bc113d8be5a8af92e1be4d0edc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 11:52:50 +0100 Subject: [PATCH 082/221] updated some tests --- test/simple_varinfo.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 9494ae6c1..9847d6959 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -80,7 +80,7 @@ _, svi_new = DynamicPPL.evaluate!!(model, svi, SamplingContext()) # Realization for `m` should be different wp. 1. - for vn in keys(model) + for vn in DynamicPPL.TestUtils.varnames(model) # `VarName` functions similarly to `PropertyLens` so # we just strip this part from `vn` to get a lens we can use # to extract the corresponding value of `m`. @@ -101,7 +101,7 @@ # Update the realizations in `svi_new`. svi_eval = svi_new - for vn in keys(model) + for vn in DynamicPPL.TestUtils.varnames(model) l = getlens(vn) svi_eval = DynamicPPL.setindex!!(svi_eval, get(m_eval, l), vn) end @@ -113,7 +113,7 @@ logπ = logjoint(model, svi_eval) # Values should not have changed. - for vn in keys(model) + for vn in DynamicPPL.TestUtils.varnames(model) l = getlens(vn) @test svi_eval[vn] == get(m_eval, l) end @@ -145,7 +145,7 @@ ) # Realizations from model should all be equal to the unconstrained realization. - for vn in keys(model) + for vn in DynamicPPL.TestUtils.varnames(model) @test get(retval_unconstrained, vn) ≈ svi[vn] rtol = 1e-6 end From 6350ccdc1be1958cc97e71847a2e44bf04b513c4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 11:52:58 +0100 Subject: [PATCH 083/221] renamed AbstractConstraint to AbstractTransformation and its subtypes --- src/simple_varinfo.jl | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 4ca25b2f2..20f86222a 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -1,7 +1,7 @@ -abstract type AbstractConstraint end +abstract type AbstractTransformation end -struct Constrained <: AbstractConstraint end -struct Unconstrained <: AbstractConstraint end +struct NoTransformation <: AbstractTransformation end +struct DefaultTransformation <: AbstractTransformation end """ $(TYPEDEF) @@ -86,7 +86,7 @@ ERROR: KeyError: key x[1:2] not found [...] ``` -You can also sample in _unconstrained_ space: +You can also sample in _transformed_ space: ```jldoctest simplevarinfo-general julia> @model demo_constrained() = x ~ Exponential() @@ -121,19 +121,19 @@ julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true ``` -Evaluation in unconstrained space of course also works: +Evaluation in transformed space of course also works: ```jldoctest simplevarinfo-general julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true) -Unconstrained SimpleVarInfo((x = -1.0,), 0.0) +Transformed SimpleVarInfo((x = -1.0,), 0.0) julia> # (✓) Positive probability mass on negative numbers! getlogp(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) -1.3678794411714423 -julia> # While if we forget to make indicate that it's unconstrained/transformed: +julia> # While if we forget to make indicate that it's transformed: vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false) -Constrained SimpleVarInfo((x = -1.0,), 0.0) +SimpleVarInfo((x = -1.0,), 0.0) julia> # (✓) No probability mass on negative numbers! getlogp(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) @@ -188,16 +188,16 @@ ERROR: type NamedTuple has no field b [...] ``` """ -struct SimpleVarInfo{NT,T,C<:AbstractConstraint} <: AbstractVarInfo +struct SimpleVarInfo{NT,T,C<:AbstractTransformation} <: AbstractVarInfo "underlying representation of the realization represented" values::NT "holds the accumulated log-probability" logp::T - "represents whether it assumes variables to be constrained or unconstrained" - constraint::C + "represents whether it assumes variables to be transformed" + transformation::C end -SimpleVarInfo(values, logp) = SimpleVarInfo(values, logp, Constrained()) +SimpleVarInfo(values, logp) = SimpleVarInfo(values, logp, NoTransformation()) function SimpleVarInfo{T}(θ) where {T<:Real} return SimpleVarInfo(θ, zero(T)) @@ -254,15 +254,15 @@ function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) end function Base.show( - io::IO, ::MIME"text/plain", svi::SimpleVarInfo{<:Any,<:Any,<:Constrained} + io::IO, ::MIME"text/plain", svi::SimpleVarInfo{<:Any,<:Any,<:NoTransformation} ) - return print(io, "Constrained SimpleVarInfo(", svi.values, ", ", svi.logp, ")") + return print(io, "SimpleVarInfo(", svi.values, ", ", svi.logp, ")") end function Base.show( - io::IO, ::MIME"text/plain", svi::SimpleVarInfo{<:Any,<:Any,<:Unconstrained} + io::IO, ::MIME"text/plain", svi::SimpleVarInfo{<:Any,<:Any,<:DefaultTransformation} ) - return print(io, "Unconstrained SimpleVarInfo(", svi.values, ", ", svi.logp, ")") + return print(io, "Transformed SimpleVarInfo(", svi.values, ", ", svi.logp, ")") end # `NamedTuple` @@ -516,14 +516,14 @@ increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing # NOTE: We don't implement `settrans!!(vi, trans, vn)`. function settrans!!(vi::SimpleVarInfo, trans) - return SimpleVarInfo(vi.values, vi.logp, trans ? Unconstrained() : Constrained()) + return SimpleVarInfo(vi.values, vi.logp, trans ? DefaultTransformation() : NoTransformation()) end function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) return Setfield.@set vi.varinfo = settrans!!(vi, trans) end -istrans(vi::SimpleVarInfo{<:Any,<:Any,<:Constrained}) = false -istrans(vi::SimpleVarInfo{<:Any,<:Any,<:Unconstrained}) = true +istrans(vi::SimpleVarInfo{<:Any,<:Any,<:NoTransformation}) = false +istrans(vi::SimpleVarInfo{<:Any,<:Any,<:DefaultTransformation}) = true istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi) istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) From 951e4c36ba67f518671744282d59f5dc7d044b68 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 11:58:49 +0100 Subject: [PATCH 084/221] updated tests --- test/loglikelihoods.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index 0e7f9a3d9..eaf1e00bd 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -2,7 +2,7 @@ @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS vi = VarInfo(m) - for vn in keys(m) + for vn in DynamicPPL.TestUtils.varnames(m) if vi[vn] isa Real vi = DynamicPPL.setindex!!(vi, 1.0, vn) else From dcd92c926f28cff980fc5a91623037269d8af36c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 13:51:36 +0100 Subject: [PATCH 085/221] fixed nodist usage --- src/context_implementations.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index a255eb02f..2bef931a4 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -318,14 +318,12 @@ function dot_tilde_assume( end function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi) - nodist = nodist.(right) - return dot_assume(nodist, left, vn, vi) + return dot_assume(nodist.(right), left, vn, vi) end function dot_tilde_assume( rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, vi ) - nodist = nodist.(right) - return dot_assume(rng, sampler, nodist, vn, left, vi) + return dot_assume(rng, sampler, nodist.(right), vn, left, vi) end # `PriorContext` From 2922ffa352a8a928b8898b0bcf788bdbb3fab82a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 14:09:24 +0100 Subject: [PATCH 086/221] fixed implementation of nodist --- src/context_implementations.jl | 4 ++-- src/distribution_wrappers.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 2bef931a4..0068546d5 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -318,12 +318,12 @@ function dot_tilde_assume( end function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi) - return dot_assume(nodist.(right), left, vn, vi) + return dot_assume(nodist(right), left, vn, vi) end function dot_tilde_assume( rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, vi ) - return dot_assume(rng, sampler, nodist.(right), vn, left, vi) + return dot_assume(rng, sampler, nodist(right), vn, left, vi) end # `PriorContext` diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index fb4ecab64..9761123c0 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -35,8 +35,8 @@ struct NoDist{variate,support,Td<:Distribution{variate,support}} <: end NoDist(dist::NamedDist) = NamedDist(NoDist(dist.dist), dist.name) -# HACK(torfjelde): Useful to have a constructor we can use in broadcasting. nodist(dist::Distribution) = NoDist(dist) +nodist(dists::AbstractArray) = nodist.(dist) Base.length(dist::NoDist) = Base.length(dist.dist) Base.size(dist::NoDist) = Base.size(dist.dist) From 5266a4ba6fcc233887903a519f076e24c86ff35c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 14:16:04 +0100 Subject: [PATCH 087/221] fixed typo --- src/distribution_wrappers.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index 9761123c0..65479f035 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -36,7 +36,7 @@ end NoDist(dist::NamedDist) = NamedDist(NoDist(dist.dist), dist.name) nodist(dist::Distribution) = NoDist(dist) -nodist(dists::AbstractArray) = nodist.(dist) +nodist(dists::AbstractArray) = nodist.(dists) Base.length(dist::NoDist) = Base.length(dist.dist) Base.size(dist::NoDist) = Base.size(dist.dist) From 3c38710e138a9bf1d9f35e0500496df4b4f4c581 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 14:16:09 +0100 Subject: [PATCH 088/221] formatting --- src/simple_varinfo.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 20f86222a..b6f01fd08 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -516,7 +516,9 @@ increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing # NOTE: We don't implement `settrans!!(vi, trans, vn)`. function settrans!!(vi::SimpleVarInfo, trans) - return SimpleVarInfo(vi.values, vi.logp, trans ? DefaultTransformation() : NoTransformation()) + return SimpleVarInfo( + vi.values, vi.logp, trans ? DefaultTransformation() : NoTransformation() + ) end function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) return Setfield.@set vi.varinfo = settrans!!(vi, trans) From ba92f3f1b47ba414ea32736c7901d9e1bb368684 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 14:23:24 +0100 Subject: [PATCH 089/221] bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 7adb220b3..bfa13d956 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.19.3" +version = "0.19.4" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From 70c864c1b910e6008d1e02f31f36f5cbd7d7fe6d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 14:37:58 +0100 Subject: [PATCH 090/221] fixed ThreadsafeVarInfo --- src/threadsafe.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index b42cf82a5..7c7dd13ac 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -62,11 +62,24 @@ islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl getindex(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex(vi.varinfo, spl) getindex(vi::ThreadSafeVarInfo, spl::SampleFromPrior) = getindex(vi.varinfo, spl) getindex(vi::ThreadSafeVarInfo, spl::SampleFromUniform) = getindex(vi.varinfo, spl) + getindex(vi::ThreadSafeVarInfo, vn::VarName) = getindex(vi.varinfo, vn) +function getindex(vi::ThreadSafeVarInfo, vn::VarName, dist::Distribution) + return getindex(vi.varinfo, vn, dist) +end getindex(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) = getindex(vi.varinfo, vns) +function getindex(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}, dist::Distribution) + return getindex(vi.varinfo, vns, dist) +end getindex_raw(vi::ThreadSafeVarInfo, vn::VarName) = getindex_raw(vi.varinfo, vn) +function getindex_raw(vi::ThreadSafeVarInfo, vn::VarName, dist::Distribution) + return getindex_raw(vi.varinfo, vn, dist) +end getindex_raw(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) = getindex_raw(vi.varinfo, vns) +function getindex_raw(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}, dist::Distribution) + return getindex_raw(vi.varinfo, vns, dist) +end function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::AbstractSampler) return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) From 58436994f423fff5077020766fa708e693f42566 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 15:22:21 +0100 Subject: [PATCH 091/221] fixed tests of pointwise_loglikelihoods --- src/test_utils.jl | 2 +- test/loglikelihoods.jl | 22 ++++++++-------------- 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 1eca56350..7ce3cf3b8 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -700,7 +700,7 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) - return [@varname(s[:, 1]), @varname(s[:, 2]), @varname(m[1]), @varname(m[2])] + return [@varname(s[:, 1]), @varname(s[:, 2]), @varname(m)] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)} diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index eaf1e00bd..bd04a76a5 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -1,15 +1,14 @@ @testset "loglikelihoods.jl" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - vi = VarInfo(m) + example_values = DynamicPPL.TestUtils.example_values(m) + # Instantiate a `VarInfo` with the example values. + vi = VarInfo(m) for vn in DynamicPPL.TestUtils.varnames(m) - if vi[vn] isa Real - vi = DynamicPPL.setindex!!(vi, 1.0, vn) - else - vi = DynamicPPL.setindex!!(vi, ones(size(vi[vn])), vn) - end + vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) end + # Compute the pointwise loglikelihoods. lls = pointwise_loglikelihoods(m, vi) if isempty(lls) @@ -17,14 +16,9 @@ continue end - loglikelihood = if length(keys(lls)) == 1 && length(m.args.x) == 1 - # Only have one observation, so we need to double it - # for comparison with other models. - 2 * sum(lls[first(keys(lls))]) - else - sum(sum, values(lls)) - end + loglikelihood = sum(sum, values(lls)) + loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(m, example_values...) - @test loglikelihood ≈ -324.45158270528947 + @test loglikelihood ≈ loglikelihood_true end end From 66f41a936a3d3172fee9d0320d4eaeb6e78f5c84 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 17:08:15 +0100 Subject: [PATCH 092/221] Apply suggestions from code review Co-authored-by: David Widmann --- src/context_implementations.jl | 14 +++++++------- src/simple_varinfo.jl | 14 +++++++------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 0068546d5..742082346 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -291,7 +291,7 @@ function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left, var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!!.(Ref(vi), false, _vns) + settrans!!.((vi,), false, _vns) dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, vi) else dot_tilde_assume(LikelihoodContext(), right, left, vn, vi) @@ -310,7 +310,7 @@ function dot_tilde_assume( var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!!.(Ref(vi), false, _vns) + settrans!!.((vi,), false, _vns) dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, vi) else dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi) @@ -332,7 +332,7 @@ function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!!.(Ref(vi), false, _vns) + settrans!!.((vi,), false, _vns) dot_tilde_assume(PriorContext(), _right, _left, _vns, vi) else dot_tilde_assume(PriorContext(), right, left, vn, vi) @@ -351,7 +351,7 @@ function dot_tilde_assume( var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!!.(Ref(vi), false, _vns) + settrans!!.((vi,), false, _vns) dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, vi) else dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, vi) @@ -525,11 +525,11 @@ function get_and_set_val!( # 2. Define an anonymous function which returns `nothing`, which # we then broadcast. This will allocate a vector of `nothing` though. if istrans(vi) - push!!.(Ref(vi), vns, link.(Ref(vi), vns, dists, r), dists, Ref(spl)) + push!!.((vi,), vns, link.((vi,), vns, dists, r), dists, (spl,)) # `push!!` sets the trans-flag to `false` by default. - settrans!!.(Ref(vi), true, vns) + settrans!!.((vi,), true, vns) else - push!!.(Ref(vi), vns, r, dists, Ref(spl)) + push!!.((vi,), vns, r, dists, (spl,)) end end return r diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index b6f01fd08..74a0bac20 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -267,13 +267,13 @@ end # `NamedTuple` function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution) - return maybe_invlink(vi, vn, dist, Base.getindex(vi, vn)) + return maybe_invlink(vi, vn, dist, getindex(vi, vn)) end function Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution) - vals_linked = map(vns) do vn - maybe_invlink(vi, vn, dist, Base.getindex(vi, vn)) + vals_linked = mapreduce(vcat, vns) do vn + getindex(vi, vn, dist) end - return reconstruct(dist, reduce(vcat, vals_linked), length(vns)) + return reconstruct(dist, vals_linked, length(vns)) end Base.getindex(vi::SimpleVarInfo, vn::VarName) = get(vi.values, vn) @@ -326,9 +326,9 @@ function getindex_raw(vi::SimpleVarInfo, vn::VarName, dist::Distribution) end getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}) = vi[vns] function getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution) - vals = getindex_raw(vi, vns) # `reconstruct` expects a flattened `Vector` regardless of the type of `dist`, so we `vcat` everything. - return reconstruct(dist, reduce(vcat, vals), length(vns)) + vals = mapreduce(Base.Fix1(getindex_raw, vi), vcat, vns) + return reconstruct(dist, vals, length(vns)) end Base.haskey(vi::SimpleVarInfo, vn::VarName) = _haskey(vi.values, vn) @@ -482,7 +482,7 @@ function dot_assume( vi = BangBang.setindex!!(vi, value_raw, vns) # Compute logp. - lp = sum(Bijectors.logpdf_with_trans.(dists, value, map(Base.Fix1(istrans, vi), vns))) + lp = sum(Bijectors.logpdf_with_trans.(dists, value, istrans.((vi,), vns))) return value, lp, vi end From eb2d6b5ae2b4350f5df88c527ff95abaf62103f8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 17:37:32 +0100 Subject: [PATCH 093/221] allow type-stable settrans!! for SimpleVarInfo --- src/simple_varinfo.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 74a0bac20..dd546fb3f 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -516,16 +516,16 @@ increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing # NOTE: We don't implement `settrans!!(vi, trans, vn)`. function settrans!!(vi::SimpleVarInfo, trans) - return SimpleVarInfo( - vi.values, vi.logp, trans ? DefaultTransformation() : NoTransformation() - ) + return settrans!!(vi, trans ? DefaultTransformation() : NoTransformation()) +end +function settrans!!(vi::SimpleVarInfo, transformation::AbstractTransformation) + return Setfield.@set vi.transformation = transformation end function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) - return Setfield.@set vi.varinfo = settrans!!(vi, trans) + return Setfield.@set vi.varinfo = settrans!!(vi.varinfo, trans) end -istrans(vi::SimpleVarInfo{<:Any,<:Any,<:NoTransformation}) = false -istrans(vi::SimpleVarInfo{<:Any,<:Any,<:DefaultTransformation}) = true +istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi) istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) From e8cdb91e7096e1570ee447fcb98e089f220d136f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 17:45:01 +0100 Subject: [PATCH 094/221] use maybe_invlink in getindex for VarInfo --- src/varinfo.jl | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 582582b77..b7048b64b 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -941,11 +941,7 @@ getindex(vi::AbstractVarInfo, vn::VarName) = getindex(vi, vn, getdist(vi, vn)) function getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution) @assert haskey(vi, vn) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" val = getindex_raw(vi, vn, dist) - return if istrans(vi, vn) - Bijectors.invlink(dist, val) - else - val - end + return maybe_invlink(vi, vn, dist, val) end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) return getindex(vi, vns, getdist(vi, first(vns))) @@ -953,11 +949,7 @@ end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distribution) @assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" val = getindex_raw(vi, vns, dist) - return if istrans(vi, vns[1]) - Bijectors.invlink(dist, val) - else - val - end + return maybe_invlink.((vi,), vns, (dist,), val) end getindex_raw(vi::AbstractVarInfo, vn::VarName) = getindex_raw(vi, vn, getdist(vi, vn)) From 359d384eae0847e8271ab60493fdb0684137295b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 17:45:11 +0100 Subject: [PATCH 095/221] added comment to warn about buggy behavior --- src/varinfo.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/varinfo.jl b/src/varinfo.jl index b7048b64b..938fa6285 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -944,6 +944,8 @@ function getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution) return maybe_invlink(vi, vn, dist, val) end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) + # NOTE(torfjelde): Using `getdist(vi, first(vns))` won't be correct in cases + # such as `x .~ [Normal(), Exponential()]`. return getindex(vi, vns, getdist(vi, first(vns))) end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distribution) From ab0a99bf8e5b86cce12a407b376b6ea197d98a22 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 17:46:51 +0100 Subject: [PATCH 096/221] Update src/context_implementations.jl Co-authored-by: David Widmann --- src/context_implementations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 742082346..4454701b8 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -514,7 +514,7 @@ function get_and_set_val!( else # r = reshape(vi[vec(vns)], size(vns)) r_raw = getindex_raw(vi, vec(vns)) - r = maybe_invlink.(Ref(vi), vns, dists, reshape(r_raw, size(vns))) + r = maybe_invlink.((vi,), vns, dists, reshape(r_raw, size(vns))) end else f = (vn, dist) -> init(rng, dist, spl) From dd109134bda4574ba1a9f4ec2cda8a44f40c2947 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 17:50:16 +0100 Subject: [PATCH 097/221] just fix potential bug in getindex for VarInfo --- src/varinfo.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 938fa6285..58f737785 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -944,9 +944,7 @@ function getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution) return maybe_invlink(vi, vn, dist, val) end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) - # NOTE(torfjelde): Using `getdist(vi, first(vns))` won't be correct in cases - # such as `x .~ [Normal(), Exponential()]`. - return getindex(vi, vns, getdist(vi, first(vns))) + return getindex.((vi,), vns, getdist.((vi,), vns)) end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distribution) @assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" From 18d28ccc8647720551d6d15cf8e742474d5deeeb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 17:51:48 +0100 Subject: [PATCH 098/221] revert previous change because it likely introduces bugs --- src/varinfo.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 58f737785..938fa6285 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -944,7 +944,9 @@ function getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution) return maybe_invlink(vi, vn, dist, val) end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) - return getindex.((vi,), vns, getdist.((vi,), vns)) + # NOTE(torfjelde): Using `getdist(vi, first(vns))` won't be correct in cases + # such as `x .~ [Normal(), Exponential()]`. + return getindex(vi, vns, getdist(vi, first(vns))) end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distribution) @assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" From 32b7aab6744c52af7ea4b01d9fa7916e38442a5c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 17:53:51 +0100 Subject: [PATCH 099/221] elaborate in comment regarding potential bug --- src/varinfo.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/varinfo.jl b/src/varinfo.jl index 938fa6285..4748643f3 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -946,6 +946,9 @@ end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) # NOTE(torfjelde): Using `getdist(vi, first(vns))` won't be correct in cases # such as `x .~ [Normal(), Exponential()]`. + # BUT we also can't fix this here because this will lead to "incorrect" + # behavior if `vns` arose from something like `x .~ MvNormal(zeros(2), I)`, + # where by "incorrect" we mean there exists pieces of code expecting this behavior. return getindex(vi, vns, getdist(vi, first(vns))) end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distribution) From f782fe25b1b992e496e823a707688803f9c6eb99 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 17:55:31 +0100 Subject: [PATCH 100/221] added error message to dot_assume --- src/simple_varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index dd546fb3f..e7a389ee8 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -494,7 +494,7 @@ function dot_assume( var::AbstractMatrix, vi::SimpleOrThreadSafeSimple, ) - @assert length(dist) == size(var, 1) + @assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))" # r = get_and_set_val!(rng, vi, vns, dist, spl) n = length(vns) From 7d3493dc53c9f27adf93f498df5dc376f74c364d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 17:56:06 +0100 Subject: [PATCH 101/221] added error message to dot_assume again --- src/context_implementations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 4454701b8..6271b5d8c 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -389,7 +389,7 @@ function dot_assume( vns::AbstractVector{<:VarName}, vi::AbstractVarInfo, ) - @assert length(dist) == size(var, 1) + @assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))" # NOTE: We cannot work with `var` here because we might have a model of the form # # m = Vector{Float64}(undef, n) From 912d7f847fd4642837a55d1cabe4428e79be1146 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 2 Jul 2022 12:03:45 +0100 Subject: [PATCH 102/221] Apply suggestions from code review Co-authored-by: David Widmann --- src/test_utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 7ce3cf3b8..ad28585d7 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -185,7 +185,7 @@ end function _demo_logprior_true_with_logabsdet_jacobian(model, s, m) b = Bijectors.bijector(InverseGamma(2, 3)) s_unconstrained = b.(s) - Δlogp = sum(Base.Fix1(Bijectors.logabsdetjac, b).(s)) + Δlogp = sum(Base.Fix1(Bijectors.logabsdetjac, b), s) return (s=s_unconstrained, m=m), logprior_true(model, s, m) - Δlogp end From a276e4a662795b9726e91823a52c8f5b20188965 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 2 Jul 2022 12:05:56 +0100 Subject: [PATCH 103/221] renamed posterior_mean_values to posterior_mean --- docs/src/api.md | 2 +- src/test_utils.jl | 34 +++++++++++++++++----------------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index debad2944..ab7e7fe60 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -132,7 +132,7 @@ Finally, the following methods can also be of use: ```@docs DynamicPPL.TestUtils.varnames DynamicPPL.TestUtils.example_values -DynamicPPL.TestUtils.posterior_mean_values +DynamicPPL.TestUtils.posterior_mean ``` ## Advanced diff --git a/src/test_utils.jl b/src/test_utils.jl index ad28585d7..bbc974677 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -135,15 +135,15 @@ corresponding value using `get`, e.g. `get(example_values(model), varname)`. example_values(model::Model) = example_values(Random.GLOBAL_RNG, model) """ - posterior_mean_values(model::Model) + posterior_mean(model::Model) Return a `NamedTuple` compatible with `varnames(model)` where the values represent the posterior mean under `model`. "Compatible" means that a `varname` from `varnames(model)` can be used to extract the -corresponding value using `get`, e.g. `get(posterior_mean_values(model), varname)`. +corresponding value using `get`, e.g. `get(posterior_mean(model), varname)`. """ -function posterior_mean_values end +function posterior_mean end """ demo_dynamic_constraint() @@ -226,7 +226,7 @@ function example_values( end return (s=s, m=m) end -function posterior_mean_values(model::Model{typeof(demo_dot_assume_dot_observe)}) +function posterior_mean(model::Model{typeof(demo_dot_assume_dot_observe)}) vals = example_values(model) vals.s .= 2.375 vals.m .= 0.75 @@ -274,7 +274,7 @@ function example_values( end return (s=s, m=m) end -function posterior_mean_values(model::Model{typeof(demo_assume_index_observe)}) +function posterior_mean(model::Model{typeof(demo_assume_index_observe)}) vals = example_values(model) vals.s .= 2.375 vals.m .= 0.75 @@ -311,7 +311,7 @@ function example_values( s = rand(rng, product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])) return (s=s, m=rand(rng, MvNormal(zero(model.args.x), Diagonal(s)))) end -function posterior_mean_values(model::Model{typeof(demo_assume_multivariate_observe)}) +function posterior_mean(model::Model{typeof(demo_assume_multivariate_observe)}) vals = example_values(model) vals.s .= 2.375 vals.m .= 0.75 @@ -357,7 +357,7 @@ function example_values( end return (s=s, m=m) end -function posterior_mean_values(model::Model{typeof(demo_dot_assume_observe_index)}) +function posterior_mean(model::Model{typeof(demo_dot_assume_observe_index)}) vals = example_values(model) vals.s .= 2.375 vals.m .= 0.75 @@ -395,7 +395,7 @@ function example_values( m = rand(rng, Normal(0, sqrt(s))) return (s=s, m=m) end -function posterior_mean_values(model::Model{typeof(demo_assume_dot_observe)}) +function posterior_mean(model::Model{typeof(demo_assume_dot_observe)}) return (s=2.375, m=0.75) end @@ -429,7 +429,7 @@ function example_values( s = rand(rng, product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])) return (s=s, m=rand(rng, MvNormal(zeros(2), Diagonal(s)))) end -function posterior_mean_values(model::Model{typeof(demo_assume_observe_literal)}) +function posterior_mean(model::Model{typeof(demo_assume_observe_literal)}) vals = example_values(model) vals.s .= 2.375 vals.m .= 0.75 @@ -476,7 +476,7 @@ function example_values( end return (s=s, m=m) end -function posterior_mean_values(model::Model{typeof(demo_dot_assume_observe_index_literal)}) +function posterior_mean(model::Model{typeof(demo_dot_assume_observe_index_literal)}) vals = example_values(model) vals.s .= 2.375 vals.m .= 0.75 @@ -512,7 +512,7 @@ function example_values( m = rand(rng, Normal(0, sqrt(s))) return (s=s, m=m) end -function posterior_mean_values(model::Model{typeof(demo_assume_literal_dot_observe)}) +function posterior_mean(model::Model{typeof(demo_assume_literal_dot_observe)}) return (s=2.375, m=0.75) end @@ -564,7 +564,7 @@ function example_values( end return (s=s, m=m) end -function posterior_mean_values( +function posterior_mean( model::Model{typeof(demo_assume_submodel_observe_index_literal)} ) vals = example_values(model) @@ -615,7 +615,7 @@ function example_values( end return (s=s, m=m) end -function posterior_mean_values(model::Model{typeof(demo_dot_assume_observe_submodel)}) +function posterior_mean(model::Model{typeof(demo_dot_assume_observe_submodel)}) vals = example_values(model) vals.s .= 2.375 vals.m .= 0.75 @@ -660,7 +660,7 @@ function example_values( end return (s=s, m=m) end -function posterior_mean_values(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) +function posterior_mean(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) vals = example_values(model) vals.s .= 2.375 vals.m .= 0.75 @@ -711,7 +711,7 @@ function example_values( m = rand(rng, MvNormal(zeros(n), Diagonal(vec(s)))) return (s=s, m=m) end -function posterior_mean_values( +function posterior_mean( model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)} ) vals = example_values(model) @@ -758,7 +758,7 @@ Test that `sampler` produces correct marginal posterior means on each model in ` In short, this method iterates through `models`, calls `AbstractMCMC.sample` on the `model` and `sampler` to produce a `chain`, and then checks `meanfunction(chain, vn)` for every (leaf) varname `vn` against the corresponding value returned by -[`posterior_mean_values`](@ref) for each model. +[`posterior_mean`](@ref) for each model. # Arguments - `meanfunction`: A callable which computes the mean of the marginal means from the @@ -783,7 +783,7 @@ function test_sampler_on_models( ) @testset "$(typeof(sampler)) on $(nameof(model))" for model in models chain = AbstractMCMC.sample(model, sampler, args...; kwargs...) - target_values = posterior_mean_values(model) + target_values = posterior_mean(model) for vn in varnames(model) # We want to compare elementwise which can be achieved by # extracting the leaves of the `VarName` and the corresponding value. From 626eea212d8e500c0af51fb63a4b216c51a2a436 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 2 Jul 2022 12:55:10 +0100 Subject: [PATCH 104/221] made demo models a bit more complex, now including different observations --- src/test_utils.jl | 175 +++++++++++++++++++++++++++++++--------------- 1 file changed, 119 insertions(+), 56 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index bbc974677..f88414a7d 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -190,7 +190,7 @@ function _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end @model function demo_dot_assume_dot_observe( - x=[1.5, 1.5], ::Type{TV}=Vector{Float64} + x=[1.5, 2.0], ::Type{TV}=Vector{Float64} ) where {TV} # `dot_assume` and `observe` s = TV(undef, length(x)) @@ -228,13 +228,18 @@ function example_values( end function posterior_mean(model::Model{typeof(demo_dot_assume_dot_observe)}) vals = example_values(model) - vals.s .= 2.375 - vals.m .= 0.75 + + vals.s[1] = 19 / 8 + vals.m[1] = 3 / 4 + + vals.s[2] = 8 / 3 + vals.m[2] = 1 + return vals end @model function demo_assume_index_observe( - x=[1.5, 1.5], ::Type{TV}=Vector{Float64} + x=[1.5, 2.0], ::Type{TV}=Vector{Float64} ) where {TV} # `assume` with indexing and `observe` s = TV(undef, length(x)) @@ -276,12 +281,17 @@ function example_values( end function posterior_mean(model::Model{typeof(demo_assume_index_observe)}) vals = example_values(model) - vals.s .= 2.375 - vals.m .= 0.75 + + vals.s[1] = 19 / 8 + vals.m[1] = 3 / 4 + + vals.s[2] = 8 / 3 + vals.m[2] = 1 + return vals end -@model function demo_assume_multivariate_observe(x=[1.5, 1.5]) +@model function demo_assume_multivariate_observe(x=[1.5, 2.0]) # Multivariate `assume` and `observe` s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) m ~ MvNormal(zero(x), Diagonal(s)) @@ -313,13 +323,18 @@ function example_values( end function posterior_mean(model::Model{typeof(demo_assume_multivariate_observe)}) vals = example_values(model) - vals.s .= 2.375 - vals.m .= 0.75 + + vals.s[1] = 19 / 8 + vals.m[1] = 3 / 4 + + vals.s[2] = 8 / 3 + vals.m[2] = 1 + return vals end @model function demo_dot_assume_observe_index( - x=[1.5, 1.5], ::Type{TV}=Vector{Float64} + x=[1.5, 2.0], ::Type{TV}=Vector{Float64} ) where {TV} # `dot_assume` and `observe` with indexing s = TV(undef, length(x)) @@ -359,14 +374,19 @@ function example_values( end function posterior_mean(model::Model{typeof(demo_dot_assume_observe_index)}) vals = example_values(model) - vals.s .= 2.375 - vals.m .= 0.75 + + vals.s[1] = 19 / 8 + vals.m[1] = 3 / 4 + + vals.s[2] = 8 / 3 + vals.m[2] = 1 + return vals end # Using vector of `length` 1 here so the posterior of `m` is the same # as the others. -@model function demo_assume_dot_observe(x=[1.5]) +@model function demo_assume_dot_observe(x=[1.5, 2.0]) # `assume` and `dot_observe` s ~ InverseGamma(2, 3) m ~ Normal(0, sqrt(s)) @@ -396,16 +416,16 @@ function example_values( return (s=s, m=m) end function posterior_mean(model::Model{typeof(demo_assume_dot_observe)}) - return (s=2.375, m=0.75) + return (s=49 / 24, m=7 / 6) end @model function demo_assume_observe_literal() # `assume` and literal `observe` s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) m ~ MvNormal(zeros(2), Diagonal(s)) - [1.5, 1.5] ~ MvNormal(m, Diagonal(s)) + [1.5, 2.0] ~ MvNormal(m, Diagonal(s)) - return (; s=s, m=m, x=[1.5, 1.5], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) end function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) @@ -413,7 +433,7 @@ function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) return logpdf(s_dist, s) + logpdf(m_dist, m) end function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) - return logpdf(MvNormal(m, Diagonal(s)), [1.5, 1.5]) + return logpdf(MvNormal(m, Diagonal(s)), [1.5, 2.0]) end function logprior_true_with_logabsdet_jacobian( model::Model{typeof(demo_assume_observe_literal)}, s, m @@ -431,8 +451,13 @@ function example_values( end function posterior_mean(model::Model{typeof(demo_assume_observe_literal)}) vals = example_values(model) - vals.s .= 2.375 - vals.m .= 0.75 + + vals.s[1] = 19 / 8 + vals.m[1] = 3 / 4 + + vals.s[2] = 8 / 3 + vals.m[2] = 1 + return vals end @@ -443,11 +468,10 @@ end s .~ InverseGamma(2, 3) m .~ Normal.(0, sqrt.(s)) - for i in eachindex(m) - 1.5 ~ Normal(m[i], sqrt(s[i])) - end + 1.5 ~ Normal(m[1], sqrt(s[1])) + 2.0 ~ Normal(m[2], sqrt(s[2])) - return (; s=s, m=m, x=fill(1.5, length(m)), logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) end function logprior_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) @@ -455,7 +479,7 @@ end function loglikelihood_true( model::Model{typeof(demo_dot_assume_observe_index_literal)}, s, m ) - return sum(logpdf.(Normal.(m, sqrt.(s)), fill(1.5, length(m)))) + return sum(logpdf.(Normal.(m, sqrt.(s)), [1.5, 2.0])) end function logprior_true_with_logabsdet_jacobian( model::Model{typeof(demo_dot_assume_observe_index_literal)}, s, m @@ -478,8 +502,13 @@ function example_values( end function posterior_mean(model::Model{typeof(demo_dot_assume_observe_index_literal)}) vals = example_values(model) - vals.s .= 2.375 - vals.m .= 0.75 + + vals.s[1] = 19 / 8 + vals.m[1] = 3 / 4 + + vals.s[2] = 8 / 3 + vals.m[2] = 1 + return vals end @@ -487,15 +516,15 @@ end # `assume` and literal `dot_observe` s ~ InverseGamma(2, 3) m ~ Normal(0, sqrt(s)) - [1.5] .~ Normal(m, sqrt(s)) + [1.5, 2.0] .~ Normal(m, sqrt(s)) - return (; s=s, m=m, x=[1.5], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) end function logprior_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, m) return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) end function loglikelihood_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, m) - return logpdf(Normal(m, sqrt(s)), 1.5) + return logpdf(Normal(m, sqrt(s)), [1.5, 2.0]) end function logprior_true_with_logabsdet_jacobian( model::Model{typeof(demo_assume_literal_dot_observe)}, s, m @@ -513,7 +542,7 @@ function example_values( return (s=s, m=m) end function posterior_mean(model::Model{typeof(demo_assume_literal_dot_observe)}) - return (s=2.375, m=0.75) + return (s=49 / 24, m=7 / 6) end @model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} @@ -528,11 +557,10 @@ end @model function demo_assume_submodel_observe_index_literal() # Submodel prior @submodel s, m = _prior_dot_assume() - for i in eachindex(m, s) - 1.5 ~ Normal(m[i], sqrt(s[i])) - end + 1.5 ~ Normal(m[1], sqrt(s[1])) + 2.0 ~ Normal(m[2], sqrt(s[2])) - return (; s=s, m=m, x=[1.5, 1.5], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) end function logprior_true( model::Model{typeof(demo_assume_submodel_observe_index_literal)}, s, m @@ -542,7 +570,7 @@ end function loglikelihood_true( model::Model{typeof(demo_assume_submodel_observe_index_literal)}, s, m ) - return sum(logpdf.(Normal.(m, sqrt.(s)), 1.5)) + return sum(logpdf.(Normal.(m, sqrt.(s)), [1.5, 2.0])) end function logprior_true_with_logabsdet_jacobian( model::Model{typeof(demo_assume_submodel_observe_index_literal)}, s, m @@ -564,12 +592,15 @@ function example_values( end return (s=s, m=m) end -function posterior_mean( - model::Model{typeof(demo_assume_submodel_observe_index_literal)} -) +function posterior_mean(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) vals = example_values(model) - vals.s .= 2.375 - vals.m .= 0.75 + + vals.s[1] = 19 / 8 + vals.m[1] = 3 / 4 + + vals.s[2] = 8 / 3 + vals.m[2] = 1 + return vals end @@ -578,7 +609,7 @@ end end @model function demo_dot_assume_observe_submodel( - x=[1.5, 1.5], ::Type{TV}=Vector{Float64} + x=[1.5, 2.0], ::Type{TV}=Vector{Float64} ) where {TV} s = TV(undef, length(x)) s .~ InverseGamma(2, 3) @@ -617,13 +648,18 @@ function example_values( end function posterior_mean(model::Model{typeof(demo_dot_assume_observe_submodel)}) vals = example_values(model) - vals.s .= 2.375 - vals.m .= 0.75 + + vals.s[1] = 19 / 8 + vals.m[1] = 3 / 4 + + vals.s[2] = 8 / 3 + vals.m[2] = 1 + return vals end @model function demo_dot_assume_dot_observe_matrix( - x=fill(1.5, 2, 1), ::Type{TV}=Vector{Float64} + x=transpose([1.5 2.0;]), ::Type{TV}=Vector{Float64} ) where {TV} s = TV(undef, length(x)) s .~ InverseGamma(2, 3) @@ -662,13 +698,18 @@ function example_values( end function posterior_mean(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) vals = example_values(model) - vals.s .= 2.375 - vals.m .= 0.75 + + vals.s[1] = 19 / 8 + vals.m[1] = 3 / 4 + + vals.s[2] = 8 / 3 + vals.m[2] = 1 + return vals end @model function demo_dot_assume_matrix_dot_observe_matrix( - x=fill(1.5, 2, 1), ::Type{TV}=Array{Float64} + x=transpose([1.5 2.0;]), ::Type{TV}=Array{Float64} ) where {TV} n = length(x) d = length(x) ÷ 2 @@ -711,12 +752,15 @@ function example_values( m = rand(rng, MvNormal(zeros(n), Diagonal(vec(s)))) return (s=s, m=m) end -function posterior_mean( - model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)} -) +function posterior_mean(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) vals = example_values(model) - vals.s .= 2.375 - vals.m .= 0.75 + + vals.s[1] = 19 / 8 + vals.m[1] = 3 / 4 + + vals.s[2] = 8 / 3 + vals.m[2] = 1 + return vals end @@ -727,13 +771,32 @@ the generative process s ~ InverseGamma(2, 3) m ~ Normal(0, √s) 1.5 ~ Normal(m, √s) + 2.0 ~ Normal(m, √s) + +or by + + s[1] ~ InverseGamma(2, 3) + s[2] ~ InverseGamma(2, 3) + m[1] ~ Normal(0, √s) + m[2] ~ Normal(0, √s) + 1.5 ~ Normal(m[1], √s[1]) + 2.0 ~ Normal(m[2], √s[2]) + +These are examples of a Normal-InverseGamma conjugate prior with Normal likelihood, +for which the posterior is known in closed form. + +In particular, for the univariate model (the former one): + + mean(s) == 49 / 24 + mean(m) == 7 / 6 -_or_ a product of such distributions. +And for the multivariate one (the latter one): -The posterior for both `s` and `m` here is known in closed form. In particular, + mean(s[1]) == 19 / 8 + mean(m[1]) == 3 / 4 + mean(s[2]) == 8 / 3 + mean(m[2]) == 1 - mean(s) == 19 / 8 - mean(m) == 3 / 4 """ const DEMO_MODELS = ( demo_dot_assume_dot_observe(), From 15589247ed2d26ff3144b9a4ddb7102f6e5fc047 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 2 Jul 2022 18:11:34 +0100 Subject: [PATCH 105/221] Update docs/src/api.md Co-authored-by: David Widmann --- docs/src/api.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index ab7e7fe60..9f4dae3f5 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -120,7 +120,7 @@ DynamicPPL.TestUtils.loglikelihood_true DynamicPPL.TestUtils.logjoint_true ``` -And in the case where the model might include constrained variables, it can also be useful to define +And in the case where the model includes constrained variables, it can also be useful to define ```@docs DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian From a62c881ef3de794f47e8084d9b0b2112a9b5066b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 2 Jul 2022 19:43:20 +0100 Subject: [PATCH 106/221] reduce number of method definitions by defining some useful type unions in TestUtils --- src/test_utils.jl | 145 +++++++++++++--------------------------------- 1 file changed, 39 insertions(+), 106 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index f88414a7d..ebab44d91 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -226,17 +226,6 @@ function example_values( end return (s=s, m=m) end -function posterior_mean(model::Model{typeof(demo_dot_assume_dot_observe)}) - vals = example_values(model) - - vals.s[1] = 19 / 8 - vals.m[1] = 3 / 4 - - vals.s[2] = 8 / 3 - vals.m[2] = 1 - - return vals -end @model function demo_assume_index_observe( x=[1.5, 2.0], ::Type{TV}=Vector{Float64} @@ -279,17 +268,6 @@ function example_values( end return (s=s, m=m) end -function posterior_mean(model::Model{typeof(demo_assume_index_observe)}) - vals = example_values(model) - - vals.s[1] = 19 / 8 - vals.m[1] = 3 / 4 - - vals.s[2] = 8 / 3 - vals.m[2] = 1 - - return vals -end @model function demo_assume_multivariate_observe(x=[1.5, 2.0]) # Multivariate `assume` and `observe` @@ -321,17 +299,6 @@ function example_values( s = rand(rng, product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])) return (s=s, m=rand(rng, MvNormal(zero(model.args.x), Diagonal(s)))) end -function posterior_mean(model::Model{typeof(demo_assume_multivariate_observe)}) - vals = example_values(model) - - vals.s[1] = 19 / 8 - vals.m[1] = 3 / 4 - - vals.s[2] = 8 / 3 - vals.m[2] = 1 - - return vals -end @model function demo_dot_assume_observe_index( x=[1.5, 2.0], ::Type{TV}=Vector{Float64} @@ -372,17 +339,6 @@ function example_values( end return (s=s, m=m) end -function posterior_mean(model::Model{typeof(demo_dot_assume_observe_index)}) - vals = example_values(model) - - vals.s[1] = 19 / 8 - vals.m[1] = 3 / 4 - - vals.s[2] = 8 / 3 - vals.m[2] = 1 - - return vals -end # Using vector of `length` 1 here so the posterior of `m` is the same # as the others. @@ -415,9 +371,6 @@ function example_values( m = rand(rng, Normal(0, sqrt(s))) return (s=s, m=m) end -function posterior_mean(model::Model{typeof(demo_assume_dot_observe)}) - return (s=49 / 24, m=7 / 6) -end @model function demo_assume_observe_literal() # `assume` and literal `observe` @@ -449,17 +402,6 @@ function example_values( s = rand(rng, product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])) return (s=s, m=rand(rng, MvNormal(zeros(2), Diagonal(s)))) end -function posterior_mean(model::Model{typeof(demo_assume_observe_literal)}) - vals = example_values(model) - - vals.s[1] = 19 / 8 - vals.m[1] = 3 / 4 - - vals.s[2] = 8 / 3 - vals.m[2] = 1 - - return vals -end @model function demo_dot_assume_observe_index_literal(::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and literal `observe` with indexing @@ -500,17 +442,6 @@ function example_values( end return (s=s, m=m) end -function posterior_mean(model::Model{typeof(demo_dot_assume_observe_index_literal)}) - vals = example_values(model) - - vals.s[1] = 19 / 8 - vals.m[1] = 3 / 4 - - vals.s[2] = 8 / 3 - vals.m[2] = 1 - - return vals -end @model function demo_assume_literal_dot_observe() # `assume` and literal `dot_observe` @@ -541,9 +472,6 @@ function example_values( m = rand(rng, Normal(0, sqrt(s))) return (s=s, m=m) end -function posterior_mean(model::Model{typeof(demo_assume_literal_dot_observe)}) - return (s=49 / 24, m=7 / 6) -end @model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} s = TV(undef, 2) @@ -592,17 +520,6 @@ function example_values( end return (s=s, m=m) end -function posterior_mean(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) - vals = example_values(model) - - vals.s[1] = 19 / 8 - vals.m[1] = 3 / 4 - - vals.s[2] = 8 / 3 - vals.m[2] = 1 - - return vals -end @model function _likelihood_mltivariate_observe(s, m, x) return x ~ MvNormal(m, Diagonal(s)) @@ -646,17 +563,6 @@ function example_values( end return (s=s, m=m) end -function posterior_mean(model::Model{typeof(demo_dot_assume_observe_submodel)}) - vals = example_values(model) - - vals.s[1] = 19 / 8 - vals.m[1] = 3 / 4 - - vals.s[2] = 8 / 3 - vals.m[2] = 1 - - return vals -end @model function demo_dot_assume_dot_observe_matrix( x=transpose([1.5 2.0;]), ::Type{TV}=Vector{Float64} @@ -696,17 +602,6 @@ function example_values( end return (s=s, m=m) end -function posterior_mean(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) - vals = example_values(model) - - vals.s[1] = 19 / 8 - vals.m[1] = 3 / 4 - - vals.s[2] = 8 / 3 - vals.m[2] = 1 - - return vals -end @model function demo_dot_assume_matrix_dot_observe_matrix( x=transpose([1.5 2.0;]), ::Type{TV}=Array{Float64} @@ -752,7 +647,45 @@ function example_values( m = rand(rng, MvNormal(zeros(n), Diagonal(vec(s)))) return (s=s, m=m) end -function posterior_mean(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) + +const DemoModels = Union{ + Model{typeof(demo_dot_assume_dot_observe)}, + Model{typeof(demo_assume_index_observe)}, + Model{typeof(demo_assume_multivariate_observe)}, + Model{typeof(demo_dot_assume_observe_index)}, + Model{typeof(demo_assume_dot_observe)}, + Model{typeof(demo_assume_literal_dot_observe)}, + Model{typeof(demo_assume_observe_literal)}, + Model{typeof(demo_dot_assume_observe_index_literal)}, + Model{typeof(demo_assume_submodel_observe_index_literal)}, + Model{typeof(demo_dot_assume_observe_submodel)}, + Model{typeof(demo_dot_assume_dot_observe_matrix)}, + Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, +} +_observations(model::DemoModels) = [1.5, 2.0] + +const UnivariateAssumeDemoModels = Union{ + Model{typeof(demo_assume_dot_observe)}, + Model{typeof(demo_assume_literal_dot_observe)}, +} +function posterior_mean(model::UnivariateAssumeDemoModels) + return (s=49 / 24, m=7 / 6) +end + +const MultivariateAssumeDemoModels = Union{ + Model{typeof(demo_dot_assume_dot_observe)}, + Model{typeof(demo_assume_index_observe)}, + Model{typeof(demo_assume_multivariate_observe)}, + Model{typeof(demo_dot_assume_observe_index)}, + Model{typeof(demo_assume_observe_literal)}, + Model{typeof(demo_dot_assume_observe_index_literal)}, + Model{typeof(demo_assume_submodel_observe_index_literal)}, + Model{typeof(demo_dot_assume_observe_submodel)}, + Model{typeof(demo_dot_assume_dot_observe_matrix)}, + Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, +} +function posterior_mean(model::MultivariateAssumeDemoModels) + # Get some containers to fill. vals = example_values(model) vals.s[1] = 19 / 8 From 5cc195aea99e8a0df65abc9de9bf92494e0ed4f9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 2 Jul 2022 22:42:27 +0100 Subject: [PATCH 107/221] removed unnecessary method --- src/test_utils.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index ebab44d91..8515fe4e6 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -662,7 +662,6 @@ const DemoModels = Union{ Model{typeof(demo_dot_assume_dot_observe_matrix)}, Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, } -_observations(model::DemoModels) = [1.5, 2.0] const UnivariateAssumeDemoModels = Union{ Model{typeof(demo_assume_dot_observe)}, From 702f2ff942d109e170eeb68355dddfa7dc6e5043 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 00:45:21 +0100 Subject: [PATCH 108/221] fixed a couple of loglikelihood_true definitions --- src/test_utils.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 8515fe4e6..7ece1e56e 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -455,7 +455,7 @@ function logprior_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) end function loglikelihood_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, m) - return logpdf(Normal(m, sqrt(s)), [1.5, 2.0]) + return loglikelihood(Normal(m, sqrt(s)), [1.5, 2.0]) end function logprior_true_with_logabsdet_jacobian( model::Model{typeof(demo_assume_literal_dot_observe)}, s, m @@ -623,7 +623,8 @@ function logprior_true( ) n = length(model.args.x) s_vec = vec(s) - return loglikelihood(InverseGamma(2, 3), s_vec) + logpdf(MvNormal(zeros(n), s_vec), m) + return loglikelihood(InverseGamma(2, 3), s_vec) + + logpdf(MvNormal(zeros(n), Diagonal(s_vec)), m) end function loglikelihood_true( model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m From d8f497019113c4d35793a1f394b5a1b89ae3e125 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 00:45:38 +0100 Subject: [PATCH 109/221] style --- src/test_utils.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 7ece1e56e..496a40e50 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -665,8 +665,7 @@ const DemoModels = Union{ } const UnivariateAssumeDemoModels = Union{ - Model{typeof(demo_assume_dot_observe)}, - Model{typeof(demo_assume_literal_dot_observe)}, + Model{typeof(demo_assume_dot_observe)},Model{typeof(demo_assume_literal_dot_observe)} } function posterior_mean(model::UnivariateAssumeDemoModels) return (s=49 / 24, m=7 / 6) From 56f30bc45b1a326f3edc715beb1e62647c648b73 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 00:46:08 +0100 Subject: [PATCH 110/221] added tests for logprior and loglikelihood computation for SimpleVarInfo --- test/simple_varinfo.jl | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 5e598217a..955c3676b 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -90,14 +90,27 @@ ### Evaluation ### values_eval_constrained = DynamicPPL.TestUtils.example_values(model) if DynamicPPL.istrans(svi) + _, logpri_true = DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian( + model, values_eval_constrained... + ) values_eval, logπ_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( model, values_eval_constrained... ) else + logpri_true = DynamicPPL.TestUtils.logprior_true( + model, values_eval_constrained... + ) + logπ_true = DynamicPPL.TestUtils.logjoint_true( + model, values_eval_constrained... + ) values_eval = values_eval_constrained - logπ_true = DynamicPPL.TestUtils.logjoint_true(model, values_eval...) end + # No logabsdet-jacobian correction needed for the likelihood. + loglik_true = DynamicPPL.TestUtils.loglikelihood_true( + model, values_eval_constrained... + ) + # Update the realizations in `svi_new`. svi_eval = svi_new for vn in DynamicPPL.TestUtils.varnames(model) @@ -109,13 +122,19 @@ # Compute `logjoint` using the varinfo. logπ = logjoint(model, svi_eval) + logpri = logprior(model, svi_eval) + loglik = loglikelihood(model, svi_eval) + + retval_svi, _ = DynamicPPL.evaluate!!(model, svi, LikelihoodContext()) # Values should not have changed. for vn in DynamicPPL.TestUtils.varnames(model) @test svi_eval[vn] == get(values_eval, vn) end - # Compare `logjoint` computations. + # Compare log-probability computations. + @test logpri ≈ logpri_true + @test loglik ≈ loglik_true @test logπ ≈ logπ_true end end From 2eaef02d9e47fa3cd9de9b34689ab51652bf7dd4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 00:46:52 +0100 Subject: [PATCH 111/221] fixed implementation of logpdf_with_trans for NoDist --- src/distribution_wrappers.jl | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index 65479f035..d8968a68e 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -51,9 +51,21 @@ Distributions.logpdf(d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}) = 0 Distributions.minimum(d::NoDist) = minimum(d.dist) Distributions.maximum(d::NoDist) = maximum(d.dist) -Bijectors.logpdf_with_trans(d::NoDist{<:Univariate}, ::Real) = 0 -Bijectors.logpdf_with_trans(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0 -function Bijectors.logpdf_with_trans(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}) +Bijectors.logpdf_with_trans(d::NoDist{<:Univariate}, ::Real, ::Bool) = 0 +function Bijectors.logpdf_with_trans( + d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}, ::Bool +) + return 0 +end +function Bijectors.logpdf_with_trans( + d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}, ::Bool +) return zeros(Int, size(x, 2)) end -Bijectors.logpdf_with_trans(d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}) = 0 +function Bijectors.logpdf_with_trans( + d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}, ::Bool +) + return 0 +end + +Bijectors.bijector(d::NoDist) = Bijectors.bijector(d.dist) From f0f981b77c329b792bfc9d2f3786f454cc3fec23 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 00:56:21 +0100 Subject: [PATCH 112/221] added _protect_dists method to help with broadcasting of NoDist --- src/context_implementations.jl | 4 ++-- src/distribution_wrappers.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 6271b5d8c..eaf77f07f 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -318,12 +318,12 @@ function dot_tilde_assume( end function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi) - return dot_assume(nodist(right), left, vn, vi) + return dot_assume(NoDist.(_protect_dists(right)), left, vn, vi) end function dot_tilde_assume( rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, vi ) - return dot_assume(rng, sampler, nodist(right), vn, left, vi) + return dot_assume(rng, sampler, NoDist.(_protect_dists(right)), vn, left, vi) end # `PriorContext` diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index 65479f035..ee58b107a 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -35,8 +35,8 @@ struct NoDist{variate,support,Td<:Distribution{variate,support}} <: end NoDist(dist::NamedDist) = NamedDist(NoDist(dist.dist), dist.name) -nodist(dist::Distribution) = NoDist(dist) -nodist(dists::AbstractArray) = nodist.(dists) +_protect_dists(x) = x +_protect_dists(x::Distribution) = tuple(x) Base.length(dist::NoDist) = Base.length(dist.dist) Base.size(dist::NoDist) = Base.size(dist.dist) From 1e0b946f0d240ca45db22e618118a3cbece5f080 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 00:58:59 +0100 Subject: [PATCH 113/221] simplified show for SimpleVarInfo --- src/simple_varinfo.jl | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index e7a389ee8..dc6eeae3e 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -254,15 +254,13 @@ function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) end function Base.show( - io::IO, ::MIME"text/plain", svi::SimpleVarInfo{<:Any,<:Any,<:NoTransformation} + io::IO, ::MIME"text/plain", svi::SimpleVarInfo ) - return print(io, "SimpleVarInfo(", svi.values, ", ", svi.logp, ")") -end + if !(svi.transform isa NoTransformation) + print(io, "Transformed ") + end -function Base.show( - io::IO, ::MIME"text/plain", svi::SimpleVarInfo{<:Any,<:Any,<:DefaultTransformation} -) - return print(io, "Transformed SimpleVarInfo(", svi.values, ", ", svi.logp, ")") + return print(io, "SimpleVarInfo(", svi.values, ", ", svi.logp, ")") end # `NamedTuple` From faa0e4210f2e5c060841ca67d469d76caf28e2b4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 00:59:22 +0100 Subject: [PATCH 114/221] styling --- src/simple_varinfo.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index dc6eeae3e..4fb0208c9 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -253,9 +253,7 @@ function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) return vi end -function Base.show( - io::IO, ::MIME"text/plain", svi::SimpleVarInfo -) +function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo) if !(svi.transform isa NoTransformation) print(io, "Transformed ") end From 78f22e177bcc071b42618eaf8696cc37daf5b2bb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 01:01:24 +0100 Subject: [PATCH 115/221] removed unused variable --- test/simple_varinfo.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 955c3676b..7163d3106 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -125,8 +125,6 @@ logpri = logprior(model, svi_eval) loglik = loglikelihood(model, svi_eval) - retval_svi, _ = DynamicPPL.evaluate!!(model, svi, LikelihoodContext()) - # Values should not have changed. for vn in DynamicPPL.TestUtils.varnames(model) @test svi_eval[vn] == get(values_eval, vn) From 025a4d47e79ca4bddad6d11547407971190e4926 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 01:03:39 +0100 Subject: [PATCH 116/221] added test for transformed values for the logprior_true and loglikelihood_true methods --- test/simple_varinfo.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 7163d3106..175f264d4 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -90,12 +90,15 @@ ### Evaluation ### values_eval_constrained = DynamicPPL.TestUtils.example_values(model) if DynamicPPL.istrans(svi) - _, logpri_true = DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian( + _values_prior, logpri_true = DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian( model, values_eval_constrained... ) values_eval, logπ_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( model, values_eval_constrained... ) + # Make sure that these two computation paths provide the same + # transformed values. + @test values_eval == _values_prior else logpri_true = DynamicPPL.TestUtils.logprior_true( model, values_eval_constrained... From 9e7f493be1315b5faaef93f84b1adde23105357d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 01:06:16 +0100 Subject: [PATCH 117/221] fixed bug in show for SimpleVarInfo --- src/simple_varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 4fb0208c9..5b9edefdf 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -254,7 +254,7 @@ function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) end function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo) - if !(svi.transform isa NoTransformation) + if !(svi.transformation isa NoTransformation) print(io, "Transformed ") end From 0a9383b1c3609ac5f310e798e1d7b4119cb6b74a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 01:14:45 +0100 Subject: [PATCH 118/221] Revert "added _protect_dists method to help with broadcasting of NoDist" This reverts commit f0f981b77c329b792bfc9d2f3786f454cc3fec23. --- src/context_implementations.jl | 4 ++-- src/distribution_wrappers.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index eaf77f07f..6271b5d8c 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -318,12 +318,12 @@ function dot_tilde_assume( end function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi) - return dot_assume(NoDist.(_protect_dists(right)), left, vn, vi) + return dot_assume(nodist(right), left, vn, vi) end function dot_tilde_assume( rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, vi ) - return dot_assume(rng, sampler, NoDist.(_protect_dists(right)), vn, left, vi) + return dot_assume(rng, sampler, nodist(right), vn, left, vi) end # `PriorContext` diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index ee58b107a..65479f035 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -35,8 +35,8 @@ struct NoDist{variate,support,Td<:Distribution{variate,support}} <: end NoDist(dist::NamedDist) = NamedDist(NoDist(dist.dist), dist.name) -_protect_dists(x) = x -_protect_dists(x::Distribution) = tuple(x) +nodist(dist::Distribution) = NoDist(dist) +nodist(dists::AbstractArray) = nodist.(dists) Base.length(dist::NoDist) = Base.length(dist.dist) Base.size(dist::NoDist) = Base.size(dist.dist) From f5c60aec1eb7ccad4aa7ec14b83f6c350b246140 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 01:17:21 +0100 Subject: [PATCH 119/221] renamed test_sampler_on_models to test_sampler --- src/test_utils.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 496a40e50..4e728884c 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -746,7 +746,7 @@ const DEMO_MODELS = ( ) """ - test_sampler_on_models(meanfunction, models, sampler, args...; kwargs...) + test_sampler(meanfunction, models, sampler, args...; kwargs...) Test that `sampler` produces correct marginal posterior means on each model in `models`. @@ -767,7 +767,7 @@ for every (leaf) varname `vn` against the corresponding value returned by - `rtol=1e-3`: Relative tolerance used in `@test`. - `kwargs...`: Keyword arguments forwarded to `sample`. """ -function test_sampler_on_models( +function test_sampler( meanfunction, models, sampler::AbstractMCMC.AbstractSampler, @@ -796,12 +796,12 @@ end Test `sampler` on every model in [`DEMO_MODELS`](@ref). -This is just a proxy for `test_sampler_on_models(meanfunction, DEMO_MODELS, sampler, args...; kwargs...)`. +This is just a proxy for `test_sampler(meanfunction, DEMO_MODELS, sampler, args...; kwargs...)`. """ function test_sampler_on_demo_models( meanfunction, sampler::AbstractMCMC.AbstractSampler, args...; kwargs... ) - return test_sampler_on_models(meanfunction, DEMO_MODELS, sampler, args...; kwargs...) + return test_sampler(meanfunction, DEMO_MODELS, sampler, args...; kwargs...) end """ From d8b0a75a7d56baa75589327db40e2456c5be1dd4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 01:33:16 +0100 Subject: [PATCH 120/221] fixed getindex with vector of varnames for AbstractVarInfo --- src/varinfo.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 4748643f3..22728ba9a 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -944,7 +944,7 @@ function getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution) return maybe_invlink(vi, vn, dist, val) end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) - # NOTE(torfjelde): Using `getdist(vi, first(vns))` won't be correct in cases + # FIXME(torfjelde): Using `getdist(vi, first(vns))` won't be correct in cases # such as `x .~ [Normal(), Exponential()]`. # BUT we also can't fix this here because this will lead to "incorrect" # behavior if `vns` arose from something like `x .~ MvNormal(zeros(2), I)`, @@ -953,8 +953,10 @@ function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distribution) @assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" - val = getindex_raw(vi, vns, dist) - return maybe_invlink.((vi,), vns, (dist,), val) + vals_linked = mapreduce(vcat, vns) do vn + getindex(vi, vn, dist) + end + return reconstruct(dist, vals_linked, length(vns)) end getindex_raw(vi::AbstractVarInfo, vn::VarName) = getindex_raw(vi, vn, getdist(vi, vn)) From 25f05de007337c0b56799272a96948024909f7c6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 01:37:44 +0100 Subject: [PATCH 121/221] updated docs --- docs/src/api.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index 9f4dae3f5..9aa481cc4 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -103,7 +103,7 @@ NamedDist DynamicPPL provides several demo models and helpers for testing samplers in the `DynamicPPL.TestUtils` submodule. ```@docs -DynamicPPL.TestUtils.test_sampler_on_models +DynamicPPL.TestUtils.test_sampler DynamicPPL.TestUtils.test_sampler_on_demo_models DynamicPPL.TestUtils.test_sampler_continuous ``` From e05fa291e0b3dd835e887df1c44c95720a2995be Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 12:13:24 +0100 Subject: [PATCH 122/221] share implementation of example_values --- src/test_utils.jl | 139 +++++++--------------------------------------- 1 file changed, 20 insertions(+), 119 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 4e728884c..7b3758f40 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -166,12 +166,6 @@ end function varnames(model::Model{typeof(demo_dynamic_constraint)}) return [@varname(m), @varname(x)] end -function example_values( - rng::Random.AbstractRNG, model::Model{typeof(demo_dynamic_constraint)} -) - m = rand(rng, Normal()) - return (m=m, x=rand(rng, truncated(Normal(), m, Inf))) -end function logprior_true_with_logabsdet_jacobian( model::Model{typeof(demo_dynamic_constraint)}, m, x ) @@ -215,17 +209,6 @@ end function varnames(model::Model{typeof(demo_dot_assume_dot_observe)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end -function example_values( - rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_dot_observe)} -) - n = length(model.args.x) - s = rand(rng, InverseGamma(2, 3), n) - m = similar(s) - for i in eachindex(m, s) - m[i] = rand(rng, Normal(0, sqrt(s[i]))) - end - return (s=s, m=m) -end @model function demo_assume_index_observe( x=[1.5, 2.0], ::Type{TV}=Vector{Float64} @@ -257,17 +240,6 @@ end function varnames(model::Model{typeof(demo_assume_index_observe)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end -function example_values( - rng::Random.AbstractRNG, model::Model{typeof(demo_assume_index_observe)} -) - n = length(model.args.x) - s = rand(rng, InverseGamma(2, 3), n) - m = similar(s) - for i in eachindex(m, s) - m[i] = rand(rng, Normal(0, sqrt(s[i]))) - end - return (s=s, m=m) -end @model function demo_assume_multivariate_observe(x=[1.5, 2.0]) # Multivariate `assume` and `observe` @@ -293,12 +265,6 @@ end function varnames(model::Model{typeof(demo_assume_multivariate_observe)}) return [@varname(s), @varname(m)] end -function example_values( - rng::Random.AbstractRNG, model::Model{typeof(demo_assume_multivariate_observe)} -) - s = rand(rng, product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])) - return (s=s, m=rand(rng, MvNormal(zero(model.args.x), Diagonal(s)))) -end @model function demo_dot_assume_observe_index( x=[1.5, 2.0], ::Type{TV}=Vector{Float64} @@ -328,17 +294,6 @@ end function varnames(model::Model{typeof(demo_dot_assume_observe_index)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end -function example_values( - rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_observe_index)} -) - n = length(model.args.x) - s = rand(rng, InverseGamma(2, 3), n) - m = similar(s) - for i in eachindex(m, s) - m[i] = rand(rng, Normal(0, sqrt(s[i]))) - end - return (s=s, m=m) -end # Using vector of `length` 1 here so the posterior of `m` is the same # as the others. @@ -364,13 +319,6 @@ end function varnames(model::Model{typeof(demo_assume_dot_observe)}) return [@varname(s), @varname(m)] end -function example_values( - rng::Random.AbstractRNG, model::Model{typeof(demo_assume_dot_observe)} -) - s = rand(rng, InverseGamma(2, 3)) - m = rand(rng, Normal(0, sqrt(s))) - return (s=s, m=m) -end @model function demo_assume_observe_literal() # `assume` and literal `observe` @@ -396,12 +344,6 @@ end function varnames(model::Model{typeof(demo_assume_observe_literal)}) return [@varname(s), @varname(m)] end -function example_values( - rng::Random.AbstractRNG, model::Model{typeof(demo_assume_observe_literal)} -) - s = rand(rng, product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])) - return (s=s, m=rand(rng, MvNormal(zeros(2), Diagonal(s)))) -end @model function demo_dot_assume_observe_index_literal(::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and literal `observe` with indexing @@ -431,17 +373,6 @@ end function varnames(model::Model{typeof(demo_dot_assume_observe_index_literal)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end -function example_values( - rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_observe_index_literal)} -) - n = 2 - s = rand(rng, InverseGamma(2, 3), n) - m = similar(s) - for i in eachindex(m, s) - m[i] = rand(rng, Normal(0, sqrt(s[i]))) - end - return (s=s, m=m) -end @model function demo_assume_literal_dot_observe() # `assume` and literal `dot_observe` @@ -465,13 +396,6 @@ end function varnames(model::Model{typeof(demo_assume_literal_dot_observe)}) return [@varname(s), @varname(m)] end -function example_values( - rng::Random.AbstractRNG, model::Model{typeof(demo_assume_literal_dot_observe)} -) - s = rand(rng, InverseGamma(2, 3)) - m = rand(rng, Normal(0, sqrt(s))) - return (s=s, m=m) -end @model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} s = TV(undef, 2) @@ -508,18 +432,6 @@ end function varnames(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end -function example_values( - rng::Random.AbstractRNG, - model::Model{typeof(demo_assume_submodel_observe_index_literal)}, -) - n = 2 - s = rand(rng, InverseGamma(2, 3), n) - m = similar(s) - for i in eachindex(m, s) - m[i] = rand(rng, Normal(0, sqrt(s[i]))) - end - return (s=s, m=m) -end @model function _likelihood_mltivariate_observe(s, m, x) return x ~ MvNormal(m, Diagonal(s)) @@ -552,17 +464,6 @@ end function varnames(model::Model{typeof(demo_dot_assume_observe_submodel)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end -function example_values( - rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_observe_submodel)} -) - n = length(model.args.x) - s = rand(rng, InverseGamma(2, 3), n) - m = similar(s) - for i in eachindex(m, s) - m[i] = rand(rng, Normal(0, sqrt(s[i]))) - end - return (s=s, m=m) -end @model function demo_dot_assume_dot_observe_matrix( x=transpose([1.5 2.0;]), ::Type{TV}=Vector{Float64} @@ -591,17 +492,6 @@ end function varnames(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end -function example_values( - rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_dot_observe_matrix)} -) - n = length(model.args.x) - s = rand(rng, InverseGamma(2, 3), n) - m = similar(s) - for i in eachindex(m, s) - m[i] = rand(rng, Normal(0, sqrt(s[i]))) - end - return (s=s, m=m) -end @model function demo_dot_assume_matrix_dot_observe_matrix( x=transpose([1.5 2.0;]), ::Type{TV}=Array{Float64} @@ -639,15 +529,6 @@ end function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) return [@varname(s[:, 1]), @varname(s[:, 2]), @varname(m)] end -function example_values( - rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)} -) - n = length(model.args.x) - d = n ÷ 2 - s = rand(rng, product_distribution([InverseGamma(2, 3) for _ in 1:d]), 2) - m = rand(rng, MvNormal(zeros(n), Diagonal(vec(s)))) - return (s=s, m=m) -end const DemoModels = Union{ Model{typeof(demo_dot_assume_dot_observe)}, @@ -670,6 +551,12 @@ const UnivariateAssumeDemoModels = Union{ function posterior_mean(model::UnivariateAssumeDemoModels) return (s=49 / 24, m=7 / 6) end +function example_values(rng::Random.AbstractRNG, model::UnivariateAssumeDemoModels) + s = rand(rng, InverseGamma(2, 3)) + m = rand(rng, Normal(0, sqrt(s))) + + return (s=s, m=m) +end const MultivariateAssumeDemoModels = Union{ Model{typeof(demo_dot_assume_dot_observe)}, @@ -695,6 +582,20 @@ function posterior_mean(model::MultivariateAssumeDemoModels) return vals end +function example_values( + rng::Random.AbstractRNG, model::MultivariateAssumeDemoModels +) + # Get template values from `model`. + retval = model(rng) + vals = (s = retval.s, m = retval.m) + # Fill containers with realizations from prior. + for i in LinearIndices(vals.s) + vals.s[i] = rand(rng, InverseGamma(2, 3)) + vals.m[i] = rand(rng, Normal(0, sqrt(vals.s[i]))) + end + + return vals +end """ A collection of models corresponding to the posterior distribution defined by From 431664dbe0e2aa21e0c6bdb7dd3984659187f01b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 4 Jul 2022 11:45:32 +0100 Subject: [PATCH 123/221] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/test_utils.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 7b3758f40..f68e8cafd 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -582,12 +582,10 @@ function posterior_mean(model::MultivariateAssumeDemoModels) return vals end -function example_values( - rng::Random.AbstractRNG, model::MultivariateAssumeDemoModels -) +function example_values(rng::Random.AbstractRNG, model::MultivariateAssumeDemoModels) # Get template values from `model`. retval = model(rng) - vals = (s = retval.s, m = retval.m) + vals = (s=retval.s, m=retval.m) # Fill containers with realizations from prior. for i in LinearIndices(vals.s) vals.s[i] = rand(rng, InverseGamma(2, 3)) From 363ebae43a39a89957277be1375d5a3ece515520 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 6 Jul 2022 16:39:15 +0100 Subject: [PATCH 124/221] added unflatten and values_as for Vector --- src/simple_varinfo.jl | 8 ++++++++ src/utils.jl | 34 ++++++++++++++++++++++++++++++++++ src/varinfo.jl | 4 ++++ 3 files changed, 46 insertions(+) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 5b9edefdf..41dab2c73 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -227,6 +227,12 @@ function SimpleVarInfo{T}( return SimpleVarInfo(values, convert(T, getlogp(vi))) end +SimpleVarInfo(svi::SimpleVarInfo, spl, x::AbstractVector) = unflatten(svi, x) + +function unflatten(svi::SimpleVarInfo, x::AbstractVector) + return Setfield.@set svi.values = unflatten(svi.values, x) +end + function BangBang.empty!!(vi::SimpleVarInfo) Setfield.@set resetlogp!!(vi).values = empty!!(vi.values) end @@ -536,6 +542,8 @@ values_as(vi::SimpleVarInfo) = vi.values values_as(vi::SimpleVarInfo, ::Type{Dict}) = Dict(pairs(vi.values)) values_as(vi::SimpleVarInfo, ::Type{NamedTuple}) = NamedTuple(pairs(vi.values)) values_as(vi::SimpleVarInfo{<:NamedTuple}, ::Type{NamedTuple}) = vi.values +values_as(vi::SimpleVarInfo, ::Type{Vector}) = mapreduce(v -> vec([v;]), vcat, values(vi.values)) + """ logjoint(model::Model, θ) diff --git a/src/utils.jl b/src/utils.jl index ac9222818..3e8433886 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -426,3 +426,37 @@ function BangBang.possible( return BangBang.implements(setindex!, C) && promote_type(eltype(C), eltype(T)) <: eltype(C) end + + +# HACK(torfjelde): This makes it so it works on iterators, etc. by default. +# TODO(torfjelde): Do better. +""" + unflatten(original, x::AbstractVector) + +Return instance of `original` constructed from `x`. +""" +unflatten(original, x::AbstractVector) = map(zip(original, x)) do (original_val, x_val) + unflatten(original_val, x_val) +end +unflatten(::Real, x::Real) = x +unflatten(::Real, x::AbstractVector) = only(x) +unflatten(::AbstractVector{<:Real}, x::AbstractVector) = x +unflatten(original::AbstractArray{<:Real}, x::AbstractVector) = reshape(x, size(original)) + +function unflatten(original::Tuple, x::AbstractVector) + lengths = map(length, original) + end_indices = cumsum(lengths) + return ntuple(length(original)) do i + v = original[i] + l = lengths[i] + end_idx = end_indices[i] + start_idx = end_idx - l + 1 + return unflatten(v, @view(x[start_idx:end_idx])) + end +end +function unflatten(original::NamedTuple{names}, x::AbstractVector) where {names} + return NamedTuple{names}(unflatten(values(original), x)) +end +function unflatten(original::Dict, x::AbstractVector) + return Dict(zip(keys(original), unflatten(values(original), x))) +end diff --git a/src/varinfo.jl b/src/varinfo.jl index 22728ba9a..6dbc46699 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1557,9 +1557,11 @@ values_as(vi::VarInfo) = vi.metadata """ values_as(vi::AbstractVarInfo, ::Type{NamedTuple}) values_as(vi::AbstractVarInfo, ::Type{Dict}) + values_as(vi::AbstractVarInfo, ::Type{Vector}) Return values in `vi` as the specified type. """ +values_as(vi::VarInfo, ::Type{Vector}) = vi[SampleFromPrior()] function values_as(vi::UntypedVarInfo, ::Type{NamedTuple}) iter = values_from_metadata(vi.metadata) return NamedTuple(map(p -> Symbol(p.first) => p.second, iter)) @@ -1582,3 +1584,5 @@ function values_from_metadata(md::Metadata) vn in md.vns ) end + +unflatten(vi::VarInfo, x::AbstractVector) = VarInfo(vi, SampleFromPrior(), x) From 66424f879fac9e6342c74d3d79c7bada2caf1176 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 6 Jul 2022 16:39:31 +0100 Subject: [PATCH 125/221] added getindex for AbstractVarInfo with Colon --- src/simple_varinfo.jl | 3 +++ src/varinfo.jl | 5 +++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 41dab2c73..a0ae3233f 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -233,6 +233,9 @@ function unflatten(svi::SimpleVarInfo, x::AbstractVector) return Setfield.@set svi.values = unflatten(svi.values, x) end +Base.getindex(svi::SimpleVarInfo, ::Colon) = values_as(svi, Vector) +Base.getindex(svi::SimpleVarInfo, ::AbstractSampler) = svi[:] + function BangBang.empty!!(vi::SimpleVarInfo) Setfield.@set resetlogp!!(vi).values = empty!!(vi.values) end diff --git a/src/varinfo.jl b/src/varinfo.jl index 6dbc46699..595bf18a5 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -977,8 +977,9 @@ Return the current value(s) of the random variables sampled by `spl` in `vi`. The value(s) may or may not be transformed to Euclidean space. """ -getindex(vi::AbstractVarInfo, spl::SampleFromPrior) = copy(getall(vi)) -getindex(vi::AbstractVarInfo, spl::SampleFromUniform) = copy(getall(vi)) +getindex(vi::AbstractVarInfo, ::Colon) = copy(getall(vi)) +getindex(vi::AbstractVarInfo, ::SampleFromPrior) = vi[:] +getindex(vi::AbstractVarInfo, ::SampleFromUniform) = vi[:] getindex(vi::UntypedVarInfo, spl::Sampler) = copy(getval(vi, _getranges(vi, spl))) function getindex(vi::TypedVarInfo, spl::Sampler) # Gets the ranges as a NamedTuple From 3bc27f811afb81aed0861f524db25a10faf70f60 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 7 Jul 2022 07:11:00 +0100 Subject: [PATCH 126/221] added unflatten to VarInfo --- src/varinfo.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 595bf18a5..d0febc6a0 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -129,6 +129,9 @@ function VarInfo( end VarInfo(model::Model, args...) = VarInfo(Random.GLOBAL_RNG, model, args...) +unflatten(vi::VarInfo, x::AbstractVector) = unflatten(vi, x, SampleFromPrior()) +unflatten(vi::VarInfo, x::AbstractVector, spl) = VarInfo(vi, spl, x) + # without AbstractSampler function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) return VarInfo(rng, model, SampleFromPrior(), context) @@ -1585,5 +1588,3 @@ function values_from_metadata(md::Metadata) vn in md.vns ) end - -unflatten(vi::VarInfo, x::AbstractVector) = VarInfo(vi, SampleFromPrior(), x) From ca5b080dd503c410b0fd1eb9e3b111755e7d698a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 7 Jul 2022 07:54:03 +0100 Subject: [PATCH 127/221] added make_default_varinfo allowing specification of how to initialize AbstractVarInfo used by a Sampler --- src/sampler.jl | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/sampler.jl b/src/sampler.jl index 550c27642..d3ff14b5d 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -67,6 +67,20 @@ function AbstractMCMC.step( return vi, nothing end +function make_default_varinfo( + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler +) + return make_default_varinfo(rng, model, sampler, DefaultContext()) +end +function make_default_varinfo( + rng::Random.AbstractRNG, + model::Model, + sampler::AbstractSampler, + context::AbstractContext, +) + return VarInfo(rng, model, sampler, context) +end + # initial step: general interface for resuming and function AbstractMCMC.step( rng::Random.AbstractRNG, @@ -83,7 +97,7 @@ function AbstractMCMC.step( # Sample initial values. _spl = initialsampler(spl) - vi = VarInfo(rng, model, _spl) + vi = make_default_varinfo(rng, model, _spl) # Update the parameters if provided. if init_params !== nothing From cb05fc908d1b76dbfc4c73e52d45bc276a555fcc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 10 Jul 2022 16:38:23 +0100 Subject: [PATCH 128/221] added unflatten also taking sampler for SimpleVarInfo --- src/simple_varinfo.jl | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index a0ae3233f..e32bdb1ca 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -229,13 +229,11 @@ end SimpleVarInfo(svi::SimpleVarInfo, spl, x::AbstractVector) = unflatten(svi, x) +unflatten(svi::SimpleVarInfo, spl, x::AbstractVector) = unflatten(svi, x) function unflatten(svi::SimpleVarInfo, x::AbstractVector) return Setfield.@set svi.values = unflatten(svi.values, x) end -Base.getindex(svi::SimpleVarInfo, ::Colon) = values_as(svi, Vector) -Base.getindex(svi::SimpleVarInfo, ::AbstractSampler) = svi[:] - function BangBang.empty!!(vi::SimpleVarInfo) Setfield.@set resetlogp!!(vi).values = empty!!(vi.values) end @@ -317,11 +315,8 @@ end # HACK: Needed to disambiguiate. Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(Base.Fix1(getindex, vi), vns) -Base.getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.values -Base.getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.values - -# TODO: Should we do better? -Base.getindex(vi::SimpleVarInfo, spl::Sampler) = vi.values +Base.getindex(svi::SimpleVarInfo, ::Colon) = values_as(svi, Vector) +Base.getindex(svi::SimpleVarInfo, ::AbstractSampler) = svi[:] # Since we don't perform any transformations in `getindex` for `SimpleVarInfo` # we simply call `getindex` in `getindex_raw`. From 45445cf6b95468696625bc2c42aecd9569e3b898 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 10 Jul 2022 16:38:42 +0100 Subject: [PATCH 129/221] added tonamedtuple impl for SimpleVarInfo --- src/simple_varinfo.jl | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index e32bdb1ca..22e98c8dc 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -513,6 +513,44 @@ end # HACK: Allows us to re-use the implementation of `dot_tilde`, etc. for literals. increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing +setgid!(vi::SimpleOrThreadSafeSimple, gid::Selector, vn::VarName) = nothing + +function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:NamedTuple{names}}) where {names} + nt_vals = map(keys(vi)) do vn + val = vi[vn] + vns = collect(DynamicPPL.TestUtils.varname_leaves(vn, val)) + vals = map(Base.Fix1(getindex, vi), vns) + (vals, map(string, vns)) + end + + return NamedTuple{names}(nt_vals) +end + +function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:Dict}) + syms_to_result = Dict{Symbol,Tuple{Vector{Real},Vector{String}}}() + for vn in keys(vi) + val = vi[vn] + vns = collect(DynamicPPL.TestUtils.varname_leaves(vn, val)) + vals = map(Base.Fix1(getindex, vi), vns) + + # Determine the corresponding symbol. + sym = only(unique(map(getsym, vns))) + + # Initialize entry if not yet initialized. + if !haskey(syms_to_result, sym) + syms_to_result[sym] = (Real[], String[]) + end + + # Combine with old result. + old_result = syms_to_result[sym] + syms_to_result[sym] = ( + vcat(old_result[1], vals), + vcat(old_result[2], map(string, vns)) + ) + end + + return NamedTuple(pairs(syms_to_result)) +end # NOTE: We don't implement `settrans!!(vi, trans, vn)`. function settrans!!(vi::SimpleVarInfo, trans) From ea8f844b1b1d76393698bce9c9fd6a3cd614377c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 10 Jul 2022 16:39:06 +0100 Subject: [PATCH 130/221] fixed implementation of unflatten for arrays --- src/utils.jl | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 3e8433886..ea10cacff 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -435,11 +435,18 @@ end Return instance of `original` constructed from `x`. """ -unflatten(original, x::AbstractVector) = map(zip(original, x)) do (original_val, x_val) - unflatten(original_val, x_val) +function unflatten(original, x::AbstractVector) + lengths = map(length, original) + end_indices = cumsum(lengths) + return map(zip(original, lengths, end_indices)) do (v, l, end_idx) + start_idx = end_idx - l + 1 + return unflatten(v, @view(x[start_idx:end_idx])) + end end + unflatten(::Real, x::Real) = x unflatten(::Real, x::AbstractVector) = only(x) +unflatten(::AbstractVector{<:Real}, x::Real) = vcat(x) unflatten(::AbstractVector{<:Real}, x::AbstractVector) = x unflatten(original::AbstractArray{<:Real}, x::AbstractVector) = reshape(x, size(original)) @@ -458,5 +465,5 @@ function unflatten(original::NamedTuple{names}, x::AbstractVector) where {names} return NamedTuple{names}(unflatten(values(original), x)) end function unflatten(original::Dict, x::AbstractVector) - return Dict(zip(keys(original), unflatten(values(original), x))) + return Dict(zip(keys(original), unflatten(collect(values(original)), x))) end From 52274ba5634e58046b93311cae9defe2081b0dcb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 10 Jul 2022 16:39:17 +0100 Subject: [PATCH 131/221] added default impl of unflatten taking sampler --- src/varinfo.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index d0febc6a0..2ca2cf82f 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -129,8 +129,8 @@ function VarInfo( end VarInfo(model::Model, args...) = VarInfo(Random.GLOBAL_RNG, model, args...) -unflatten(vi::VarInfo, x::AbstractVector) = unflatten(vi, x, SampleFromPrior()) -unflatten(vi::VarInfo, x::AbstractVector, spl) = VarInfo(vi, spl, x) +unflatten(vi::VarInfo, x::AbstractVector) = unflatten(vi, SampleFromPrior(), x) +unflatten(vi::VarInfo, spl, x::AbstractVector) = VarInfo(vi, spl, x) # without AbstractSampler function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) From 7da0ee979da155e893c9d7a0e7d7a11e2d6b995c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 12 Jul 2022 12:42:26 +0100 Subject: [PATCH 132/221] improved tonamedtuple for SimpleVarInfo with Dict --- src/simple_varinfo.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 22e98c8dc..d8f4e2725 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -529,6 +529,7 @@ end function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:Dict}) syms_to_result = Dict{Symbol,Tuple{Vector{Real},Vector{String}}}() for vn in keys(vi) + # Extract the leaf varnames and values. val = vi[vn] vns = collect(DynamicPPL.TestUtils.varname_leaves(vn, val)) vals = map(Base.Fix1(getindex, vi), vns) @@ -542,13 +543,14 @@ function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:Dict}) end # Combine with old result. - old_result = syms_to_result[sym] + old_vals, old_string_vns = syms_to_result[sym] syms_to_result[sym] = ( - vcat(old_result[1], vals), - vcat(old_result[2], map(string, vns)) + vcat(old_vals, vals), + vcat(old_string_vns, map(string, vns)) ) end + # Construct `NamedTuple`. return NamedTuple(pairs(syms_to_result)) end From b3499a37d4316fc0fe3819aaee580e3cf3ff1158 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 12 Jul 2022 12:48:54 +0100 Subject: [PATCH 133/221] added marginal_mean_of_samples according to suggestions --- src/test_utils.jl | 37 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index f68e8cafd..53068efbf 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -645,18 +645,26 @@ const DEMO_MODELS = ( ) """ - test_sampler(meanfunction, models, sampler, args...; kwargs...) + marginal_mean_of_samples(chain, varname) + +Return the mean of variable represented by `varname` in `chain`. +""" +marginal_mean_of_samples(chain, varname) = mean(Array(chain[Symbol(varname)])) + +""" + test_sampler(models, sampler, args...; kwargs...) Test that `sampler` produces correct marginal posterior means on each model in `models`. In short, this method iterates through `models`, calls `AbstractMCMC.sample` on the -`model` and `sampler` to produce a `chain`, and then checks `meanfunction(chain, vn)` +`model` and `sampler` to produce a `chain`, and then checks [`marginal_mean_of_samples(chain, vn)`](@ref) for every (leaf) varname `vn` against the corresponding value returned by [`posterior_mean`](@ref) for each model. +To change how comparison is done for a particular `chain` type, one can overload +[`marginal_mean_of_samples(chain, vn)`](@ref) for the corresponding type. + # Arguments -- `meanfunction`: A callable which computes the mean of the marginal means from the - chain resulting from the `sample` call. - `models`: A collection of instaces of [`DynamicPPL.Model`](@ref) to test on. - `sampler`: The `AbstractMCMC.AbstractSampler` to test. - `args...`: Arguments forwarded to `sample`. @@ -667,7 +675,6 @@ for every (leaf) varname `vn` against the corresponding value returned by - `kwargs...`: Keyword arguments forwarded to `sample`. """ function test_sampler( - meanfunction, models, sampler::AbstractMCMC.AbstractSampler, args...; @@ -683,7 +690,7 @@ function test_sampler( # extracting the leaves of the `VarName` and the corresponding value. for vn_leaf in varname_leaves(vn, get(target_values, vn)) target_value = get(target_values, vn_leaf) - chain_mean_value = meanfunction(chain, vn_leaf) + chain_mean_value = marginal_mean_of_samples(chain, vn_leaf) @test chain_mean_value ≈ target_value atol = atol rtol = rtol end end @@ -698,30 +705,22 @@ Test `sampler` on every model in [`DEMO_MODELS`](@ref). This is just a proxy for `test_sampler(meanfunction, DEMO_MODELS, sampler, args...; kwargs...)`. """ function test_sampler_on_demo_models( - meanfunction, sampler::AbstractMCMC.AbstractSampler, args...; kwargs... + sampler::AbstractMCMC.AbstractSampler, args...; kwargs... ) - return test_sampler(meanfunction, DEMO_MODELS, sampler, args...; kwargs...) + return test_sampler(DEMO_MODELS, sampler, args...; kwargs...) end """ - test_sampler_continuous([meanfunction, ]sampler, args...; kwargs...) + test_sampler_continuous(sampler, args...; kwargs...) Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`. As of right now, this is just an alias for [`test_sampler_on_demo_models`](@ref). """ function test_sampler_continuous( - meanfunction, sampler::AbstractMCMC.AbstractSampler, args...; kwargs... + sampler::AbstractMCMC.AbstractSampler, args...; kwargs... ) - return test_sampler_on_demo_models(meanfunction, sampler, args...; kwargs...) -end - -function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; kwargs...) - # Default for `MCMCChains.Chains`. - return test_sampler_continuous(sampler, args...; kwargs...) do chain, vn - # HACK(torfjelde): This assumes that we can index into `chain` with `Symbol(vn)`. - mean(Array(chain[Symbol(vn)])) - end + return test_sampler_on_demo_models(sampler, args...; kwargs...) end end From 2bd5dcddcef5bd2aa0ab79197cd72f1026fc9383 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 13 Jul 2022 10:17:05 +0100 Subject: [PATCH 134/221] removed example_values in favour of rand with NamedTuple --- docs/src/api.md | 1 - src/test_utils.jl | 37 +++++++++++++++---------------------- test/loglikelihoods.jl | 2 +- test/simple_varinfo.jl | 4 ++-- 4 files changed, 18 insertions(+), 26 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 9aa481cc4..c7133a5f9 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -131,7 +131,6 @@ Finally, the following methods can also be of use: ```@docs DynamicPPL.TestUtils.varnames -DynamicPPL.TestUtils.example_values DynamicPPL.TestUtils.posterior_mean ``` diff --git a/src/test_utils.jl b/src/test_utils.jl index 53068efbf..508eb275f 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -124,16 +124,6 @@ function varnames(model::Model) ) end -""" - example_values(model::Model) - -Return a `NamedTuple` compatible with `varnames(model)` with values in support of `model`. - -"Compatible" means that a `varname` from `varnames(model)` can be used to extract the -corresponding value using `get`, e.g. `get(example_values(model), varname)`. -""" -example_values(model::Model) = example_values(Random.GLOBAL_RNG, model) - """ posterior_mean(model::Model) @@ -545,13 +535,21 @@ const DemoModels = Union{ Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, } +# We require demo models to have explict impleentations of `rand` since we want +# these to be considered as ground truth. +function Random.rand(rng::Random.AbstractRNG, ::Type{NamedTuple}, model::DemoModels) + return error("demo models requires explicit implementation of rand") +end + const UnivariateAssumeDemoModels = Union{ Model{typeof(demo_assume_dot_observe)},Model{typeof(demo_assume_literal_dot_observe)} } function posterior_mean(model::UnivariateAssumeDemoModels) return (s=49 / 24, m=7 / 6) end -function example_values(rng::Random.AbstractRNG, model::UnivariateAssumeDemoModels) +function Random.rand( + rng::Random.AbstractRNG, ::Type{NamedTuple}, model::UnivariateAssumeDemoModels +) s = rand(rng, InverseGamma(2, 3)) m = rand(rng, Normal(0, sqrt(s))) @@ -572,7 +570,7 @@ const MultivariateAssumeDemoModels = Union{ } function posterior_mean(model::MultivariateAssumeDemoModels) # Get some containers to fill. - vals = example_values(model) + vals = Random.rand(model) vals.s[1] = 19 / 8 vals.m[1] = 3 / 4 @@ -582,7 +580,9 @@ function posterior_mean(model::MultivariateAssumeDemoModels) return vals end -function example_values(rng::Random.AbstractRNG, model::MultivariateAssumeDemoModels) +function Random.rand( + rng::Random.AbstractRNG, ::Type{NamedTuple}, model::MultivariateAssumeDemoModels +) # Get template values from `model`. retval = model(rng) vals = (s=retval.s, m=retval.m) @@ -675,12 +675,7 @@ To change how comparison is done for a particular `chain` type, one can overload - `kwargs...`: Keyword arguments forwarded to `sample`. """ function test_sampler( - models, - sampler::AbstractMCMC.AbstractSampler, - args...; - atol=1e-1, - rtol=1e-3, - kwargs..., + models, sampler::AbstractMCMC.AbstractSampler, args...; atol=1e-1, rtol=1e-3, kwargs... ) @testset "$(typeof(sampler)) on $(nameof(model))" for model in models chain = AbstractMCMC.sample(model, sampler, args...; kwargs...) @@ -717,9 +712,7 @@ Test that `sampler` produces the correct marginal posterior means on all models As of right now, this is just an alias for [`test_sampler_on_demo_models`](@ref). """ -function test_sampler_continuous( - sampler::AbstractMCMC.AbstractSampler, args...; kwargs... -) +function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; kwargs...) return test_sampler_on_demo_models(sampler, args...; kwargs...) end diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index bd04a76a5..b390997af 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -1,6 +1,6 @@ @testset "loglikelihoods.jl" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - example_values = DynamicPPL.TestUtils.example_values(m) + example_values = rand(NamedTuple, m) # Instantiate a `VarInfo` with the example values. vi = VarInfo(m) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 175f264d4..6a8c545ca 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -62,7 +62,7 @@ DynamicPPL.TestUtils.DEMO_MODELS # We might need to pre-allocate for the variable `m`, so we need # to see whether this is the case. - svi_nt = SimpleVarInfo(DynamicPPL.TestUtils.example_values(model)) + svi_nt = SimpleVarInfo(rand(NamedTuple, model)) svi_dict = SimpleVarInfo(VarInfo(model), Dict) @testset "$(nameof(typeof(DynamicPPL.values_as(svi))))" for svi in ( @@ -88,7 +88,7 @@ @test getlogp(svi_new) != 0 ### Evaluation ### - values_eval_constrained = DynamicPPL.TestUtils.example_values(model) + values_eval_constrained = rand(NamedTuple, model) if DynamicPPL.istrans(svi) _values_prior, logpri_true = DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian( model, values_eval_constrained... From 61a594cb6c7283b8079982c32d3ef9ee4b22c063 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 13 Jul 2022 10:39:22 +0100 Subject: [PATCH 135/221] updated docs --- docs/src/api.md | 1 + src/test_utils.jl | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index c7133a5f9..809e6c49e 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -106,6 +106,7 @@ DynamicPPL provides several demo models and helpers for testing samplers in the DynamicPPL.TestUtils.test_sampler DynamicPPL.TestUtils.test_sampler_on_demo_models DynamicPPL.TestUtils.test_sampler_continuous +DynamicPPL.TestUtils.marginal_mean_of_samples ``` ```@docs diff --git a/src/test_utils.jl b/src/test_utils.jl index 508eb275f..ef314fa91 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -657,12 +657,12 @@ marginal_mean_of_samples(chain, varname) = mean(Array(chain[Symbol(varname)])) Test that `sampler` produces correct marginal posterior means on each model in `models`. In short, this method iterates through `models`, calls `AbstractMCMC.sample` on the -`model` and `sampler` to produce a `chain`, and then checks [`marginal_mean_of_samples(chain, vn)`](@ref) +`model` and `sampler` to produce a `chain`, and then checks `marginal_mean_of_samples(chain, vn)` for every (leaf) varname `vn` against the corresponding value returned by [`posterior_mean`](@ref) for each model. To change how comparison is done for a particular `chain` type, one can overload -[`marginal_mean_of_samples(chain, vn)`](@ref) for the corresponding type. +[`marginal_mean_of_samples`](@ref) for the corresponding type. # Arguments - `models`: A collection of instaces of [`DynamicPPL.Model`](@ref) to test on. From 6c941bd5d704b9301beacf7c95bcc52cb1bfd627 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 18 Jul 2022 10:30:50 +0100 Subject: [PATCH 136/221] fixed method ambiguity error --- src/simple_varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index d8f4e2725..4c994581e 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -316,7 +316,7 @@ end Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(Base.Fix1(getindex, vi), vns) Base.getindex(svi::SimpleVarInfo, ::Colon) = values_as(svi, Vector) -Base.getindex(svi::SimpleVarInfo, ::AbstractSampler) = svi[:] +Base.getindex(svi::SimpleVarInfo, ::Sampler) = svi[:] # Since we don't perform any transformations in `getindex` for `SimpleVarInfo` # we simply call `getindex` in `getindex_raw`. From 5e92e568eae57c25454ba81ace526ac9eb9f81f6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 18 Jul 2022 10:31:20 +0100 Subject: [PATCH 137/221] added islinked for SimpleVarInfo --- src/simple_varinfo.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 4c994581e..6de918273 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -569,6 +569,8 @@ istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi) istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) +islinked(vi::SimpleVarInfo, ::Union{Sampler,SampleFromPrior}) = istrans(vi) + """ values_as(varinfo[, Type]) From aabc45af03b6b4b18a818ef419fa74cafbf16de4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 18 Jul 2022 10:31:33 +0100 Subject: [PATCH 138/221] formatting --- src/simple_varinfo.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 6de918273..6d5df2f9f 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -582,8 +582,9 @@ values_as(vi::SimpleVarInfo) = vi.values values_as(vi::SimpleVarInfo, ::Type{Dict}) = Dict(pairs(vi.values)) values_as(vi::SimpleVarInfo, ::Type{NamedTuple}) = NamedTuple(pairs(vi.values)) values_as(vi::SimpleVarInfo{<:NamedTuple}, ::Type{NamedTuple}) = vi.values -values_as(vi::SimpleVarInfo, ::Type{Vector}) = mapreduce(v -> vec([v;]), vcat, values(vi.values)) - +function values_as(vi::SimpleVarInfo, ::Type{Vector}) + return mapreduce(v -> vec([v;]), vcat, values(vi.values)) +end """ logjoint(model::Model, θ) From 939540c1e534a0297e3d2d929ad41b0aa5d79db9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 18 Jul 2022 10:32:58 +0100 Subject: [PATCH 139/221] added link!! and invlink!! as BangBang alternatives to link! and invlink! --- src/DynamicPPL.jl | 3 ++ src/bijectors.jl | 101 ++++++++++++++++++++++++++++++++++++++++++ src/simple_varinfo.jl | 7 +-- src/varinfo.jl | 26 +++++++++++ 4 files changed, 132 insertions(+), 5 deletions(-) create mode 100644 src/bijectors.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 86d4e0def..1027ae3ad 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -57,7 +57,9 @@ export AbstractVarInfo, setorder!, istrans, link!, + link!!, invlink!, + invlink!!, tonamedtuple, # VarName (reexport from AbstractPPL) VarName, @@ -150,5 +152,6 @@ include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") include("test_utils.jl") +include("bijectors.jl") end # module diff --git a/src/bijectors.jl b/src/bijectors.jl new file mode 100644 index 000000000..a510fa162 --- /dev/null +++ b/src/bijectors.jl @@ -0,0 +1,101 @@ +using Bijectors + +function Bijectors.Stacked( + model::DynamicPPL.Model, + ::Val{sym2ranges}=Val(false); + varinfo::DynamicPPL.VarInfo=DynamicPPL.VarInfo(model), +) where {sym2ranges} + dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...) + + num_ranges = sum([ + length(varinfo.metadata[sym].ranges) for sym in keys(varinfo.metadata) + ]) + ranges = Vector{UnitRange{Int}}(undef, num_ranges) + idx = 0 + range_idx = 1 + + # ranges might be discontinuous => values are vectors of ranges rather than just ranges + sym_lookup = Dict{Symbol,Vector{UnitRange{Int}}}() + for sym in keys(varinfo.metadata) + sym_lookup[sym] = Vector{UnitRange{Int}}() + for r in varinfo.metadata[sym].ranges + ranges[range_idx] = idx .+ r + push!(sym_lookup[sym], ranges[range_idx]) + range_idx += 1 + end + + idx += varinfo.metadata[sym].ranges[end][end] + end + + b = Bijectors.Stacked(map(Bijectors.bijector, dists), ranges) + return sym2ranges ? (b, Dict(zip(keys(sym_lookup), values(sym_lookup)))) : b +end + +link!!(vi::AbstractVarInfo, model::Model) = link!!(vi, SampleFromPrior(), model) +function link!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) + return link!!(t, vi, SampleFromPrior(), model) +end +function link!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) + # Use `default_transformation` to decide which transformation to use if none is specified. + return link!!(default_transformation(model, vi), vi, spl, model) +end +function link!!( + t::DefaultTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model +) + # TODO: Implement this properly, e.g. using a context or something. + return link!!(BijectorTransformation(Bijectors.Stacked(model)), vi, spl, model) +end +function link!!(t::DefaultTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) + # TODO: Implement this properly, e.g. using a context or something. + DynamicPPL.link!(vi, spl) + return vi +end +function link!!( + t::BijectorTransformation{<:Bijectors.Stacked}, + vi::AbstractVarInfo, + spl::AbstractSampler, + model::Model, +) + b = t.bijector + x = vi[spl] + y, logjac = with_logabsdet_jacobian(b, x) + # TODO: Do we need this? + lp_new = getlogp(vi) - logjac + vi_new = setlogp!!(unflatten(vi, spl, y), lp_new) + return settrans!!(vi_new, t) +end + +invlink!!(vi::AbstractVarInfo, model::Model) = invlink!!(vi, SampleFromPrior(), model) +function invlink!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) + return invlink!!(t, vi, SampleFromPrior(), model) +end +function invlink!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) + # Here we extract the `transformation` from `vi` rather than using the default one. + return invlink!!(transformation(vi), vi, spl, model) +end +function invlink!!( + ::DefaultTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model +) + # TODO: Implement this properly, e.g. using a context or something. + return invlink!!(BijectorTransformation(Bijectors.Stacked(model)), vi, spl, model) +end +function invlink!!(::DefaultTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) + # TODO: Implement this properly, e.g. using a context or something. + DynamicPPL.invlink!(vi, spl) + return vi +end +function invlink!!( + t::BijectorTransformation{<:Bijectors.Stacked}, + vi::AbstractVarInfo, + spl::AbstractSampler, + model::Model, +) + b = t.bijector + ib = inv(b) + y = vi[spl] + x, logjac = with_logabsdet_jacobian(ib, y) + # TODO: Do we need this? + lp_new = getlogp(vi) - logjac + vi_new = setlogp!!(unflatten(vi, spl, x), lp_new) + return settrans!!(vi_new, NoTransformation()) +end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 6d5df2f9f..4587767e4 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -1,8 +1,3 @@ -abstract type AbstractTransformation end - -struct NoTransformation <: AbstractTransformation end -struct DefaultTransformation <: AbstractTransformation end - """ $(TYPEDEF) @@ -197,6 +192,8 @@ struct SimpleVarInfo{NT,T,C<:AbstractTransformation} <: AbstractVarInfo transformation::C end +transformation(vi::SimpleVarInfo) = vi.transformation + SimpleVarInfo(values, logp) = SimpleVarInfo(values, logp, NoTransformation()) function SimpleVarInfo{T}(θ) where {T<:Real} diff --git a/src/varinfo.jl b/src/varinfo.jl index 2ca2cf82f..f33299fb7 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -2,6 +2,30 @@ #### Types for typed and untyped VarInfo #### +abstract type AbstractTransformation end + +struct NoTransformation <: AbstractTransformation end +struct DefaultTransformation <: AbstractTransformation end + +struct BijectorTransformation{B<:Bijectors.AbstractBijector} <: AbstractTransformation + bijector::B +end + +""" + default_transformation(model::Model[, vi::AbstractVarInfo]) + +Return the `AbstractTransformation` currently related to `model` and, potentially, `vi`. +""" +default_transformation(model::Model, ::AbstractVarInfo) = default_transformation(model) +default_transformation(::Model) = DefaultTransformation() + +""" + transformation(vi::AbstractVarInfo) + +Return the `AbstractTransformation` related to `vi`. +""" +function transformation end + #################### # VarInfo metadata # #################### @@ -104,6 +128,8 @@ end const UntypedVarInfo = VarInfo{<:Metadata} const TypedVarInfo = VarInfo{<:NamedTuple} +transformation(vi::VarInfo) = DefaultTransformation() + function VarInfo(old_vi::UntypedVarInfo, spl, x::AbstractVector) new_vi = deepcopy(old_vi) new_vi[spl] = x From aecf97fad2d8994bad02fa73dfd7d9490155ebd8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 18 Jul 2022 10:33:15 +0100 Subject: [PATCH 140/221] added specialized implementation for NamedBijector and SimpleVarInfo when used together --- src/simple_varinfo.jl | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 4587767e4..ebce2a1eb 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -541,10 +541,7 @@ function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:Dict}) # Combine with old result. old_vals, old_string_vns = syms_to_result[sym] - syms_to_result[sym] = ( - vcat(old_vals, vals), - vcat(old_string_vns, map(string, vns)) - ) + syms_to_result[sym] = (vcat(old_vals, vals), vcat(old_string_vns, map(string, vns))) end # Construct `NamedTuple`. @@ -678,3 +675,35 @@ julia> # Truth. ``` """ Distributions.loglikelihood(model::Model, θ) = loglikelihood(model, SimpleVarInfo(θ)) + +# Allow usage of `NamedBijector` too. +function link!!( + t::BijectorTransformation{<:Bijectors.NamedBijector}, + vi::SimpleVarInfo{<:NamedTuple}, + spl::AbstractSampler, + model::Model, +) + # TODO: Make sure that `spl` is respected. + b = t.bijector + x = vi.values + y, logjac = with_logabsdet_jacobian(b, x) + lp_new = getlogp(vi) - logjac + vi_new = setlogp!!(Setfield.@set(vi.values = y), lp_new) + return settrans!!(vi_new, t) +end + +function invlink!!( + t::BijectorTransformation{<:Bijectors.NamedBijector}, + vi::SimpleVarInfo{<:NamedTuple}, + spl::AbstractSampler, + model::Model, +) + # TODO: Make sure that `spl` is respected. + b = t.bijector + ib = inverse(b) + y = vi.values + x, logjac = with_logabsdet_jacobian(ib, y) + lp_new = getlogp(vi) - logjac + vi_new = setlogp!!(Setfield.@set(vi.values = x), lp_new) + return settrans!!(vi_new, NoTransformation()) +end From 9dcefdb25823158e8f62b84b62d9e70db7b57180 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 19 Jul 2022 13:10:01 +0100 Subject: [PATCH 141/221] use inverse instead of inv --- src/bijectors.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bijectors.jl b/src/bijectors.jl index a510fa162..de91eef32 100644 --- a/src/bijectors.jl +++ b/src/bijectors.jl @@ -91,7 +91,7 @@ function invlink!!( model::Model, ) b = t.bijector - ib = inv(b) + ib = inverse(b) y = vi[spl] x, logjac = with_logabsdet_jacobian(ib, y) # TODO: Do we need this? From fd0796b2130ada58b9725c908fb603d05460dd72 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 19 Jul 2022 13:10:50 +0100 Subject: [PATCH 142/221] preserve DefaultTransformation --- src/bijectors.jl | 11 ++++++++--- test/distribution_wrappers.jl | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/bijectors.jl b/src/bijectors.jl index de91eef32..24b244fe3 100644 --- a/src/bijectors.jl +++ b/src/bijectors.jl @@ -43,7 +43,12 @@ function link!!( t::DefaultTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model ) # TODO: Implement this properly, e.g. using a context or something. - return link!!(BijectorTransformation(Bijectors.Stacked(model)), vi, spl, model) + # Fall back to `Bijectors.Stacked` but then we act like we're using + # the `DefaultTransformation` by setting the transformation accordingly. + return settrans!!( + link!!(BijectorTransformation(Bijectors.Stacked(model)), vi, spl, model), + t + ) end function link!!(t::DefaultTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) # TODO: Implement this properly, e.g. using a context or something. @@ -51,7 +56,7 @@ function link!!(t::DefaultTransformation, vi::VarInfo, spl::AbstractSampler, mod return vi end function link!!( - t::BijectorTransformation{<:Bijectors.Stacked}, + t::BijectorTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model, @@ -85,7 +90,7 @@ function invlink!!(::DefaultTransformation, vi::VarInfo, spl::AbstractSampler, m return vi end function invlink!!( - t::BijectorTransformation{<:Bijectors.Stacked}, + t::BijectorTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model, diff --git a/test/distribution_wrappers.jl b/test/distribution_wrappers.jl index 350ce6014..0e7eb8bed 100644 --- a/test/distribution_wrappers.jl +++ b/test/distribution_wrappers.jl @@ -9,5 +9,5 @@ @test minimum(nd) == -Inf @test maximum(nd) == Inf @test logpdf(nd, 15.0) == 0 - @test Bijectors.logpdf_with_trans(nd, 0) == 0 + @test Bijectors.logpdf_with_trans(nd, 0, true) == 0 end From 94e5d48189c0234a92484b8659ed76636b0614f6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 23 Jul 2022 15:29:37 +0100 Subject: [PATCH 143/221] removed duplicated defs --- src/simple_varinfo.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 034f8f58e..ebce2a1eb 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -1,8 +1,3 @@ -abstract type AbstractTransformation end - -struct NoTransformation <: AbstractTransformation end -struct DefaultTransformation <: AbstractTransformation end - """ $(TYPEDEF) From 15fdf193520288bed9a49c005ead512f9b7b81e8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 23 Jul 2022 16:14:31 +0100 Subject: [PATCH 144/221] style --- src/bijectors.jl | 13 +++---------- src/utils.jl | 1 - 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/src/bijectors.jl b/src/bijectors.jl index 24b244fe3..aa02bb56e 100644 --- a/src/bijectors.jl +++ b/src/bijectors.jl @@ -46,8 +46,7 @@ function link!!( # Fall back to `Bijectors.Stacked` but then we act like we're using # the `DefaultTransformation` by setting the transformation accordingly. return settrans!!( - link!!(BijectorTransformation(Bijectors.Stacked(model)), vi, spl, model), - t + link!!(BijectorTransformation(Bijectors.Stacked(model)), vi, spl, model), t ) end function link!!(t::DefaultTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) @@ -56,10 +55,7 @@ function link!!(t::DefaultTransformation, vi::VarInfo, spl::AbstractSampler, mod return vi end function link!!( - t::BijectorTransformation, - vi::AbstractVarInfo, - spl::AbstractSampler, - model::Model, + t::BijectorTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model ) b = t.bijector x = vi[spl] @@ -90,10 +86,7 @@ function invlink!!(::DefaultTransformation, vi::VarInfo, spl::AbstractSampler, m return vi end function invlink!!( - t::BijectorTransformation, - vi::AbstractVarInfo, - spl::AbstractSampler, - model::Model, + t::BijectorTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model ) b = t.bijector ib = inverse(b) diff --git a/src/utils.jl b/src/utils.jl index ea10cacff..13de147ab 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -427,7 +427,6 @@ function BangBang.possible( promote_type(eltype(C), eltype(T)) <: eltype(C) end - # HACK(torfjelde): This makes it so it works on iterators, etc. by default. # TODO(torfjelde): Do better. """ From 48dfb9ccf1a8e13b31a3b11d5991dc8747a04f33 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 23 Jul 2022 17:19:50 +0100 Subject: [PATCH 145/221] fixed empty!! and added isempty for SimpleVarInfo --- src/simple_varinfo.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index ebce2a1eb..860f9b509 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -232,8 +232,9 @@ function unflatten(svi::SimpleVarInfo, x::AbstractVector) end function BangBang.empty!!(vi::SimpleVarInfo) - Setfield.@set resetlogp!!(vi).values = empty!!(vi.values) + return resetlogp!!(Setfield.@set vi.values = empty!!(vi.values)) end +Base.isempty(vi::SimpleVarInfo) = isempty(vi.values) getlogp(vi::SimpleVarInfo) = vi.logp setlogp!!(vi::SimpleVarInfo, logp) = Setfield.@set vi.logp = logp From 70ba82d24f0c5981c077de9d4988542a83aa07d1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 23 Jul 2022 17:20:04 +0100 Subject: [PATCH 146/221] added setindex!! for sampler with SimpleVarInfo --- src/simple_varinfo.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 860f9b509..a7101445d 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -367,6 +367,10 @@ function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName) return Setfield.@set vi.values = set!!(vi.values, vn, val) end +function BangBang.setindex!!(vi::SimpleVarInfo, val, spl::AbstractSampler) + return unflatten(vi, spl, val) +end + # TODO: Specialize to handle certain cases, e.g. a collection of `VarName` with # same symbol and same type of, say, `IndexLens`, for improved `.~` performance. function BangBang.setindex!!(vi::SimpleVarInfo, vals, vns::AbstractVector{<:VarName}) From d3bff268b9a8bfd4e0bc42b3721d14e2dfa92238 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 23 Jul 2022 17:20:15 +0100 Subject: [PATCH 147/221] made values_as compatible with empty SimpleVarInfo --- src/simple_varinfo.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index a7101445d..59cd16071 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -581,7 +581,8 @@ values_as(vi::SimpleVarInfo) = vi.values values_as(vi::SimpleVarInfo, ::Type{Dict}) = Dict(pairs(vi.values)) values_as(vi::SimpleVarInfo, ::Type{NamedTuple}) = NamedTuple(pairs(vi.values)) values_as(vi::SimpleVarInfo{<:NamedTuple}, ::Type{NamedTuple}) = vi.values -function values_as(vi::SimpleVarInfo, ::Type{Vector}) +function values_as(vi::SimpleVarInfo{<:Any,T}, ::Type{Vector}) where {T} + length(vi.values) == 0 && return T[] return mapreduce(v -> vec([v;]), vcat, values(vi.values)) end From b14e9cf3c7c3ab5fae29d4154feac4044fbf9dcf Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 23 Jul 2022 17:20:32 +0100 Subject: [PATCH 148/221] added tests for base functionality for SimpleVarInfo too --- test/varinfo.jl | 57 +++++++++++++++++++++++++------------------------ 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/test/varinfo.jl b/test/varinfo.jl index 32c90bf47..e162e8180 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -45,20 +45,35 @@ @test hash(vn2) == hash(vn1) @test inspace(vn1, (:x,)) - function test_base!(vi) - empty!!(vi) + # Tests for `inspace` + space = (:x, :y, @varname(z[1]), @varname(M[1:10, :])) + + @test inspace(@varname(x), space) + @test inspace(@varname(y), space) + @test inspace(@varname(x[1]), space) + @test inspace(@varname(z[1][1]), space) + @test inspace(@varname(z[1][:]), space) + @test inspace(@varname(z[1][2:3:10]), space) + @test inspace(@varname(M[[2, 3], 1]), space) + @test inspace(@varname(M[:, 1:4]), space) + @test inspace(@varname(M[1, [2, 4, 6]]), space) + @test !inspace(@varname(z[2]), space) + @test !inspace(@varname(z), space) + + function test_base!!(vi_original) + vi = empty!!(vi_original) @test getlogp(vi) == 0 - @test get_num_produce(vi) == 0 + @test isempty(vi[:]) vn = @varname x dist = Normal(0, 1) r = rand(dist) - gid = Selector() + gid = DynamicPPL.Selector() @test isempty(vi) @test ~haskey(vi, vn) @test !(vn in keys(vi)) - push!!(vi, vn, r, dist, gid) + vi = push!!(vi, vn, r, dist, gid) @test ~isempty(vi) @test haskey(vi, vn) @test vn in keys(vi) @@ -68,37 +83,23 @@ @test vi[vn] == r @test vi[SampleFromPrior()][1] == r - vi[vn] = [2 * r] + vi = DynamicPPL.setindex!!(vi, 2 * r, vn) @test vi[vn] == 2 * r @test vi[SampleFromPrior()][1] == 2 * r - vi[SampleFromPrior()] = [3 * r] + vi = DynamicPPL.setindex!!(vi, [3 * r], SampleFromPrior()) @test vi[vn] == 3 * r @test vi[SampleFromPrior()][1] == 3 * r - empty!!(vi) + vi = empty!!(vi) @test isempty(vi) - push!!(vi, vn, r, dist, gid) - - function test_inspace() - space = (:x, :y, @varname(z[1]), @varname(M[1:10, :])) - - @test inspace(@varname(x), space) - @test inspace(@varname(y), space) - @test inspace(@varname(x[1]), space) - @test inspace(@varname(z[1][1]), space) - @test inspace(@varname(z[1][:]), space) - @test inspace(@varname(z[1][2:3:10]), space) - @test inspace(@varname(M[[2, 3], 1]), space) - @test inspace(@varname(M[:, 1:4]), space) - @test inspace(@varname(M[1, [2, 4, 6]]), space) - @test !inspace(@varname(z[2]), space) - @test !inspace(@varname(z), space) - end - return test_inspace() + return push!!(vi, vn, r, dist, gid) end + vi = VarInfo() - test_base!(vi) - test_base!(empty!!(TypedVarInfo(vi))) + test_base!!(vi) + test_base!!(TypedVarInfo(vi)) + test_base!!(SimpleVarInfo()) + test_base!!(SimpleVarInfo(Dict())) end @testset "flags" begin # Test flag setting: From 9f106face56eb9dcebfd78c1b9ff63803075ceaf Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 23 Jul 2022 17:47:24 +0100 Subject: [PATCH 149/221] renamed bijectors.jl to transforming.jl --- src/DynamicPPL.jl | 2 +- src/{bijectors.jl => transforming.jl} | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) rename src/{bijectors.jl => transforming.jl} (98%) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 1027ae3ad..c573e7f32 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -152,6 +152,6 @@ include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") include("test_utils.jl") -include("bijectors.jl") +include("transforming.jl") end # module diff --git a/src/bijectors.jl b/src/transforming.jl similarity index 98% rename from src/bijectors.jl rename to src/transforming.jl index aa02bb56e..51067ac6b 100644 --- a/src/bijectors.jl +++ b/src/transforming.jl @@ -52,6 +52,7 @@ end function link!!(t::DefaultTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) # TODO: Implement this properly, e.g. using a context or something. DynamicPPL.link!(vi, spl) + # TODO: Add `logabsdet_jacobian` correction to `logp`! return vi end function link!!( @@ -60,7 +61,7 @@ function link!!( b = t.bijector x = vi[spl] y, logjac = with_logabsdet_jacobian(b, x) - # TODO: Do we need this? + lp_new = getlogp(vi) - logjac vi_new = setlogp!!(unflatten(vi, spl, y), lp_new) return settrans!!(vi_new, t) From bf34356e7f7f92dd891aa8058c2bb6f3aa23928d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 23 Jul 2022 17:47:36 +0100 Subject: [PATCH 150/221] fixed update of logp after initialize_parameters!! --- src/sampler.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/sampler.jl b/src/sampler.jl index d3ff14b5d..0891e7136 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -108,11 +108,7 @@ function AbstractMCMC.step( # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588 # and https://github.com/TuringLang/Turing.jl/issues/1563 # to avoid that existing variables are resampled - if _spl isa SampleFromUniform - model(rng, vi, SampleFromPrior()) - else - model(rng, vi, _spl) - end + vi = last(evaluate!!(model, vi, DefaultContext())) end return initialstep(rng, model, spl, vi; init_params=init_params, kwargs...) From 3e5f763a6979d5f59a1d9883c0dc8a8631a303ad Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 23 Jul 2022 17:48:19 +0100 Subject: [PATCH 151/221] remove now-redundant todo --- src/sampler.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sampler.jl b/src/sampler.jl index 0891e7136..50df65158 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -104,7 +104,6 @@ function AbstractMCMC.step( vi = initialize_parameters!!(vi, init_params, spl) # Update joint log probability. - # TODO: fix properly by using sampler and evaluation contexts # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588 # and https://github.com/TuringLang/Turing.jl/issues/1563 # to avoid that existing variables are resampled From e649f3763ed5d71f0cbbc9aab85e2ba10d38f0c8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 23 Jul 2022 17:53:59 +0100 Subject: [PATCH 152/221] improved the initial step --- src/sampler.jl | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/sampler.jl b/src/sampler.jl index 50df65158..d2b9cce11 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -78,7 +78,8 @@ function make_default_varinfo( sampler::AbstractSampler, context::AbstractContext, ) - return VarInfo(rng, model, sampler, context) + init_sampler = initialsampler(sampler) + return VarInfo(rng, model, init_sampler, context) end # initial step: general interface for resuming and @@ -96,8 +97,7 @@ function AbstractMCMC.step( end # Sample initial values. - _spl = initialsampler(spl) - vi = make_default_varinfo(rng, model, _spl) + vi = make_default_varinfo(rng, model, spl) # Update the parameters if provided. if init_params !== nothing @@ -141,8 +141,7 @@ function initialize_parameters!!(vi::AbstractVarInfo, init_params, spl::Sampler) # Get all values. linked = islinked(vi, spl) if linked - # TODO: Make work with immutable `vi`. - invlink!(vi, spl) + vi = invlink!!(vi, spl) end theta = vi[spl] length(theta) == length(init_theta) || @@ -159,8 +158,7 @@ function initialize_parameters!!(vi::AbstractVarInfo, init_params, spl::Sampler) # Update in `vi`. vi = setindex!!(vi, theta, spl) if linked - # TODO: Make work with immutable `vi`. - link!(vi, spl) + vi = link!!(vi, spl) end return vi From 594127035cdb5b6cdf20092abc94fa1247f2cc58 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 23 Jul 2022 18:24:35 +0100 Subject: [PATCH 153/221] fixed bug with initialize_parameters!! introduced in previous commit --- src/sampler.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/sampler.jl b/src/sampler.jl index d2b9cce11..4acfed55f 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -101,7 +101,7 @@ function AbstractMCMC.step( # Update the parameters if provided. if init_params !== nothing - vi = initialize_parameters!!(vi, init_params, spl) + vi = initialize_parameters!!(vi, init_params, spl, model) # Update joint log probability. # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588 @@ -130,7 +130,7 @@ By default, it returns an instance of [`SampleFromPrior`](@ref). """ initialsampler(spl::Sampler) = SampleFromPrior() -function initialize_parameters!!(vi::AbstractVarInfo, init_params, spl::Sampler) +function initialize_parameters!!(vi::AbstractVarInfo, init_params, spl::Sampler, model::Model) @debug "Using passed-in initial variable values" init_params # Flatten parameters. @@ -141,7 +141,7 @@ function initialize_parameters!!(vi::AbstractVarInfo, init_params, spl::Sampler) # Get all values. linked = islinked(vi, spl) if linked - vi = invlink!!(vi, spl) + vi = invlink!!(vi, spl, model) end theta = vi[spl] length(theta) == length(init_theta) || @@ -158,7 +158,7 @@ function initialize_parameters!!(vi::AbstractVarInfo, init_params, spl::Sampler) # Update in `vi`. vi = setindex!!(vi, theta, spl) if linked - vi = link!!(vi, spl) + vi = link!!(vi, spl, model) end return vi From f30b8759024fe4d08db7d0b26b049eef5f3bfc9f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 23 Jul 2022 18:27:18 +0100 Subject: [PATCH 154/221] Update src/sampler.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/sampler.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/sampler.jl b/src/sampler.jl index 4acfed55f..24befae69 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -130,7 +130,9 @@ By default, it returns an instance of [`SampleFromPrior`](@ref). """ initialsampler(spl::Sampler) = SampleFromPrior() -function initialize_parameters!!(vi::AbstractVarInfo, init_params, spl::Sampler, model::Model) +function initialize_parameters!!( + vi::AbstractVarInfo, init_params, spl::Sampler, model::Model +) @debug "Using passed-in initial variable values" init_params # Flatten parameters. From cb3e1f4d4814267c170c11bf32542d338a3d23c0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 23 Jul 2022 18:29:00 +0100 Subject: [PATCH 155/221] add some comments on tonamedtuple --- src/simple_varinfo.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 59cd16071..a91237e2f 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -517,6 +517,8 @@ end increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing setgid!(vi::SimpleOrThreadSafeSimple, gid::Selector, vn::VarName) = nothing +# We need these to be compatible with how chains are constructed from `AbstractVarInfo` in Turing.jl. +# TODO: Move away from using these `tonamedtuple` methods. function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:NamedTuple{names}}) where {names} nt_vals = map(keys(vi)) do vn val = vi[vn] From f79fab4a3909d55edbbd933001451b68168c0f7d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 24 Jul 2022 10:08:47 +0100 Subject: [PATCH 156/221] Apply suggestions from code review Co-authored-by: David Widmann --- src/sampler.jl | 2 +- src/simple_varinfo.jl | 2 +- src/transforming.jl | 6 ++---- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/sampler.jl b/src/sampler.jl index 24befae69..52f7a303a 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -97,7 +97,7 @@ function AbstractMCMC.step( end # Sample initial values. - vi = make_default_varinfo(rng, model, spl) + vi = default_varinfo(rng, model, spl) # Update the parameters if provided. if init_params !== nothing diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index a91237e2f..9bee5526a 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -584,7 +584,7 @@ values_as(vi::SimpleVarInfo, ::Type{Dict}) = Dict(pairs(vi.values)) values_as(vi::SimpleVarInfo, ::Type{NamedTuple}) = NamedTuple(pairs(vi.values)) values_as(vi::SimpleVarInfo{<:NamedTuple}, ::Type{NamedTuple}) = vi.values function values_as(vi::SimpleVarInfo{<:Any,T}, ::Type{Vector}) where {T} - length(vi.values) == 0 && return T[] + isempty(vi.values) && return T[] return mapreduce(v -> vec([v;]), vcat, values(vi.values)) end diff --git a/src/transforming.jl b/src/transforming.jl index 51067ac6b..f53a9435c 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -1,9 +1,7 @@ -using Bijectors - function Bijectors.Stacked( - model::DynamicPPL.Model, + model::Model, ::Val{sym2ranges}=Val(false); - varinfo::DynamicPPL.VarInfo=DynamicPPL.VarInfo(model), + varinfo::VarInfo=VarInfo(model), ) where {sym2ranges} dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...) From 58c25509ed71cf8179abb2d7092af42bc23acb0f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 24 Jul 2022 15:12:07 +0100 Subject: [PATCH 157/221] Update src/transforming.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/transforming.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transforming.jl b/src/transforming.jl index f53a9435c..9a34ae1d8 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -1,7 +1,5 @@ function Bijectors.Stacked( - model::Model, - ::Val{sym2ranges}=Val(false); - varinfo::VarInfo=VarInfo(model), + model::Model, ::Val{sym2ranges}=Val(false); varinfo::VarInfo=VarInfo(model) ) where {sym2ranges} dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...) From c316e702243dd3ce33ee279bc1334becc761fbc9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 24 Jul 2022 22:06:19 +0100 Subject: [PATCH 158/221] renamed make_default_varinfo to default_varinfo --- src/sampler.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sampler.jl b/src/sampler.jl index 52f7a303a..c4d8304c3 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -67,12 +67,12 @@ function AbstractMCMC.step( return vi, nothing end -function make_default_varinfo( +function default_varinfo( rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler ) - return make_default_varinfo(rng, model, sampler, DefaultContext()) + return default_varinfo(rng, model, sampler, DefaultContext()) end -function make_default_varinfo( +function default_varinfo( rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler, From 998fcf44b6bcbb5c3a486e91dfeef98d72f32a95 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 24 Jul 2022 22:06:32 +0100 Subject: [PATCH 159/221] simplified impls of getindex --- src/simple_varinfo.jl | 1 - src/varinfo.jl | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 9bee5526a..f689225dc 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -314,7 +314,6 @@ end Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(Base.Fix1(getindex, vi), vns) Base.getindex(svi::SimpleVarInfo, ::Colon) = values_as(svi, Vector) -Base.getindex(svi::SimpleVarInfo, ::Sampler) = svi[:] # Since we don't perform any transformations in `getindex` for `SimpleVarInfo` # we simply call `getindex` in `getindex_raw`. diff --git a/src/varinfo.jl b/src/varinfo.jl index f33299fb7..a249192d9 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1007,8 +1007,7 @@ Return the current value(s) of the random variables sampled by `spl` in `vi`. The value(s) may or may not be transformed to Euclidean space. """ getindex(vi::AbstractVarInfo, ::Colon) = copy(getall(vi)) -getindex(vi::AbstractVarInfo, ::SampleFromPrior) = vi[:] -getindex(vi::AbstractVarInfo, ::SampleFromUniform) = vi[:] +getindex(vi::AbstractVarInfo, ::AbstractSampler) = vi[:] getindex(vi::UntypedVarInfo, spl::Sampler) = copy(getval(vi, _getranges(vi, spl))) function getindex(vi::TypedVarInfo, spl::Sampler) # Gets the ranges as a NamedTuple From 0913a243eaf6f360e65c6eb44826771aab50a08a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 24 Jul 2022 22:08:32 +0100 Subject: [PATCH 160/221] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/sampler.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/sampler.jl b/src/sampler.jl index c4d8304c3..3a4daf0b1 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -67,9 +67,7 @@ function AbstractMCMC.step( return vi, nothing end -function default_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler -) +function default_varinfo(rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler) return default_varinfo(rng, model, sampler, DefaultContext()) end function default_varinfo( From 9af2638fe6987241cbc3e62f1e066b2b186492da Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 24 Jul 2022 23:20:18 +0100 Subject: [PATCH 161/221] made impls of default getindex for VarInfo a bit more sensible --- src/varinfo.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index a249192d9..1602dcdcc 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1006,7 +1006,7 @@ Return the current value(s) of the random variables sampled by `spl` in `vi`. The value(s) may or may not be transformed to Euclidean space. """ -getindex(vi::AbstractVarInfo, ::Colon) = copy(getall(vi)) +getindex(vi::AbstractVarInfo, ::Colon) = values_as(vi, Vector) getindex(vi::AbstractVarInfo, ::AbstractSampler) = vi[:] getindex(vi::UntypedVarInfo, spl::Sampler) = copy(getval(vi, _getranges(vi, spl))) function getindex(vi::TypedVarInfo, spl::Sampler) @@ -1590,7 +1590,7 @@ values_as(vi::VarInfo) = vi.metadata Return values in `vi` as the specified type. """ -values_as(vi::VarInfo, ::Type{Vector}) = vi[SampleFromPrior()] +values_as(vi::VarInfo, ::Type{Vector}) = copy(getall(vi)) function values_as(vi::UntypedVarInfo, ::Type{NamedTuple}) iter = values_from_metadata(vi.metadata) return NamedTuple(map(p -> Symbol(p.first) => p.second, iter)) From 482ade72793f739128752175bfcc3e380f9680af Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 24 Jul 2022 23:20:41 +0100 Subject: [PATCH 162/221] removed unnecessary namespace specification --- src/model.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model.jl b/src/model.jl index f7dc4b113..3ff27e680 100644 --- a/src/model.jl +++ b/src/model.jl @@ -529,7 +529,7 @@ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T} SamplingContext(rng, SampleFromPrior(), DefaultContext()), ), ) - return DynamicPPL.values_as(x, T) + return values_as(x, T) end # Default RNG and type From b79bf28c72754204a6a014f4c072380544e3b297 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 24 Jul 2022 23:20:57 +0100 Subject: [PATCH 163/221] use isempty(vi) instead of checking its values --- src/simple_varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index f689225dc..b77815fb0 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -583,7 +583,7 @@ values_as(vi::SimpleVarInfo, ::Type{Dict}) = Dict(pairs(vi.values)) values_as(vi::SimpleVarInfo, ::Type{NamedTuple}) = NamedTuple(pairs(vi.values)) values_as(vi::SimpleVarInfo{<:NamedTuple}, ::Type{NamedTuple}) = vi.values function values_as(vi::SimpleVarInfo{<:Any,T}, ::Type{Vector}) where {T} - isempty(vi.values) && return T[] + isempty(vi) && return T[] return mapreduce(v -> vec([v;]), vcat, values(vi.values)) end From 15087c5f9c28b519b527116e1d138e5f3c9258d1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 24 Jul 2022 23:21:12 +0100 Subject: [PATCH 164/221] fix values_as for certain combinations --- src/simple_varinfo.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index b77815fb0..99db746fa 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -579,8 +579,10 @@ Return the values/realizations in `varinfo` as `Type`, if implemented. If no `Type` is provided, return values as stored in `varinfo`. """ values_as(vi::SimpleVarInfo) = vi.values -values_as(vi::SimpleVarInfo, ::Type{Dict}) = Dict(pairs(vi.values)) -values_as(vi::SimpleVarInfo, ::Type{NamedTuple}) = NamedTuple(pairs(vi.values)) +values_as(vi::SimpleVarInfo, ::Type{Dict}) = Dict(zip(keys(vi), values(vi.values))) +function values_as(vi::SimpleVarInfo{<:Dict}, ::Type{NamedTuple}) + return NamedTuple((Symbol(k), v) for (k, v) in vi.values) +end values_as(vi::SimpleVarInfo{<:NamedTuple}, ::Type{NamedTuple}) = vi.values function values_as(vi::SimpleVarInfo{<:Any,T}, ::Type{Vector}) where {T} isempty(vi) && return T[] From 88dbdca6b5c65f8fa38ec14351b1755b76e5bf22 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 25 Jul 2022 10:55:58 +0100 Subject: [PATCH 165/221] added deprecation warnings for link! and invlink! --- src/varinfo.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/varinfo.jl b/src/varinfo.jl index 1602dcdcc..b2934cf5b 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -796,6 +796,10 @@ of their distributions to the Euclidean space and set their corresponding `"tran flag values to `true`. """ function link!(vi::UntypedVarInfo, spl::Sampler) + Base.depwarn( + "`link!(varinfo, sampler)` is deprecated, use `link!!(varinfo, sampler, model)` instead.", + :link!, + ) # TODO: Change to a lazy iterator over `vns` vns = _getvns(vi, spl) if ~istrans(vi, vns[1]) @@ -815,6 +819,10 @@ function link!(vi::UntypedVarInfo, spl::Sampler) end end function link!(vi::TypedVarInfo, spl::AbstractSampler) + Base.depwarn( + "`link!(varinfo, sampler)` is deprecated, use `link!!(varinfo, sampler, model)` instead.", + :link!, + ) return link!(vi, spl, Val(getspace(spl))) end function link!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) @@ -865,6 +873,10 @@ Euclidean space back to the support of their distributions and sets their corres `"trans"` flag values to `false`. """ function invlink!(vi::UntypedVarInfo, spl::AbstractSampler) + Base.depwarn( + "`invlink!(varinfo, sampler)` is deprecated, use `invlink!!(varinfo, sampler, model)` instead.", + :invlink!, + ) vns = _getvns(vi, spl) if istrans(vi, vns[1]) for vn in vns @@ -882,6 +894,10 @@ function invlink!(vi::UntypedVarInfo, spl::AbstractSampler) end end function invlink!(vi::TypedVarInfo, spl::AbstractSampler) + Base.depwarn( + "`invlink!(varinfo, sampler)` is deprecated, use `invlink!!(varinfo, sampler, model)` instead.", + :invlink!, + ) return invlink!(vi, spl, Val(getspace(spl))) end function invlink!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) From 9b3c40fc6c2159ff10ad880d2456fbd1f457c8f5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 27 Jul 2022 01:13:43 +0100 Subject: [PATCH 166/221] add logabsdet-jacobian term in link! and invlink! --- src/varinfo.jl | 48 ++++++++++++++++++++---------------------------- 1 file changed, 20 insertions(+), 28 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index b2934cf5b..f3664a1c4 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -807,11 +807,11 @@ function link!(vi::UntypedVarInfo, spl::Sampler) @debug "X -> ℝ for $(vn)..." dist = getdist(vi, vn) # TODO: Use inplace versions to avoid allocations - setval!( - vi, - vectorize(dist, Bijectors.link(dist, reconstruct(dist, getval(vi, vn)))), - vn, - ) + b = bijector(dist) + x = reconstruct(dist, getval(vi, vn)) + y, logjac = with_logabsdet_jacobian(b, x) + setval!(vi, vectorize(dist, y), vn) + acclogp!!(vi, -logjac) settrans!!(vi, true, vn) end else @@ -844,14 +844,11 @@ end for vn in f_vns @debug "X -> R for $(vn)..." dist = getdist(vi, vn) - setval!( - vi, - vectorize( - dist, - Bijectors.link(dist, reconstruct(dist, getval(vi, vn))), - ), - vn, - ) + x = reconstruct(dist, getval(vi, vn)) + b = bijector(dist) + y, logjac = with_logabsdet_jacobian(b, x) + setval!(vi, vectorize(dist, y), vn) + acclogp!!(vi, -logjac) settrans!!(vi, true, vn) end else @@ -882,11 +879,11 @@ function invlink!(vi::UntypedVarInfo, spl::AbstractSampler) for vn in vns @debug "ℝ -> X for $(vn)..." dist = getdist(vi, vn) - setval!( - vi, - vectorize(dist, Bijectors.invlink(dist, reconstruct(dist, getval(vi, vn)))), - vn, - ) + y = reconstruct(dist, getval(vi, vn)) + b = bijector(dist) + x, logjac = with_logabsdet_jacobian(b, y) + setval!(vi, vectorize(dist, x), vn) + acclogp!!(vi, -logjac) settrans!!(vi, false, vn) end else @@ -919,16 +916,11 @@ end for vn in f_vns @debug "ℝ -> X for $(vn)..." dist = getdist(vi, vn) - setval!( - vi, - vectorize( - dist, - Bijectors.invlink( - dist, reconstruct(dist, getval(vi, vn)) - ), - ), - vn, - ) + y = reconstruct(dist, getval(vi, vn)) + b = inv(bijector(dist)) + x, logjac = with_logabsdet_jacobian(b, y) + setval!(vi, vectorize(dist, x), vn) + acclogp!!(vi, -logjac) settrans!!(vi, false, vn) end else From 57d321ce324c0b053defc318dac815ad925e4e51 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 27 Jul 2022 01:15:19 +0100 Subject: [PATCH 167/221] use context to implement link!! and invlink!! --- src/transforming.jl | 157 ++++++++++++++++++++++++++++++-------------- 1 file changed, 106 insertions(+), 51 deletions(-) diff --git a/src/transforming.jl b/src/transforming.jl index 9a34ae1d8..233c31cfc 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -1,30 +1,83 @@ -function Bijectors.Stacked( - model::Model, ::Val{sym2ranges}=Val(false); varinfo::VarInfo=VarInfo(model) -) where {sym2ranges} - dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...) - - num_ranges = sum([ - length(varinfo.metadata[sym].ranges) for sym in keys(varinfo.metadata) - ]) - ranges = Vector{UnitRange{Int}}(undef, num_ranges) - idx = 0 - range_idx = 1 - - # ranges might be discontinuous => values are vectors of ranges rather than just ranges - sym_lookup = Dict{Symbol,Vector{UnitRange{Int}}}() - for sym in keys(varinfo.metadata) - sym_lookup[sym] = Vector{UnitRange{Int}}() - for r in varinfo.metadata[sym].ranges - ranges[range_idx] = idx .+ r - push!(sym_lookup[sym], ranges[range_idx]) - range_idx += 1 - end - - idx += varinfo.metadata[sym].ranges[end][end] +struct LazyTransformationContext{isinverse} <: AbstractContext end +NodeTrait(::LazyTransformationContext) = IsLeaf() + +function tilde_assume( + ::LazyTransformationContext{isinverse}, right, vn, vi +) where {isinverse} + r = vi[vn, right] + lp = Bijectors.logpdf_with_trans(right, r, !isinverse) + + if istrans(vi, vn) + @assert isinverse "Trying to link already transformed variables" + else + @assert !isinverse "Trying to invlink non-transformed variables" + end + + # Only transform if `!isinverse` since `vi[vn, right]` + # already performs the inverse transformation if it's transformed. + r_transformed = isinverse ? r : bijector(right)(r) + return r, lp, setindex!!(vi, r_transformed, vn) +end + +function dot_tilde_assume( + ::LazyTransformationContext{isinverse}, + dist::Distribution, + var::AbstractArray, + vns::AbstractArray{<:VarName}, + vi, +) where {isinverse} + r = getindex.((vi,), vns, (dist,)) + b = bijector(dist) + + is_trans_uniques = unique(istrans.((vi,), vns)) + @assert length(is_trans_uniques) == 1 "LazyTransformationContext only supports transforming all variables" + is_trans = first(is_trans_uniques) + if is_trans + @assert isinverse "Trying to link already transformed variables" + else + @assert !isinverse "Trying to invlink non-transformed variables" + end + + # Only transform if `!isinverse` since `vi[vn, right]` + # already performs the inverse transformation if it's transformed. + r_transformed = isinverse ? r : b.(r) + lp = sum(Bijectors.logpdf_with_trans.((dist,), r, (!isinverse,))) + return r, lp, setindex!!(vi, r_transformed, vns) +end + +function dot_tilde_assume( + ::LazyTransformationContext{isinverse}, + dist::MultivariateDistribution, + var::AbstractMatrix, + vns::AbstractVector{<:VarName}, + vi::AbstractVarInfo, +) where {isinverse} + @assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))" + r = vi[vns, dist] + + # Compute `logpdf` with logabsdet-jacobian correction. + lp = sum(zip(vns, eachcol(r))) do (vn, ri) + return Bijectors.logpdf_with_trans(dist, ri, !isinverse) + end + + # Transform _all_ values. + is_trans_uniques = unique(istrans.((vi,), vns)) + @assert length(is_trans_uniques) == 1 "LazyTransformationContext only supports transforming all variables" + is_trans = first(is_trans_uniques) + if is_trans + @assert isinverse "Trying to link already transformed variables" + else + @assert !isinverse "Trying to invlink non-transformed variables" end - b = Bijectors.Stacked(map(Bijectors.bijector, dists), ranges) - return sym2ranges ? (b, Dict(zip(keys(sym_lookup), values(sym_lookup)))) : b + b = bijector(dist) + for (vn, ri) in zip(vns, eachcol(r)) + # Only transform if `!isinverse` since `vi[vn, right]` + # already performs the inverse transformation if it's transformed. + vi = DynamicPPL.setindex!!(vi, isinverse ? ri : b(ri), vn) + end + + return r, lp, vi end link!!(vi::AbstractVarInfo, model::Model) = link!!(vi, SampleFromPrior(), model) @@ -38,30 +91,12 @@ end function link!!( t::DefaultTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model ) - # TODO: Implement this properly, e.g. using a context or something. - # Fall back to `Bijectors.Stacked` but then we act like we're using - # the `DefaultTransformation` by setting the transformation accordingly. - return settrans!!( - link!!(BijectorTransformation(Bijectors.Stacked(model)), vi, spl, model), t - ) + return settrans!!(last(evaluate!!(model, vi, LazyTransformationContext{false}())), t) end function link!!(t::DefaultTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) - # TODO: Implement this properly, e.g. using a context or something. - DynamicPPL.link!(vi, spl) - # TODO: Add `logabsdet_jacobian` correction to `logp`! + link!(vi, spl) return vi end -function link!!( - t::BijectorTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model -) - b = t.bijector - x = vi[spl] - y, logjac = with_logabsdet_jacobian(b, x) - - lp_new = getlogp(vi) - logjac - vi_new = setlogp!!(unflatten(vi, spl, y), lp_new) - return settrans!!(vi_new, t) -end invlink!!(vi::AbstractVarInfo, model::Model) = invlink!!(vi, SampleFromPrior(), model) function invlink!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) @@ -74,22 +109,42 @@ end function invlink!!( ::DefaultTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model ) - # TODO: Implement this properly, e.g. using a context or something. - return invlink!!(BijectorTransformation(Bijectors.Stacked(model)), vi, spl, model) + return settrans!!( + last(evaluate!!(model, vi, LazyTransformationContext{true}())), NoTransformation() + ) end function invlink!!(::DefaultTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) - # TODO: Implement this properly, e.g. using a context or something. - DynamicPPL.invlink!(vi, spl) + invlink!(vi, spl) return vi end + +# BijectorTransformation +function link!!( + t::BijectorTransformation{<:Bijectors.Bijector{1}}, + vi::AbstractVarInfo, + spl::AbstractSampler, + model::Model, +) + b = t.bijector + x = vi[spl] + y, logjac = with_logabsdet_jacobian(b, x) + + lp_new = getlogp(vi) - logjac + vi_new = setlogp!!(unflatten(vi, spl, y), lp_new) + return settrans!!(vi_new, t) +end + function invlink!!( - t::BijectorTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model + t::BijectorTransformation{<:Bijectors.Bijector{1}}, + vi::AbstractVarInfo, + spl::AbstractSampler, + model::Model, ) b = t.bijector ib = inverse(b) y = vi[spl] x, logjac = with_logabsdet_jacobian(ib, y) - # TODO: Do we need this? + lp_new = getlogp(vi) - logjac vi_new = setlogp!!(unflatten(vi, spl, x), lp_new) return settrans!!(vi_new, NoTransformation()) From 4409149477673a080160f67ed2ad311f6ff92192 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 27 Jul 2022 01:19:40 +0100 Subject: [PATCH 168/221] added tests for link!! and invlink!! --- test/simple_varinfo.jl | 48 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 6a8c545ca..d6c0b6d3c 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -58,6 +58,54 @@ end end + @testset "link!! & invlink!! on $(nameof(model))" for model in + DynamicPPL.TestUtils.DEMO_MODELS + values_constrained = rand(NamedTuple, model) + @testset "$(typeof(vi))" for vi in ( + SimpleVarInfo(Dict()), SimpleVarInfo(values_constrained), VarInfo(model) + ) + for vn in DynamicPPL.TestUtils.varnames(model) + vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) + end + vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) + lp_orig = getlogp(vi) + + # `link!!` + vi_linked = link!!(deepcopy(vi), model) + lp_linked = getlogp(vi_linked) + values_unconstrained, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + model, values_constrained... + ) + # Should result in the correct logjoint. + @test lp_linked ≈ lp_linked_true + # Should be approx. the same as the "lazy" transformation. + @test logjoint(model, vi_linked) ≈ lp_linked + + # TODO: Should not `VarInfo` also error here? The current implementation + # only warns and acts as a no-op. + if vi isa SimpleVarInfo + @test_throws AssertionError link!!(vi_linked, model) + end + + # `invlink!!` + vi_invlinked = invlink!!(deepcopy(vi_linked), model) + lp_invlinked = getlogp(vi_invlinked) + lp_invlinked_true = DynamicPPL.TestUtils.logjoint_true( + model, values_constrained... + ) + # Should result in the correct logjoint. + @test lp_invlinked ≈ lp_invlinked_true + # Should be approx. the same as the "lazy" transformation. + @test logjoint(model, vi_invlinked) ≈ lp_invlinked + + # Should result in same values. + @test all( + DynamicPPL.getindex_raw(vi_invlinked, vn) ≈ get(values_constrained, vn) for + vn in DynamicPPL.TestUtils.varnames(model) + ) + end + end + @testset "SimpleVarInfo on $(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS # We might need to pre-allocate for the variable `m`, so we need From c34f257352aa8403e72f011fc0dd11693ae66d64 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 27 Jul 2022 01:19:51 +0100 Subject: [PATCH 169/221] added a note comment --- src/varinfo.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/varinfo.jl b/src/varinfo.jl index f3664a1c4..c85277297 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -128,6 +128,9 @@ end const UntypedVarInfo = VarInfo{<:Metadata} const TypedVarInfo = VarInfo{<:NamedTuple} +# NOTE: This is kind of weird, but it effectively preserves the "old" +# behavior where we're allowed to call `link!` on the same `VarInfo` +# multiple times. transformation(vi::VarInfo) = DefaultTransformation() function VarInfo(old_vi::UntypedVarInfo, spl, x::AbstractVector) From 73765e7a67a9787c2c0445b5d5e78a51c12f3061 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 27 Jul 2022 14:43:53 +0100 Subject: [PATCH 170/221] renamed DefaultTransformation to LazyTransformation and BijectorTransformation to StaticTransformation --- src/simple_varinfo.jl | 6 +++--- src/transforming.jl | 19 +++++++++---------- src/varinfo.jl | 10 +++++----- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 99db746fa..b8671ef65 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -556,7 +556,7 @@ end # NOTE: We don't implement `settrans!!(vi, trans, vn)`. function settrans!!(vi::SimpleVarInfo, trans) - return settrans!!(vi, trans ? DefaultTransformation() : NoTransformation()) + return settrans!!(vi, trans ? LazyTransformation() : NoTransformation()) end function settrans!!(vi::SimpleVarInfo, transformation::AbstractTransformation) return Setfield.@set vi.transformation = transformation @@ -687,7 +687,7 @@ Distributions.loglikelihood(model::Model, θ) = loglikelihood(model, SimpleVarIn # Allow usage of `NamedBijector` too. function link!!( - t::BijectorTransformation{<:Bijectors.NamedBijector}, + t::StaticTransformation{<:Bijectors.NamedBijector}, vi::SimpleVarInfo{<:NamedTuple}, spl::AbstractSampler, model::Model, @@ -702,7 +702,7 @@ function link!!( end function invlink!!( - t::BijectorTransformation{<:Bijectors.NamedBijector}, + t::StaticTransformation{<:Bijectors.NamedBijector}, vi::SimpleVarInfo{<:NamedTuple}, spl::AbstractSampler, model::Model, diff --git a/src/transforming.jl b/src/transforming.jl index 233c31cfc..22cd22f36 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -89,11 +89,11 @@ function link!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) return link!!(default_transformation(model, vi), vi, spl, model) end function link!!( - t::DefaultTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model + t::LazyTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model ) return settrans!!(last(evaluate!!(model, vi, LazyTransformationContext{false}())), t) end -function link!!(t::DefaultTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) +function link!!(t::LazyTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) link!(vi, spl) return vi end @@ -107,25 +107,25 @@ function invlink!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) return invlink!!(transformation(vi), vi, spl, model) end function invlink!!( - ::DefaultTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model + ::LazyTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model ) return settrans!!( last(evaluate!!(model, vi, LazyTransformationContext{true}())), NoTransformation() ) end -function invlink!!(::DefaultTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) +function invlink!!(::LazyTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) invlink!(vi, spl) return vi end -# BijectorTransformation +# Vector-based ones. function link!!( - t::BijectorTransformation{<:Bijectors.Bijector{1}}, + t::StaticTransformation{<:Bijectors.Bijector{1}}, vi::AbstractVarInfo, spl::AbstractSampler, model::Model, ) - b = t.bijector + b = inverse(t.bijector) x = vi[spl] y, logjac = with_logabsdet_jacobian(b, x) @@ -135,15 +135,14 @@ function link!!( end function invlink!!( - t::BijectorTransformation{<:Bijectors.Bijector{1}}, + t::StaticTransformation{<:Bijectors.Bijector{1}}, vi::AbstractVarInfo, spl::AbstractSampler, model::Model, ) b = t.bijector - ib = inverse(b) y = vi[spl] - x, logjac = with_logabsdet_jacobian(ib, y) + x, logjac = with_logabsdet_jacobian(b, y) lp_new = getlogp(vi) - logjac vi_new = setlogp!!(unflatten(vi, spl, x), lp_new) diff --git a/src/varinfo.jl b/src/varinfo.jl index c85277297..ba48bd58a 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -5,10 +5,10 @@ abstract type AbstractTransformation end struct NoTransformation <: AbstractTransformation end -struct DefaultTransformation <: AbstractTransformation end +struct LazyTransformation <: AbstractTransformation end -struct BijectorTransformation{B<:Bijectors.AbstractBijector} <: AbstractTransformation - bijector::B +struct StaticTransformation{F} <: AbstractTransformation + bijector::F end """ @@ -17,7 +17,7 @@ end Return the `AbstractTransformation` currently related to `model` and, potentially, `vi`. """ default_transformation(model::Model, ::AbstractVarInfo) = default_transformation(model) -default_transformation(::Model) = DefaultTransformation() +default_transformation(::Model) = LazyTransformation() """ transformation(vi::AbstractVarInfo) @@ -131,7 +131,7 @@ const TypedVarInfo = VarInfo{<:NamedTuple} # NOTE: This is kind of weird, but it effectively preserves the "old" # behavior where we're allowed to call `link!` on the same `VarInfo` # multiple times. -transformation(vi::VarInfo) = DefaultTransformation() +transformation(vi::VarInfo) = LazyTransformation() function VarInfo(old_vi::UntypedVarInfo, spl, x::AbstractVector) new_vi = deepcopy(old_vi) From 0c0c393258a0a420c76b7e6edbe90e22d087c31d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 27 Jul 2022 14:44:22 +0100 Subject: [PATCH 171/221] added maybe_invlink_before_eval!! allowing invlinking once --- src/model.jl | 7 +++- src/transforming.jl | 79 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 1 deletion(-) diff --git a/src/model.jl b/src/model.jl index 3ff27e680..4c567dc52 100644 --- a/src/model.jl +++ b/src/model.jl @@ -490,7 +490,12 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf context_new = setleafcontext( context, setleafcontext(model.context, leafcontext(context)) ) - model.f(model, varinfo, context_new, $(unwrap_args...)) + model.f( + model, + maybe_invlink_before_eval!!(varinfo, context_new, model), + context_new, + $(unwrap_args...), + ) end end diff --git a/src/transforming.jl b/src/transforming.jl index 22cd22f36..dcabc4ccd 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -98,6 +98,85 @@ function link!!(t::LazyTransformation, vi::VarInfo, spl::AbstractSampler, model: return vi end +""" + maybe_invlink_before_eval!!([t::Transformation,] vi, context, model) + +Return a possibly invlinked version of `vi`. + +This will be called prior to `model` evaluation, allowing one to perform a single +`invlink!!` _before_ evaluation rather lazyily evaluate the transforms on as-we-need +basis as is done with [`LazyTransformation` ](@ref). + +# Examples +```julia-repl +julia> using DynamicPPL, Distributions, Bijectors + +julia> @model demo() = x ~ Normal() +demo (generic function with 2 methods) + +julia> # By subtyping `Bijector{1}`, we inherit the `(inv)link!!` defined for + # bijectors which acts on 1-dimensional arrays, i.e. vectors. + struct MyBijector <: Bijectors.Bijector{1} end + +julia> # Define some dummy `inverse` which will be used in the `link!!` call. + Bijectors.inverse(f::MyBijector) = identity + +julia> # We need to define `with_logabsdet_jacobian` for `MyBijector` + # (`identity` already has `with_logabsdet_jacobian` defined) + function Bijectors.with_logabsdet_jacobian(::MyBijector, x) + # Just using a large number of the logabsdet-jacobian term + # for demonstration purposes. + return (x, 1000) + end + +julia> # Change the `default_transformation` for our model to be a + # `StaticTransformation` using `MyBijector`. + function DynamicPPL.default_transformation(::Model{typeof(demo)}) + return DynamicPPL.StaticTransformation(MyBijector()) + end + +julia> model = demo(); + +julia> vi = SimpleVarInfo(x=1.0) +SimpleVarInfo((x = 1.0,), 0.0) + +julia> # Uses the `inverse` of `MyBijector`, which we have defined as `identity` + vi_linked = link!!(vi, model) +Transformed SimpleVarInfo((x = 1.0,), 0.0) + +julia> # Now performs a single `invlink!!` before model evaluation. + logjoint(model, vi_linked) +-1001.4189385332047 +``` +""" +function maybe_invlink_before_eval!!( + vi::AbstractVarInfo, context::AbstractContext, model::Model +) + return maybe_invlink_before_eval!!(transformation(vi), vi, context, model) +end +function maybe_invlink_before_eval!!( + t::AbstractTransformation, + vi::AbstractVarInfo, + context::AbstractContext, + model::Model, +) + # Default behavior is to _not_ transform. + return vi +end +function maybe_invlink_before_eval!!( + t::StaticTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model +) + return invlink!!(t, vi, _default_sampler(context), model) +end + +function _default_sampler(context::AbstractContext) + return _default_sampler(NodeTrait(_default_sampler, context), context) +end +_default_sampler(::IsLeaf, context::AbstractContext) = SampleFromPrior() +function _default_sampler(::IsParent, context::AbstractContext) + return _default_sampler(childcontext(context)) +end + invlink!!(vi::AbstractVarInfo, model::Model) = invlink!!(vi, SampleFromPrior(), model) function invlink!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) return invlink!!(t, vi, SampleFromPrior(), model) From 5e5175519e6c124b486eadbbfe815defd21b6e17 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 27 Jul 2022 14:45:20 +0100 Subject: [PATCH 172/221] formatting --- src/transforming.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/transforming.jl b/src/transforming.jl index dcabc4ccd..625daa240 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -155,10 +155,7 @@ function maybe_invlink_before_eval!!( return maybe_invlink_before_eval!!(transformation(vi), vi, context, model) end function maybe_invlink_before_eval!!( - t::AbstractTransformation, - vi::AbstractVarInfo, - context::AbstractContext, - model::Model, + t::AbstractTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model ) # Default behavior is to _not_ transform. return vi From 3dbc7a926c4af7e72583e10581a29e9bc42d3e33 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 28 Jul 2022 17:24:57 +0100 Subject: [PATCH 173/221] use OrderedDict instead of Dict for SimpleVarInfo as it preserves the order of insertion --- Project.toml | 1 + src/DynamicPPL.jl | 2 ++ src/simple_varinfo.jl | 23 ++++++++++++++--------- src/utils.jl | 2 +- 4 files changed, 18 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index 89aa66be2..557872ae2 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index c573e7f32..c789013b8 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -4,6 +4,7 @@ using AbstractMCMC: AbstractSampler, AbstractChains using AbstractPPL using Bijectors using Distributions +using OrderedCollections: OrderedDict using AbstractMCMC: AbstractMCMC using BangBang: BangBang, push!!, empty!!, setindex!! @@ -75,6 +76,7 @@ export AbstractVarInfo, Sample, init, vectorize, + OrderedDict, # Model Model, getmissings, diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index b8671ef65..130425455 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -4,7 +4,7 @@ A simple wrapper of the parameters with a `logp` field for accumulation of the logdensity. -Currently only implemented for `NT<:NamedTuple` and `NT<:Dict`. +Currently only implemented for `NT<:NamedTuple` and `NT<:AbstractDict`. # Fields $(FIELDS) @@ -64,8 +64,8 @@ julia> # (×) If we don't provide the container... ERROR: type NamedTuple has no field x [...] -julia> # If one does not know the varnames, we can use a `Dict` instead. - _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo{Float64}(Dict()), ctx); +julia> # If one does not know the varnames, we can use a `OrderedDict` instead. + _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo{Float64}(OrderedDict()), ctx); julia> # (✓) Sort of fast, but only possible at runtime. vi[@varname(x[1])] @@ -81,6 +81,11 @@ ERROR: KeyError: key x[1:2] not found [...] ``` +_Technically_, it's possible to use any implementation of `AbstractDict` in place of +`OrderedDict`, but `OrderedDict` ensures that certain operations, e.g. linearization/flattening +of the values in the varinfo, are consistent between evaluations. Hence `OrderedDict` is +the preferred implementation of `AbstractDict` to use here. + You can also sample in _transformed_ space: ```jldoctest simplevarinfo-general @@ -104,8 +109,8 @@ julia> xs = [last(DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo() julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true -julia> # And with `Dict` of course! - _, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true), ctx); +julia> # And with `OrderedDict` of course! + _, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict()), true), ctx); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ 0.6225185067787314 @@ -160,9 +165,9 @@ ERROR: type NamedTuple has no field b [...] ``` -Using `Dict` as underlying storage. +Using `OrderedDict` as underlying storage. ```jldoctest -julia> svi_dict = SimpleVarInfo(Dict(@varname(m) => (a = [1.0], ))); +julia> svi_dict = SimpleVarInfo(OrderedDict(@varname(m) => (a = [1.0], ))); julia> svi_dict[@varname(m)] (a = [1.0],) @@ -279,7 +284,7 @@ end Base.getindex(vi::SimpleVarInfo, vn::VarName) = get(vi.values, vn) -# `Dict` +# `AbstractDict` function Base.getindex(vi::SimpleVarInfo{<:AbstractDict}, vn::VarName) if haskey(vi.values, vn) return vi.values[vn] @@ -421,7 +426,7 @@ function BangBang.push!!( return Setfield.@set vi.values = set!!(vi.values, vn, value) end -# `Dict` +# `AbstractDict` function BangBang.push!!( vi::SimpleVarInfo{<:AbstractDict}, vn::VarName, diff --git a/src/utils.jl b/src/utils.jl index 13de147ab..54fa66433 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -463,6 +463,6 @@ end function unflatten(original::NamedTuple{names}, x::AbstractVector) where {names} return NamedTuple{names}(unflatten(values(original), x)) end -function unflatten(original::Dict, x::AbstractVector) +function unflatten(original::AbstractDict, x::AbstractVector) return Dict(zip(keys(original), unflatten(collect(values(original)), x))) end From 809de9a8b935eb76b2d860015766604b93dd8a5a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 29 Jul 2022 10:47:17 +0100 Subject: [PATCH 174/221] added compat entry for OrderedCollections --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 557872ae2..62d997620 100644 --- a/Project.toml +++ b/Project.toml @@ -29,6 +29,7 @@ ConstructionBase = "1" Distributions = "0.23.8, 0.24, 0.25" DocStringExtensions = "0.8" MacroTools = "0.5.6" +OrderedCollections = "1" Setfield = "0.7.1, 0.8" ZygoteRules = "0.2" julia = "1.6" From 656175f3cb00940e151dd4fe2bc2b94880bb7dbd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 29 Jul 2022 10:47:17 +0100 Subject: [PATCH 175/221] added compat entry for OrderedCollections --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 162448316..47990b043 100644 --- a/Project.toml +++ b/Project.toml @@ -28,6 +28,7 @@ ConstructionBase = "1" Distributions = "0.23.8, 0.24, 0.25" DocStringExtensions = "0.8, 0.9" MacroTools = "0.5.6" +OrderedCollections = "1" Setfield = "0.7.1, 0.8" ZygoteRules = "0.2" julia = "1.6" From a2259785958d50877edfd61d83cb4545894fddbe Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 28 Jul 2022 17:24:57 +0100 Subject: [PATCH 176/221] use OrderedDict instead of Dict for SimpleVarInfo as it preserves the order of insertion --- Project.toml | 1 + src/DynamicPPL.jl | 2 ++ src/simple_varinfo.jl | 23 ++++++++++++++--------- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index 47990b043..8695ef3f3 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 86d4e0def..4cd41c09b 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -4,6 +4,7 @@ using AbstractMCMC: AbstractSampler, AbstractChains using AbstractPPL using Bijectors using Distributions +using OrderedCollections: OrderedDict using AbstractMCMC: AbstractMCMC using BangBang: BangBang, push!!, empty!!, setindex!! @@ -73,6 +74,7 @@ export AbstractVarInfo, Sample, init, vectorize, + OrderedDict, # Model Model, getmissings, diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 5b9edefdf..9e7d98e10 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -9,7 +9,7 @@ struct DefaultTransformation <: AbstractTransformation end A simple wrapper of the parameters with a `logp` field for accumulation of the logdensity. -Currently only implemented for `NT<:NamedTuple` and `NT<:Dict`. +Currently only implemented for `NT<:NamedTuple` and `NT<:AbstractDict`. # Fields $(FIELDS) @@ -69,8 +69,8 @@ julia> # (×) If we don't provide the container... ERROR: type NamedTuple has no field x [...] -julia> # If one does not know the varnames, we can use a `Dict` instead. - _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo{Float64}(Dict()), ctx); +julia> # If one does not know the varnames, we can use a `OrderedDict` instead. + _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo{Float64}(OrderedDict()), ctx); julia> # (✓) Sort of fast, but only possible at runtime. vi[@varname(x[1])] @@ -86,6 +86,11 @@ ERROR: KeyError: key x[1:2] not found [...] ``` +_Technically_, it's possible to use any implementation of `AbstractDict` in place of +`OrderedDict`, but `OrderedDict` ensures that certain operations, e.g. linearization/flattening +of the values in the varinfo, are consistent between evaluations. Hence `OrderedDict` is +the preferred implementation of `AbstractDict` to use here. + You can also sample in _transformed_ space: ```jldoctest simplevarinfo-general @@ -109,8 +114,8 @@ julia> xs = [last(DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo() julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true -julia> # And with `Dict` of course! - _, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true), ctx); +julia> # And with `OrderedDict` of course! + _, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict()), true), ctx); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ 0.6225185067787314 @@ -165,9 +170,9 @@ ERROR: type NamedTuple has no field b [...] ``` -Using `Dict` as underlying storage. +Using `OrderedDict` as underlying storage. ```jldoctest -julia> svi_dict = SimpleVarInfo(Dict(@varname(m) => (a = [1.0], ))); +julia> svi_dict = SimpleVarInfo(OrderedDict(@varname(m) => (a = [1.0], ))); julia> svi_dict[@varname(m)] (a = [1.0],) @@ -274,7 +279,7 @@ end Base.getindex(vi::SimpleVarInfo, vn::VarName) = get(vi.values, vn) -# `Dict` +# `AbstractDict` function Base.getindex(vi::SimpleVarInfo{<:AbstractDict}, vn::VarName) if haskey(vi.values, vn) return vi.values[vn] @@ -416,7 +421,7 @@ function BangBang.push!!( return Setfield.@set vi.values = set!!(vi.values, vn, value) end -# `Dict` +# `AbstractDict` function BangBang.push!!( vi::SimpleVarInfo{<:AbstractDict}, vn::VarName, From fce67ee0ce4db50f917b983b315026bcd5c95e53 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 30 Jul 2022 12:45:20 +0100 Subject: [PATCH 177/221] improvements to values_as --- src/simple_varinfo.jl | 17 +++---- src/varinfo.jl | 102 +++++++++++++++++++++++++++++++++++++----- 2 files changed, 99 insertions(+), 20 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 9e7d98e10..9199d52a3 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -530,17 +530,14 @@ istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi) istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) -""" - values_as(varinfo[, Type]) - -Return the values/realizations in `varinfo` as `Type`, if implemented. - -If no `Type` is provided, return values as stored in `varinfo`. -""" values_as(vi::SimpleVarInfo) = vi.values -values_as(vi::SimpleVarInfo, ::Type{Dict}) = Dict(pairs(vi.values)) -values_as(vi::SimpleVarInfo, ::Type{NamedTuple}) = NamedTuple(pairs(vi.values)) -values_as(vi::SimpleVarInfo{<:NamedTuple}, ::Type{NamedTuple}) = vi.values +values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values +function values_as(vi::SimpleVarInfo, ::Type{D}) where {D<:AbstractDict} + return Setfield.constructorof(D)(zip(keys(vi), values(vi.values))) +end +function values_as(vi::SimpleVarInfo{<:AbstractDict}, ::Type{NamedTuple}) + return NamedTuple((Symbol(k), v) for (k, v) in vi.values) +end """ logjoint(model::Model, θ) diff --git a/src/varinfo.jl b/src/varinfo.jl index 22728ba9a..606647330 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1550,30 +1550,112 @@ function _setval_and_resample_kernel!(vi::VarInfo, vn::VarName, values, keys) end """ - values_as(vi::AbstractVarInfo) -""" -values_as(vi::VarInfo) = vi.metadata + values_as(varinfo[, Type]) -""" - values_as(vi::AbstractVarInfo, ::Type{NamedTuple}) - values_as(vi::AbstractVarInfo, ::Type{Dict}) +Return the values/realizations in `varinfo` as `Type`, if implemented. + +If no `Type` is provided, return values as stored in `varinfo`. + +# Examples + +`SimpleVarInfo` with `NamedTuple`: + +```jldoctest +julia> data = (x = 1.0, m = [2.0]); + +julia> values_as(SimpleVarInfo(data)) +(x = 1.0, m = [2.0]) + +julia> values_as(SimpleVarInfo(data), NamedTuple) +(x = 1.0, m = [2.0]) + +julia> values_as(SimpleVarInfo(data), OrderedDict) +OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Any} with 2 entries: + x => 1.0 + m => [2.0] +``` + +`SimpleVarInfo` with `OrderedDict`: + +```jldoctest +julia> data = OrderedDict{Any,Any}(@varname(x) => 1.0, @varname(m) => [2.0]); + +julia> values_as(SimpleVarInfo(data)) +OrderedDict{Any, Any} with 2 entries: + x => 1.0 + m => [2.0] + +julia> values_as(SimpleVarInfo(data), NamedTuple) +(x = 1.0, m = [2.0]) + +julia> values_as(SimpleVarInfo(data), OrderedDict) +OrderedDict{Any, Any} with 2 entries: + x => 1.0 + m => [2.0] +``` + +`TypedVarInfo`: + +```jldoctest +julia> # Just use an example model to construct the `VarInfo` because we're lazy. + vi = VarInfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); + +julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; + +julia> # For the sake of brevity, let's just check the type. + md = values_as(vi); md.s isa DynamicPPL.Metadata +true + +julia> values_as(vi, NamedTuple) +(s = 1.0, m = 2.0) + +julia> values_as(vi, OrderedDict) +OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Float64} with 2 entries: + s => 1.0 + m => 2.0 +``` -Return values in `vi` as the specified type. +`UntypedVarInfo`: + +```jldoctest +julia> # Just use an example model to construct the `VarInfo` because we're lazy. + vi = VarInfo(); DynamicPPL.TestUtils.demo_assume_dot_observe()(vi); + +julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; + +julia> # For the sake of brevity, let's just check the type. + values_as(vi) isa DynamicPPL.Metadata +true + +julia> values_as(vi, NamedTuple) +(s = 1.0, m = 2.0) + +julia> values_as(vi, OrderedDict) +OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Float64} with 2 entries: + s => 1.0 + m => 2.0 +``` """ +values_as(vi::VarInfo) = vi.metadata function values_as(vi::UntypedVarInfo, ::Type{NamedTuple}) iter = values_from_metadata(vi.metadata) return NamedTuple(map(p -> Symbol(p.first) => p.second, iter)) end -values_as(vi::UntypedVarInfo, ::Type{Dict}) = Dict(values_from_metadata(vi.metadata)) +function values_as(vi::UntypedVarInfo, ::Type{D}) where {D<:AbstractDict} + # TODO: Should we just use `ConstructionBase.constructorof` here instead? + return Setfield.constructorof(D)(values_from_metadata(vi.metadata)) +end function values_as(vi::VarInfo{<:NamedTuple{names}}, ::Type{NamedTuple}) where {names} iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names) return NamedTuple(map(p -> Symbol(p.first) => p.second, iter)) end -function values_as(vi::VarInfo{<:NamedTuple{names}}, ::Type{Dict}) where {names} +function values_as( + vi::VarInfo{<:NamedTuple{names}}, ::Type{D} +) where {names,D<:AbstractDict} iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names) - return Dict(iter) + return Setfield.constructorof(D)(iter) end function values_from_metadata(md::Metadata) From b165b35732b406d8e0d92c52bcabf679a603720e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 30 Jul 2022 12:45:35 +0100 Subject: [PATCH 178/221] export values_as --- src/DynamicPPL.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 4cd41c09b..0b632f120 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -60,6 +60,7 @@ export AbstractVarInfo, link!, invlink!, tonamedtuple, + values_as, # VarName (reexport from AbstractPPL) VarName, inspace, From af9c520bb971322f665067ff9077e736bd5f17a8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 30 Jul 2022 12:45:58 +0100 Subject: [PATCH 179/221] added values_as to docs --- docs/src/api.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/src/api.md b/docs/src/api.md index 809e6c49e..f7e5443fc 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -167,6 +167,10 @@ push!! empty!! ``` +```@docs +values_as +``` + #### `SimpleVarInfo` ```@docs From 47c30e3afd1680d59f75ea11dd282e8522e1caee Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 30 Jul 2022 12:46:23 +0100 Subject: [PATCH 180/221] added proper testing for values_as --- src/varname.jl | 3 +++ test/varinfo.jl | 70 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+) diff --git a/src/varname.jl b/src/varname.jl index 343bb0da8..6f42981f0 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -1,3 +1,6 @@ +# FIXME: This fix should be in `AbstractPPL`. +AbstractPPL.subsumes(::Setfield.IdentityLens, ::Setfield.IdentityLens) = true + """ subsumes_string(u::String, v::String[, u_indexing]) diff --git a/test/varinfo.jl b/test/varinfo.jl index 32c90bf47..e1d9bb160 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,3 +1,23 @@ +# TODO: Should all this go somewhere else? Seems useful for more tests. +short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) = short_varinfo_name(vi.varinfo) +short_varinfo_name(::TypedVarInfo) = "TypedVarInfo" +short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo" +short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}" +short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}" + +function update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns) + for vn in vns + vi = DynamicPPL.setindex!!(vi, get(vals, vn), vn) + end + return vi +end + +function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns) + for vn in vns + @test vi[vn] == get(vals, vn) + end +end + @testset "varinfo.jl" begin @testset "TypedVarInfo" begin @model gdemo(x, y) = begin @@ -314,4 +334,54 @@ x = Bijectors.invlink(dist, DynamicPPL.getindex_raw(vi, vn)) @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) end + + @testset "values_as" begin + @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS + example_values = rand(NamedTuple, model) + vns = DynamicPPL.TestUtils.varnames(model) + + vi_untyped = VarInfo() + model(vi_untyped) + vi_typed = TypedVarInfo(vi_untyped) + svi_typed = SimpleVarInfo(example_values) + svi_untyped = SimpleVarInfo(OrderedDict()) + + varinfos = map((vi_untyped, vi_typed, svi_typed, svi_untyped)) do vi + # Set them all to the same values. + update_values!!(vi, example_values, vns) + end + + @testset "$(short_varinfo_name(vi))" for vi in varinfos + # Just making sure. + test_values(vi, example_values, vns) + + @testset "NamedTuple" begin + vals = values_as(vi, NamedTuple) + for vn in vns + if haskey(vals, Symbol(vn)) + # Assumed to be of form `(var"m[1]" = 1.0, ...)`. + @test getindex(vals, Symbol(vn)) == getindex(vi, vn) + else + # Assumed to be of form `(m = [1.0, ...], ...)`. + @test get(vals, vn) == getindex(vi, vn) + end + end + end + + @testset "OrderedDict" begin + vals = values_as(vi, OrderedDict) + # All varnames in `vns` should be subsumed by one of `keys(vals)`. + @test all(vns) do vn + any(DynamicPPL.subsumes(vn_left, vn) for vn_left in keys(vals)) + end + # Iterate over `keys(vals)` because we might have scenarios such as + # `vals = OrderedDict(@varname(m) => [1.0])` but `@varname(m[1])` is + # the varname present in `vns`, not `@varname(m)`. + for vn in keys(vals) + @test getindex(vals, vn) == getindex(vi, vn) + end + end + end + end + end end From 40477a41a35c1d3d78927e29109c2e5df53ea738 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 30 Jul 2022 12:51:47 +0100 Subject: [PATCH 181/221] bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8695ef3f3..381076b82 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.20.0" +version = "0.20.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From a972f8ec887e8788931f1e6f695b33f947972bf1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 30 Jul 2022 13:19:29 +0100 Subject: [PATCH 182/221] Apply suggestions from code review --- src/varinfo.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 606647330..2d677fb9c 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1642,7 +1642,6 @@ function values_as(vi::UntypedVarInfo, ::Type{NamedTuple}) return NamedTuple(map(p -> Symbol(p.first) => p.second, iter)) end function values_as(vi::UntypedVarInfo, ::Type{D}) where {D<:AbstractDict} - # TODO: Should we just use `ConstructionBase.constructorof` here instead? return Setfield.constructorof(D)(values_from_metadata(vi.metadata)) end From ab2a8b5634e788974329f2089b4eee47be2d4a11 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 18 Aug 2022 14:07:12 +0100 Subject: [PATCH 183/221] use ConstructionBase explicitly --- src/DynamicPPL.jl | 1 + src/varinfo.jl | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 0b632f120..6871e2f93 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -10,6 +10,7 @@ using AbstractMCMC: AbstractMCMC using BangBang: BangBang, push!!, empty!!, setindex!! using ChainRulesCore: ChainRulesCore using MacroTools: MacroTools +using ConstructionBase: ConstructionBase using Setfield: Setfield using ZygoteRules: ZygoteRules diff --git a/src/varinfo.jl b/src/varinfo.jl index 2d677fb9c..00b99162f 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1642,7 +1642,7 @@ function values_as(vi::UntypedVarInfo, ::Type{NamedTuple}) return NamedTuple(map(p -> Symbol(p.first) => p.second, iter)) end function values_as(vi::UntypedVarInfo, ::Type{D}) where {D<:AbstractDict} - return Setfield.constructorof(D)(values_from_metadata(vi.metadata)) + return ConstructionBase.constructorof(D)(values_from_metadata(vi.metadata)) end function values_as(vi::VarInfo{<:NamedTuple{names}}, ::Type{NamedTuple}) where {names} @@ -1654,7 +1654,7 @@ function values_as( vi::VarInfo{<:NamedTuple{names}}, ::Type{D} ) where {names,D<:AbstractDict} iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names) - return Setfield.constructorof(D)(iter) + return ConstructionBase.constructorof(D)(iter) end function values_from_metadata(md::Metadata) From ee7fcd6c8ced0b4d6ee9b142fe3377e8ac843555 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 18 Aug 2022 14:12:51 +0100 Subject: [PATCH 184/221] use OrderedDict in rand instead of NamedTuple as it supports arbitrary models --- src/model.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model.jl b/src/model.jl index f7dc4b113..d768f746e 100644 --- a/src/model.jl +++ b/src/model.jl @@ -525,7 +525,7 @@ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T} x = last( evaluate!!( model, - SimpleVarInfo{Float64}(), + SimpleVarInfo{Float64}(OrderedDict()), SamplingContext(rng, SampleFromPrior(), DefaultContext()), ), ) From 0082505eef4e16b2103b023cc2a160ad0acba911 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 9 Sep 2022 12:48:28 +0100 Subject: [PATCH 185/221] properly deprecate link! and invlink! --- src/transforming.jl | 6 ++++-- src/varinfo.jl | 44 +++++++++++++++++++++++++++++++------------- 2 files changed, 35 insertions(+), 15 deletions(-) diff --git a/src/transforming.jl b/src/transforming.jl index 625daa240..68cf6f8d7 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -94,7 +94,8 @@ function link!!( return settrans!!(last(evaluate!!(model, vi, LazyTransformationContext{false}())), t) end function link!!(t::LazyTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) - link!(vi, spl) + # Call `_link!` instead of `link!` to avoid deprecation warning. + _link!(vi, spl) return vi end @@ -190,7 +191,8 @@ function invlink!!( ) end function invlink!!(::LazyTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) - invlink!(vi, spl) + # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. + _invlink!(vi, spl) return vi end diff --git a/src/varinfo.jl b/src/varinfo.jl index 15503f471..2070d4210 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -798,11 +798,21 @@ Transform the values of the random variables sampled by `spl` in `vi` from the s of their distributions to the Euclidean space and set their corresponding `"trans"` flag values to `true`. """ -function link!(vi::UntypedVarInfo, spl::Sampler) +function link!(vi::VarInfo, spl::AbstractSampler) Base.depwarn( "`link!(varinfo, sampler)` is deprecated, use `link!!(varinfo, sampler, model)` instead.", :link!, ) + return _link!(vi, spl) +end +function link!(vi::VarInfo, spl::AbstractSampler, spaceval::Val) + Base.depwarn( + "`link!(varinfo, sampler, spaceval)` is deprecated, use `link!!(varinfo, sampler, model)` instead.", + :link!, + ) + return _link!(vi, spl, spaceval) +end +function _link!(vi::UntypedVarInfo, spl::Sampler) # TODO: Change to a lazy iterator over `vns` vns = _getvns(vi, spl) if ~istrans(vi, vns[1]) @@ -821,14 +831,14 @@ function link!(vi::UntypedVarInfo, spl::Sampler) @warn("[DynamicPPL] attempt to link a linked vi") end end -function link!(vi::TypedVarInfo, spl::AbstractSampler) +function _link!(vi::TypedVarInfo, spl::AbstractSampler) Base.depwarn( "`link!(varinfo, sampler)` is deprecated, use `link!!(varinfo, sampler, model)` instead.", :link!, ) - return link!(vi, spl, Val(getspace(spl))) + return _link!(vi, spl, Val(getspace(spl))) end -function link!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) +function _link!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) vns = _getvns(vi, spl) return _link!(vi.metadata, vi, vns, spaceval) end @@ -872,18 +882,30 @@ Transform the values of the random variables sampled by `spl` in `vi` from the Euclidean space back to the support of their distributions and sets their corresponding `"trans"` flag values to `false`. """ -function invlink!(vi::UntypedVarInfo, spl::AbstractSampler) +function invlink!(vi::VarInfo, spl::AbstractSampler) Base.depwarn( "`invlink!(varinfo, sampler)` is deprecated, use `invlink!!(varinfo, sampler, model)` instead.", :invlink!, ) + return _invlink!(vi, spl) +end + +function invlink!(vi::VarInfo, spl::AbstractSampler, spaceval::Val) + Base.depwarn( + "`invlink!(varinfo, sampler, spaceval)` is deprecated, use `invlink!!(varinfo, sampler, model)` instead.", + :invlink!, + ) + return _invlink!(vi, spl, spaceval) +end + +function _invlink!(vi::UntypedVarInfo, spl::AbstractSampler) vns = _getvns(vi, spl) if istrans(vi, vns[1]) for vn in vns @debug "ℝ -> X for $(vn)..." dist = getdist(vi, vn) y = reconstruct(dist, getval(vi, vn)) - b = bijector(dist) + b = inverse(bijector(dist)) x, logjac = with_logabsdet_jacobian(b, y) setval!(vi, vectorize(dist, x), vn) acclogp!!(vi, -logjac) @@ -893,14 +915,10 @@ function invlink!(vi::UntypedVarInfo, spl::AbstractSampler) @warn("[DynamicPPL] attempt to invlink an invlinked vi") end end -function invlink!(vi::TypedVarInfo, spl::AbstractSampler) - Base.depwarn( - "`invlink!(varinfo, sampler)` is deprecated, use `invlink!!(varinfo, sampler, model)` instead.", - :invlink!, - ) - return invlink!(vi, spl, Val(getspace(spl))) +function _invlink!(vi::TypedVarInfo, spl::AbstractSampler) + return _invlink!(vi, spl, Val(getspace(spl))) end -function invlink!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) +function _invlink!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) vns = _getvns(vi, spl) return _invlink!(vi.metadata, vi, vns, spaceval) end From 597dfdac5e3b37be941e2aeb416d31fe4699777b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 9 Sep 2022 12:48:56 +0100 Subject: [PATCH 186/221] added transformation impls for ThreadSafeVarInfo --- src/threadsafe.jl | 2 ++ src/transforming.jl | 8 ++++++++ 2 files changed, 10 insertions(+) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 4dc37e4ea..9060f5d5c 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -17,6 +17,8 @@ const ThreadSafeVarInfoWithRef{V<:AbstractVarInfo} = ThreadSafeVarInfo{ V,<:AbstractArray{<:Ref} } +transformation(vi::ThreadSafeVarInfo) = transformation(vi.varinfo) + # Instead of updating the log probability of the underlying variables we # just update the array of log probabilities. function acclogp!!(vi::ThreadSafeVarInfo, logp) diff --git a/src/transforming.jl b/src/transforming.jl index 68cf6f8d7..d7e18553c 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -88,6 +88,11 @@ function link!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) # Use `default_transformation` to decide which transformation to use if none is specified. return link!!(default_transformation(model, vi), vi, spl, model) end +function link!!( + t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model +) + return link!!(t, vi.varinfo, spl, model) +end function link!!( t::LazyTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model ) @@ -183,6 +188,9 @@ function invlink!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) # Here we extract the `transformation` from `vi` rather than using the default one. return invlink!!(transformation(vi), vi, spl, model) end +function invlink!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model) + return invlink!!(t, vi.varinfo, spl, model) +end function invlink!!( ::LazyTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model ) From be7ae6ce798249e85e8b129e193813bcd706a251 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 9 Sep 2022 12:49:18 +0100 Subject: [PATCH 187/221] added missing impl of values_as for VarInfo and Vector --- src/varinfo.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/varinfo.jl b/src/varinfo.jl index 2070d4210..835df4120 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1695,6 +1695,7 @@ OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Float64} with 2 entri ``` """ values_as(vi::VarInfo) = vi.metadata +values_as(vi::VarInfo, ::Type{Vector}) = copy(getall(vi)) function values_as(vi::UntypedVarInfo, ::Type{NamedTuple}) iter = values_from_metadata(vi.metadata) return NamedTuple(map(p -> Symbol(p.first) => p.second, iter)) From 8dfc7c049a17bc6de02e3c210028108ef9aeb15f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 9 Sep 2022 12:49:36 +0100 Subject: [PATCH 188/221] use inverse instead of deprecated inv --- src/varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 835df4120..c2ef3a022 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -938,7 +938,7 @@ end @debug "ℝ -> X for $(vn)..." dist = getdist(vi, vn) y = reconstruct(dist, getval(vi, vn)) - b = inv(bijector(dist)) + b = inverse(bijector(dist)) x, logjac = with_logabsdet_jacobian(b, y) setval!(vi, vectorize(dist, x), vn) acclogp!!(vi, -logjac) From e475c87b6406cab5dea74eb1736c1b2c90212f71 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 9 Sep 2022 12:52:06 +0100 Subject: [PATCH 189/221] Update src/transforming.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/transforming.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transforming.jl b/src/transforming.jl index d7e18553c..d1bcaff8e 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -188,7 +188,9 @@ function invlink!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) # Here we extract the `transformation` from `vi` rather than using the default one. return invlink!!(transformation(vi), vi, spl, model) end -function invlink!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model) +function invlink!!( + t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model +) return invlink!!(t, vi.varinfo, spl, model) end function invlink!!( From 1e6b0a9468673b3713dd95c6b4a432c234124ee3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 9 Sep 2022 14:00:24 +0100 Subject: [PATCH 190/221] renamed nested_haskey and defined common method called getvalue and hasvalue --- src/contexts.jl | 19 +++-- src/simple_varinfo.jl | 2 +- src/utils.jl | 167 ++++++++++++++++++++++++++++++------------ 3 files changed, 130 insertions(+), 58 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index bd8acf278..2c59cf68c 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -298,34 +298,33 @@ childcontext(context::ConditionContext) = context.context setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child) """ - hasvalue(context, vn) + hasvalue(context::AbstractContext, vn::VarName) Return `true` if `vn` is found in `context`. """ -hasvalue(context, vn) = false -hasvalue(context::ConditionContext, vn::VarName) = nested_haskey(context.values, vn) +hasvalue(context::AbstractContext, vn::VarName) = false +hasvalue(context::ConditionContext, vn::VarName) = hasvalue(context.values, vn) function hasvalue(context::ConditionContext, vns::AbstractArray{<:VarName}) - return all(Base.Fix1(nested_haskey, context.values), vns) + return all(Base.Fix1(hasvalue, context.values), vns) end """ - getvalue(context, vn) + getvalue(context::AbstractContext, vn::VarName) Return value of `vn` in `context`. """ -function getvalue(context::AbstractContext, vn) +function getvalue(context::AbstractContext, vn::VarName) return error("context $(context) does not contain value for $vn") end -getvalue(context::NamedConditionContext, vn) = get(context.values, vn) -getvalue(context::ConditionContext, vn) = nested_getindex(context.values, vn) +getvalue(context::ConditionContext, vn::VarName) = getvalue(context.values, vn) """ hasvalue_nested(context, vn) Return `true` if `vn` is found in `context` or any of its descendants. -This is contrast to [`hasvalue`](@ref) which only checks for `vn` in `context`, -not recursively checking if `vn` is in any of its descendants. +This is contrast to [`hasvalue(::AbstractContext, ::VarName)`](@ref) which only checks +for `vn` in `context`, not recursively checking if `vn` is in any of its descendants. """ function hasvalue_nested(context::AbstractContext, vn) return hasvalue_nested(NodeTrait(hasvalue_nested, context), context, vn) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index a8c7996b0..f7b741836 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -331,7 +331,7 @@ function getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribut return reconstruct(dist, vals, length(vns)) end -Base.haskey(vi::SimpleVarInfo, vn::VarName) = nested_haskey(vi.values, vn) +Base.haskey(vi::SimpleVarInfo, vn::VarName) = hasvalue(vi.values, vn) function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName) # For `NamedTuple` we treat the symbol in `vn` as the _property_ to set. diff --git a/src/utils.jl b/src/utils.jl index de45978b1..1ce1330a7 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -467,113 +467,154 @@ function unflatten(original::AbstractDict, x::AbstractVector) return Dict(zip(keys(original), unflatten(collect(values(original)), x))) end +# TODO: Move `getvalue` and `hasvalue` to AbstractPPL.jl. """ - nested_getindex(values::AbstractDict, vn::VarName) + getvalue(vals, vn::VarName) -Return value corresponding to `vn` in `values` by also looking -in the the actual values of the dict. +Return the value(s) in `vals` represented by `vn`. + +Note that this method is different from `getindex`. See examples below. # Examples +For `NamedTuple`: + ```jldoctest -julia> DynamicPPL.nested_getindex(Dict(@varname(x) => [1.0]), @varname(x)) # same as `getindex` +julia> vals = (x = [1.0],); + +julia> DynamicPPL.getvalue(vals, @varname(x)) # same as `getindex` 1-element Vector{Float64}: 1.0 -julia> DynamicPPL.nested_getindex(Dict(@varname(x) => [1.0]), @varname(x[1])) # different from `getindex` +julia> DynamicPPL.getvalue(vals, @varname(x[1])) # different from `getindex` 1.0 -julia> DynamicPPL.nested_getindex(Dict(@varname(x) => [1.0]), @varname(x[2])) +julia> DynamicPPL.getvalue(vals, @varname(x[2])) ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] [...] ``` -""" -function nested_getindex(values::AbstractDict, vn::VarName) - maybeval = get(values, vn, nothing) - if maybeval !== nothing - return maybeval - end - # Split the lens into the key / `parent` and the extraction lens / `child`. - parent, child, issuccess = splitlens(getlens(vn)) do lens - l = lens === nothing ? Setfield.IdentityLens() : lens - haskey(values, VarName(vn, l)) - end - # When combined with `VarInfo`, `nothing` is equivalent to `IdentityLens`. - keylens = parent === nothing ? Setfield.IdentityLens() : parent +For `AbstractDict`: - # If we found a valid split, then we can extract the value. - if !issuccess - # At this point we just throw an error since the key could not be found. - throw(KeyError(vn)) - end +```jldoctest +julia> vals = Dict(@varname(x) => [1.0]); - # TODO: Should we also check that we `canview` the extracted `value` - # rather than just let it fail upon `get` call? - value = values[VarName(vn, keylens)] - return get(value, child) -end +julia> DynamicPPL.getvalue(vals, @varname(x)) # same as `getindex` +1-element Vector{Float64}: + 1.0 + +julia> DynamicPPL.getvalue(vals, @varname(x[1])) # different from `getindex` +1.0 + +julia> DynamicPPL.getvalue(vals, @varname(x[2])) +ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] +[...] +``` + +In the `AbstractDict` case we can also have keys such as `v[1]`: + +```jldoctest +julia> vals = Dict(@varname(x[1]) => [1.0,]); + +julia> DynamicPPL.getvalue(vals, @varname(x[1])) # same as `getindex` +1-element Vector{Float64}: + 1.0 + +julia> DynamicPPL.getvalue(vals, @varname(x[1][1])) # different from `getindex` +1.0 + +julia> DynamicPPL.getvalue(vals, @varname(x[1][2])) +ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] +[...] + +julia> DynamicPPL.getvalue(vals, @varname(x[2][1])) +ERROR: KeyError: key x[2][1] not found +[...] +``` +""" +getvalue(vals::NamedTuple, vn::VarName) = get(vals, vn) +getvalue(vals::AbstractDict, vn::VarName) = nested_getindex(vals, vn) """ - nested_haskey(x, vn::VarName) + hasvalue(vals, vn::VarName) -Determine whether `x` has a mapping for a given `vn`. +Determine whether `vals` has a mapping for a given `vn`, as compatible with [`getvalue`](@ref). # Examples With `x` as a `NamedTuple`: + ```jldoctest -julia> DynamicPPL.nested_haskey((x = 1.0, ), @varname(x)) +julia> DynamicPPL.hasvalue((x = 1.0, ), @varname(x)) true -julia> DynamicPPL.nested_haskey((x = 1.0, ), @varname(x[1])) +julia> DynamicPPL.hasvalue((x = 1.0, ), @varname(x[1])) false -julia> DynamicPPL.nested_haskey((x = [1.0],), @varname(x)) +julia> DynamicPPL.hasvalue((x = [1.0],), @varname(x)) true -julia> DynamicPPL.nested_haskey((x = [1.0],), @varname(x[1])) +julia> DynamicPPL.hasvalue((x = [1.0],), @varname(x[1])) true -julia> DynamicPPL.nested_haskey((x = [1.0],), @varname(x[2])) +julia> DynamicPPL.hasvalue((x = [1.0],), @varname(x[2])) false ``` With `x` as a `AbstractDict`: + ```jldoctest -julia> DynamicPPL.nested_haskey(Dict(@varname(x) => 1.0, ), @varname(x)) +julia> DynamicPPL.hasvalue(Dict(@varname(x) => 1.0, ), @varname(x)) true -julia> DynamicPPL.nested_haskey(Dict(@varname(x) => 1.0, ), @varname(x[1])) +julia> DynamicPPL.hasvalue(Dict(@varname(x) => 1.0, ), @varname(x[1])) false -julia> DynamicPPL.nested_haskey(Dict(@varname(x) => [1.0]), @varname(x)) +julia> DynamicPPL.hasvalue(Dict(@varname(x) => [1.0]), @varname(x)) true -julia> DynamicPPL.nested_haskey(Dict(@varname(x) => [1.0]), @varname(x[1])) +julia> DynamicPPL.hasvalue(Dict(@varname(x) => [1.0]), @varname(x[1])) true -julia> DynamicPPL.nested_haskey(Dict(@varname(x) => [1.0]), @varname(x[2])) +julia> DynamicPPL.hasvalue(Dict(@varname(x) => [1.0]), @varname(x[2])) +false +``` + +In the `AbstractDict` case we can also have keys such as `v[1]`: + +```jldoctest +julia> vals = Dict(@varname(x[1]) => [1.0,]); + +julia> DynamicPPL.hasvalue(vals, @varname(x[1])) # same as `haskey` +true + +julia> DynamicPPL.hasvalue(vals, @varname(x[1][1])) # different from `haskey` +true + +julia> DynamicPPL.hasvalue(vals, @varname(x[1][2])) +false + +julia> DynamicPPL.hasvalue(vals, @varname(x[2][1])) false ``` """ -function nested_haskey(nt::NamedTuple, vn::VarName{sym}) where {sym} +function hasvalue(vals::NamedTuple, vn::VarName{sym}) where {sym} # LHS: Ensure that `nt` indeed has the property we want. # RHS: Ensure that the lens can view into `nt`. - return haskey(nt, sym) && canview(getlens(vn), getproperty(nt, sym)) + return haskey(vals, sym) && canview(getlens(vn), getproperty(vals, sym)) end # For `dictlike` we need to check wether `vn` is "immediately" present, or # if some ancestor of `vn` is present in `dictlike`. -function nested_haskey(dict::AbstractDict, vn::VarName) +function hasvalue(vals::AbstractDict, vn::VarName) # First we check if `vn` is present as is. - haskey(dict, vn) && return true + haskey(vals, vn) && return true # If `vn` is not present, we check any parent-varnames by attempting # to split the lens into the key / `parent` and the extraction lens / `child`. # If `issuccess` is `true`, we found such a split, and hence `vn` is present. parent, child, issuccess = splitlens(getlens(vn)) do lens l = lens === nothing ? Setfield.IdentityLens() : lens - haskey(dict, VarName(vn, l)) + haskey(vals, VarName(vn, l)) end # When combined with `VarInfo`, `nothing` is equivalent to `IdentityLens`. keylens = parent === nothing ? Setfield.IdentityLens() : parent @@ -582,11 +623,43 @@ function nested_haskey(dict::AbstractDict, vn::VarName) issuccess || return false # At this point we just need to check that we `canview` the value. - value = dict[VarName(vn, keylens)] + value = vals[VarName(vn, keylens)] return canview(child, value) end +""" + nested_getindex(values::AbstractDict, vn::VarName) + +Return value corresponding to `vn` in `values` by also looking +in the the actual values of the dict. +""" +function nested_getindex(values::AbstractDict, vn::VarName) + maybeval = get(values, vn, nothing) + if maybeval !== nothing + return maybeval + end + + # Split the lens into the key / `parent` and the extraction lens / `child`. + parent, child, issuccess = splitlens(getlens(vn)) do lens + l = lens === nothing ? Setfield.IdentityLens() : lens + haskey(values, VarName(vn, l)) + end + # When combined with `VarInfo`, `nothing` is equivalent to `IdentityLens`. + keylens = parent === nothing ? Setfield.IdentityLens() : parent + + # If we found a valid split, then we can extract the value. + if !issuccess + # At this point we just throw an error since the key could not be found. + throw(KeyError(vn)) + end + + # TODO: Should we also check that we `canview` the extracted `value` + # rather than just let it fail upon `get` call? + value = values[VarName(vn, keylens)] + return get(value, child) +end + """ float_type_with_fallback(x) From e23763b1450ddc27feece43dfcd6259963318f70 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 9 Sep 2022 19:32:32 +0100 Subject: [PATCH 191/221] minor version bump --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 46c2cb600..3ce653baa 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.20.2" +version = "0.21.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From 43b034d8ce1474dc5d4f764913ccf7e29d6d2fb9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 6 Oct 2022 23:16:26 +0100 Subject: [PATCH 192/221] removed unnecessary and confusing constructor --- src/simple_varinfo.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index f7b741836..3cadc0da8 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -248,8 +248,6 @@ function SimpleVarInfo{T}( return SimpleVarInfo(values, convert(T, getlogp(vi))) end -SimpleVarInfo(svi::SimpleVarInfo, spl, x::AbstractVector) = unflatten(svi, x) - unflatten(svi::SimpleVarInfo, spl, x::AbstractVector) = unflatten(svi, x) function unflatten(svi::SimpleVarInfo, x::AbstractVector) return Setfield.@set svi.values = unflatten(svi.values, x) From 5c7df848c57cc0efcd7ab10f3f278f193fd35b9a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 19 Oct 2022 13:33:18 +0100 Subject: [PATCH 193/221] added TODO comment to deprecate --- src/varinfo.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/varinfo.jl b/src/varinfo.jl index c2ef3a022..4c93bdf03 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -159,6 +159,8 @@ end VarInfo(model::Model, args...) = VarInfo(Random.GLOBAL_RNG, model, args...) unflatten(vi::VarInfo, x::AbstractVector) = unflatten(vi, SampleFromPrior(), x) + +# TODO: deprecate. unflatten(vi::VarInfo, spl, x::AbstractVector) = VarInfo(vi, spl, x) # without AbstractSampler From 5c7163dca0e4accbdcd4541fbccce33ee7dfff65 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 24 Oct 2022 16:44:28 +0100 Subject: [PATCH 194/221] renamed LazyTransformation to DynamicTransformation --- src/simple_varinfo.jl | 2 +- src/transforming.jl | 29 +++++++++++++++-------------- src/varinfo.jl | 6 +++--- 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 3cadc0da8..9b592a706 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -526,7 +526,7 @@ end # NOTE: We don't implement `settrans!!(vi, trans, vn)`. function settrans!!(vi::SimpleVarInfo, trans) - return settrans!!(vi, trans ? LazyTransformation() : NoTransformation()) + return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation()) end function settrans!!(vi::SimpleVarInfo, transformation::AbstractTransformation) return Setfield.@set vi.transformation = transformation diff --git a/src/transforming.jl b/src/transforming.jl index d1bcaff8e..150a64f25 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -1,8 +1,8 @@ -struct LazyTransformationContext{isinverse} <: AbstractContext end -NodeTrait(::LazyTransformationContext) = IsLeaf() +struct DynamicTransformationContext{isinverse} <: AbstractContext end +NodeTrait(::DynamicTransformationContext) = IsLeaf() function tilde_assume( - ::LazyTransformationContext{isinverse}, right, vn, vi + ::DynamicTransformationContext{isinverse}, right, vn, vi ) where {isinverse} r = vi[vn, right] lp = Bijectors.logpdf_with_trans(right, r, !isinverse) @@ -20,7 +20,7 @@ function tilde_assume( end function dot_tilde_assume( - ::LazyTransformationContext{isinverse}, + ::DynamicTransformationContext{isinverse}, dist::Distribution, var::AbstractArray, vns::AbstractArray{<:VarName}, @@ -30,7 +30,7 @@ function dot_tilde_assume( b = bijector(dist) is_trans_uniques = unique(istrans.((vi,), vns)) - @assert length(is_trans_uniques) == 1 "LazyTransformationContext only supports transforming all variables" + @assert length(is_trans_uniques) == 1 "DynamicTransformationContext only supports transforming all variables" is_trans = first(is_trans_uniques) if is_trans @assert isinverse "Trying to link already transformed variables" @@ -46,7 +46,7 @@ function dot_tilde_assume( end function dot_tilde_assume( - ::LazyTransformationContext{isinverse}, + ::DynamicTransformationContext{isinverse}, dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, @@ -62,7 +62,7 @@ function dot_tilde_assume( # Transform _all_ values. is_trans_uniques = unique(istrans.((vi,), vns)) - @assert length(is_trans_uniques) == 1 "LazyTransformationContext only supports transforming all variables" + @assert length(is_trans_uniques) == 1 "DynamicTransformationContext only supports transforming all variables" is_trans = first(is_trans_uniques) if is_trans @assert isinverse "Trying to link already transformed variables" @@ -94,11 +94,11 @@ function link!!( return link!!(t, vi.varinfo, spl, model) end function link!!( - t::LazyTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model + t::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model ) - return settrans!!(last(evaluate!!(model, vi, LazyTransformationContext{false}())), t) + return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) end -function link!!(t::LazyTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) +function link!!(t::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) # Call `_link!` instead of `link!` to avoid deprecation warning. _link!(vi, spl) return vi @@ -111,7 +111,7 @@ Return a possibly invlinked version of `vi`. This will be called prior to `model` evaluation, allowing one to perform a single `invlink!!` _before_ evaluation rather lazyily evaluate the transforms on as-we-need -basis as is done with [`LazyTransformation` ](@ref). +basis as is done with [`DynamicTransformation` ](@ref). # Examples ```julia-repl @@ -194,13 +194,14 @@ function invlink!!( return invlink!!(t, vi.varinfo, spl, model) end function invlink!!( - ::LazyTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model + ::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model ) return settrans!!( - last(evaluate!!(model, vi, LazyTransformationContext{true}())), NoTransformation() + last(evaluate!!(model, vi, DynamicTransformationContext{true}())), + NoTransformation(), ) end -function invlink!!(::LazyTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) +function invlink!!(::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. _invlink!(vi, spl) return vi diff --git a/src/varinfo.jl b/src/varinfo.jl index 4c93bdf03..1eca1bc81 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -5,7 +5,7 @@ abstract type AbstractTransformation end struct NoTransformation <: AbstractTransformation end -struct LazyTransformation <: AbstractTransformation end +struct DynamicTransformation <: AbstractTransformation end struct StaticTransformation{F} <: AbstractTransformation bijector::F @@ -17,7 +17,7 @@ end Return the `AbstractTransformation` currently related to `model` and, potentially, `vi`. """ default_transformation(model::Model, ::AbstractVarInfo) = default_transformation(model) -default_transformation(::Model) = LazyTransformation() +default_transformation(::Model) = DynamicTransformation() """ transformation(vi::AbstractVarInfo) @@ -131,7 +131,7 @@ const TypedVarInfo = VarInfo{<:NamedTuple} # NOTE: This is kind of weird, but it effectively preserves the "old" # behavior where we're allowed to call `link!` on the same `VarInfo` # multiple times. -transformation(vi::VarInfo) = LazyTransformation() +transformation(vi::VarInfo) = DynamicTransformation() function VarInfo(old_vi::UntypedVarInfo, spl, x::AbstractVector) new_vi = deepcopy(old_vi) From 5f21dd78b63ab58f5bd1d53367ace6c869c8c32e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 24 Oct 2022 17:18:24 +0100 Subject: [PATCH 195/221] updated docs and docstrings --- docs/make.jl | 2 +- docs/src/api.md | 1 + src/simple_varinfo.jl | 7 ------- src/varinfo.jl | 41 +++++++++++++++++++++++++++++++++++++---- 4 files changed, 39 insertions(+), 12 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 000b8dbae..681d70ccd 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -3,7 +3,7 @@ using DynamicPPL using DynamicPPL: AbstractPPL # Doctest setup -DocMeta.setdocmeta!(DynamicPPL, :DocTestSetup, :(using DynamicPPL); recursive=true) +DocMeta.setdocmeta!(DynamicPPL, :DocTestSetup, :(using DynamicPPL, Distributions); recursive=true) makedocs(; sitename="DynamicPPL", diff --git a/docs/src/api.md b/docs/src/api.md index f7e5443fc..fa38f881a 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -163,6 +163,7 @@ resetlogp!! ```@docs getindex +DynamicPPL.getindex_raw push!! empty!! ``` diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 9b592a706..442aeec96 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -541,13 +541,6 @@ istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinf islinked(vi::SimpleVarInfo, ::Union{Sampler,SampleFromPrior}) = istrans(vi) -""" - values_as(varinfo[, Type]) - -Return the values/realizations in `varinfo` as `Type`, if implemented. - -If no `Type` is provided, return values as stored in `varinfo`. -""" values_as(vi::SimpleVarInfo) = vi.values values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values function values_as(vi::SimpleVarInfo{<:Any,T}, ::Type{Vector}) where {T} diff --git a/src/varinfo.jl b/src/varinfo.jl index 1eca1bc81..4c8cadf14 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -988,14 +988,13 @@ end # The default getindex & setindex!() for get & set values # NOTE: vi[vn] will always transform the variable to its original space and Julia type """ - getindex(vi::VarInfo, vn::VarName) - getindex(vi::VarInfo, vns::Vector{<:VarName}) + getindex(vi::AbstractVarInfo, vn::VarName[, dist::Distribution]) + getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}[, dist::Distribution]) Return the current value(s) of `vn` (`vns`) in `vi` in the support of its (their) distribution(s). -If the value(s) is (are) transformed to the Euclidean space, it is -(they are) transformed back. +If `dist` is specified, the value(s) will be reshaped accordingly. """ getindex(vi::AbstractVarInfo, vn::VarName) = getindex(vi, vn, getdist(vi, vn)) function getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution) @@ -1019,6 +1018,20 @@ function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distributio return reconstruct(dist, vals_linked, length(vns)) end +""" + getindex_raw(vi::AbstractVarInfo, vn::VarName[, dist::Distribution]) + getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}[, dist::Distribution]) + +Return the current value(s) of `vn` (`vns`) in `vi`. + +If `dist` is specified, the value(s) will be reshaped accordingly. + +The difference between `getindex(vi, vn, dist)` and `getindex_raw` is that +`getindex` will also transform the value(s) to the support of the distribution(s). +This is _not_ the case for `getindex_raw`. + +See also: [`getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution)`](@ref) +""" getindex_raw(vi::AbstractVarInfo, vn::VarName) = getindex_raw(vi, vn, getdist(vi, vn)) function getindex_raw(vi::AbstractVarInfo, vn::VarName, dist::Distribution) return reconstruct(dist, getval(vi, vn)) @@ -1633,6 +1646,11 @@ julia> values_as(SimpleVarInfo(data), OrderedDict) OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Any} with 2 entries: x => 1.0 m => [2.0] + +julia> values_as(SimpleVarInfo(data), Vector) +2-element Vector{Float64}: + 1.0 + 2.0 ``` `SimpleVarInfo` with `OrderedDict`: @@ -1652,6 +1670,11 @@ julia> values_as(SimpleVarInfo(data), OrderedDict) OrderedDict{Any, Any} with 2 entries: x => 1.0 m => [2.0] + +julia> values_as(SimpleVarInfo(data), Vector) +2-element Vector{Float64}: + 1.0 + 2.0 ``` `TypedVarInfo`: @@ -1673,6 +1696,11 @@ julia> values_as(vi, OrderedDict) OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Float64} with 2 entries: s => 1.0 m => 2.0 + +julia> values_as(vi, Vector) +2-element Vector{Float64}: + 1.0 + 2.0 ``` `UntypedVarInfo`: @@ -1694,6 +1722,11 @@ julia> values_as(vi, OrderedDict) OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Float64} with 2 entries: s => 1.0 m => 2.0 + +julia> values_as(vi, Vector) +2-element Vector{Real}: + 1.0 + 2.0 ``` """ values_as(vi::VarInfo) = vi.metadata From e04467685e671d782f48ceed9eebcc6b310dc7aa Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Tue, 25 Oct 2022 13:51:23 +0100 Subject: [PATCH 196/221] Update docs/make.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/make.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/make.jl b/docs/make.jl index 681d70ccd..6b88c18cd 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -3,7 +3,9 @@ using DynamicPPL using DynamicPPL: AbstractPPL # Doctest setup -DocMeta.setdocmeta!(DynamicPPL, :DocTestSetup, :(using DynamicPPL, Distributions); recursive=true) +DocMeta.setdocmeta!( + DynamicPPL, :DocTestSetup, :(using DynamicPPL, Distributions); recursive=true +) makedocs(; sitename="DynamicPPL", From effec2b80829219c294d1cca8ed27894f75557b8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 25 Oct 2022 14:47:38 +0100 Subject: [PATCH 197/221] increased tolerance in one of the tests --- test/serialization.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/serialization.jl b/test/serialization.jl index 7ea81e410..a2d9abb36 100644 --- a/test/serialization.jl +++ b/test/serialization.jl @@ -11,7 +11,7 @@ samples_m = last.(samples) @test mean(samples_s) ≈ 3 atol = 0.2 - @test mean(samples_m) ≈ 0 atol = 0.1 + @test mean(samples_m) ≈ 0 atol = 0.15 end @testset "pmap" begin # Add worker processes. From 5fabd07ca35bd85283965eaa2bb61c2ff84c7ce4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 25 Oct 2022 15:03:44 +0100 Subject: [PATCH 198/221] increase tolerance of tests --- test/sampler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/sampler.jl b/test/sampler.jl index 959ec3ccd..ba1a8a600 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -15,7 +15,7 @@ @test length(chains) == N # Expected value of ``X`` where ``X ~ N(2, ...)`` is 2. - @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.1 + @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.15 # Expected value of ``X`` where ``X ~ IG(2, 3)`` is 3. @test mean(vi[@varname(s)] for vi in chains) ≈ 3 atol = 0.2 From 2e0fe49e97bbdc93fbe150e608b516979a26ead7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 25 Oct 2022 15:38:15 +0100 Subject: [PATCH 199/221] updated docs for varinfos, in particular the shared interface --- docs/src/api.md | 23 +++++++++++++++++++++-- src/transforming.jl | 21 +++++++++++++++++++++ src/varinfo.jl | 33 +++++++++++++++------------------ 3 files changed, 57 insertions(+), 20 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index fa38f881a..3c0830e3e 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -154,6 +154,8 @@ AbstractVarInfo ### Common API +#### Accumulation of log-probabilities + ```@docs getlogp setlogp!! @@ -161,6 +163,8 @@ acclogp!! resetlogp!! ``` +#### Variables and their realizations + ```@docs getindex DynamicPPL.getindex_raw @@ -172,6 +176,23 @@ empty!! values_as ``` +#### Transformations + +```@docs +DynamicPPL.istrans +DynamicPPL.settrans!! +DynamicPPL.transformation +DynamicPPL.link!! +DynamicPPL.invlink!! +DynamicPPL.default_transformation +``` + +#### Utils + +```@docs +DynamicPPL.tonamedtuple +``` + #### `SimpleVarInfo` ```@docs @@ -190,10 +211,8 @@ TypedVarInfo One main characteristic of [`VarInfo`](@ref) is that samples are stored in a linearized form. ```@docs -tonamedtuple link! invlink! -istrans ``` ```@docs diff --git a/src/transforming.jl b/src/transforming.jl index 150a64f25..50b3792cb 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -80,6 +80,16 @@ function dot_tilde_assume( return r, lp, vi end +""" + link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) + link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) + +Transforms the variables in `vi` to their linked space, using the transformatio `t`. + +If `t` is not provided, `default_transformation(model, vi)` will be used. + +See also: [`default_transformation`](@ref), [`invlink!!`](@ref). +""" link!!(vi::AbstractVarInfo, model::Model) = link!!(vi, SampleFromPrior(), model) function link!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) return link!!(t, vi, SampleFromPrior(), model) @@ -180,6 +190,17 @@ function _default_sampler(::IsParent, context::AbstractContext) return _default_sampler(childcontext(context)) end +""" + invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) + invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) + +Transform the variables in `vi` to their constrained space, using the (inverse of) +transformation `t`. + +If `t` is not provided, `default_transformation(model, vi)` will be used. + +See also: [`default_transformation`](@ref), [`link!!`](@ref). +""" invlink!!(vi::AbstractVarInfo, model::Model) = invlink!!(vi, SampleFromPrior(), model) function invlink!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) return invlink!!(t, vi, SampleFromPrior(), model) diff --git a/src/varinfo.jl b/src/varinfo.jl index 4c8cadf14..169f60505 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -392,11 +392,13 @@ Return the set of sampler selectors associated with `vn` in `vi`. getgid(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).gids[getidx(vi, vn)] """ - settrans!!(vi::VarInfo, trans::Bool, vn::VarName) + settrans!!(vi::VarInfo, trans::Bool[, vn::VarName]) -Set the `trans` flag value of `vn` in `vi`. +Return `vi` with `istrans(vi, vn)` evaluating to `true`. + +If `vn` is not specified, then `istrans(vi)` evaluates to `true` for all variables. """ -function settrans!!(vi::AbstractVarInfo, trans::Bool, vn::VarName) +function settrans!!(vi::VarInfo, trans::Bool, vn::VarName) if trans set_flag!(vi, vn, "trans") else @@ -406,11 +408,6 @@ function settrans!!(vi::AbstractVarInfo, trans::Bool, vn::VarName) return vi end -""" - settrans!!(vi::AbstractVarInfo, trans) - -Return new instance of `vi` but with `istrans(vi, trans)` now evaluating to `true`. -""" function settrans!!(vi::VarInfo, trans::Bool) for vn in keys(vi) settrans!!(vi, trans, vn) @@ -692,19 +689,19 @@ function setgid!(vi::VarInfo, gid::Selector, vn::VarName) end """ - istrans(vi::AbstractVarInfo) + istrans(vi::AbstractVarInfo[, vns::Union{VarName, AbstractVector{<:Varname}}]) Return `true` if `vi` is working in unconstrained space, and `false` if `vi` is assuming realizations to be in support of the corresponding distributions. -""" -istrans(vi::AbstractVarInfo) = false # `VarInfo` works in constrained space by default. -""" - istrans(vi::VarInfo, vn::VarName) +If `vns` is provided, then only check if this/these varname(s) are transformed. -Return true if `vn`'s values in `vi` are transformed to Euclidean space, and false if -they are in the support of `vn`'s distribution. +!!! warning + Not all implementations of `AbstractVarInfo` support transforming only a subset of + the variables. """ +istrans(vi::AbstractVarInfo) = false # `VarInfo` works in constrained space by default. +# TODO: Should this be restricted to `vi::VarInfo`? Requires explicit impl. for `ThreadSafeVarInfo`. istrans(vi::AbstractVarInfo, vn::VarName) = is_flagged(vi, vn, "trans") function istrans(vi::AbstractVarInfo, vns::AbstractVector{<:VarName}) return all(Base.Fix1(istrans, vi), vns) @@ -719,7 +716,7 @@ Return the log of the joint probability of the observed data and parameters samp getlogp(vi::AbstractVarInfo) = vi.logp[] """ - setlogp!!(vi::VarInfo, logp) + setlogp!!(vi::AbstractVarInfo, logp) Set the log of the joint probability of the observed data and parameters sampled in `vi` to `logp`, mutating if it makes sense. @@ -735,7 +732,7 @@ end Add `logp` to the value of the log of the joint probability of the observed data and parameters sampled in `vi`, mutating if it makes sense. """ -function acclogp!!(vi::VarInfo, logp) +function acclogp!!(vi::AbstractVarInfo, logp) vi.logp[] += logp return vi end @@ -1118,7 +1115,7 @@ end end """ - tonamedtuple(vi::VarInfo) + tonamedtuple(vi::AbstractVarInfo) Convert a `vi` into a `NamedTuple` where each variable symbol maps to the values and indexing string of the variable. From 965fcf5e229dbe585b5a1a126ef005dea6dacac8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 27 Oct 2022 17:27:31 +0100 Subject: [PATCH 200/221] big refactoring of the varinfo related implementations and docs --- docs/src/api.md | 10 + src/DynamicPPL.jl | 12 +- src/abstract_varinfo.jl | 536 ++++++++++++++++++++++++++++++++++++++++ src/simple_varinfo.jl | 6 +- src/threadsafe.jl | 44 ++-- src/transforming.jl | 164 ------------ src/varinfo.jl | 380 +++++----------------------- 7 files changed, 640 insertions(+), 512 deletions(-) create mode 100644 src/abstract_varinfo.jl diff --git a/docs/src/api.md b/docs/src/api.md index 3c0830e3e..87c604ffb 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -166,10 +166,12 @@ resetlogp!! #### Variables and their realizations ```@docs +keys getindex DynamicPPL.getindex_raw push!! empty!! +isempty ``` ```@docs @@ -178,6 +180,13 @@ values_as #### Transformations +```@docs +DynamicPPL.AbstractTransformation +DynamicPPL.NoTransformation +DynamicPPL.DynamicTransformation +DynamicPPL.StaticTransformation +``` + ```@docs DynamicPPL.istrans DynamicPPL.settrans!! @@ -185,6 +194,7 @@ DynamicPPL.transformation DynamicPPL.link!! DynamicPPL.invlink!! DynamicPPL.default_transformation +DynamicPPL.maybe_invlink_before_eval!! ``` #### Utils diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 36b8ca4ce..443bdd797 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -127,7 +127,6 @@ export loglikelihood # Used here and overloaded in Turing function getspace end -# Necessary forward declarations """ AbstractVarInfo @@ -135,10 +134,16 @@ Abstract supertype for data structures that capture random variables when execut probabilistic model and accumulate log densities such as the log likelihood or the log joint probability of the model. -See also: [`VarInfo`](@ref) +See also: [`VarInfo`](@ref), [`SimpleVarInfo`](@ref). """ abstract type AbstractVarInfo <: AbstractModelTrace end +const LEGACY_WARNING = """ +!!! warning + This method is considered legacy, and is likely to be deprecated in the future. +""" + +# Necessary forward declarations include("utils.jl") include("selector.jl") include("model.jl") @@ -146,8 +151,9 @@ include("sampler.jl") include("varname.jl") include("distribution_wrappers.jl") include("contexts.jl") -include("varinfo.jl") +include("abstract_varinfo.jl") include("threadsafe.jl") +include("varinfo.jl") include("simple_varinfo.jl") include("context_implementations.jl") include("compiler.jl") diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl new file mode 100644 index 000000000..7503123b2 --- /dev/null +++ b/src/abstract_varinfo.jl @@ -0,0 +1,536 @@ +# Transformation related. +""" + $(TYPEDEF) + +Represents a transformation to be used in `link!!` and `invlink!!`, amongst others. + +A concrete implementation of this should implement the following methods: +- [`link!!`](@ref): transforms the [`AbstractVarInfo`](@ref) to the unconstrained space. +- [`invlink!!`](@ref): transforms the [`AbstractVarInfo`](@ref) to the constrained space. + +And potentially: +- [`maybe_invlink_before_eval!!`](@ref): hook to decide whether to transform _before_ + evaluating the model. + +See also: [`link!!`](@ref), [`invlink!!`](@ref), [`maybe_invlink_before_eval!!`](@ref). +""" +abstract type AbstractTransformation end + +""" + $(TYPEDEF) + +Transformation which applies the identity function. +""" +struct NoTransformation <: AbstractTransformation end + +""" + $(TYPEDEF) + +Transformation which transforms the variables on a per-need-basis +in the execution of a given `Model`. + +This is in constrast to `StaticTransformation` which transforms all variables +_before_ the execution of a given `Model`. + +See also: [`StaticTransformation`](@ref). +""" +struct DynamicTransformation <: AbstractTransformation end + +""" + $(TYPEDEF) + +Transformation which transforms all variables _before_ the execution of a given `Model`. + +This is done through the `maybe_invlink_before_eval!!` method. + +See also: [`DynamicTransformation`](@ref), [`maybe_invlink_before_eval!!`](@ref). + +# Fields +$(TYPEDFIELDS) +""" +struct StaticTransformation{F} <: AbstractTransformation + "The function, assumed to implement the `Bijectors` interface, to be applied to the variables" + bijector::F +end + +""" + default_transformation(model::Model[, vi::AbstractVarInfo]) + +Return the `AbstractTransformation` currently related to `model` and, potentially, `vi`. +""" +default_transformation(model::Model, ::AbstractVarInfo) = default_transformation(model) +default_transformation(::Model) = DynamicTransformation() + +""" + transformation(vi::AbstractVarInfo) + +Return the `AbstractTransformation` related to `vi`. +""" +function transformation end + +# Accumulation of log-probabilities. +""" + getlogp(vi::AbstractVarInfo) + +Return the log of the joint probability of the observed data and parameters sampled in +`vi`. +""" +function getlogp end + +""" + setlogp!!(vi::AbstractVarInfo, logp) + +Set the log of the joint probability of the observed data and parameters sampled in +`vi` to `logp`, mutating if it makes sense. +""" +function setlogp!! end + +""" + acclogp!!(vi::AbstractVarInfo, logp) + +Add `logp` to the value of the log of the joint probability of the observed data and +parameters sampled in `vi`, mutating if it makes sense. +""" +function acclogp!! end + +""" + resetlogp!!(vi::AbstractVarInfo) + +Reset the value of the log of the joint probability of the observed data and parameters +sampled in `vi` to 0, mutating if it makes sense. +""" +resetlogp!!(vi::AbstractVarInfo) = setlogp!!(vi, zero(getlogp(vi))) + +# Variables and their realizations. +@doc """ + keys(vi::AbstractVarInfo) + +Return an iterator over all `vns` in `vi`. +""" Base.keys + +@doc """ + getindex(vi::AbstractVarInfo, vn::VarName[, dist::Distribution]) + getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}[, dist::Distribution]) + +Return the current value(s) of `vn` (`vns`) in `vi` in the support of its (their) +distribution(s). + +If `dist` is specified, the value(s) will be reshaped accordingly. + +See also: [`getindex_raw(vi::AbstractVarInfo, vn::VarName, dist::Distribution)`](@ref) +""" Base.getindex + +""" + getindex(vi::AbstractVarInfo, ::Colon) + getindex(vi::AbstractVarInfo, ::AbstractSampler) + +Return the current value(s) of `vn` (`vns`) in `vi` in the support of its (their) +distribution(s) as a flattened `Vector`. + +The default implementation is to call [`values_as`](@ref) with `Vector` as the type-argument. + +See also: [`getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution)`](@ref) +""" +Base.getindex(vi::AbstractVarInfo, ::Colon) = values_as(vi, Vector) +Base.getindex(vi::AbstractVarInfo, ::AbstractSampler) = vi[:] + +""" + getindex_raw(vi::AbstractVarInfo, vn::VarName[, dist::Distribution]) + getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}[, dist::Distribution]) + +Return the current value(s) of `vn` (`vns`) in `vi`. + +If `dist` is specified, the value(s) will be reshaped accordingly. + +See also: [`getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution)`](@ref) + +!!! note + The difference between `getindex(vi, vn, dist)` and `getindex_raw` is that + `getindex` will also transform the value(s) to the support of the distribution(s). + This is _not_ the case for `getindex_raw`. + +""" +function getindex_raw end + +""" + push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) + +Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to +the `VarInfo` `vi`, mutating if it makes sense. +""" +function BangBang.push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) + return BangBang.push!!(vi, vn, r, dist, Set{Selector}([])) +end + +""" + push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler) + +Push a new random variable `vn` with a sampled value `r` sampled with a sampler `spl` +from a distribution `dist` to `VarInfo` `vi`, if it makes sense. + +The sampler is passed here to invalidate its cache where defined. + +$(LEGACY_WARNING) +""" +function BangBang.push!!( + vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler +) + return BangBang.push!!(vi, vn, r, dist, spl.selector) +end +function BangBang.push!!( + vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler +) + return BangBang.push!!(vi, vn, r, dist) +end + +""" + push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) + +Push a new random variable `vn` with a sampled value `r` sampled with a sampler of +selector `gid` from a distribution `dist` to `VarInfo` `vi`. + +$(LEGACY_WARNING) +""" +function BangBang.push!!( + vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector +) + return BangBang.push!!(vi, vn, r, dist, Set([gid])) +end + + +@doc """ + empty!!(vi::AbstractVarInfo) + +Empty the fields of `vi.metadata` and reset `vi.logp[]` and `vi.num_produce[]` to +zeros. + +This is useful when using a sampling algorithm that assumes an empty `vi`, e.g. `SMC`. +""" BangBang.empty!! + +@doc """ + isempty(vi::AbstractVarInfo) + +Return true if `vi` is empty and false otherwise. +""" Base.isempty + +""" + values_as(varinfo[, Type]) + +Return the values/realizations in `varinfo` as `Type`, if implemented. + +If no `Type` is provided, return values as stored in `varinfo`. + +# Examples + +`SimpleVarInfo` with `NamedTuple`: + +```jldoctest +julia> data = (x = 1.0, m = [2.0]); + +julia> values_as(SimpleVarInfo(data)) +(x = 1.0, m = [2.0]) + +julia> values_as(SimpleVarInfo(data), NamedTuple) +(x = 1.0, m = [2.0]) + +julia> values_as(SimpleVarInfo(data), OrderedDict) +OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Any} with 2 entries: + x => 1.0 + m => [2.0] + +julia> values_as(SimpleVarInfo(data), Vector) +2-element Vector{Float64}: + 1.0 + 2.0 +``` + +`SimpleVarInfo` with `OrderedDict`: + +```jldoctest +julia> data = OrderedDict{Any,Any}(@varname(x) => 1.0, @varname(m) => [2.0]); + +julia> values_as(SimpleVarInfo(data)) +OrderedDict{Any, Any} with 2 entries: + x => 1.0 + m => [2.0] + +julia> values_as(SimpleVarInfo(data), NamedTuple) +(x = 1.0, m = [2.0]) + +julia> values_as(SimpleVarInfo(data), OrderedDict) +OrderedDict{Any, Any} with 2 entries: + x => 1.0 + m => [2.0] + +julia> values_as(SimpleVarInfo(data), Vector) +2-element Vector{Float64}: + 1.0 + 2.0 +``` + +`TypedVarInfo`: + +```jldoctest +julia> # Just use an example model to construct the `VarInfo` because we're lazy. + vi = VarInfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); + +julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; + +julia> # For the sake of brevity, let's just check the type. + md = values_as(vi); md.s isa DynamicPPL.Metadata +true + +julia> values_as(vi, NamedTuple) +(s = 1.0, m = 2.0) + +julia> values_as(vi, OrderedDict) +OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Float64} with 2 entries: + s => 1.0 + m => 2.0 + +julia> values_as(vi, Vector) +2-element Vector{Float64}: + 1.0 + 2.0 +``` + +`UntypedVarInfo`: + +```jldoctest +julia> # Just use an example model to construct the `VarInfo` because we're lazy. + vi = VarInfo(); DynamicPPL.TestUtils.demo_assume_dot_observe()(vi); + +julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; + +julia> # For the sake of brevity, let's just check the type. + values_as(vi) isa DynamicPPL.Metadata +true + +julia> values_as(vi, NamedTuple) +(s = 1.0, m = 2.0) + +julia> values_as(vi, OrderedDict) +OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Float64} with 2 entries: + s => 1.0 + m => 2.0 + +julia> values_as(vi, Vector) +2-element Vector{Real}: + 1.0 + 2.0 +``` +""" +function values_as end + +""" + eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromPrior} + +Determine the default `eltype` of the values returned by `vi[spl]`. + +!!! warning + This should generally not be called explicitly, as it's only used in + [`matchingvalue`](@ref) to determine the default type to use in place of + type-parameters passed to the model. + + This method is considered legacy, and is likely to be deprecated in the future. +""" +function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromPrior}) + return eltype(Core.Compiler.return_type(getindex, Tuple{typeof(vi),typeof(spl)})) +end + +# Transformations +""" + istrans(vi::AbstractVarInfo[, vns::Union{VarName, AbstractVector{<:Varname}}]) + +Return `true` if `vi` is working in unconstrained space, and `false` +if `vi` is assuming realizations to be in support of the corresponding distributions. + +If `vns` is provided, then only check if this/these varname(s) are transformed. + +!!! warning + Not all implementations of `AbstractVarInfo` support transforming only a subset of + the variables. +""" +istrans(vi::AbstractVarInfo) = istrans(vi, collect(keys(vi))) +function istrans(vi::AbstractVarInfo, vns::AbstractVector{<:VarName}) + return all(Base.Fix1(istrans, vi), vns) +end + +""" + settrans!!(vi::AbstractVarInfo, trans::Bool[, vn::VarName]) + +Return `vi` with `istrans(vi, vn)` evaluating to `true`. + +If `vn` is not specified, then `istrans(vi)` evaluates to `true` for all variables. +""" +function settrans!! end + +""" + link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) + link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) + +Transforms the variables in `vi` to their linked space, using the transformation `t`. + +If `t` is not provided, `default_transformation(model, vi)` will be used. + +See also: [`default_transformation`](@ref), [`invlink!!`](@ref). +""" +link!!(vi::AbstractVarInfo, model::Model) = link!!(vi, SampleFromPrior(), model) +function link!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) + return link!!(t, vi, SampleFromPrior(), model) +end +function link!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) + # Use `default_transformation` to decide which transformation to use if none is specified. + return link!!(default_transformation(model, vi), vi, spl, model) +end + +""" + invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) + invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) + +Transform the variables in `vi` to their constrained space, using the (inverse of) +transformation `t`. + +If `t` is not provided, `default_transformation(model, vi)` will be used. + +See also: [`default_transformation`](@ref), [`link!!`](@ref). +""" +invlink!!(vi::AbstractVarInfo, model::Model) = invlink!!(vi, SampleFromPrior(), model) +function invlink!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) + return invlink!!(t, vi, SampleFromPrior(), model) +end +function invlink!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) + # Here we extract the `transformation` from `vi` rather than using the default one. + return invlink!!(transformation(vi), vi, spl, model) +end + +# Vector-based ones. +function link!!( + t::StaticTransformation{<:Bijectors.Bijector{1}}, + vi::AbstractVarInfo, + spl::AbstractSampler, + model::Model, +) + b = inverse(t.bijector) + x = vi[spl] + y, logjac = with_logabsdet_jacobian(b, x) + + lp_new = getlogp(vi) - logjac + vi_new = setlogp!!(unflatten(vi, spl, y), lp_new) + return settrans!!(vi_new, t) +end + +function invlink!!( + t::StaticTransformation{<:Bijectors.Bijector{1}}, + vi::AbstractVarInfo, + spl::AbstractSampler, + model::Model, +) + b = t.bijector + y = vi[spl] + x, logjac = with_logabsdet_jacobian(b, y) + + lp_new = getlogp(vi) - logjac + vi_new = setlogp!!(unflatten(vi, spl, x), lp_new) + return settrans!!(vi_new, NoTransformation()) +end + +""" + maybe_invlink_before_eval!!([t::Transformation,] vi, context, model) + +Return a possibly invlinked version of `vi`. + +This will be called prior to `model` evaluation, allowing one to perform a single +`invlink!!` _before_ evaluation rather than lazyily evaluating the transforms on as-we-need +basis as is done with [`DynamicTransformation`](@ref). + +See also: [`StaticTransformation`](@ref), [`DynamicTransformation`](@ref). + +# Examples +```julia-repl +julia> using DynamicPPL, Distributions, Bijectors + +julia> @model demo() = x ~ Normal() +demo (generic function with 2 methods) + +julia> # By subtyping `Bijector{1}`, we inherit the `(inv)link!!` defined for + # bijectors which acts on 1-dimensional arrays, i.e. vectors. + struct MyBijector <: Bijectors.Bijector{1} end + +julia> # Define some dummy `inverse` which will be used in the `link!!` call. + Bijectors.inverse(f::MyBijector) = identity + +julia> # We need to define `with_logabsdet_jacobian` for `MyBijector` + # (`identity` already has `with_logabsdet_jacobian` defined) + function Bijectors.with_logabsdet_jacobian(::MyBijector, x) + # Just using a large number of the logabsdet-jacobian term + # for demonstration purposes. + return (x, 1000) + end + +julia> # Change the `default_transformation` for our model to be a + # `StaticTransformation` using `MyBijector`. + function DynamicPPL.default_transformation(::Model{typeof(demo)}) + return DynamicPPL.StaticTransformation(MyBijector()) + end + +julia> model = demo(); + +julia> vi = SimpleVarInfo(x=1.0) +SimpleVarInfo((x = 1.0,), 0.0) + +julia> # Uses the `inverse` of `MyBijector`, which we have defined as `identity` + vi_linked = link!!(vi, model) +Transformed SimpleVarInfo((x = 1.0,), 0.0) + +julia> # Now performs a single `invlink!!` before model evaluation. + logjoint(model, vi_linked) +-1001.4189385332047 +``` +""" +function maybe_invlink_before_eval!!( + vi::AbstractVarInfo, context::AbstractContext, model::Model +) + return maybe_invlink_before_eval!!(transformation(vi), vi, context, model) +end +function maybe_invlink_before_eval!!( + t::AbstractTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model +) + # Default behavior is to _not_ transform. + return vi +end +function maybe_invlink_before_eval!!( + t::StaticTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model +) + return invlink!!(t, vi, _default_sampler(context), model) +end + +function _default_sampler(context::AbstractContext) + return _default_sampler(NodeTrait(_default_sampler, context), context) +end +_default_sampler(::IsLeaf, context::AbstractContext) = SampleFromPrior() +function _default_sampler(::IsParent, context::AbstractContext) + return _default_sampler(childcontext(context)) +end + +# Utilities +""" + tonamedtuple(vi::AbstractVarInfo) + +Convert a `vi` into a `NamedTuple` where each variable symbol maps to the values and +indexing string of the variable. + +For example, a model that had a vector of vector-valued +variables `x` would return + +```julia +(x = ([1.5, 2.0], [3.0, 1.0], ["x[1]", "x[2]"]), ) +``` +""" +function tonamedtuple end + + +# Legacy code that is currently overloaded for the sake of simplicity. +# TODO: Remove when possible. +increment_num_produce!(::AbstractVarInfo) = nothing +setgid!(vi::AbstractVarInfo, gid::Selector, vn::VarName) = nothing diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 442aeec96..a0766374d 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -131,7 +131,7 @@ julia> # (✓) Positive probability mass on negative numbers! getlogp(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) -1.3678794411714423 -julia> # While if we forget to make indicate that it's transformed: +julia> # While if we forget to indicate that it's transformed: vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false) SimpleVarInfo((x = -1.0,), 0.0) @@ -482,10 +482,6 @@ function dot_assume( return value, lp, vi end -# HACK: Allows us to re-use the implementation of `dot_tilde`, etc. for literals. -increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing -setgid!(vi::SimpleOrThreadSafeSimple, gid::Selector, vn::VarName) = nothing - # We need these to be compatible with how chains are constructed from `AbstractVarInfo` in Turing.jl. # TODO: Move away from using these `tonamedtuple` methods. function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:NamedTuple{names}}) where {names} diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 9060f5d5c..4ea67b31b 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -75,27 +75,32 @@ link!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = link!(vi.varinfo, spl) invlink!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = invlink!(vi.varinfo, spl) islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl) -getindex(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex(vi.varinfo, spl) -getindex(vi::ThreadSafeVarInfo, spl::SampleFromPrior) = getindex(vi.varinfo, spl) -getindex(vi::ThreadSafeVarInfo, spl::SampleFromUniform) = getindex(vi.varinfo, spl) +function link!!( + t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model +) + return link!!(t, vi.varinfo, spl, model) +end + +function invlink!!( + t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model +) + return invlink!!(t, vi.varinfo, spl, model) +end +# `getindex` +getindex(vi::ThreadSafeVarInfo, ::Colon) = getindex(vi.varinfo, Colon()) getindex(vi::ThreadSafeVarInfo, vn::VarName) = getindex(vi.varinfo, vn) function getindex(vi::ThreadSafeVarInfo, vn::VarName, dist::Distribution) return getindex(vi.varinfo, vn, dist) end -getindex(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) = getindex(vi.varinfo, vns) -function getindex(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}, dist::Distribution) - return getindex(vi.varinfo, vns, dist) -end +getindex(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex(vi.varinfo, spl) +getindex_raw(vi::ThreadSafeVarInfo, ::Colon) = getindex_raw(vi.varinfo, Colon()) getindex_raw(vi::ThreadSafeVarInfo, vn::VarName) = getindex_raw(vi.varinfo, vn) function getindex_raw(vi::ThreadSafeVarInfo, vn::VarName, dist::Distribution) return getindex_raw(vi.varinfo, vn, dist) end -getindex_raw(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) = getindex_raw(vi.varinfo, vns) -function getindex_raw(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}, dist::Distribution) - return getindex_raw(vi.varinfo, vns, dist) -end +getindex_raw(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex_raw(vi.varinfo, spl) function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::AbstractSampler) return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) @@ -123,11 +128,7 @@ function BangBang.empty!!(vi::ThreadSafeVarInfo) return resetlogp!!(Setfield.@set!(vi.varinfo = empty!!(vi.varinfo))) end -function BangBang.push!!( - vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} -) - return Setfield.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist, gidset) -end +values_as(vi::ThreadSafeVarInfo, ::Type{T}) where {T} = values_as(vi.varinfo, T) function unset_flag!(vi::ThreadSafeVarInfo, vn::VarName, flag::String) return unset_flag!(vi.varinfo, vn, flag) @@ -137,3 +138,14 @@ function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, flag::String) end tonamedtuple(vi::ThreadSafeVarInfo) = tonamedtuple(vi.varinfo) + +# Transformations. +function settrans!!(vi::ThreadSafeVarInfo, trans::Bool, vn::VarName) + return Setfield.@set vi.varinfo = settrans!!(vi.varinfo, trans, vn) +end +function settrans!!(vi::ThreadSafeVarInfo, spl::AbstractSampler, dist::Distribution) + return Setfield.@set vi.varinfo = settrans!!(vi.varinfo, spl, dist) +end + +istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn) +istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns) diff --git a/src/transforming.jl b/src/transforming.jl index 50b3792cb..f4b50b057 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -80,140 +80,12 @@ function dot_tilde_assume( return r, lp, vi end -""" - link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) - -Transforms the variables in `vi` to their linked space, using the transformatio `t`. - -If `t` is not provided, `default_transformation(model, vi)` will be used. - -See also: [`default_transformation`](@ref), [`invlink!!`](@ref). -""" -link!!(vi::AbstractVarInfo, model::Model) = link!!(vi, SampleFromPrior(), model) -function link!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - return link!!(t, vi, SampleFromPrior(), model) -end -function link!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) - # Use `default_transformation` to decide which transformation to use if none is specified. - return link!!(default_transformation(model, vi), vi, spl, model) -end -function link!!( - t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model -) - return link!!(t, vi.varinfo, spl, model) -end function link!!( t::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model ) return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) end -function link!!(t::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) - # Call `_link!` instead of `link!` to avoid deprecation warning. - _link!(vi, spl) - return vi -end - -""" - maybe_invlink_before_eval!!([t::Transformation,] vi, context, model) - -Return a possibly invlinked version of `vi`. - -This will be called prior to `model` evaluation, allowing one to perform a single -`invlink!!` _before_ evaluation rather lazyily evaluate the transforms on as-we-need -basis as is done with [`DynamicTransformation` ](@ref). - -# Examples -```julia-repl -julia> using DynamicPPL, Distributions, Bijectors - -julia> @model demo() = x ~ Normal() -demo (generic function with 2 methods) - -julia> # By subtyping `Bijector{1}`, we inherit the `(inv)link!!` defined for - # bijectors which acts on 1-dimensional arrays, i.e. vectors. - struct MyBijector <: Bijectors.Bijector{1} end - -julia> # Define some dummy `inverse` which will be used in the `link!!` call. - Bijectors.inverse(f::MyBijector) = identity - -julia> # We need to define `with_logabsdet_jacobian` for `MyBijector` - # (`identity` already has `with_logabsdet_jacobian` defined) - function Bijectors.with_logabsdet_jacobian(::MyBijector, x) - # Just using a large number of the logabsdet-jacobian term - # for demonstration purposes. - return (x, 1000) - end - -julia> # Change the `default_transformation` for our model to be a - # `StaticTransformation` using `MyBijector`. - function DynamicPPL.default_transformation(::Model{typeof(demo)}) - return DynamicPPL.StaticTransformation(MyBijector()) - end - -julia> model = demo(); - -julia> vi = SimpleVarInfo(x=1.0) -SimpleVarInfo((x = 1.0,), 0.0) - -julia> # Uses the `inverse` of `MyBijector`, which we have defined as `identity` - vi_linked = link!!(vi, model) -Transformed SimpleVarInfo((x = 1.0,), 0.0) - -julia> # Now performs a single `invlink!!` before model evaluation. - logjoint(model, vi_linked) --1001.4189385332047 -``` -""" -function maybe_invlink_before_eval!!( - vi::AbstractVarInfo, context::AbstractContext, model::Model -) - return maybe_invlink_before_eval!!(transformation(vi), vi, context, model) -end -function maybe_invlink_before_eval!!( - t::AbstractTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model -) - # Default behavior is to _not_ transform. - return vi -end -function maybe_invlink_before_eval!!( - t::StaticTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model -) - return invlink!!(t, vi, _default_sampler(context), model) -end - -function _default_sampler(context::AbstractContext) - return _default_sampler(NodeTrait(_default_sampler, context), context) -end -_default_sampler(::IsLeaf, context::AbstractContext) = SampleFromPrior() -function _default_sampler(::IsParent, context::AbstractContext) - return _default_sampler(childcontext(context)) -end - -""" - invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) - -Transform the variables in `vi` to their constrained space, using the (inverse of) -transformation `t`. -If `t` is not provided, `default_transformation(model, vi)` will be used. - -See also: [`default_transformation`](@ref), [`link!!`](@ref). -""" -invlink!!(vi::AbstractVarInfo, model::Model) = invlink!!(vi, SampleFromPrior(), model) -function invlink!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - return invlink!!(t, vi, SampleFromPrior(), model) -end -function invlink!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) - # Here we extract the `transformation` from `vi` rather than using the default one. - return invlink!!(transformation(vi), vi, spl, model) -end -function invlink!!( - t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model -) - return invlink!!(t, vi.varinfo, spl, model) -end function invlink!!( ::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model ) @@ -222,39 +94,3 @@ function invlink!!( NoTransformation(), ) end -function invlink!!(::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) - # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. - _invlink!(vi, spl) - return vi -end - -# Vector-based ones. -function link!!( - t::StaticTransformation{<:Bijectors.Bijector{1}}, - vi::AbstractVarInfo, - spl::AbstractSampler, - model::Model, -) - b = inverse(t.bijector) - x = vi[spl] - y, logjac = with_logabsdet_jacobian(b, x) - - lp_new = getlogp(vi) - logjac - vi_new = setlogp!!(unflatten(vi, spl, y), lp_new) - return settrans!!(vi_new, t) -end - -function invlink!!( - t::StaticTransformation{<:Bijectors.Bijector{1}}, - vi::AbstractVarInfo, - spl::AbstractSampler, - model::Model, -) - b = t.bijector - y = vi[spl] - x, logjac = with_logabsdet_jacobian(b, y) - - lp_new = getlogp(vi) - logjac - vi_new = setlogp!!(unflatten(vi, spl, x), lp_new) - return settrans!!(vi_new, NoTransformation()) -end diff --git a/src/varinfo.jl b/src/varinfo.jl index 169f60505..a02a0d589 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -2,30 +2,6 @@ #### Types for typed and untyped VarInfo #### -abstract type AbstractTransformation end - -struct NoTransformation <: AbstractTransformation end -struct DynamicTransformation <: AbstractTransformation end - -struct StaticTransformation{F} <: AbstractTransformation - bijector::F -end - -""" - default_transformation(model::Model[, vi::AbstractVarInfo]) - -Return the `AbstractTransformation` currently related to `model` and, potentially, `vi`. -""" -default_transformation(model::Model, ::AbstractVarInfo) = default_transformation(model) -default_transformation(::Model) = DynamicTransformation() - -""" - transformation(vi::AbstractVarInfo) - -Return the `AbstractTransformation` related to `vi`. -""" -function transformation end - #################### # VarInfo metadata # #################### @@ -127,6 +103,7 @@ struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo end const UntypedVarInfo = VarInfo{<:Metadata} const TypedVarInfo = VarInfo{<:NamedTuple} +const MaybeThreadSafeVarInfo{Tmeta} = Union{VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}}} # NOTE: This is kind of weird, but it effectively preserves the "old" # behavior where we're allowed to call `link!` on the same `VarInfo` @@ -302,11 +279,11 @@ Return the index range of `vn` in the metadata of `vi`. getrange(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).ranges[getidx(vi, vn)] """ - getranges(vi::AbstractVarInfo, vns::Vector{<:VarName}) + getranges(vi::VarInfo, vns::Vector{<:VarName}) Return the indices of `vns` in the metadata of `vi` corresponding to `vn`. """ -function getranges(vi::AbstractVarInfo, vns::Vector{<:VarName}) +function getranges(vi::VarInfo, vns::Vector{<:VarName}) return mapreduce(vn -> getrange(vi, vn), vcat, vns; init=Int[]) end @@ -342,7 +319,7 @@ Return the value(s) of `vns`. The values may or may not be transformed to Euclidean space. """ -function getval(vi::AbstractVarInfo, vns::Vector{<:VarName}) +function getval(vi::VarInfo, vns::Vector{<:VarName}) return mapreduce(vn -> getval(vi, vn), vcat, vns) end @@ -391,13 +368,6 @@ Return the set of sampler selectors associated with `vn` in `vi`. """ getgid(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).gids[getidx(vi, vn)] -""" - settrans!!(vi::VarInfo, trans::Bool[, vn::VarName]) - -Return `vi` with `istrans(vi, vn)` evaluating to `true`. - -If `vn` is not specified, then `istrans(vi)` evaluates to `true` for all variables. -""" function settrans!!(vi::VarInfo, trans::Bool, vn::VarName) if trans set_flag!(vi, vn, "trans") @@ -444,7 +414,7 @@ end end # Get all indices of variables belonging to a given sampler -@inline function _getidcs(vi::AbstractVarInfo, spl::Sampler) +@inline function _getidcs(vi::VarInfo, spl::Sampler) # NOTE: 0b00 is the sanity flag for # |\____ getidcs (mask = 0b10) # \_____ getranges (mask = 0b01) @@ -494,8 +464,8 @@ end end # Get all vns of variables belonging to spl -_getvns(vi::AbstractVarInfo, spl::Sampler) = _getvns(vi, spl.selector, Val(getspace(spl))) -function _getvns(vi::AbstractVarInfo, spl::Union{SampleFromPrior,SampleFromUniform}) +_getvns(vi::VarInfo, spl::Sampler) = _getvns(vi, spl.selector, Val(getspace(spl))) +function _getvns(vi::VarInfo, spl::Union{SampleFromPrior,SampleFromUniform}) return _getvns(vi, Selector(), Val(())) end function _getvns(vi::UntypedVarInfo, s::Selector, space) @@ -515,7 +485,7 @@ end end # Get the index (in vals) ranges of all the vns of variables belonging to spl -@inline function _getranges(vi::AbstractVarInfo, spl::Sampler) +@inline function _getranges(vi::VarInfo, spl::Sampler) ## Uncomment the spl.info stuff when it is concretely typed, not Dict{Symbol, Any} #if ~haskey(spl.info, :cache_updated) spl.info[:cache_updated] = CACHERESET end #if haskey(spl.info, :ranges) && (spl.info[:cache_updated] & CACHERANGES) > 0 @@ -528,7 +498,7 @@ end #end end # Get the index (in vals) ranges of all the vns of variables belonging to selector `s` in `space` -@inline function _getranges(vi::AbstractVarInfo, s::Selector, space) +@inline function _getranges(vi::VarInfo, s::Selector, space) return _getranges(vi, _getidcs(vi, s, space)) end @inline function _getranges(vi::UntypedVarInfo, idcs::Vector{Int}) @@ -637,14 +607,6 @@ function TypedVarInfo(vi::UntypedVarInfo) end TypedVarInfo(vi::TypedVarInfo) = vi -""" - empty!!(vi::VarInfo) - -Empty the fields of `vi.metadata` and reset `vi.logp[]` and `vi.num_produce[]` to -zeros. - -This is useful when using a sampling algorithm that assumes an empty `vi`, e.g. `SMC`. -""" function BangBang.empty!!(vi::VarInfo) _empty!(vi.metadata) resetlogp!!(vi) @@ -661,11 +623,6 @@ end end # Functions defined only for UntypedVarInfo -""" - keys(vi::AbstractVarInfo) - -Return an iterator over all `vns` in `vi`. -""" Base.keys(vi::UntypedVarInfo) = keys(vi.metadata.idcs) @generated function Base.keys(vi::TypedVarInfo{<:NamedTuple{names}}) where {names} @@ -688,63 +645,20 @@ function setgid!(vi::VarInfo, gid::Selector, vn::VarName) return push!(getmetadata(vi, vn).gids[getidx(vi, vn)], gid) end -""" - istrans(vi::AbstractVarInfo[, vns::Union{VarName, AbstractVector{<:Varname}}]) - -Return `true` if `vi` is working in unconstrained space, and `false` -if `vi` is assuming realizations to be in support of the corresponding distributions. - -If `vns` is provided, then only check if this/these varname(s) are transformed. - -!!! warning - Not all implementations of `AbstractVarInfo` support transforming only a subset of - the variables. -""" -istrans(vi::AbstractVarInfo) = false # `VarInfo` works in constrained space by default. -# TODO: Should this be restricted to `vi::VarInfo`? Requires explicit impl. for `ThreadSafeVarInfo`. -istrans(vi::AbstractVarInfo, vn::VarName) = is_flagged(vi, vn, "trans") -function istrans(vi::AbstractVarInfo, vns::AbstractVector{<:VarName}) - return all(Base.Fix1(istrans, vi), vns) -end - -""" - getlogp(vi::VarInfo) +istrans(vi::VarInfo, vn::VarName) = is_flagged(vi, vn, "trans") -Return the log of the joint probability of the observed data and parameters sampled in -`vi`. -""" -getlogp(vi::AbstractVarInfo) = vi.logp[] +getlogp(vi::VarInfo) = vi.logp[] -""" - setlogp!!(vi::AbstractVarInfo, logp) - -Set the log of the joint probability of the observed data and parameters sampled in -`vi` to `logp`, mutating if it makes sense. -""" function setlogp!!(vi::VarInfo, logp) vi.logp[] = logp return vi end -""" - acclogp!!(vi::VarInfo, logp) - -Add `logp` to the value of the log of the joint probability of the observed data and -parameters sampled in `vi`, mutating if it makes sense. -""" -function acclogp!!(vi::AbstractVarInfo, logp) +function acclogp!!(vi::VarInfo, logp) vi.logp[] += logp return vi end -""" - resetlogp!!(vi::AbstractVarInfo) - -Reset the value of the log of the joint probability of the observed data and parameters -sampled in `vi` to 0, mutating if it makes sense. -""" -resetlogp!!(vi::AbstractVarInfo) = setlogp!!(vi, zero(getlogp(vi))) - """ get_num_produce(vi::VarInfo) @@ -767,18 +681,13 @@ Add 1 to `num_produce` in `vi`. increment_num_produce!(vi::VarInfo) = vi.num_produce[] += 1 """ - reset_num_produce!(vi::AbstractVarInfo) + reset_num_produce!(vi::VarInfo) Reset the value of `num_produce` the log of the joint probability of the observed data and parameters sampled in `vi` to 0. """ -reset_num_produce!(vi::AbstractVarInfo) = set_num_produce!(vi, 0) - -""" - isempty(vi::VarInfo) +reset_num_produce!(vi::VarInfo) = set_num_produce!(vi, 0) -Return true if `vi` is empty and false otherwise. -""" isempty(vi::UntypedVarInfo) = isempty(vi.metadata.idcs) isempty(vi::TypedVarInfo) = _isempty(vi.metadata) @generated function _isempty(metadata::NamedTuple{names}) where {names} @@ -790,6 +699,12 @@ isempty(vi::TypedVarInfo) = _isempty(vi.metadata) end # X -> R for all variables associated with given sampler +function link!!(t::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) + # Call `_link!` instead of `link!` to avoid deprecation warning. + _link!(vi, spl) + return vi +end + """ link!(vi::VarInfo, spl::Sampler) @@ -874,6 +789,12 @@ end end # R -> X for all variables associated with given sampler +function invlink!!(::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) + # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. + _invlink!(vi, spl) + return vi +end + """ invlink!(vi::VarInfo, spl::AbstractSampler) @@ -984,22 +905,13 @@ end # The default getindex & setindex!() for get & set values # NOTE: vi[vn] will always transform the variable to its original space and Julia type -""" - getindex(vi::AbstractVarInfo, vn::VarName[, dist::Distribution]) - getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}[, dist::Distribution]) - -Return the current value(s) of `vn` (`vns`) in `vi` in the support of its (their) -distribution(s). - -If `dist` is specified, the value(s) will be reshaped accordingly. -""" -getindex(vi::AbstractVarInfo, vn::VarName) = getindex(vi, vn, getdist(vi, vn)) -function getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution) +getindex(vi::VarInfo, vn::VarName) = getindex(vi, vn, getdist(vi, vn)) +function getindex(vi::VarInfo, vn::VarName, dist::Distribution) @assert haskey(vi, vn) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" val = getindex_raw(vi, vn, dist) return maybe_invlink(vi, vn, dist, val) end -function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) +function getindex(vi::VarInfo, vns::Vector{<:VarName}) # FIXME(torfjelde): Using `getdist(vi, first(vns))` won't be correct in cases # such as `x .~ [Normal(), Exponential()]`. # BUT we also can't fix this here because this will lead to "incorrect" @@ -1007,7 +919,7 @@ function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) # where by "incorrect" we mean there exists pieces of code expecting this behavior. return getindex(vi, vns, getdist(vi, first(vns))) end -function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distribution) +function getindex(vi::VarInfo, vns::Vector{<:VarName}, dist::Distribution) @assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" vals_linked = mapreduce(vcat, vns) do vn getindex(vi, vn, dist) @@ -1015,28 +927,14 @@ function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distributio return reconstruct(dist, vals_linked, length(vns)) end -""" - getindex_raw(vi::AbstractVarInfo, vn::VarName[, dist::Distribution]) - getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}[, dist::Distribution]) - -Return the current value(s) of `vn` (`vns`) in `vi`. - -If `dist` is specified, the value(s) will be reshaped accordingly. - -The difference between `getindex(vi, vn, dist)` and `getindex_raw` is that -`getindex` will also transform the value(s) to the support of the distribution(s). -This is _not_ the case for `getindex_raw`. - -See also: [`getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution)`](@ref) -""" -getindex_raw(vi::AbstractVarInfo, vn::VarName) = getindex_raw(vi, vn, getdist(vi, vn)) -function getindex_raw(vi::AbstractVarInfo, vn::VarName, dist::Distribution) +getindex_raw(vi::VarInfo, vn::VarName) = getindex_raw(vi, vn, getdist(vi, vn)) +function getindex_raw(vi::VarInfo, vn::VarName, dist::Distribution) return reconstruct(dist, getval(vi, vn)) end -function getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}) +function getindex_raw(vi::VarInfo, vns::Vector{<:VarName}) return getindex_raw(vi, vns, getdist(vi, first(vns))) end -function getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distribution) +function getindex_raw(vi::VarInfo, vns::Vector{<:VarName}, dist::Distribution) return reconstruct(dist, getval(vi, vns), length(vns)) end @@ -1047,8 +945,6 @@ Return the current value(s) of the random variables sampled by `spl` in `vi`. The value(s) may or may not be transformed to Euclidean space. """ -getindex(vi::AbstractVarInfo, ::Colon) = values_as(vi, Vector) -getindex(vi::AbstractVarInfo, ::AbstractSampler) = vi[:] getindex(vi::UntypedVarInfo, spl::Sampler) = copy(getval(vi, _getranges(vi, spl))) function getindex(vi::TypedVarInfo, spl::Sampler) # Gets the ranges as a NamedTuple @@ -1072,8 +968,8 @@ Set the current value(s) of the random variable `vn` in `vi` to `val`. The value(s) may or may not be transformed to Euclidean space. """ -setindex!(vi::AbstractVarInfo, val, vn::VarName) = (setval!(vi, val, vn); return vi) -function BangBang.setindex!!(vi::AbstractVarInfo, val, vn::VarName) +setindex!(vi::VarInfo, val, vn::VarName) = (setval!(vi, val, vn); return vi) +function BangBang.setindex!!(vi::VarInfo, val, vn::VarName) return (setindex!(vi, val, vn); return vi) end @@ -1084,7 +980,7 @@ Set the current value(s) of the random variables sampled by `spl` in `vi` to `va The value(s) may or may not be transformed to Euclidean space. """ -setindex!(vi::AbstractVarInfo, val, spl::SampleFromPrior) = setall!(vi, val) +setindex!(vi::VarInfo, val, spl::SampleFromPrior) = setall!(vi, val) setindex!(vi::UntypedVarInfo, val, spl::Sampler) = setval!(vi, val, _getranges(vi, spl)) function setindex!(vi::TypedVarInfo, val, spl::Sampler) # Gets a `NamedTuple` mapping each symbol to the indices in the symbol's `vals` field sampled from the sampler `spl` @@ -1093,7 +989,7 @@ function setindex!(vi::TypedVarInfo, val, spl::Sampler) return nothing end -function BangBang.setindex!!(vi::AbstractVarInfo, val, spl::AbstractSampler) +function BangBang.setindex!!(vi::VarInfo, val, spl::AbstractSampler) setindex!(vi, val, spl) return vi end @@ -1114,19 +1010,6 @@ end return expr end -""" - tonamedtuple(vi::AbstractVarInfo) - -Convert a `vi` into a `NamedTuple` where each variable symbol maps to the values and -indexing string of the variable. - -For example, a model that had a vector of vector-valued -variables `x` would return - -```julia -(x = ([1.5, 2.0], [3.0, 1.0], ["x[1]", "x[2]"]), ) -``` -""" function tonamedtuple(vi::VarInfo) return tonamedtuple(vi.metadata, vi) end @@ -1149,10 +1032,6 @@ end return map(vn -> vi[vn], f_vns) end -function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromPrior}) - return eltype(Core.Compiler.return_type(getindex, Tuple{typeof(vi),typeof(spl)})) -end - """ haskey(vi::VarInfo, vn::VarName) @@ -1213,46 +1092,6 @@ function Base.show(io::IO, vi::UntypedVarInfo) return print(io, ")") end -""" - push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) - -Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to -the `VarInfo` `vi`, mutating if it makes sense. -""" -function BangBang.push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) - return BangBang.push!!(vi, vn, r, dist, Set{Selector}([])) -end - -""" - push!!(vi::VarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler) - -Push a new random variable `vn` with a sampled value `r` sampled with a sampler `spl` -from a distribution `dist` to `VarInfo` `vi`, if it makes sense. - -The sampler is passed here to invalidate its cache where defined. -""" -function BangBang.push!!( - vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler -) - return BangBang.push!!(vi, vn, r, dist, spl.selector) -end -function BangBang.push!!( - vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler -) - return BangBang.push!!(vi, vn, r, dist) -end - -""" - push!!(vi::VarInfo, vn::VarName, r, dist::Distribution, gid::Selector) - -Push a new random variable `vn` with a sampled value `r` sampled with a sampler of -selector `gid` from a distribution `dist` to `VarInfo` `vi`. -""" -function BangBang.push!!( - vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector -) - return BangBang.push!!(vi, vn, r, dist, Set([gid])) -end function BangBang.push!!( vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} ) @@ -1378,7 +1217,7 @@ end Set `vn`'s `gid` to `Set([spl.selector])`, if `vn` does not have a sampler selector linked and `vn`'s symbol is in the space of `spl`. """ -function updategid!(vi::AbstractVarInfo, vn::VarName, spl::Sampler) +function updategid!(vi::VarInfo, vn::VarName, spl::Sampler) if inspace(vn, getspace(spl)) setgid!(vi, spl.selector, vn) end @@ -1386,11 +1225,11 @@ end # TODO: Maybe rename or something? """ - _apply!(kernel!, vi::AbstractVarInfo, values, keys) + _apply!(kernel!, vi::VarInfo, values, keys) Calls `kernel!(vi, vn, values, keys)` for every `vn` in `vi`. """ -function _apply!(kernel!, vi::AbstractVarInfo, values, keys) +function _apply!(kernel!, vi::MaybeThreadSafeVarInfo, values, keys) keys_strings = map(string, collectmaybe(keys)) num_indices_seen = 0 @@ -1448,7 +1287,7 @@ end end end -function _find_missing_keys(vi::AbstractVarInfo, keys) +function _find_missing_keys(vi::MaybeThreadSafeVarInfo, keys) string_vns = map(string, collectmaybe(Base.keys(vi))) # If `key` isn't subsumed by any element of `string_vns`, it is not present in `vi`. missing_keys = filter(keys) do key @@ -1459,9 +1298,9 @@ function _find_missing_keys(vi::AbstractVarInfo, keys) end """ - setval!(vi::AbstractVarInfo, x) - setval!(vi::AbstractVarInfo, values, keys) - setval!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int) + setval!(vi::VarInfo, x) + setval!(vi::VarInfo, values, keys) + setval!(vi::VarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int) Set the values in `vi` to the provided values and leave those which are not present in `x` or `chains` unchanged. @@ -1515,15 +1354,15 @@ julia> var_info[@varname(x[1])] # [✓] unchanged -0.22312984965118443 ``` """ -setval!(vi::AbstractVarInfo, x) = setval!(vi, values(x), keys(x)) -setval!(vi::AbstractVarInfo, values, keys) = _apply!(_setval_kernel!, vi, values, keys) +setval!(vi::VarInfo, x) = setval!(vi, values(x), keys(x)) +setval!(vi::VarInfo, values, keys) = _apply!(_setval_kernel!, vi, values, keys) function setval!( - vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int + vi::VarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int ) return setval!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) end -function _setval_kernel!(vi::VarInfo, vn::VarName, values, keys) +function _setval_kernel!(vi::MaybeThreadSafeVarInfo, vn::VarName, values, keys) indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) if !isempty(indices) val = reduce(vcat, values[indices]) @@ -1535,9 +1374,9 @@ function _setval_kernel!(vi::VarInfo, vn::VarName, values, keys) end """ - setval_and_resample!(vi::AbstractVarInfo, x) - setval_and_resample!(vi::AbstractVarInfo, values, keys) - setval_and_resample!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx, chain_idx) + setval_and_resample!(vi::VarInfo, x) + setval_and_resample!(vi::VarInfo, values, keys) + setval_and_resample!(vi::VarInfo, chains::AbstractChains, sample_idx, chain_idx) Set the values in `vi` to the provided values and those which are not present in `x` or `chains` to *be* resampled. @@ -1592,19 +1431,19 @@ julia> var_info[@varname(x[1])] # [✓] changed ## See also - [`setval!`](@ref) """ -function setval_and_resample!(vi::AbstractVarInfo, x) +function setval_and_resample!(vi::MaybeThreadSafeVarInfo, x) return setval_and_resample!(vi, values(x), keys(x)) end -function setval_and_resample!(vi::AbstractVarInfo, values, keys) +function setval_and_resample!(vi::MaybeThreadSafeVarInfo, values, keys) return _apply!(_setval_and_resample_kernel!, vi, values, keys) end function setval_and_resample!( - vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int + vi::MaybeThreadSafeVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int ) return setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) end -function _setval_and_resample_kernel!(vi::VarInfo, vn::VarName, values, keys) +function _setval_and_resample_kernel!(vi::MaybeThreadSafeVarInfo, vn::VarName, values, keys) indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) if !isempty(indices) val = reduce(vcat, values[indices]) @@ -1619,113 +1458,6 @@ function _setval_and_resample_kernel!(vi::VarInfo, vn::VarName, values, keys) return indices end -""" - values_as(varinfo[, Type]) - -Return the values/realizations in `varinfo` as `Type`, if implemented. - -If no `Type` is provided, return values as stored in `varinfo`. - -# Examples - -`SimpleVarInfo` with `NamedTuple`: - -```jldoctest -julia> data = (x = 1.0, m = [2.0]); - -julia> values_as(SimpleVarInfo(data)) -(x = 1.0, m = [2.0]) - -julia> values_as(SimpleVarInfo(data), NamedTuple) -(x = 1.0, m = [2.0]) - -julia> values_as(SimpleVarInfo(data), OrderedDict) -OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Any} with 2 entries: - x => 1.0 - m => [2.0] - -julia> values_as(SimpleVarInfo(data), Vector) -2-element Vector{Float64}: - 1.0 - 2.0 -``` - -`SimpleVarInfo` with `OrderedDict`: - -```jldoctest -julia> data = OrderedDict{Any,Any}(@varname(x) => 1.0, @varname(m) => [2.0]); - -julia> values_as(SimpleVarInfo(data)) -OrderedDict{Any, Any} with 2 entries: - x => 1.0 - m => [2.0] - -julia> values_as(SimpleVarInfo(data), NamedTuple) -(x = 1.0, m = [2.0]) - -julia> values_as(SimpleVarInfo(data), OrderedDict) -OrderedDict{Any, Any} with 2 entries: - x => 1.0 - m => [2.0] - -julia> values_as(SimpleVarInfo(data), Vector) -2-element Vector{Float64}: - 1.0 - 2.0 -``` - -`TypedVarInfo`: - -```jldoctest -julia> # Just use an example model to construct the `VarInfo` because we're lazy. - vi = VarInfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); - -julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; - -julia> # For the sake of brevity, let's just check the type. - md = values_as(vi); md.s isa DynamicPPL.Metadata -true - -julia> values_as(vi, NamedTuple) -(s = 1.0, m = 2.0) - -julia> values_as(vi, OrderedDict) -OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Float64} with 2 entries: - s => 1.0 - m => 2.0 - -julia> values_as(vi, Vector) -2-element Vector{Float64}: - 1.0 - 2.0 -``` - -`UntypedVarInfo`: - -```jldoctest -julia> # Just use an example model to construct the `VarInfo` because we're lazy. - vi = VarInfo(); DynamicPPL.TestUtils.demo_assume_dot_observe()(vi); - -julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; - -julia> # For the sake of brevity, let's just check the type. - values_as(vi) isa DynamicPPL.Metadata -true - -julia> values_as(vi, NamedTuple) -(s = 1.0, m = 2.0) - -julia> values_as(vi, OrderedDict) -OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Float64} with 2 entries: - s => 1.0 - m => 2.0 - -julia> values_as(vi, Vector) -2-element Vector{Real}: - 1.0 - 2.0 -``` -""" values_as(vi::VarInfo) = vi.metadata values_as(vi::VarInfo, ::Type{Vector}) = copy(getall(vi)) function values_as(vi::UntypedVarInfo, ::Type{NamedTuple}) From 200a886ac1e7029f998702d618c5d5e2cd56e52c Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Thu, 27 Oct 2022 20:33:04 +0100 Subject: [PATCH 201/221] Update src/abstract_varinfo.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/abstract_varinfo.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 7503123b2..59bca291d 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -197,7 +197,6 @@ function BangBang.push!!( return BangBang.push!!(vi, vn, r, dist, Set([gid])) end - @doc """ empty!!(vi::AbstractVarInfo) From 5a0296ece86f2b46ae42509c934fe2804acbc3f3 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Thu, 27 Oct 2022 20:36:36 +0100 Subject: [PATCH 202/221] Update src/abstract_varinfo.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/abstract_varinfo.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 59bca291d..b0d3551ff 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -528,7 +528,6 @@ variables `x` would return """ function tonamedtuple end - # Legacy code that is currently overloaded for the sake of simplicity. # TODO: Remove when possible. increment_num_produce!(::AbstractVarInfo) = nothing From a2e332ec133f57e1093454b9b28a568743dc439d Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Thu, 27 Oct 2022 20:37:26 +0100 Subject: [PATCH 203/221] Update src/varinfo.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/varinfo.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index a02a0d589..a036f2670 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -103,7 +103,9 @@ struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo end const UntypedVarInfo = VarInfo{<:Metadata} const TypedVarInfo = VarInfo{<:NamedTuple} -const MaybeThreadSafeVarInfo{Tmeta} = Union{VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}}} +const MaybeThreadSafeVarInfo{Tmeta} = Union{ + VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}} +} # NOTE: This is kind of weird, but it effectively preserves the "old" # behavior where we're allowed to call `link!` on the same `VarInfo` From d95063530081dc52074c07c30f000f0896f4b6ec Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Sat, 29 Oct 2022 11:26:51 +0100 Subject: [PATCH 204/221] Update varinfo.jl --- src/varinfo.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index a036f2670..49c78690d 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1358,9 +1358,7 @@ julia> var_info[@varname(x[1])] # [✓] unchanged """ setval!(vi::VarInfo, x) = setval!(vi, values(x), keys(x)) setval!(vi::VarInfo, values, keys) = _apply!(_setval_kernel!, vi, values, keys) -function setval!( - vi::VarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int -) +function setval!(vi::VarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int) return setval!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) end From e0907c16f2a031f21cad6736d19afd55a466f096 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 31 Oct 2022 20:23:24 +0000 Subject: [PATCH 205/221] fixed bugs with linking --- src/abstract_varinfo.jl | 4 +--- src/context_implementations.jl | 2 ++ src/utils.jl | 16 ++++++++++------ src/varinfo.jl | 14 ++++++++------ 4 files changed, 21 insertions(+), 15 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 7503123b2..651113ec0 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -197,7 +197,6 @@ function BangBang.push!!( return BangBang.push!!(vi, vn, r, dist, Set([gid])) end - @doc """ empty!!(vi::AbstractVarInfo) @@ -353,7 +352,7 @@ If `vns` is provided, then only check if this/these varname(s) are transformed. """ istrans(vi::AbstractVarInfo) = istrans(vi, collect(keys(vi))) function istrans(vi::AbstractVarInfo, vns::AbstractVector{<:VarName}) - return all(Base.Fix1(istrans, vi), vns) + return !isempty(vns) && all(Base.Fix1(istrans, vi), vns) end """ @@ -529,7 +528,6 @@ variables `x` would return """ function tonamedtuple end - # Legacy code that is currently overloaded for the sake of simplicity. # TODO: Remove when possible. increment_num_produce!(::AbstractVarInfo) = nothing diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 6271b5d8c..3e350d8c6 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -526,6 +526,8 @@ function get_and_set_val!( # we then broadcast. This will allocate a vector of `nothing` though. if istrans(vi) push!!.((vi,), vns, link.((vi,), vns, dists, r), dists, (spl,)) + # NOTE: Need to add the correction. + acclogp!!(vi, sum(logabsdetjac.(bijector.(dists), r))) # `push!!` sets the trans-flag to `false` by default. settrans!!.((vi,), true, vns) else diff --git a/src/utils.jl b/src/utils.jl index 1ce1330a7..b9056873a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -236,22 +236,26 @@ istransformable(::Transformable) = true # Single-sample initialisations # ################################# -inittrans(rng, dist::UnivariateDistribution) = invlink(dist, randrealuni(rng)) +inittrans(rng, dist::UnivariateDistribution) = Bijectors.invlink(dist, randrealuni(rng)) function inittrans(rng, dist::MultivariateDistribution) - return invlink(dist, randrealuni(rng, size(dist)[1])) + return Bijectors.invlink(dist, randrealuni(rng, size(dist)[1])) +end +function inittrans(rng, dist::MatrixDistribution) + return Bijectors.invlink(dist, randrealuni(rng, size(dist)...)) end -inittrans(rng, dist::MatrixDistribution) = invlink(dist, randrealuni(rng, size(dist)...)) ################################ # Multi-sample initialisations # ################################ -inittrans(rng, dist::UnivariateDistribution, n::Int) = invlink(dist, randrealuni(rng, n)) +function inittrans(rng, dist::UnivariateDistribution, n::Int) + return Bijectors.invlink(dist, randrealuni(rng, n)) +end function inittrans(rng, dist::MultivariateDistribution, n::Int) - return invlink(dist, randrealuni(rng, size(dist)[1], n)) + return Bijectors.invlink(dist, randrealuni(rng, size(dist)[1], n)) end function inittrans(rng, dist::MatrixDistribution, n::Int) - return invlink(dist, [randrealuni(rng, size(dist)...) for _ in 1:n]) + return Bijectors.invlink(dist, [randrealuni(rng, size(dist)...) for _ in 1:n]) end ####################### diff --git a/src/varinfo.jl b/src/varinfo.jl index a02a0d589..f662884bf 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -103,7 +103,9 @@ struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo end const UntypedVarInfo = VarInfo{<:Metadata} const TypedVarInfo = VarInfo{<:NamedTuple} -const MaybeThreadSafeVarInfo{Tmeta} = Union{VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}}} +const MaybeThreadSafeVarInfo{Tmeta} = Union{ + VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}} +} # NOTE: This is kind of weird, but it effectively preserves the "old" # behavior where we're allowed to call `link!` on the same `VarInfo` @@ -874,8 +876,10 @@ end return expr end -maybe_link(vi, vn, dist, val) = istrans(vi, vn) ? Bijectors.link(dist, val) : val -maybe_invlink(vi, vn, dist, val) = istrans(vi, vn) ? Bijectors.invlink(dist, val) : val +link(vi, vn, dist, val) = Bijectors.link(dist, val) +invlink(vi, vn, dist, val) = Bijectors.invlink(dist, val) +maybe_link(vi, vn, dist, val) = istrans(vi, vn) ? link(vi, vn, dist, val) : val +maybe_invlink(vi, vn, dist, val) = istrans(vi, vn) ? invlink(vi, vn, dist, val) : val """ islinked(vi::VarInfo, spl::Union{Sampler, SampleFromPrior}) @@ -1356,9 +1360,7 @@ julia> var_info[@varname(x[1])] # [✓] unchanged """ setval!(vi::VarInfo, x) = setval!(vi, values(x), keys(x)) setval!(vi::VarInfo, values, keys) = _apply!(_setval_kernel!, vi, values, keys) -function setval!( - vi::VarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int -) +function setval!(vi::VarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int) return setval!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) end From e5d8984bdaf341954a5b64fb177ef218efdb350a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 31 Oct 2022 21:50:39 +0000 Subject: [PATCH 206/221] fixed threadsafevarinfo issues --- src/context_implementations.jl | 4 ++-- src/threadsafe.jl | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 3e350d8c6..6b538240e 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -204,14 +204,14 @@ function assume( sampler::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, - vi::AbstractVarInfo, + vi::MaybeThreadSafeVarInfo, ) if haskey(vi, vn) # Always overwrite the parameters with new ones for `SampleFromUniform`. if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") unset_flag!(vi, vn, "del") r = init(rng, dist, sampler) - vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r)) + BangBang.setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r)), vn) setorder!(vi, vn, get_num_produce(vi)) else # Otherwise we just extract it. diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 4ea67b31b..32a6df4af 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -55,6 +55,12 @@ function setlogp!!(vi::ThreadSafeVarInfoWithRef, logp) return ThreadSafeVarInfo(setlogp!!(vi.varinfo, logp), vi.logps) end +function BangBang.push!!( + vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} +) + return Setfield.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist, gidset) +end + get_num_produce(vi::ThreadSafeVarInfo) = get_num_produce(vi.varinfo) increment_num_produce!(vi::ThreadSafeVarInfo) = increment_num_produce!(vi.varinfo) reset_num_produce!(vi::ThreadSafeVarInfo) = reset_num_produce!(vi.varinfo) @@ -90,16 +96,28 @@ end # `getindex` getindex(vi::ThreadSafeVarInfo, ::Colon) = getindex(vi.varinfo, Colon()) getindex(vi::ThreadSafeVarInfo, vn::VarName) = getindex(vi.varinfo, vn) +getindex(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = getindex(vi.varinfo, vns) function getindex(vi::ThreadSafeVarInfo, vn::VarName, dist::Distribution) return getindex(vi.varinfo, vn, dist) end +function getindex(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}, dist::Distribution) + return getindex(vi.varinfo, vns, dist) +end getindex(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex(vi.varinfo, spl) getindex_raw(vi::ThreadSafeVarInfo, ::Colon) = getindex_raw(vi.varinfo, Colon()) getindex_raw(vi::ThreadSafeVarInfo, vn::VarName) = getindex_raw(vi.varinfo, vn) +function getindex_raw(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) + return getindex_raw(vi.varinfo, vns) +end function getindex_raw(vi::ThreadSafeVarInfo, vn::VarName, dist::Distribution) return getindex_raw(vi.varinfo, vn, dist) end +function getindex_raw( + vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}, dist::Distribution +) + return getindex_raw(vi.varinfo, vns, dist) +end getindex_raw(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex_raw(vi.varinfo, spl) function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::AbstractSampler) From 63b36386062af7f3012d53cd6fb1dcbd9692fc74 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Nov 2022 12:28:54 +0000 Subject: [PATCH 207/221] added tests for StaticBijector --- src/test_utils.jl | 33 +++++++++++++++++++++++++++ test/simple_varinfo.jl | 52 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) diff --git a/src/test_utils.jl b/src/test_utils.jl index bcc649675..9c9034ee5 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -644,6 +644,39 @@ const DEMO_MODELS = ( demo_dot_assume_matrix_dot_observe_matrix(), ) +# Model to test `StaticTransformation` with. +""" + demo_static_transformation() + +Simple model for which [`default_transformation`](@ref) returns a [`StaticTransformation`](@ref). +""" +@model function demo_static_transformation() + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + 1.5 ~ Normal(m, sqrt(s)) + 2.0 ~ Normal(m, sqrt(s)) + + return (; s, m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) +end + +function DynamicPPL.default_transformation(::Model{typeof(demo_static_transformation)}) + b = Bijectors.stack(Bijectors.Exp{0}(), Bijectors.Identity{0}()) + return DynamicPPL.StaticTransformation(b) +end + +posterior_mean(::Model{typeof(demo_static_transformation)}) = (s=49 / 24, m=7 / 6) +function logprior_true(::Model{typeof(demo_static_transformation)}, s, m) + return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) +end +function loglikelihood_true(::Model{typeof(demo_static_transformation)}, s, m) + return logpdf(Normal(m, sqrt(s)), 1.5) + logpdf(Normal(m, sqrt(s)), 2.0) +end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_static_transformation)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end + """ marginal_mean_of_samples(chain, varname) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index d6c0b6d3c..a1bbfd503 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -218,4 +218,56 @@ @test lp ≈ lp_true end end + + @testset "Static transformation" begin + model = DynamicPPL.TestUtils.demo_static_transformation() + + varinfos = setup_varinfos( + model, rand(NamedTuple, model), [@varname(s), @varname(m)] + ) + @testset "$(short_varinfo_name(vi))" for vi in varinfos + # Initialize varinfo and link. + vi_linked = DynamicPPL.link!!(vi, model) + + # Make sure `maybe_invlink_before_eval!!` results in `invlink!!`. + @test !DynamicPPL.istrans( + DynamicPPL.maybe_invlink_before_eval!!( + deepcopy(vi), SamplingContext(), model + ), + ) + + # Resulting varinfo should no longer be transformed. + vi_result = last(DynamicPPL.evaluate!!(model, deepcopy(vi), SamplingContext())) + @test !DynamicPPL.istrans(vi_result) + + # Set the values to something that is out of domain if we're in constrained space. + for vn in keys(vi) + vi_linked = DynamicPPL.setindex!!(vi_linked, -rand(), vn) + end + + retval, vi_linked_result = DynamicPPL.evaluate!!( + model, deepcopy(vi_linked), DefaultContext() + ) + + @test DynamicPPL.getindex_raw(vi_linked, @varname(s)) ≠ retval.s # `s` is unconstrained in original + @test DynamicPPL.getindex_raw(vi_linked_result, @varname(s)) == retval.s # `s` is constrained in result + + # `m` should not be transformed. + @test vi_linked[@varname(m)] == retval.m + @test vi_linked_result[@varname(m)] == retval.m + + # Compare to truth. + retval_unconstrained, lp_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + model, retval.s, retval.m + ) + + # Realizations in `vi_linked` should all be equal to the unconstrained realization. + @test DynamicPPL.getindex_raw(vi_linked, @varname(s)) ≈ retval_unconstrained.s + @test DynamicPPL.getindex_raw(vi_linked, @varname(m)) ≈ retval_unconstrained.m + + # The resulting varinfo should hold the correct logp. + lp = getlogp(vi_linked_result) + @test lp ≈ lp_true + end + end end From 3ffeef15042175c7a4e99eb0d1c45bca4dde1422 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Nov 2022 12:29:17 +0000 Subject: [PATCH 208/221] added impl of maybe_invlink_before_eval!! for VarInfo --- src/varinfo.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/varinfo.jl b/src/varinfo.jl index f662884bf..14f8c6fd9 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -388,6 +388,9 @@ function settrans!!(vi::VarInfo, trans::Bool) return vi end +settrans!!(vi::VarInfo, trans::NoTransformation) = settrans!!(vi, false) +settrans!!(vi::VarInfo, trans::AbstractTransformation) = settrans!!(vi, true) + """ syms(vi::VarInfo) @@ -797,6 +800,14 @@ function invlink!!(::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, m return vi end +function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, model::Model) + # Because `VarInfo` does not contain any information about what the transformation + # other than whether or not it has actually been transformed, the best we can do + # is just assume that `default_transformation` is the correct one if `istrans(vi)`. + t = istrans(vi) ? default_transformation(model, vi) : NoTransformation() + return maybe_invlink_before_eval!!(t, vi, context, model) +end + """ invlink!(vi::VarInfo, spl::AbstractSampler) From 8ed91ecb97246f5771d34d859bc10e15a8b73cea Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Nov 2022 12:29:51 +0000 Subject: [PATCH 209/221] fixed bug in invlink!! for StaticBijector --- src/abstract_varinfo.jl | 2 +- src/simple_varinfo.jl | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 651113ec0..38c25a8dc 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -429,7 +429,7 @@ function invlink!!( y = vi[spl] x, logjac = with_logabsdet_jacobian(b, y) - lp_new = getlogp(vi) - logjac + lp_new = getlogp(vi) + logjac vi_new = setlogp!!(unflatten(vi, spl, x), lp_new) return settrans!!(vi_new, NoTransformation()) end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index a0766374d..119015915 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -654,7 +654,7 @@ function link!!( model::Model, ) # TODO: Make sure that `spl` is respected. - b = t.bijector + b = inverse(t.bijector) x = vi.values y, logjac = with_logabsdet_jacobian(b, x) lp_new = getlogp(vi) - logjac @@ -670,10 +670,9 @@ function invlink!!( ) # TODO: Make sure that `spl` is respected. b = t.bijector - ib = inverse(b) y = vi.values - x, logjac = with_logabsdet_jacobian(ib, y) - lp_new = getlogp(vi) - logjac + x, logjac = with_logabsdet_jacobian(b, y) + lp_new = getlogp(vi) + logjac vi_new = setlogp!!(Setfield.@set(vi.values = x), lp_new) return settrans!!(vi_new, NoTransformation()) end From 8ddfb4c425cf23d700ad108adb7c62001605c365 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Nov 2022 12:30:05 +0000 Subject: [PATCH 210/221] added maybe_invlink_before_eval!! impl for ThreadSafeVarInfo --- src/abstract_varinfo.jl | 9 +++++++-- src/threadsafe.jl | 11 +++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 38c25a8dc..c9441ebb9 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -493,9 +493,14 @@ function maybe_invlink_before_eval!!( return maybe_invlink_before_eval!!(transformation(vi), vi, context, model) end function maybe_invlink_before_eval!!( - t::AbstractTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model + ::NoTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model ) - # Default behavior is to _not_ transform. + return vi +end +function maybe_invlink_before_eval!!( + ::DynamicTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model +) + # `DynamicTransformation` is meant to _not_ do the transformation statically, hence we do nothing. return vi end function maybe_invlink_before_eval!!( diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 32a6df4af..85ad0e23e 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -93,6 +93,17 @@ function invlink!!( return invlink!!(t, vi.varinfo, spl, model) end +function maybe_invlink_before_eval!!( + vi::ThreadSafeVarInfo, context::AbstractContext, model::Model +) + # Defer to the wrapped `AbstractVarInfo` object. + # NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the `getlogp(vi.varinfo)` + # hence the log-absdet-jacobian term will correctly be included in the `getlogp(vi)`. + return Setfield.@set vi.varinfo = maybe_invlink_before_eval!!( + vi.varinfo, context, model + ) +end + # `getindex` getindex(vi::ThreadSafeVarInfo, ::Colon) = getindex(vi.varinfo, Colon()) getindex(vi::ThreadSafeVarInfo, vn::VarName) = getindex(vi.varinfo, vn) From 3246cf477f689def60fadb5e8670d7171008ebd8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Nov 2022 12:41:41 +0000 Subject: [PATCH 211/221] fixed bug in doctests --- src/varinfo.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/varinfo.jl b/src/varinfo.jl index 14f8c6fd9..db8738e9e 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -630,6 +630,9 @@ end # Functions defined only for UntypedVarInfo Base.keys(vi::UntypedVarInfo) = keys(vi.metadata.idcs) +# HACK: Necessary to avoid returning `Any[]` which won't dispatch correctly +# on other methods in the codebase which requires `Vector{<:VarName}`. +Base.keys(vi::TypedVarInfo{<:NamedTuple{()}}) = VarName[] @generated function Base.keys(vi::TypedVarInfo{<:NamedTuple{names}}) where {names} expr = Expr(:call) push!(expr.args, :vcat) From 74b5d93cb8c3165afed37e914702a08ca9a34eb2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 2 Nov 2022 11:32:33 +0000 Subject: [PATCH 212/221] relax constraint on istrans --- src/abstract_varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index c9441ebb9..526fce3cf 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -351,7 +351,7 @@ If `vns` is provided, then only check if this/these varname(s) are transformed. the variables. """ istrans(vi::AbstractVarInfo) = istrans(vi, collect(keys(vi))) -function istrans(vi::AbstractVarInfo, vns::AbstractVector{<:VarName}) +function istrans(vi::AbstractVarInfo, vns::AbstractVector) return !isempty(vns) && all(Base.Fix1(istrans, vi), vns) end From c6264e50932c4688573a1fa14bbddbd5ede3d6b8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 2 Nov 2022 11:33:06 +0000 Subject: [PATCH 213/221] fixed unflatten for Dict to respect the original type --- src/utils.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index b9056873a..8f076efee 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -468,7 +468,8 @@ function unflatten(original::NamedTuple{names}, x::AbstractVector) where {names} return NamedTuple{names}(unflatten(values(original), x)) end function unflatten(original::AbstractDict, x::AbstractVector) - return Dict(zip(keys(original), unflatten(collect(values(original)), x))) + D = ConstructionBase.constructorof(typeof(original)) + return D(zip(keys(original), unflatten(collect(values(original)), x))) end # TODO: Move `getvalue` and `hasvalue` to AbstractPPL.jl. From da04c7bfc98c70bee19d6bd07c3044964d8f3550 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 2 Nov 2022 11:33:19 +0000 Subject: [PATCH 214/221] suggest using OrderedDict instead of Dict in docstrings --- src/simple_varinfo.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 119015915..d1d637d27 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -571,8 +571,8 @@ julia> # Using a `NamedTuple`. logjoint(demo([1.0]), (m = 100.0, )) -9902.33787706641 -julia> # Using a `Dict`. - logjoint(demo([1.0]), Dict(@varname(m) => 100.0)) +julia> # Using a `OrderedDict`. + logjoint(demo([1.0]), OrderedDict(@varname(m) => 100.0)) -9902.33787706641 julia> # Truth. @@ -603,8 +603,8 @@ julia> # Using a `NamedTuple`. logprior(demo([1.0]), (m = 100.0, )) -5000.918938533205 -julia> # Using a `Dict`. - logprior(demo([1.0]), Dict(@varname(m) => 100.0)) +julia> # Using a `OrderedDict`. + logprior(demo([1.0]), OrderedDict(@varname(m) => 100.0)) -5000.918938533205 julia> # Truth. @@ -635,8 +635,8 @@ julia> # Using a `NamedTuple`. loglikelihood(demo([1.0]), (m = 100.0, )) -4901.418938533205 -julia> # Using a `Dict`. - loglikelihood(demo([1.0]), Dict(@varname(m) => 100.0)) +julia> # Using a `OrderedDict`. + loglikelihood(demo([1.0]), OrderedDict(@varname(m) => 100.0)) -4901.418938533205 julia> # Truth. From 41fd89a538ec815723ab5ee643aeb4ca9cab14b9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 2 Nov 2022 11:33:34 +0000 Subject: [PATCH 215/221] fixed doctest --- src/model.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model.jl b/src/model.jl index 7af98e945..09f0d36ad 100644 --- a/src/model.jl +++ b/src/model.jl @@ -458,7 +458,7 @@ julia> conditioned(cm).var"a.m" 1.0 julia> keys(VarInfo(cm)) # <= no variables are sampled -Any[] +VarName[] ``` """ conditioned(model::Model) = conditioned(model.context) From e7b8b1063b1fe5e3b698b975774c91c5e3f4b612 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 2 Nov 2022 12:09:14 +0000 Subject: [PATCH 216/221] added docs for unflatten for varinfos --- docs/src/api.md | 1 + src/abstract_varinfo.jl | 10 ++++++++++ 2 files changed, 11 insertions(+) diff --git a/docs/src/api.md b/docs/src/api.md index 87c604ffb..1096a62b0 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -200,6 +200,7 @@ DynamicPPL.maybe_invlink_before_eval!! #### Utils ```@docs +DynamicPPL.unflatten DynamicPPL.tonamedtuple ``` diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 526fce3cf..e77f03023 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -518,6 +518,16 @@ function _default_sampler(::IsParent, context::AbstractContext) end # Utilities +""" + unflatten(vi::AbstractVarInfo[, spl::AbstractSampler], x::AbstractVector) + +Return a new instance of `vi` with the values of `x` assigned to the variables. + +If `spl` is provided, `x` is assumed to be realizations only for variables related +to `spl`. +""" +function unflatten end + """ tonamedtuple(vi::AbstractVarInfo) From 1c1b6ed36c08819208cc97db348be5a5d785dd9b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 3 Nov 2022 11:34:49 +0000 Subject: [PATCH 217/221] added comment to explain settrans!! for VarInfo --- src/varinfo.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/varinfo.jl b/src/varinfo.jl index db8738e9e..f8dfbbfca 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -389,6 +389,11 @@ function settrans!!(vi::VarInfo, trans::Bool) end settrans!!(vi::VarInfo, trans::NoTransformation) = settrans!!(vi, false) +# HACK: This is necessary to make something like `link!!(transformation, vi, model)` +# work properly, which will transform the variables according to `transformation` +# and then call `settrans!!(vi, transformation)`. An alternative would be to add +# the `transformation` to the `VarInfo` object, but at the moment doesn't seem +# worth it as `VarInfo` has its own way of handling transformations. settrans!!(vi::VarInfo, trans::AbstractTransformation) = settrans!!(vi, true) """ From fbf9e0aafac4cdcb98ede60ebf181ac4e22cb02e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 3 Nov 2022 11:35:38 +0000 Subject: [PATCH 218/221] renamed MaybeThreadSafeVarInfo --- src/context_implementations.jl | 2 +- src/varinfo.jl | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 6b538240e..810600072 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -204,7 +204,7 @@ function assume( sampler::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, - vi::MaybeThreadSafeVarInfo, + vi::VarInfoOrThreadSafeVarInfo, ) if haskey(vi, vn) # Always overwrite the parameters with new ones for `SampleFromUniform`. diff --git a/src/varinfo.jl b/src/varinfo.jl index f8dfbbfca..f4438ad8b 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -103,7 +103,7 @@ struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo end const UntypedVarInfo = VarInfo{<:Metadata} const TypedVarInfo = VarInfo{<:NamedTuple} -const MaybeThreadSafeVarInfo{Tmeta} = Union{ +const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}} } @@ -1252,7 +1252,7 @@ end Calls `kernel!(vi, vn, values, keys)` for every `vn` in `vi`. """ -function _apply!(kernel!, vi::MaybeThreadSafeVarInfo, values, keys) +function _apply!(kernel!, vi::VarInfoOrThreadSafeVarInfo, values, keys) keys_strings = map(string, collectmaybe(keys)) num_indices_seen = 0 @@ -1310,7 +1310,7 @@ end end end -function _find_missing_keys(vi::MaybeThreadSafeVarInfo, keys) +function _find_missing_keys(vi::VarInfoOrThreadSafeVarInfo, keys) string_vns = map(string, collectmaybe(Base.keys(vi))) # If `key` isn't subsumed by any element of `string_vns`, it is not present in `vi`. missing_keys = filter(keys) do key @@ -1383,7 +1383,7 @@ function setval!(vi::VarInfo, chains::AbstractChains, sample_idx::Int, chain_idx return setval!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) end -function _setval_kernel!(vi::MaybeThreadSafeVarInfo, vn::VarName, values, keys) +function _setval_kernel!(vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, keys) indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) if !isempty(indices) val = reduce(vcat, values[indices]) @@ -1452,19 +1452,19 @@ julia> var_info[@varname(x[1])] # [✓] changed ## See also - [`setval!`](@ref) """ -function setval_and_resample!(vi::MaybeThreadSafeVarInfo, x) +function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, x) return setval_and_resample!(vi, values(x), keys(x)) end -function setval_and_resample!(vi::MaybeThreadSafeVarInfo, values, keys) +function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, values, keys) return _apply!(_setval_and_resample_kernel!, vi, values, keys) end function setval_and_resample!( - vi::MaybeThreadSafeVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int + vi::VarInfoOrThreadSafeVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int ) return setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) end -function _setval_and_resample_kernel!(vi::MaybeThreadSafeVarInfo, vn::VarName, values, keys) +function _setval_and_resample_kernel!(vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, keys) indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) if !isempty(indices) val = reduce(vcat, values[indices]) From a4279832a6cea1dec37a038e686606d517297745 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 3 Nov 2022 11:39:00 +0000 Subject: [PATCH 219/221] added comment on maybe_inlink_before_eval!! --- src/model.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/model.jl b/src/model.jl index 09f0d36ad..b7e0984c5 100644 --- a/src/model.jl +++ b/src/model.jl @@ -592,6 +592,10 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf ) model.f( model, + # Maybe perform `invlink!!` once prior to evaluation to avoid + # lazy `invlink`-ing of the parameters. This can be useful for + # speeding up computation. See docs for `maybe_invlink_before_eval!!` + # for more information. maybe_invlink_before_eval!!(varinfo, context_new, model), context_new, $(unwrap_args...), From 6b126b8601ac365f660e2f8e16639dabb1c3e99b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 3 Nov 2022 11:41:02 +0000 Subject: [PATCH 220/221] removed unnecessary defs in tests --- test/varinfo.jl | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/test/varinfo.jl b/test/varinfo.jl index 5890310ee..a94de4a29 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,23 +1,3 @@ -# TODO: Should all this go somewhere else? Seems useful for more tests. -short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) = short_varinfo_name(vi.varinfo) -short_varinfo_name(::TypedVarInfo) = "TypedVarInfo" -short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo" -short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}" -short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}" - -function update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns) - for vn in vns - vi = DynamicPPL.setindex!!(vi, get(vals, vn), vn) - end - return vi -end - -function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns) - for vn in vns - @test vi[vn] == get(vals, vn) - end -end - @testset "varinfo.jl" begin @testset "TypedVarInfo" begin @model gdemo(x, y) = begin From d715b0c4a842cfe15645e5330c7f7c1c366d37b7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 3 Nov 2022 11:42:13 +0000 Subject: [PATCH 221/221] formatting --- src/varinfo.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index f4438ad8b..6107b869f 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1464,7 +1464,9 @@ function setval_and_resample!( return setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) end -function _setval_and_resample_kernel!(vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, keys) +function _setval_and_resample_kernel!( + vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, keys +) indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) if !isempty(indices) val = reduce(vcat, values[indices])