From 8b916f8640280b9883a587d6214155eae2c75892 Mon Sep 17 00:00:00 2001
From: Tor Erlend Fjelde
Date: Thu, 21 Jul 2022 12:19:50 +0000
Subject: [PATCH] Perform invlinking in assume rather than implicitly in getindex (#360)

Currently, in `assume`, etc., `invlink` is called implicitly in `getindex` using the
distribution extracted from `vi`. This has a couple of drawbacks:

1. We can only use the distribution for a particular `vn` stored in `vi` obtained during
   the initial run. This means that we can't even run models where the distributions have
   dynamic domains, i.e. where the domain of a particular random variable depends on the
   realizations of other random variables.
2. We have to store the distribution for each `vn` in `vi`. This was fine when we only had
   `VarInfo`, because we also need it for other functionality, but this is not the case
   for `SimpleVarInfo` (nor will it be).

So, in this PR we introduce `getindex_raw`, which behaves like `getindex` but does not
`invlink` a value that is already linked, and use it within `assume`, etc., where we now
use the distributions passed to `assume` rather than those stored in `vi`. E.g. the
following now works:

```julia
julia> @model demo() = x ~ InverseGamma(2, 3)
demo (generic function with 2 methods)

julia> model = demo();

julia> vi = SimpleVarInfo((x = 10.0, ), true)
SimpleVarInfo((x = 10.0,), 0.0, true)

julia> _, vi = DynamicPPL.evaluate!!(model, vi, DefaultContext())
(22026.465794806718, SimpleVarInfo{NamedTuple{(:x,), Tuple{Float64}}, Float64}((x = 10.0,), -17.80291162245307, true))
```

Co-authored-by: Hong Ge
---
 Project.toml                   |   6 +-
 docs/src/api.md                |  22 +-
 src/DynamicPPL.jl              |   2 +
 src/context_implementations.jl |  94 +++--
 src/distribution_wrappers.jl   |  31 +-
 src/simple_varinfo.jl          | 230 +++++++++---
 src/test_utils.jl              | 661 ++++++++++++++++++++++++++-------
 src/threadsafe.jl              |  16 +
 src/utils.jl                   |   9 +
 src/varinfo.jl                 |  89 +++--
 test/Project.toml              |   2 +-
 test/contexts.jl               |  26 +-
 test/distribution_wrappers.jl  |   2 +-
 test/loglikelihoods.jl         |  26 +-
 test/model.jl                  |  16 +
 test/simple_varinfo.jl         | 114 ++++--
 test/turing/Project.toml       |   2 +
 test/varinfo.jl                |  43 +++
 18 files changed, 1062 insertions(+), 329 deletions(-)

diff --git a/Project.toml b/Project.toml
index 726a74343..bfa13d956 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "DynamicPPL"
 uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
-version = "0.19.3"
+version = "0.19.4"
 
 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -8,7 +8,9 @@ AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
 BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -22,7 +24,9 @@ AbstractPPL = "0.5.1"
 BangBang = "0.3"
 Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9, 0.10"
 ChainRulesCore = "0.9.7, 0.10, 1"
+ConstructionBase = "1"
 Distributions = "0.23.8, 0.24, 0.25"
+DocStringExtensions = "0.8"
 MacroTools = "0.5.6"
 Setfield = "0.7.1, 0.8"
 ZygoteRules = "0.2"
diff --git a/docs/src/api.md b/docs/src/api.md
index 133b86e9b..809e6c49e 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -103,8 +103,14 @@ NamedDist
 DynamicPPL provides several demo models and helpers for testing samplers in the
`DynamicPPL.TestUtils` submodule. ```@docs -DynamicPPL.TestUtils.test_sampler_demo_models +DynamicPPL.TestUtils.test_sampler +DynamicPPL.TestUtils.test_sampler_on_demo_models DynamicPPL.TestUtils.test_sampler_continuous +DynamicPPL.TestUtils.marginal_mean_of_samples +``` + +```@docs +DynamicPPL.TestUtils.DEMO_MODELS ``` For every demo model, one can define the true log prior, log likelihood, and log joint probabilities. @@ -115,6 +121,20 @@ DynamicPPL.TestUtils.loglikelihood_true DynamicPPL.TestUtils.logjoint_true ``` +And in the case where the model includes constrained variables, it can also be useful to define + +```@docs +DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian +DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian +``` + +Finally, the following methods can also be of use: + +```@docs +DynamicPPL.TestUtils.varnames +DynamicPPL.TestUtils.posterior_mean +``` + ## Advanced ### Variable names diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 4a44fa869..86d4e0def 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -12,6 +12,8 @@ using MacroTools: MacroTools using Setfield: Setfield using ZygoteRules: ZygoteRules +using DocStringExtensions + using Random: Random import Base: diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 20c4af446..6271b5d8c 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -55,7 +55,7 @@ end function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi) if haskey(context.vars, getsym(vn)) vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end return tilde_assume(PriorContext(), right, vn, vi) end @@ -64,7 +64,7 @@ function tilde_assume( ) if haskey(context.vars, getsym(vn)) vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end return tilde_assume(rng, PriorContext(), sampler, right, vn, vi) end @@ -72,7 +72,7 @@ end function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi) if haskey(context.vars, getsym(vn)) vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end return tilde_assume(LikelihoodContext(), right, vn, vi) end @@ -86,7 +86,7 @@ function tilde_assume( ) if haskey(context.vars, getsym(vn)) vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi) end @@ -194,7 +194,7 @@ end # fallback without sampler function assume(dist::Distribution, vn::VarName, vi) - r = vi[vn] + r = vi[vn, dist] return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi end @@ -211,16 +211,21 @@ function assume( if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") unset_flag!(vi, vn, "del") r = init(rng, dist, sampler) - vi[vn] = vectorize(dist, r) - settrans!(vi, false, vn) + vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r)) setorder!(vi, vn, get_num_produce(vi)) else - r = vi[vn] + # Otherwise we just extract it. + r = vi[vn, dist] end else r = init(rng, dist, sampler) - push!!(vi, vn, r, dist, sampler) - settrans!(vi, false, vn) + if istrans(vi) + push!!(vi, vn, link(dist, r), dist, sampler) + # By default `push!!` sets the transformed flag to `false`. 
+ settrans!!(vi, true, vn) + else + push!!(vi, vn, r, dist, sampler) + end end return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi @@ -286,7 +291,7 @@ function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left, var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) + settrans!!.((vi,), false, _vns) dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, vi) else dot_tilde_assume(LikelihoodContext(), right, left, vn, vi) @@ -305,19 +310,20 @@ function dot_tilde_assume( var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) + settrans!!.((vi,), false, _vns) dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, vi) else dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi) end end + function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi) - return dot_assume(NoDist.(right), left, vn, vi) + return dot_assume(nodist(right), left, vn, vi) end function dot_tilde_assume( rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, vi ) - return dot_assume(rng, sampler, NoDist.(right), vn, left, vi) + return dot_assume(rng, sampler, nodist(right), vn, left, vi) end # `PriorContext` @@ -326,7 +332,7 @@ function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) + settrans!!.((vi,), false, _vns) dot_tilde_assume(PriorContext(), _right, _left, _vns, vi) else dot_tilde_assume(PriorContext(), right, left, vn, vi) @@ -345,7 +351,7 @@ function dot_tilde_assume( var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) + settrans!!.((vi,), false, _vns) dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, vi) else dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, vi) @@ -383,14 +389,14 @@ function dot_assume( vns::AbstractVector{<:VarName}, vi::AbstractVarInfo, ) - @assert length(dist) == size(var, 1) + @assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))" # NOTE: We cannot work with `var` here because we might have a model of the form # # m = Vector{Float64}(undef, n) # m .~ Normal() # # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - r = vi[vns] + r = vi[vns, dist] lp = sum(zip(vns, eachcol(r))) do (vn, ri) return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) end @@ -412,19 +418,21 @@ function dot_assume( end function dot_assume( - dists::Union{Distribution,AbstractArray{<:Distribution}}, + dist::Distribution, var::AbstractArray, vns::AbstractArray{<:VarName}, vi +) + r = getindex.((vi,), vns, (dist,)) + lp = sum(Bijectors.logpdf_with_trans.((dist,), r, istrans.((vi,), vns))) + return r, lp, vi +end + +function dot_assume( + dists::AbstractArray{<:Distribution}, var::AbstractArray, vns::AbstractArray{<:VarName}, vi, ) - # NOTE: We cannot work with `var` here because we might have a model of the form - # - # m = Vector{Float64}(undef, n) - # m .~ Normal() - # - # in which case `var` will have `undef` elements, even if `m` is present in `vi`. 
- r = reshape(vi[vec(vns)], size(vns)) - lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) + r = getindex.((vi,), vns, dists) + lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns))) return r, lp, vi end @@ -438,7 +446,7 @@ function dot_assume( ) r = get_and_set_val!(rng, vi, vns, dists, spl) # Make sure `r` is not a matrix for multivariate distributions - lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) + lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns))) return r, lp, vi end function dot_assume(rng, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any) @@ -462,19 +470,23 @@ function get_and_set_val!( r = init(rng, dist, spl, n) for i in 1:n vn = vns[i] - vi[vn] = vectorize(dist, r[:, i]) - settrans!(vi, false, vn) + vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r[:, i])) setorder!(vi, vn, get_num_produce(vi)) end else - r = vi[vns] + r = vi[vns, dist] end else r = init(rng, dist, spl, n) for i in 1:n vn = vns[i] - push!!(vi, vn, r[:, i], dist, spl) - settrans!(vi, false, vn) + if istrans(vi) + push!!(vi, vn, Bijectors.link(dist, r[:, i]), dist, spl) + # `push!!` sets the trans-flag to `false` by default. + settrans!!(vi, true, vn) + else + push!!(vi, vn, r[:, i], dist, spl) + end end end return r @@ -496,12 +508,13 @@ function get_and_set_val!( for i in eachindex(vns) vn = vns[i] dist = dists isa AbstractArray ? dists[i] : dists - vi[vn] = vectorize(dist, r[i]) - settrans!(vi, false, vn) + vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r[i])) setorder!(vi, vn, get_num_produce(vi)) end else - r = reshape(vi[vec(vns)], size(vns)) + # r = reshape(vi[vec(vns)], size(vns)) + r_raw = getindex_raw(vi, vec(vns)) + r = maybe_invlink.((vi,), vns, dists, reshape(r_raw, size(vns))) end else f = (vn, dist) -> init(rng, dist, spl) @@ -511,8 +524,13 @@ function get_and_set_val!( # 1. Figure out the broadcast size and use a `foreach`. # 2. Define an anonymous function which returns `nothing`, which # we then broadcast. This will allocate a vector of `nothing` though. - push!!.(Ref(vi), vns, r, dists, Ref(spl)) - settrans!.(Ref(vi), false, vns) + if istrans(vi) + push!!.((vi,), vns, link.((vi,), vns, dists, r), dists, (spl,)) + # `push!!` sets the trans-flag to `false` by default. 
+ settrans!!.((vi,), true, vns) + else + push!!.((vi,), vns, r, dists, (spl,)) + end end return r end diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index 4045cc089..d8968a68e 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -13,6 +13,9 @@ end NamedDist(dist::Distribution, name::Symbol) = NamedDist(dist, VarName{name}()) +Base.length(dist::NamedDist) = Base.length(dist.dist) +Base.size(dist::NamedDist) = Base.size(dist.dist) + Distributions.logpdf(dist::NamedDist, x::Real) = Distributions.logpdf(dist.dist, x) function Distributions.logpdf(dist::NamedDist, x::AbstractArray{<:Real}) return Distributions.logpdf(dist.dist, x) @@ -24,12 +27,20 @@ function Distributions.loglikelihood(dist::NamedDist, x::AbstractArray{<:Real}) return Distributions.loglikelihood(dist.dist, x) end +Bijectors.bijector(d::NamedDist) = Bijectors.bijector(d.dist) + struct NoDist{variate,support,Td<:Distribution{variate,support}} <: Distribution{variate,support} dist::Td end NoDist(dist::NamedDist) = NamedDist(NoDist(dist.dist), dist.name) +nodist(dist::Distribution) = NoDist(dist) +nodist(dists::AbstractArray) = nodist.(dists) + +Base.length(dist::NoDist) = Base.length(dist.dist) +Base.size(dist::NoDist) = Base.size(dist.dist) + Distributions.rand(rng::Random.AbstractRNG, d::NoDist) = rand(rng, d.dist) Distributions.logpdf(d::NoDist{<:Univariate}, ::Real) = 0 Distributions.logpdf(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0 @@ -40,9 +51,21 @@ Distributions.logpdf(d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}) = 0 Distributions.minimum(d::NoDist) = minimum(d.dist) Distributions.maximum(d::NoDist) = maximum(d.dist) -Bijectors.logpdf_with_trans(d::NoDist{<:Univariate}, ::Real) = 0 -Bijectors.logpdf_with_trans(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0 -function Bijectors.logpdf_with_trans(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}) +Bijectors.logpdf_with_trans(d::NoDist{<:Univariate}, ::Real, ::Bool) = 0 +function Bijectors.logpdf_with_trans( + d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}, ::Bool +) + return 0 +end +function Bijectors.logpdf_with_trans( + d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}, ::Bool +) return zeros(Int, size(x, 2)) end -Bijectors.logpdf_with_trans(d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}) = 0 +function Bijectors.logpdf_with_trans( + d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}, ::Bool +) + return 0 +end + +Bijectors.bijector(d::NoDist) = Bijectors.bijector(d.dist) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 5cecda4b2..5b9edefdf 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -1,11 +1,19 @@ +abstract type AbstractTransformation end + +struct NoTransformation <: AbstractTransformation end +struct DefaultTransformation <: AbstractTransformation end + """ - SimpleVarInfo{NT,T} <: AbstractVarInfo + $(TYPEDEF) A simple wrapper of the parameters with a `logp` field for accumulation of the logdensity. Currently only implemented for `NT<:NamedTuple` and `NT<:Dict`. +# Fields +$(FIELDS) + # Notes The major differences between this and `TypedVarInfo` are: 1. `SimpleVarInfo` does not require linearization. @@ -16,7 +24,7 @@ The major differences between this and `TypedVarInfo` are: # Examples ## General usage -```jldoctest; setup=:(using Distributions) +```jldoctest simplevarinfo-general; setup=:(using Distributions) julia> using StableRNGs julia> @model function demo() @@ -78,6 +86,60 @@ ERROR: KeyError: key x[1:2] not found [...] 
 ```
+
+You can also sample in _transformed_ space:
+
+```jldoctest simplevarinfo-general
+julia> @model demo_constrained() = x ~ Exponential()
+demo_constrained (generic function with 2 methods)
+
+julia> m = demo_constrained();
+
+julia> _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo(), ctx);
+
+julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞
+1.8632965762164932
+
+julia> _, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(), true), ctx);
+
+julia> vi[@varname(x)] # (✓) -∞ < x < ∞
+-0.21080155351918753
+
+julia> xs = [last(DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(), true), ctx))[@varname(x)] for i = 1:10];
+
+julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
+true
+
+julia> # And with `Dict` of course!
+       _, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true), ctx);
+
+julia> vi[@varname(x)] # (✓) -∞ < x < ∞
+0.6225185067787314
+
+julia> xs = [last(DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(), true), ctx))[@varname(x)] for i = 1:10];
+
+julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
+true
+```
+
+Evaluation in transformed space of course also works:
+
+```jldoctest simplevarinfo-general
+julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true)
+Transformed SimpleVarInfo((x = -1.0,), 0.0)
+
+julia> # (✓) Positive probability mass on negative numbers!
+       getlogp(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext())))
+-1.3678794411714423
+
+julia> # While if we forget to indicate that it's transformed:
+       vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false)
+SimpleVarInfo((x = -1.0,), 0.0)
+
+julia> # (✓) No probability mass on negative numbers!
+       getlogp(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext())))
+-Inf
+```
+
 ## Indexing
 
 Using `NamedTuple` as underlying storage.
@@ -126,14 +188,26 @@ ERROR: type NamedTuple has no field b
 [...]
 ```
 """
-struct SimpleVarInfo{NT,T} <: AbstractVarInfo
+struct SimpleVarInfo{NT,T,C<:AbstractTransformation} <: AbstractVarInfo
+    "underlying representation of the realizations"
     values::NT
+    "holds the accumulated log-probability"
     logp::T
+    "represents whether it assumes variables to be transformed"
+    transformation::C
 end
 
-SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T))
-SimpleVarInfo{T}(; kwargs...) where {T<:Real} = SimpleVarInfo{T}(NamedTuple(kwargs))
-SimpleVarInfo(; kwargs...) = SimpleVarInfo{Float64}(NamedTuple(kwargs))
+SimpleVarInfo(values, logp) = SimpleVarInfo(values, logp, NoTransformation())
+
+function SimpleVarInfo{T}(θ) where {T<:Real}
+    return SimpleVarInfo(θ, zero(T))
+end
+function SimpleVarInfo{T}(; kwargs...) where {T<:Real}
+    return SimpleVarInfo{T}(NamedTuple(kwargs))
+end
+function SimpleVarInfo(; kwargs...)
+    return SimpleVarInfo{Float64}(NamedTuple(kwargs))
+end
 SimpleVarInfo(θ) = SimpleVarInfo{Float64}(θ)
 
 # Constructor from `Model`.
@@ -158,8 +232,8 @@ function BangBang.empty!!(vi::SimpleVarInfo)
 end
 
 getlogp(vi::SimpleVarInfo) = vi.logp
-setlogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.values, logp)
-acclogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.values, getlogp(vi) + logp)
+setlogp!!(vi::SimpleVarInfo, logp) = Setfield.@set vi.logp = logp
+acclogp!!(vi::SimpleVarInfo, logp) = Setfield.@set vi.logp = getlogp(vi) + logp
 
 """
     keys(vi::SimpleVarInfo)
 
 Return an iterator of keys present in `vi`.
""" Base.keys(vi::SimpleVarInfo) = keys(vi.values) +Base.keys(vi::SimpleVarInfo{<:NamedTuple}) = map(k -> VarName{k}(), keys(vi.values)) function setlogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) vi.logp[] = logp @@ -179,10 +254,24 @@ function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) end function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo) + if !(svi.transformation isa NoTransformation) + print(io, "Transformed ") + end + return print(io, "SimpleVarInfo(", svi.values, ", ", svi.logp, ")") end # `NamedTuple` +function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution) + return maybe_invlink(vi, vn, dist, getindex(vi, vn)) +end +function Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution) + vals_linked = mapreduce(vcat, vns) do vn + getindex(vi, vn, dist) + end + return reconstruct(dist, vals_linked, length(vns)) +end + Base.getindex(vi::SimpleVarInfo, vn::VarName) = get(vi.values, vn) # `Dict` @@ -221,9 +310,23 @@ Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(Base.Fix1(getinde Base.getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.values Base.getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.values + # TODO: Should we do better? Base.getindex(vi::SimpleVarInfo, spl::Sampler) = vi.values +# Since we don't perform any transformations in `getindex` for `SimpleVarInfo` +# we simply call `getindex` in `getindex_raw`. +getindex_raw(vi::SimpleVarInfo, vn::VarName) = vi[vn] +function getindex_raw(vi::SimpleVarInfo, vn::VarName, dist::Distribution) + return reconstruct(dist, getindex_raw(vi, vn)) +end +getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}) = vi[vns] +function getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution) + # `reconstruct` expects a flattened `Vector` regardless of the type of `dist`, so we `vcat` everything. + vals = mapreduce(Base.Fix1(getindex_raw, vi), vcat, vns) + return reconstruct(dist, vals, length(vns)) +end + Base.haskey(vi::SimpleVarInfo, vn::VarName) = _haskey(vi.values, vn) function _haskey(nt::NamedTuple, vn::VarName) # LHS: Ensure that `nt` indeed has the property we want. @@ -259,7 +362,7 @@ end function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName) # For `NamedTuple` we treat the symbol in `vn` as the _property_ to set. - return SimpleVarInfo(set!!(vi.values, vn, val), vi.logp) + return Setfield.@set vi.values = set!!(vi.values, vn, val) end # TODO: Specialize to handle certain cases, e.g. a collection of `VarName` with @@ -290,7 +393,7 @@ function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName vn_key = VarName(vn, keylens) BangBang.setindex!!(dict, set!!(dict[vn_key], child, val), vn_key) end - return SimpleVarInfo(dict_new, vi.logp) + return Setfield.@set vi.values = dict_new end # `NamedTuple` @@ -325,8 +428,8 @@ function BangBang.push!!( return vi end -const SimpleOrThreadSafeSimple{T,V} = Union{ - SimpleVarInfo{T,V},ThreadSafeVarInfo{<:SimpleVarInfo{T,V}} +const SimpleOrThreadSafeSimple{T,V,C} = Union{ + SimpleVarInfo{T,V,C},ThreadSafeVarInfo{<:SimpleVarInfo{T,V,C}} } # Necessary for `matchingvalue` to work properly. @@ -337,79 +440,90 @@ function Base.eltype( end # Context implementations -function assume(dist::Distribution, vn::VarName, vi::SimpleOrThreadSafeSimple) - left = vi[vn] - return left, Distributions.loglikelihood(dist, left), vi -end - +# NOTE: Evaluations, i.e. those without `rng` are shared with other +# implementations of `AbstractVarInfo`. 
function assume( rng::Random.AbstractRNG, - sampler::SampleFromPrior, + sampler::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, vi::SimpleOrThreadSafeSimple, ) value = init(rng, dist, sampler) - vi = BangBang.push!!(vi, vn, value, dist, sampler) - return value, Distributions.loglikelihood(dist, value), vi -end - -function dot_assume( - dist::MultivariateDistribution, - var::AbstractMatrix, - vns::AbstractVector{<:VarName}, - vi::SimpleOrThreadSafeSimple, -) - @assert length(dist) == size(var, 1) - # NOTE: We cannot work with `var` here because we might have a model of the form - # - # m = Vector{Float64}(undef, n) - # m .~ Normal() - # - # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - value = vi[vns] - lp = sum(zip(vns, eachcol(value))) do (vn, val) - return Distributions.logpdf(dist, val) - end - return value, lp, vi + # Transform if we're working in unconstrained space. + value_raw = maybe_link(vi, vn, dist, value) + vi = BangBang.push!!(vi, vn, value_raw, dist, sampler) + return value, Bijectors.logpdf_with_trans(dist, value, istrans(vi, vn)), vi end function dot_assume( + rng, + spl::Union{SampleFromPrior,SampleFromUniform}, dists::Union{Distribution,AbstractArray{<:Distribution}}, - var::AbstractArray, vns::AbstractArray{<:VarName}, + var::AbstractArray, vi::SimpleOrThreadSafeSimple, ) - # NOTE: We cannot work with `var` here because we might have a model of the form - # - # m = Vector{Float64}(undef, n) - # m .~ Normal() - # - # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - value = vi[vns] - lp = sum(Distributions.logpdf.(dists, value)) + f = (vn, dist) -> init(rng, dist, spl) + value = f.(vns, dists) + + # Transform if we're working in transformed space. + value_raw = if dists isa Distribution + maybe_link.((vi,), vns, (dists,), value) + else + maybe_link.((vi,), vns, dists, value) + end + + # Update `vi` + vi = BangBang.setindex!!(vi, value_raw, vns) + + # Compute logp. + lp = sum(Bijectors.logpdf_with_trans.(dists, value, istrans.((vi,), vns))) return value, lp, vi end function dot_assume( rng, spl::Union{SampleFromPrior,SampleFromUniform}, - dists::Union{Distribution,AbstractArray{<:Distribution}}, - vns::AbstractArray{<:VarName}, - var::AbstractArray, + dist::MultivariateDistribution, + vns::AbstractVector{<:VarName}, + var::AbstractMatrix, vi::SimpleOrThreadSafeSimple, ) - f = (vn, dist) -> init(rng, dist, spl) - value = f.(vns, dists) - vi = BangBang.setindex!!(vi, value, vns) - lp = sum(Distributions.logpdf.(dists, value)) + @assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))" + + # r = get_and_set_val!(rng, vi, vns, dist, spl) + n = length(vns) + value = init(rng, dist, spl, n) + + # Update `vi`. + for (vn, val) in zip(vns, eachcol(value)) + val_linked = maybe_link(vi, vn, dist, val) + vi = BangBang.setindex!!(vi, val_linked, vn) + end + + # Compute logp. + lp = sum(Bijectors.logpdf_with_trans(dist, value, istrans(vi))) return value, lp, vi end # HACK: Allows us to re-use the implementation of `dot_tilde`, etc. for literals. increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing -settrans!(vi::SimpleOrThreadSafeSimple, trans::Bool, vn::VarName) = nothing -istrans(::SimpleVarInfo, vn::VarName) = false + +# NOTE: We don't implement `settrans!!(vi, trans, vn)`. +function settrans!!(vi::SimpleVarInfo, trans) + return settrans!!(vi, trans ? 
DefaultTransformation() : NoTransformation())
+end
+function settrans!!(vi::SimpleVarInfo, transformation::AbstractTransformation)
+    return Setfield.@set vi.transformation = transformation
+end
+function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans)
+    return Setfield.@set vi.varinfo = settrans!!(vi.varinfo, trans)
+end
+
+istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation)
+istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi)
+istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn)
 
 """
     values_as(varinfo[, Type])
diff --git a/src/test_utils.jl b/src/test_utils.jl
index ca0fabc9a..ef314fa91 100644
--- a/src/test_utils.jl
+++ b/src/test_utils.jl
@@ -6,10 +6,35 @@ using LinearAlgebra
 using Distributions
 using Test
 
+using Random: Random
+using Bijectors: Bijectors
+using Setfield: Setfield
+
 """
-    logprior_true(model, θ)
+    varname_leaves(vn::VarName, val)
 
-Return the `logprior` of `model` for `θ`.
+Return an iterator over all varnames that are represented by `vn` on `val`,
+e.g. `varname_leaves(@varname(x), rand(2))` results in an iterator over `[@varname(x[1]), @varname(x[2])]`.
+"""
+varname_leaves(vn::VarName, val::Real) = [vn]
+function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}})
+    return (
+        VarName(vn, DynamicPPL.getlens(vn) ∘ Setfield.IndexLens(Tuple(I))) for
+        I in CartesianIndices(val)
+    )
+end
+function varname_leaves(vn::VarName, val::AbstractArray)
+    return Iterators.flatten(
+        varname_leaves(
+            VarName(vn, DynamicPPL.getlens(vn) ∘ Setfield.IndexLens(Tuple(I))), val[I]
+        ) for I in CartesianIndices(val)
+    )
+end
+
+"""
+    logprior_true(model, args...)
+
+Return the `logprior` of `model` for `args`.
 
 This should generally be implemented by hand for every specific `model`.
 
 See also: [`logjoint_true`](@ref), [`loglikelihood_true`](@ref).
 """
 function logprior_true end
 
 """
-    loglikelihood_true(model, θ)
+    loglikelihood_true(model, args...)
 
-Return the `loglikelihood` of `model` for `θ`.
+Return the `loglikelihood` of `model` for `args`.
 
 This should generally be implemented by hand for every specific `model`.
 
 See also: [`logjoint_true`](@ref), [`logprior_true`](@ref).
 """
 function loglikelihood_true end
 
 """
-    logjoint_true(model, θ)
+    logjoint_true(model, args...)
 
-Return the `logjoint` of `model` for `θ`.
+Return the `logjoint` of `model` for `args`.
 
-Defaults to `logprior_true(model, θ) + loglikelihood_true(model, θ)`.
+Defaults to `logprior_true(model, args...) + loglikelihood_true(model, args...)`.
 
 This should generally be implemented by hand for every specific `model`
 so that the returned value can be used as a ground-truth for testing things like:
 
 1. Validity of evaluation of `model` using a particular implementation of `AbstractVarInfo`.
 2. Validity of a sampler when combined with DynamicPPL by running the sampler twice: once targeting ground-truth functions, e.g. `logjoint_true`, and once targeting `model`.
 
 And more.
 """
 function logjoint_true(model::Model, args...)
     return logprior_true(model, args...) + loglikelihood_true(model, args...)
 end
 
-# A collection of models for which the mean-of-means for the posterior should
-# be same.
+"""
+    logjoint_true_with_logabsdet_jacobian(model::Model, args...)
+
+Return a tuple `(args_unconstrained, logjoint)` of `model` for `args`.
+
+Unlike [`logjoint_true`](@ref), the returned logjoint computation includes the
+log-absdet-jacobian adjustment, thus computing logjoint for the unconstrained variables.
+
+Note that `args` are assumed to be in the support of `model`, while `args_unconstrained`
+will be unconstrained.
+
+This should generally not be implemented directly; instead, one should implement
+[`logprior_true_with_logabsdet_jacobian`](@ref) for a given `model`.
+
+See also: [`logjoint_true`](@ref), [`logprior_true_with_logabsdet_jacobian`](@ref).
+"""
+function logjoint_true_with_logabsdet_jacobian(model::Model, args...)
+    args_unconstrained, lp = logprior_true_with_logabsdet_jacobian(model, args...)
+    return args_unconstrained, lp + loglikelihood_true(model, args...)
+end
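The adjustment being applied here is the usual change-of-variables identity. A minimal sketch of the relationship, using the same `Bijectors` operations the helpers rely on (the distribution and value below are arbitrary):

```julia
using Bijectors, Distributions

# If y = b(x) with b = bijector(dist) mapping the support to ℝ, then the
# log-density of the unconstrained y is logpdf(dist, x) - log|det J_b(x)|.
dist = InverseGamma(2, 3)
x = 2.5
b = Bijectors.bijector(dist)                        # `log` for positive support
y, Δlogp = Bijectors.with_logabsdet_jacobian(b, x)  # unconstrained value and log|det J|
lp_unconstrained = logpdf(dist, x) - Δlogp

# This is exactly what `Bijectors.logpdf_with_trans(dist, x, true)` computes.
lp_unconstrained ≈ Bijectors.logpdf_with_trans(dist, x, true)  # true
```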
+
+"""
+    logprior_true_with_logabsdet_jacobian(model::Model, args...)
+
+Return a tuple `(args_unconstrained, logprior_unconstrained)` of `model` for `args...`.
+
+Unlike [`logprior_true`](@ref), the returned logprior computation includes the
+log-absdet-jacobian adjustment, thus computing logprior for the unconstrained variables.
+
+Note that `args` are assumed to be in the support of `model`, while `args_unconstrained`
+will be unconstrained.
+
+See also: [`logprior_true`](@ref).
+"""
+function logprior_true_with_logabsdet_jacobian end
+
+"""
+    varnames(model::Model)
+
+Return a collection of `VarName` as they are expected to appear in the model.
+
+Even though it is recommended to implement this by hand for a particular `Model`,
+a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided.
+"""
+function varnames(model::Model)
+    return collect(
+        keys(last(DynamicPPL.evaluate!!(model, SimpleVarInfo(Dict()), SamplingContext())))
+    )
+end
+
+"""
+    posterior_mean(model::Model)
+
+Return a `NamedTuple` compatible with `varnames(model)` where the values represent
+the posterior mean under `model`.
+
+"Compatible" means that a `varname` from `varnames(model)` can be used to extract the
+corresponding value using `get`, e.g. `get(posterior_mean(model), varname)`.
+"""
+function posterior_mean end
+
+"""
+    demo_dynamic_constraint()
+
+A model with variables `m` and `x` with `x` having support depending on `m`.
+"""
+@model function demo_dynamic_constraint()
+    m ~ Normal()
+    x ~ truncated(Normal(), m, Inf)
+
+    return (m=m, x=x)
+end
+
+function logprior_true(model::Model{typeof(demo_dynamic_constraint)}, m, x)
+    return logpdf(Normal(), m) + logpdf(truncated(Normal(), m, Inf), x)
+end
+function loglikelihood_true(model::Model{typeof(demo_dynamic_constraint)}, m, x)
+    return zero(float(eltype(m)))
+end
+function varnames(model::Model{typeof(demo_dynamic_constraint)})
+    return [@varname(m), @varname(x)]
+end
+function logprior_true_with_logabsdet_jacobian(
+    model::Model{typeof(demo_dynamic_constraint)}, m, x
+)
+    b_x = Bijectors.bijector(truncated(Normal(), m, Inf))
+    x_unconstrained, Δlogp = Bijectors.with_logabsdet_jacobian(b_x, x)
+    return (m=m, x=x_unconstrained), logprior_true(model, m, x) - Δlogp
+end
+
+# A collection of models for which the posterior should be "similar".
+# Some utility methods for these.
+function _demo_logprior_true_with_logabsdet_jacobian(model, s, m) + b = Bijectors.bijector(InverseGamma(2, 3)) + s_unconstrained = b.(s) + Δlogp = sum(Base.Fix1(Bijectors.logabsdetjac, b), s) + return (s=s_unconstrained, m=m), logprior_true(model, s, m) - Δlogp +end + @model function demo_dot_assume_dot_observe( - x=[10.0, 10.0], ::Type{TV}=Vector{Float64} + x=[1.5, 2.0], ::Type{TV}=Vector{Float64} ) where {TV} # `dot_assume` and `observe` + s = TV(undef, length(x)) m = TV(undef, length(x)) - m .~ Normal() - x ~ MvNormal(m, 0.25 * I) - return (; m=m, x=x, logp=getlogp(__varinfo__)) + s .~ InverseGamma(2, 3) + m .~ Normal.(0, sqrt.(s)) + + x ~ MvNormal(m, Diagonal(s)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) +end +function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe)}, s, m) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe)}, m) - return loglikelihood(Normal(), m) +function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe)}, s, m) + return loglikelihood(MvNormal(m, Diagonal(s)), model.args.x) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe)}, m) - return loglikelihood(MvNormal(m, 0.25 * I), model.args.x) +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_dot_assume_dot_observe)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end +function varnames(model::Model{typeof(demo_dot_assume_dot_observe)}) + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end @model function demo_assume_index_observe( - x=[10.0, 10.0], ::Type{TV}=Vector{Float64} + x=[1.5, 2.0], ::Type{TV}=Vector{Float64} ) where {TV} # `assume` with indexing and `observe` + s = TV(undef, length(x)) + for i in eachindex(s) + s[i] ~ InverseGamma(2, 3) + end m = TV(undef, length(x)) for i in eachindex(m) - m[i] ~ Normal() + m[i] ~ Normal(0, sqrt(s[i])) end - x ~ MvNormal(m, 0.25 * I) + x ~ MvNormal(m, Diagonal(s)) - return (; m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_assume_index_observe)}, m) - return loglikelihood(Normal(), m) +function logprior_true(model::Model{typeof(demo_assume_index_observe)}, s, m) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function loglikelihood_true(model::Model{typeof(demo_assume_index_observe)}, m) - return logpdf(MvNormal(m, 0.25 * I), model.args.x) +function loglikelihood_true(model::Model{typeof(demo_assume_index_observe)}, s, m) + return logpdf(MvNormal(m, Diagonal(s)), model.args.x) +end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_assume_index_observe)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end +function varnames(model::Model{typeof(demo_assume_index_observe)}) + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end -@model function demo_assume_multivariate_observe(x=[10.0, 10.0]) +@model function demo_assume_multivariate_observe(x=[1.5, 2.0]) # Multivariate `assume` and `observe` - m ~ MvNormal(zero(x), I) - x ~ MvNormal(m, 0.25 * I) + s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) + m ~ MvNormal(zero(x), Diagonal(s)) + x ~ MvNormal(m, Diagonal(s)) - return (; m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) +end +function 
logprior_true(model::Model{typeof(demo_assume_multivariate_observe)}, s, m) + s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) + m_dist = MvNormal(zero(model.args.x), Diagonal(s)) + return logpdf(s_dist, s) + logpdf(m_dist, m) end -function logprior_true(model::Model{typeof(demo_assume_multivariate_observe)}, m) - return logpdf(MvNormal(zero(model.args.x), I), m) +function loglikelihood_true(model::Model{typeof(demo_assume_multivariate_observe)}, s, m) + return logpdf(MvNormal(m, Diagonal(s)), model.args.x) +end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_assume_multivariate_observe)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function loglikelihood_true(model::Model{typeof(demo_assume_multivariate_observe)}, m) - return logpdf(MvNormal(m, 0.25 * I), model.args.x) +function varnames(model::Model{typeof(demo_assume_multivariate_observe)}) + return [@varname(s), @varname(m)] end @model function demo_dot_assume_observe_index( - x=[10.0, 10.0], ::Type{TV}=Vector{Float64} + x=[1.5, 2.0], ::Type{TV}=Vector{Float64} ) where {TV} # `dot_assume` and `observe` with indexing + s = TV(undef, length(x)) + s .~ InverseGamma(2, 3) m = TV(undef, length(x)) - m .~ Normal() + m .~ Normal.(0, sqrt.(s)) for i in eachindex(x) - x[i] ~ Normal(m[i], 0.5) + x[i] ~ Normal(m[i], sqrt(s[i])) end - return (; m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_dot_assume_observe_index)}, m) - return loglikelihood(Normal(), m) +function logprior_true(model::Model{typeof(demo_dot_assume_observe_index)}, s, m) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index)}, m) - return sum(logpdf.(Normal.(m, 0.5), model.args.x)) +function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index)}, s, m) + return sum(logpdf.(Normal.(m, sqrt.(s)), model.args.x)) +end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_dot_assume_observe_index)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end +function varnames(model::Model{typeof(demo_dot_assume_observe_index)}) + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end # Using vector of `length` 1 here so the posterior of `m` is the same # as the others. 
-@model function demo_assume_dot_observe(x=[10.0]) +@model function demo_assume_dot_observe(x=[1.5, 2.0]) # `assume` and `dot_observe` - m ~ Normal() - x .~ Normal(m, 0.5) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + x .~ Normal(m, sqrt(s)) - return (; m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_assume_dot_observe)}, m) - return logpdf(Normal(), m) +function logprior_true(model::Model{typeof(demo_assume_dot_observe)}, s, m) + return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) end -function loglikelihood_true(model::Model{typeof(demo_assume_dot_observe)}, m) - return sum(logpdf.(Normal.(m, 0.5), model.args.x)) +function loglikelihood_true(model::Model{typeof(demo_assume_dot_observe)}, s, m) + return sum(logpdf.(Normal.(m, sqrt.(s)), model.args.x)) +end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_assume_dot_observe)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end +function varnames(model::Model{typeof(demo_assume_dot_observe)}) + return [@varname(s), @varname(m)] end @model function demo_assume_observe_literal() # `assume` and literal `observe` - m ~ MvNormal(zeros(2), I) - [10.0, 10.0] ~ MvNormal(m, 0.25 * I) + s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) + m ~ MvNormal(zeros(2), Diagonal(s)) + [1.5, 2.0] ~ MvNormal(m, Diagonal(s)) - return (; m=m, x=[10.0, 10.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, m) - return logpdf(MvNormal(zeros(2), I), m) +function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) + s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) + m_dist = MvNormal(zeros(2), Diagonal(s)) + return logpdf(s_dist, s) + logpdf(m_dist, m) end -function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, m) - return logpdf(MvNormal(m, 0.25 * I), [10.0, 10.0]) +function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) + return logpdf(MvNormal(m, Diagonal(s)), [1.5, 2.0]) +end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_assume_observe_literal)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end +function varnames(model::Model{typeof(demo_assume_observe_literal)}) + return [@varname(s), @varname(m)] end @model function demo_dot_assume_observe_index_literal(::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and literal `observe` with indexing + s = TV(undef, 2) m = TV(undef, 2) - m .~ Normal() - for i in eachindex(m) - 10.0 ~ Normal(m[i], 0.5) - end + s .~ InverseGamma(2, 3) + m .~ Normal.(0, sqrt.(s)) + + 1.5 ~ Normal(m[1], sqrt(s[1])) + 2.0 ~ Normal(m[2], sqrt(s[2])) - return (; m=m, x=fill(10.0, length(m)), logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, m) - return loglikelihood(Normal(), m) +function logprior_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, s, m) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, m) - return sum(logpdf.(Normal.(m, 0.5), fill(10.0, length(m)))) +function loglikelihood_true( + 
model::Model{typeof(demo_dot_assume_observe_index_literal)}, s, m +) + return sum(logpdf.(Normal.(m, sqrt.(s)), [1.5, 2.0])) +end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_dot_assume_observe_index_literal)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end +function varnames(model::Model{typeof(demo_dot_assume_observe_index_literal)}) + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end @model function demo_assume_literal_dot_observe() # `assume` and literal `dot_observe` - m ~ Normal() - [10.0] .~ Normal(m, 0.5) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + [1.5, 2.0] .~ Normal(m, sqrt(s)) - return (; m=m, x=[10.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) +end +function logprior_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, m) + return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) +end +function loglikelihood_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, m) + return loglikelihood(Normal(m, sqrt(s)), [1.5, 2.0]) end -function logprior_true(model::Model{typeof(demo_assume_literal_dot_observe)}, m) - return logpdf(Normal(), m) +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_assume_literal_dot_observe)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function loglikelihood_true(model::Model{typeof(demo_assume_literal_dot_observe)}, m) - return logpdf(Normal(m, 0.5), 10.0) +function varnames(model::Model{typeof(demo_assume_literal_dot_observe)}) + return [@varname(s), @varname(m)] end @model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} + s = TV(undef, 2) + s .~ InverseGamma(2, 3) m = TV(undef, 2) - m .~ Normal() + m .~ Normal.(0, sqrt.(s)) - return m + return s, m end @model function demo_assume_submodel_observe_index_literal() # Submodel prior - @submodel m = _prior_dot_assume() - for i in eachindex(m) - 10.0 ~ Normal(m[i], 0.5) - end + @submodel s, m = _prior_dot_assume() + 1.5 ~ Normal(m[1], sqrt(s[1])) + 2.0 ~ Normal(m[2], sqrt(s[2])) - return (; m=m, x=[10.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_assume_submodel_observe_index_literal)}, m) - return loglikelihood(Normal(), m) +function logprior_true( + model::Model{typeof(demo_assume_submodel_observe_index_literal)}, s, m +) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end function loglikelihood_true( - model::Model{typeof(demo_assume_submodel_observe_index_literal)}, m + model::Model{typeof(demo_assume_submodel_observe_index_literal)}, s, m ) - return sum(logpdf.(Normal.(m, 0.5), 10.0)) + return sum(logpdf.(Normal.(m, sqrt.(s)), [1.5, 2.0])) +end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_assume_submodel_observe_index_literal)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end +function varnames(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end -@model function _likelihood_dot_observe(m, x) - return x ~ MvNormal(m, 0.25 * I) +@model function _likelihood_mltivariate_observe(s, m, x) + return x ~ MvNormal(m, Diagonal(s)) end @model function demo_dot_assume_observe_submodel( - x=[10.0, 10.0], ::Type{TV}=Vector{Float64} + x=[1.5, 2.0], ::Type{TV}=Vector{Float64} ) where 
{TV} + s = TV(undef, length(x)) + s .~ InverseGamma(2, 3) m = TV(undef, length(x)) - m .~ Normal() + m .~ Normal.(0, sqrt.(s)) # Submodel likelihood - @submodel _likelihood_dot_observe(m, x) + @submodel _likelihood_mltivariate_observe(s, m, x) - return (; m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, m) - return loglikelihood(Normal(), m) +function logprior_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, s, m) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, m) - return logpdf(MvNormal(m, 0.25 * I), model.args.x) +function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, s, m) + return logpdf(MvNormal(m, Diagonal(s)), model.args.x) +end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_dot_assume_observe_submodel)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end +function varnames(model::Model{typeof(demo_dot_assume_observe_submodel)}) + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end @model function demo_dot_assume_dot_observe_matrix( - x=fill(10.0, 2, 1), ::Type{TV}=Vector{Float64} + x=transpose([1.5 2.0;]), ::Type{TV}=Vector{Float64} ) where {TV} + s = TV(undef, length(x)) + s .~ InverseGamma(2, 3) m = TV(undef, length(x)) - m .~ Normal() + m .~ Normal.(0, sqrt.(s)) # Dotted observe for `Matrix`. - x .~ MvNormal(m, 0.25 * I) + x .~ MvNormal(m, Diagonal(s)) - return (; m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) +end +function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, s, m) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, m) - return loglikelihood(Normal(), m) +function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, s, m) + return sum(logpdf.(Normal.(m, sqrt.(s)), model.args.x)) +end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, m) - return loglikelihood(MvNormal(m, 0.25 * I), model.args.x) +function varnames(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end +@model function demo_dot_assume_matrix_dot_observe_matrix( + x=transpose([1.5 2.0;]), ::Type{TV}=Array{Float64} +) where {TV} + n = length(x) + d = length(x) ÷ 2 + s = TV(undef, d, 2) + s .~ product_distribution([InverseGamma(2, 3) for _ in 1:d]) + s_vec = vec(s) + m ~ MvNormal(zeros(n), Diagonal(s_vec)) + + # Dotted observe for `Matrix`. 
+    x .~ MvNormal(m, Diagonal(s_vec))
+
+    return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
+end
+function logprior_true(
+    model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m
+)
+    n = length(model.args.x)
+    s_vec = vec(s)
+    return loglikelihood(InverseGamma(2, 3), s_vec) +
+           logpdf(MvNormal(zeros(n), Diagonal(s_vec)), m)
+end
+function loglikelihood_true(
+    model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m
+)
+    return loglikelihood(MvNormal(m, Diagonal(vec(s))), model.args.x)
+end
+function logprior_true_with_logabsdet_jacobian(
+    model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m
+)
+    return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
+end
+function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)})
+    return [@varname(s[:, 1]), @varname(s[:, 2]), @varname(m)]
+end
+
+const DemoModels = Union{
+    Model{typeof(demo_dot_assume_dot_observe)},
+    Model{typeof(demo_assume_index_observe)},
+    Model{typeof(demo_assume_multivariate_observe)},
+    Model{typeof(demo_dot_assume_observe_index)},
+    Model{typeof(demo_assume_dot_observe)},
+    Model{typeof(demo_assume_literal_dot_observe)},
+    Model{typeof(demo_assume_observe_literal)},
+    Model{typeof(demo_dot_assume_observe_index_literal)},
+    Model{typeof(demo_assume_submodel_observe_index_literal)},
+    Model{typeof(demo_dot_assume_observe_submodel)},
+    Model{typeof(demo_dot_assume_dot_observe_matrix)},
+    Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)},
+}
+
+# We require demo models to have explicit implementations of `rand` since we want
+# these to be considered as ground truth.
+function Random.rand(rng::Random.AbstractRNG, ::Type{NamedTuple}, model::DemoModels)
+    return error("demo models require an explicit implementation of rand")
+end
+
+const UnivariateAssumeDemoModels = Union{
+    Model{typeof(demo_assume_dot_observe)},Model{typeof(demo_assume_literal_dot_observe)}
+}
+function posterior_mean(model::UnivariateAssumeDemoModels)
+    return (s=49 / 24, m=7 / 6)
+end
+function Random.rand(
+    rng::Random.AbstractRNG, ::Type{NamedTuple}, model::UnivariateAssumeDemoModels
+)
+    s = rand(rng, InverseGamma(2, 3))
+    m = rand(rng, Normal(0, sqrt(s)))
+
+    return (s=s, m=m)
+end
+
+const MultivariateAssumeDemoModels = Union{
+    Model{typeof(demo_dot_assume_dot_observe)},
+    Model{typeof(demo_assume_index_observe)},
+    Model{typeof(demo_assume_multivariate_observe)},
+    Model{typeof(demo_dot_assume_observe_index)},
+    Model{typeof(demo_assume_observe_literal)},
+    Model{typeof(demo_dot_assume_observe_index_literal)},
+    Model{typeof(demo_assume_submodel_observe_index_literal)},
+    Model{typeof(demo_dot_assume_observe_submodel)},
+    Model{typeof(demo_dot_assume_dot_observe_matrix)},
+    Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)},
+}
+function posterior_mean(model::MultivariateAssumeDemoModels)
+    # Get some containers to fill.
+    vals = Random.rand(model)
+
+    vals.s[1] = 19 / 8
+    vals.m[1] = 3 / 4
+
+    vals.s[2] = 8 / 3
+    vals.m[2] = 1
+
+    return vals
+end
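These posterior-mean constants follow from the standard Normal-InverseGamma conjugate update. A quick sketch of the arithmetic for the univariate values (the multivariate ones follow coordinatewise, each coordinate observing a single data point, 1.5 or 2.0):

```julia
# Normal-InverseGamma conjugate update for the univariate demo posterior:
#   s ~ InverseGamma(α₀, β₀), m | s ~ Normal(μ₀, √s), observations 1.5 and 2.0.
α₀, β₀, μ₀, ν₀ = 2, 3, 0, 1      # ν₀ = 1: `m` carries one pseudo-observation of `s`
x = [1.5, 2.0]
n, x̄ = length(x), sum(x) / length(x)

ν = ν₀ + n
μ = (ν₀ * μ₀ + sum(x)) / ν                                   # == 7 / 6 == mean(m)
α = α₀ + n / 2
β = β₀ + sum(abs2, x .- x̄) / 2 + ν₀ * n * (x̄ - μ₀)^2 / (2ν)

β / (α - 1)                                                  # == 49 / 24 == mean(s)
```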
+function Random.rand(
+    rng::Random.AbstractRNG, ::Type{NamedTuple}, model::MultivariateAssumeDemoModels
+)
+    # Get template values from `model`.
+    retval = model(rng)
+    vals = (s=retval.s, m=retval.m)
+    # Fill containers with realizations from prior.
+    for i in LinearIndices(vals.s)
+        vals.s[i] = rand(rng, InverseGamma(2, 3))
+        vals.m[i] = rand(rng, Normal(0, sqrt(vals.s[i])))
+    end
+
+    return vals
+end
+
+"""
+A collection of models corresponding to the posterior distribution defined by
+the generative process
+
+    s ~ InverseGamma(2, 3)
+    m ~ Normal(0, √s)
+    1.5 ~ Normal(m, √s)
+    2.0 ~ Normal(m, √s)
+
+or by
+
+    s[1] ~ InverseGamma(2, 3)
+    s[2] ~ InverseGamma(2, 3)
+    m[1] ~ Normal(0, √s[1])
+    m[2] ~ Normal(0, √s[2])
+    1.5 ~ Normal(m[1], √s[1])
+    2.0 ~ Normal(m[2], √s[2])
+
+These are examples of a Normal-InverseGamma conjugate prior with Normal likelihood,
+for which the posterior is known in closed form.
+
+In particular, for the univariate model (the former one):
+
+    mean(s) == 49 / 24
+    mean(m) == 7 / 6
+
+And for the multivariate one (the latter one):
+
+    mean(s[1]) == 19 / 8
+    mean(m[1]) == 3 / 4
+    mean(s[2]) == 8 / 3
+    mean(m[2]) == 1
+
+"""
 const DEMO_MODELS = (
     demo_dot_assume_dot_observe(),
     demo_assume_index_observe(),
@@ -257,64 +641,79 @@ const DEMO_MODELS = (
     demo_assume_submodel_observe_index_literal(),
     demo_dot_assume_observe_submodel(),
     demo_dot_assume_dot_observe_matrix(),
+    demo_dot_assume_matrix_dot_observe_matrix(),
 )
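As a rough usage sketch (not part of the patch itself): every model in `DEMO_MODELS` supports drawing ground-truth realizations via the `rand` overloads above, which can be fed directly to the hand-coded densities:

```julia
using Distributions, Random
using DynamicPPL.TestUtils:
    DEMO_MODELS, logjoint_true, logprior_true, loglikelihood_true

rng = Random.MersenneTwister(42)
for model in DEMO_MODELS
    # Ground-truth draw of `(s, m)` from the prior.
    vals = rand(rng, NamedTuple, model)
    # `logjoint_true` defaults to the sum of the two hand-coded terms.
    lp = logprior_true(model, vals.s, vals.m)
    ll = loglikelihood_true(model, vals.s, vals.m)
    @assert logjoint_true(model, vals.s, vals.m) ≈ lp + ll
end
```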
 
-# TODO: Is this really the best/most convenient "default" test method?
 """
-    test_sampler_demo_models(meanfunction, sampler, args...; kwargs...)
+    marginal_mean_of_samples(chain, varname)
 
-Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`.
+Return the mean of the variable represented by `varname` in `chain`.
+"""
+marginal_mean_of_samples(chain, varname) = mean(Array(chain[Symbol(varname)]))
+
+"""
+    test_sampler(models, sampler, args...; kwargs...)
+
+Test that `sampler` produces correct marginal posterior means on each model in `models`.
 
-In short, this method iterators through `demo_models`, calls `AbstractMCMC.sample` on the
-`model` and `sampler` to produce a `chain`, and then checks `meanfunction(chain)` against `target`
-provided in `kwargs...`.
+In short, this method iterates through `models`, calls `AbstractMCMC.sample` on the
+`model` and `sampler` to produce a `chain`, and then checks `marginal_mean_of_samples(chain, vn)`
+for every (leaf) varname `vn` against the corresponding value returned by
+[`posterior_mean`](@ref) for each model.
+
+To change how comparison is done for a particular `chain` type, one can overload
+[`marginal_mean_of_samples`](@ref) for the corresponding type.
 
 # Arguments
-- `meanfunction`: A callable which computes the mean of the marginal means from the
-  chain resulting from the `sample` call.
+- `models`: A collection of instances of [`DynamicPPL.Model`](@ref) to test on.
 - `sampler`: The `AbstractMCMC.AbstractSampler` to test.
 - `args...`: Arguments forwarded to `sample`.
 
 # Keyword arguments
-- `target`: Value to compare result of `meanfunction(chain)` to.
 - `atol=1e-1`: Absolute tolerance used in `@test`.
 - `rtol=1e-3`: Relative tolerance used in `@test`.
 - `kwargs...`: Keyword arguments forwarded to `sample`.
 """
-function test_sampler_demo_models(
-    meanfunction,
-    sampler::AbstractMCMC.AbstractSampler,
-    args...;
-    target=8.0,
-    atol=1e-1,
-    rtol=1e-3,
-    kwargs...,
+function test_sampler(
+    models, sampler::AbstractMCMC.AbstractSampler, args...; atol=1e-1, rtol=1e-3, kwargs...
 )
-    @testset "$(nameof(typeof(sampler))) on $(nameof(m))" for model in DEMO_MODELS
+    @testset "$(typeof(sampler)) on $(nameof(model))" for model in models
         chain = AbstractMCMC.sample(model, sampler, args...; kwargs...)
-        μ = meanfunction(chain)
-        @test μ ≈ target atol = atol rtol = rtol
+        target_values = posterior_mean(model)
+        for vn in varnames(model)
+            # We want to compare elementwise which can be achieved by
+            # extracting the leaves of the `VarName` and the corresponding value.
+            for vn_leaf in varname_leaves(vn, get(target_values, vn))
+                target_value = get(target_values, vn_leaf)
+                chain_mean_value = marginal_mean_of_samples(chain, vn_leaf)
+                @test chain_mean_value ≈ target_value atol = atol rtol = rtol
+            end
+        end
     end
 end
 
 """
-    test_sampler_continuous([meanfunction, ]sampler, args...; kwargs...)
+    test_sampler_on_demo_models(sampler, args...; kwargs...)
 
-Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`.
+Test `sampler` on every model in [`DEMO_MODELS`](@ref).
 
-As of right now, this is just an alias for [`test_sampler_demo_models`](@ref).
+This is just a proxy for `test_sampler(DEMO_MODELS, sampler, args...; kwargs...)`.
 """
-function test_sampler_continuous(
-    meanfunction, sampler::AbstractMCMC.AbstractSampler, args...; kwargs...
+function test_sampler_on_demo_models(
+    sampler::AbstractMCMC.AbstractSampler, args...; kwargs...
 )
-    return test_sampler_demo_models(meanfunction, sampler, args...; kwargs...)
+    return test_sampler(DEMO_MODELS, sampler, args...; kwargs...)
 end
 
+"""
+    test_sampler_continuous(sampler, args...; kwargs...)
+
+Test that `sampler` produces the correct marginal posterior means on all models in [`DEMO_MODELS`](@ref).
+
+As of right now, this is just an alias for [`test_sampler_on_demo_models`](@ref).
+"""
 function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; kwargs...)
-    # Default for `MCMCChains.Chains`.
-    return test_sampler_continuous(sampler, args...; kwargs...) do chain
-        mean(Array(chain))
-    end
+    return test_sampler_on_demo_models(sampler, args...; kwargs...)
 end
 
 end
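Downstream, a sampler test then looks roughly like the following sketch, where `MySampler` is a hypothetical stand-in for the sampler under test:

```julia
using AbstractMCMC, DynamicPPL

# `MySampler` is hypothetical; any `AbstractMCMC.AbstractSampler` whose chains
# support `chain[Symbol(varname)]` (e.g. `MCMCChains.Chains`) works directly.
spl = MySampler()
DynamicPPL.TestUtils.test_sampler_on_demo_models(spl, 10_000)

# Chains with a different layout can instead overload how marginal means are
# extracted, rather than changing the test itself:
# DynamicPPL.TestUtils.marginal_mean_of_samples(chain::MyChain, vn) = ...
```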
end
 end
diff --git a/src/threadsafe.jl b/src/threadsafe.jl
index 6f020a352..7c7dd13ac 100644
--- a/src/threadsafe.jl
+++ b/src/threadsafe.jl
@@ -62,8 +62,24 @@ islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl
 getindex(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex(vi.varinfo, spl)
 getindex(vi::ThreadSafeVarInfo, spl::SampleFromPrior) = getindex(vi.varinfo, spl)
 getindex(vi::ThreadSafeVarInfo, spl::SampleFromUniform) = getindex(vi.varinfo, spl)
+
 getindex(vi::ThreadSafeVarInfo, vn::VarName) = getindex(vi.varinfo, vn)
+function getindex(vi::ThreadSafeVarInfo, vn::VarName, dist::Distribution)
+    return getindex(vi.varinfo, vn, dist)
+end
 getindex(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) = getindex(vi.varinfo, vns)
+function getindex(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}, dist::Distribution)
+    return getindex(vi.varinfo, vns, dist)
+end
+
+getindex_raw(vi::ThreadSafeVarInfo, vn::VarName) = getindex_raw(vi.varinfo, vn)
+function getindex_raw(vi::ThreadSafeVarInfo, vn::VarName, dist::Distribution)
+    return getindex_raw(vi.varinfo, vn, dist)
+end
+getindex_raw(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) = getindex_raw(vi.varinfo, vns)
+function getindex_raw(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}, dist::Distribution)
+    return getindex_raw(vi.varinfo, vns, dist)
+end
 
 function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::AbstractSampler)
     return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl)
diff --git a/src/utils.jl b/src/utils.jl
index 821eba38e..ac9222818 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -183,6 +183,7 @@ vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r))
 # otherwise we will have error for MatrixDistribution.
 # Note this is not the case for MultivariateDistribution so I guess this might be lack of
 # support for some types related to matrices (like PDMat).
+reconstruct(d::UnivariateDistribution, val::Real) = val
 reconstruct(d::Distribution, val::AbstractVector) = reconstruct(size(d), val)
 reconstruct(::Tuple{}, val::AbstractVector) = val[1]
 reconstruct(s::NTuple{1}, val::AbstractVector) = copy(val)
@@ -417,3 +418,11 @@ function splitlens(condition, lens)
 
     return current_parent, current_child, condition(current_parent)
 end
+
+# HACK(torfjelde): Avoids type-instability in `dot_assume` for `SimpleVarInfo`.
+function BangBang.possible(
+    ::typeof(BangBang._setindex!), ::C, ::T, ::Colon, ::Integer
+) where {C<:AbstractMatrix,T<:AbstractVector}
+    return BangBang.implements(setindex!, C) &&
+           promote_type(eltype(C), eltype(T)) <: eltype(C)
+end
diff --git a/src/varinfo.jl b/src/varinfo.jl
index 9ce0414d6..22728ba9a 100644
--- a/src/varinfo.jl
+++ b/src/varinfo.jl
@@ -358,12 +358,31 @@ Return the set of sampler selectors associated with `vn` in `vi`.
 getgid(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).gids[getidx(vi, vn)]
 
 """
-    settrans!(vi::VarInfo, trans::Bool, vn::VarName)
+    settrans!!(vi::VarInfo, trans::Bool, vn::VarName)
 
 Set the `trans` flag value of `vn` in `vi`.
 """
-function settrans!(vi::AbstractVarInfo, trans::Bool, vn::VarName)
-    return trans ? set_flag!(vi, vn, "trans") : unset_flag!(vi, vn, "trans")
+function settrans!!(vi::AbstractVarInfo, trans::Bool, vn::VarName)
+    if trans
+        set_flag!(vi, vn, "trans")
+    else
+        unset_flag!(vi, vn, "trans")
+    end
+
+    return vi
+end
+
+"""
+    settrans!!(vi::AbstractVarInfo, trans)
+
+Return `vi` (or, following the `!!` convention, a new instance) with
+`istrans(vi, vn)` evaluating to `trans` for every `vn` in `vi`.
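+
+A minimal usage sketch (assuming `model` is a `Model` with a variable `x`):
+
+```julia
+vi = settrans!!(VarInfo(model), true)
+istrans(vi, @varname(x))  # now returns `true`
+```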
+""" +function settrans!!(vi::VarInfo, trans::Bool) + for vn in keys(vi) + settrans!!(vi, trans, vn) + end + + return vi end """ @@ -638,6 +657,14 @@ function setgid!(vi::VarInfo, gid::Selector, vn::VarName) return push!(getmetadata(vi, vn).gids[getidx(vi, vn)], gid) end +""" + istrans(vi::AbstractVarInfo) + +Return `true` if `vi` is working in unconstrained space, and `false` +if `vi` is assuming realizations to be in support of the corresponding distributions. +""" +istrans(vi::AbstractVarInfo) = false # `VarInfo` works in constrained space by default. + """ istrans(vi::VarInfo, vn::VarName) @@ -645,6 +672,9 @@ Return true if `vn`'s values in `vi` are transformed to Euclidean space, and fal they are in the support of `vn`'s distribution. """ istrans(vi::AbstractVarInfo, vn::VarName) = is_flagged(vi, vn, "trans") +function istrans(vi::AbstractVarInfo, vns::AbstractVector{<:VarName}) + return all(Base.Fix1(istrans, vi), vns) +end """ getlogp(vi::VarInfo) @@ -749,7 +779,7 @@ function link!(vi::UntypedVarInfo, spl::Sampler) vectorize(dist, Bijectors.link(dist, reconstruct(dist, getval(vi, vn)))), vn, ) - settrans!(vi, true, vn) + settrans!!(vi, true, vn) end else @warn("[DynamicPPL] attempt to link a linked vi") @@ -785,7 +815,7 @@ end ), vn, ) - settrans!(vi, true, vn) + settrans!!(vi, true, vn) end else @warn("[DynamicPPL] attempt to link a linked vi") @@ -816,7 +846,7 @@ function invlink!(vi::UntypedVarInfo, spl::AbstractSampler) vectorize(dist, Bijectors.invlink(dist, reconstruct(dist, getval(vi, vn)))), vn, ) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end else @warn("[DynamicPPL] attempt to invlink an invlinked vi") @@ -854,7 +884,7 @@ end ), vn, ) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end else @warn("[DynamicPPL] attempt to invlink an invlinked vi") @@ -866,6 +896,9 @@ end return expr end +maybe_link(vi, vn, dist, val) = istrans(vi, vn) ? Bijectors.link(dist, val) : val +maybe_invlink(vi, vn, dist, val) = istrans(vi, vn) ? Bijectors.invlink(dist, val) : val + """ islinked(vi::VarInfo, spl::Union{Sampler, SampleFromPrior}) @@ -904,23 +937,37 @@ distribution(s). If the value(s) is (are) transformed to the Euclidean space, it is (they are) transformed back. """ -function getindex(vi::AbstractVarInfo, vn::VarName) +getindex(vi::AbstractVarInfo, vn::VarName) = getindex(vi, vn, getdist(vi, vn)) +function getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution) @assert haskey(vi, vn) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" - dist = getdist(vi, vn) - return if istrans(vi, vn) - Bijectors.invlink(dist, reconstruct(dist, getval(vi, vn))) - else - reconstruct(dist, getval(vi, vn)) - end + val = getindex_raw(vi, vn, dist) + return maybe_invlink(vi, vn, dist, val) end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) + # FIXME(torfjelde): Using `getdist(vi, first(vns))` won't be correct in cases + # such as `x .~ [Normal(), Exponential()]`. + # BUT we also can't fix this here because this will lead to "incorrect" + # behavior if `vns` arose from something like `x .~ MvNormal(zeros(2), I)`, + # where by "incorrect" we mean there exists pieces of code expecting this behavior. 
+ return getindex(vi, vns, getdist(vi, first(vns))) +end +function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distribution) @assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" - dist = getdist(vi, vns[1]) - return if istrans(vi, vns[1]) - Bijectors.invlink(dist, reconstruct(dist, getval(vi, vns), length(vns))) - else - reconstruct(dist, getval(vi, vns), length(vns)) + vals_linked = mapreduce(vcat, vns) do vn + getindex(vi, vn, dist) end + return reconstruct(dist, vals_linked, length(vns)) +end + +getindex_raw(vi::AbstractVarInfo, vn::VarName) = getindex_raw(vi, vn, getdist(vi, vn)) +function getindex_raw(vi::AbstractVarInfo, vn::VarName, dist::Distribution) + return reconstruct(dist, getval(vi, vn)) +end +function getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}) + return getindex_raw(vi, vns, getdist(vi, first(vns))) +end +function getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distribution) + return reconstruct(dist, getval(vi, vns), length(vns)) end """ @@ -1411,7 +1458,7 @@ function _setval_kernel!(vi::VarInfo, vn::VarName, values, keys) if !isempty(indices) val = reduce(vcat, values[indices]) setval!(vi, val, vn) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end return indices @@ -1492,7 +1539,7 @@ function _setval_and_resample_kernel!(vi::VarInfo, vn::VarName, values, keys) if !isempty(indices) val = reduce(vcat, values[indices]) setval!(vi, val, vn) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) else # Ensures that we'll resample the variable corresponding to `vn` if we run # the model on `vi` again. diff --git a/test/Project.toml b/test/Project.toml index 6a3726474..edba29e5a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -23,7 +23,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" AbstractMCMC = "2.1, 3.0, 4" AbstractPPL = "0.5.1" Bijectors = "0.9.5, 0.10" -Distributions = "0.25" +Distributions = "<0.25.65" DistributionsAD = "0.6.3" Documenter = "0.26.1, 0.27" ForwardDiff = "0.10.12" diff --git a/test/contexts.jl b/test/contexts.jl index 65629afec..edcf5d0f3 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -57,26 +57,6 @@ function remove_prefix(vn::VarName) ) end -""" - varnames(vn::VarName, val) - -Return iterator over all varnames that are represented by `vn` on `val`, -e.g. `varnames(@varname(x), rand(2))` results in an iterator over `[@varname(x[1]), @varname(x[2])]`. -""" -varnames(vn::VarName, val::Real) = [vn] -function varnames(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) - return ( - VarName(vn, getlens(vn) ∘ Setfield.IndexLens(Tuple(I))) for - I in CartesianIndices(val) - ) -end -function varnames(vn::VarName, val::AbstractArray) - return Iterators.flatten( - varnames(VarName(vn, getlens(vn) ∘ Setfield.IndexLens(Tuple(I))), val[I]) for - I in CartesianIndices(val) - ) -end - @testset "contexts.jl" begin child_contexts = [DefaultContext(), PriorContext(), LikelihoodContext()] @@ -185,7 +165,8 @@ end vn_without_prefix = remove_prefix(vn) # Let's check elementwise. - for vn_child in varnames(vn_without_prefix, val) + for vn_child in + DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) if get(val, getlens(vn_child)) === missing @test contextual_isassumption(context, vn_child) else @@ -217,7 +198,8 @@ end # `ConditionContext` with the conditioned variable. 
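+            # (`varname_leaves` replaces the removed local `varnames` helper;
+            # e.g. `varname_leaves(@varname(x), rand(2))` iterates over
+            # `x[1]` and `x[2]`.)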
vn_without_prefix = remove_prefix(vn) - for vn_child in varnames(vn_without_prefix, val) + for vn_child in + DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) # `vn_child` should be in `context`. @test hasvalue_nested(context, vn_child) # Value should be the same as extracted above. diff --git a/test/distribution_wrappers.jl b/test/distribution_wrappers.jl index 350ce6014..8bb692783 100644 --- a/test/distribution_wrappers.jl +++ b/test/distribution_wrappers.jl @@ -9,5 +9,5 @@ @test minimum(nd) == -Inf @test maximum(nd) == Inf @test logpdf(nd, 15.0) == 0 - @test Bijectors.logpdf_with_trans(nd, 0) == 0 + @test Bijectors.logpdf_with_trans(nd, 30, true) == 0 end diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index 4d5003f03..b390997af 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -1,15 +1,14 @@ @testset "loglikelihoods.jl" begin - for m in DynamicPPL.TestUtils.DEMO_MODELS - vi = VarInfo(m) + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + example_values = rand(NamedTuple, m) - vns = vi.metadata.m.vns - if length(vns) == 1 && length(vi[vns[1]]) == 1 - # Only have one latent variable. - DynamicPPL.setval!(vi, [1.0], ["m"]) - else - DynamicPPL.setval!(vi, [1.0, 1.0], ["m[1]", "m[2]"]) + # Instantiate a `VarInfo` with the example values. + vi = VarInfo(m) + for vn in DynamicPPL.TestUtils.varnames(m) + vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) end + # Compute the pointwise loglikelihoods. lls = pointwise_loglikelihoods(m, vi) if isempty(lls) @@ -17,14 +16,9 @@ continue end - loglikelihood = if length(keys(lls)) == 1 && length(m.args.x) == 1 - # Only have one observation, so we need to double it - # for comparison with other models. - 2 * sum(lls[first(keys(lls))]) - else - sum(sum, values(lls)) - end + loglikelihood = sum(sum, values(lls)) + loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(m, example_values...) - @test loglikelihood ≈ -324.45158270528947 + @test loglikelihood ≈ loglikelihood_true end end diff --git a/test/model.jl b/test/model.jl index 78ef27a6e..8eabeeba9 100644 --- a/test/model.jl +++ b/test/model.jl @@ -112,6 +112,22 @@ end @test !any(map(x -> x isa DynamicPPL.AbstractVarInfo, call_retval)) end + @testset "Dynamic constraints" begin + model = DynamicPPL.TestUtils.demo_dynamic_constraint() + vi = VarInfo(model) + spl = SampleFromPrior() + link!(vi, spl) + + for i in 1:10 + # Sample with large variations. + r_raw = randn(length(vi[spl])) * 10 + vi[spl] = r_raw + @test vi[@varname(m)] == r_raw[1] + @test vi[@varname(x)] != r_raw[2] + model(vi) + end + end + @testset "rand" begin model = gdemo_default diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 7e9346450..6a8c545ca 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -62,68 +62,112 @@ DynamicPPL.TestUtils.DEMO_MODELS # We might need to pre-allocate for the variable `m`, so we need # to see whether this is the case. - m = model().m - svi_nt = if m isa AbstractArray - SimpleVarInfo((m=similar(m),)) - else - SimpleVarInfo() - end + svi_nt = SimpleVarInfo(rand(NamedTuple, model)) svi_dict = SimpleVarInfo(VarInfo(model), Dict) - @testset "$(nameof(typeof(svi.values)))" for svi in (svi_nt, svi_dict) + @testset "$(nameof(typeof(DynamicPPL.values_as(svi))))" for svi in ( + svi_nt, + svi_dict, + DynamicPPL.settrans!!(svi_nt, true), + DynamicPPL.settrans!!(svi_dict, true), + ) # Random seed is set in each `@testset`, so we need to sample # a new realization for `m` here. 
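+        # For the demo models, `model()` returns a `NamedTuple` of the sampled
+        # realizations, so `get(retval, vn)` below extracts the value for `vn`.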
- m = model().m + retval = model() ### Sampling ### # Sample a new varinfo! _, svi_new = DynamicPPL.evaluate!!(model, svi, SamplingContext()) - # If the `m[1]` varname doesn't exist, this is a univariate model. - # TODO: Find a better way of dealing with this that is not dependent - # on knowledge of internals of `model`. - isunivariate = !haskey(svi_new, @varname(m[1])) - # Realization for `m` should be different wp. 1. - if isunivariate - @test svi_new[@varname(m)] != m - else - @test svi_new[@varname(m[1])] != m[1] - @test svi_new[@varname(m[2])] != m[2] + for vn in DynamicPPL.TestUtils.varnames(model) + @test svi_new[vn] != get(retval, vn) end + # Logjoint should be non-zero wp. 1. @test getlogp(svi_new) != 0 ### Evaluation ### - # Sample some random testing values. - m_eval = if m isa AbstractArray - randn!(similar(m)) + values_eval_constrained = rand(NamedTuple, model) + if DynamicPPL.istrans(svi) + _values_prior, logpri_true = DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian( + model, values_eval_constrained... + ) + values_eval, logπ_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + model, values_eval_constrained... + ) + # Make sure that these two computation paths provide the same + # transformed values. + @test values_eval == _values_prior else - randn(eltype(m)) + logpri_true = DynamicPPL.TestUtils.logprior_true( + model, values_eval_constrained... + ) + logπ_true = DynamicPPL.TestUtils.logjoint_true( + model, values_eval_constrained... + ) + values_eval = values_eval_constrained end + # No logabsdet-jacobian correction needed for the likelihood. + loglik_true = DynamicPPL.TestUtils.loglikelihood_true( + model, values_eval_constrained... + ) + # Update the realizations in `svi_new`. - svi_eval = if isunivariate - DynamicPPL.setindex!!(svi_new, m_eval, @varname(m)) - else - DynamicPPL.setindex!!(svi_new, m_eval, [@varname(m[1]), @varname(m[2])]) + svi_eval = svi_new + for vn in DynamicPPL.TestUtils.varnames(model) + svi_eval = DynamicPPL.setindex!!(svi_eval, get(values_eval, vn), vn) end + # Reset the logp field. svi_eval = DynamicPPL.resetlogp!!(svi_eval) # Compute `logjoint` using the varinfo. logπ = logjoint(model, svi_eval) - # Extract the parameters from `svi_eval`. - m_vi = if isunivariate - svi_eval[@varname(m)] - else - svi_eval[[@varname(m[1]), @varname(m[2])]] + logpri = logprior(model, svi_eval) + loglik = loglikelihood(model, svi_eval) + + # Values should not have changed. + for vn in DynamicPPL.TestUtils.varnames(model) + @test svi_eval[vn] == get(values_eval, vn) end - # These should not have changed. - @test m_vi == m_eval - # Compute the true `logjoint` and compare. - logπ_true = DynamicPPL.TestUtils.logjoint_true(model, m_vi) + + # Compare log-probability computations. + @test logpri ≈ logpri_true + @test loglik ≈ loglik_true @test logπ ≈ logπ_true end end + + @testset "Dynamic constraints" begin + model = DynamicPPL.TestUtils.demo_dynamic_constraint() + + # Initialize. + svi = DynamicPPL.settrans!!(SimpleVarInfo(), true) + svi = last(DynamicPPL.evaluate!!(model, svi, SamplingContext())) + + # Sample with large variations in unconstrained space. 
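+        # The support of `x` depends on the realization of `m`, so each
+        # evaluation below must link/invlink `x` using the distribution
+        # constructed from the current `m`.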
+ for i in 1:10 + for vn in keys(svi) + svi = DynamicPPL.setindex!!(svi, 10 * randn(), vn) + end + retval, svi = DynamicPPL.evaluate!!(model, svi, DefaultContext()) + @test retval.m == svi[@varname(m)] # `m` is unconstrained + @test retval.x ≠ svi[@varname(x)] # `x` is constrained depending on `m` + + retval_unconstrained, lp_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + model, retval.m, retval.x + ) + + # Realizations from model should all be equal to the unconstrained realization. + for vn in DynamicPPL.TestUtils.varnames(model) + @test get(retval_unconstrained, vn) ≈ svi[vn] rtol = 1e-6 + end + + # `getlogp` should be equal to the logjoint with log-absdet-jac correction. + lp = getlogp(svi) + @test lp ≈ lp_true + end + end end diff --git a/test/turing/Project.toml b/test/turing/Project.toml index 41a873172..e8713dae2 100644 --- a/test/turing/Project.toml +++ b/test/turing/Project.toml @@ -1,10 +1,12 @@ [deps] +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] +Distributions = "<0.25.65" DynamicPPL = "0.19" Turing = "0.21" julia = "1.3" diff --git a/test/varinfo.jl b/test/varinfo.jl index 2f7816024..32c90bf47 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -271,4 +271,47 @@ DynamicPPL.setval_and_resample!(vi, vi.metadata.x.vals, ks) @test vals_prev == vi.metadata.x.vals end + + @testset "istrans" begin + @model demo_constrained() = x ~ truncated(Normal(), 0, Inf) + model = demo_constrained() + vn = @varname(x) + dist = truncated(Normal(), 0, Inf) + + ### `VarInfo` + # Need to run once since we can't specify that we want to _sample_ + # in the unconstrained space for `VarInfo` without having `vn` + # present in the `varinfo`. + ## `UntypedVarInfo` + vi = VarInfo() + vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + vi = DynamicPPL.settrans!!(vi, true, vn) + # Sample in unconstrained space. + vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + x = Bijectors.invlink(dist, DynamicPPL.getindex_raw(vi, vn)) + @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + + ## `TypedVarInfo` + vi = VarInfo(model) + vi = DynamicPPL.settrans!!(vi, true, vn) + # Sample in unconstrained space. + vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + x = Bijectors.invlink(dist, DynamicPPL.getindex_raw(vi, vn)) + @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + + ### `SimpleVarInfo` + ## `SimpleVarInfo{<:NamedTuple}` + vi = DynamicPPL.settrans!!(SimpleVarInfo(), true) + # Sample in unconstrained space. + vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + x = Bijectors.invlink(dist, DynamicPPL.getindex_raw(vi, vn)) + @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + + ## `SimpleVarInfo{<:Dict}` + vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true) + # Sample in unconstrained space. + vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + x = Bijectors.invlink(dist, DynamicPPL.getindex_raw(vi, vn)) + @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + end end