diff --git a/docs/src/api.md b/docs/src/api.md index 97c48316e..38d9ee6b0 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -130,6 +130,17 @@ For a chain of samples, one can compute the pointwise log-likelihoods of each ob pointwise_loglikelihoods ``` +Similarly, one can compute the pointwise log-priors of each sampled random variable +with [`varwise_logpriors`](@ref). +Differently from `pointwise_loglikelihoods` it reports only a +single value for `.~` assignements. +If one needs to access the parts for single indices, one can +reformulate the model to use an explicit loop instead. + +```@docs +varwise_logpriors +``` + For converting a chain into a format that can more easily be fed into a `Model` again, for example using `condition`, you can use [`value_iterator_from_chain`](@ref). ```@docs diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index eb027b45b..14b66ee36 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -116,6 +116,7 @@ export AbstractVarInfo, logprior, logjoint, pointwise_loglikelihoods, + pointwise_logdensities, condition, decondition, fix, @@ -181,7 +182,7 @@ include("varinfo.jl") include("simple_varinfo.jl") include("context_implementations.jl") include("compiler.jl") -include("loglikelihoods.jl") +include("pointwise_logdensities.jl") include("submodel_macro.jl") include("test_utils.jl") include("transforming.jl") diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl deleted file mode 100644 index 227a70889..000000000 --- a/src/loglikelihoods.jl +++ /dev/null @@ -1,257 +0,0 @@ -# Context version -struct PointwiseLikelihoodContext{A,Ctx} <: AbstractContext - loglikelihoods::A - context::Ctx -end - -function PointwiseLikelihoodContext( - likelihoods=OrderedDict{VarName,Vector{Float64}}(), - context::AbstractContext=LikelihoodContext(), -) - return PointwiseLikelihoodContext{typeof(likelihoods),typeof(context)}( - likelihoods, context - ) -end - -NodeTrait(::PointwiseLikelihoodContext) = IsParent() -childcontext(context::PointwiseLikelihoodContext) = context.context -function setchildcontext(context::PointwiseLikelihoodContext, child) - return PointwiseLikelihoodContext(context.loglikelihoods, child) -end - -function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{VarName,Vector{Float64}}}, - vn::VarName, - logp::Real, -) - lookup = context.loglikelihoods - ℓ = get!(lookup, vn, Float64[]) - return push!(ℓ, logp) -end - -function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{VarName,Float64}}, - vn::VarName, - logp::Real, -) - return context.loglikelihoods[vn] = logp -end - -function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{String,Vector{Float64}}}, - vn::VarName, - logp::Real, -) - lookup = context.loglikelihoods - ℓ = get!(lookup, string(vn), Float64[]) - return push!(ℓ, logp) -end - -function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{String,Float64}}, - vn::VarName, - logp::Real, -) - return context.loglikelihoods[string(vn)] = logp -end - -function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{String,Vector{Float64}}}, - vn::String, - logp::Real, -) - lookup = context.loglikelihoods - ℓ = get!(lookup, vn, Float64[]) - return push!(ℓ, logp) -end - -function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{String,Float64}}, - vn::String, - logp::Real, -) - return context.loglikelihoods[vn] = logp -end - -function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vi) - # Defer literal `observe` to child-context. - return tilde_observe!!(context.context, right, left, vi) -end -function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, vi) - # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. - # we have to intercept the call to `tilde_observe!`. - logp, vi = tilde_observe(context.context, right, left, vi) - - # Track loglikelihood value. - push!(context, vn, logp) - - return left, acclogp!!(vi, logp) -end - -function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vi) - # Defer literal `observe` to child-context. - return dot_tilde_observe!!(context.context, right, left, vi) -end -function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, vi) - # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. - # we have to intercept the call to `dot_tilde_observe!`. - - # We want to treat `.~` as a collection of independent observations, - # hence we need the `logp` for each of them. Broadcasting the univariate - # `tilde_obseve` does exactly this. - logps = _pointwise_tilde_observe(context.context, right, left, vi) - - # Need to unwrap the `vn`, i.e. get one `VarName` for each entry in `left`. - _, _, vns = unwrap_right_left_vns(right, left, vn) - for (vn, logp) in zip(vns, logps) - # Track loglikelihood value. - push!(context, vn, logp) - end - - return left, acclogp!!(vi, sum(logps)) -end - -# FIXME: This is really not a good approach since it needs to stay in sync with -# the `dot_assume` implementations, but as things are _right now_ this is the best we can do. -function _pointwise_tilde_observe(context, right, left, vi) - # We need to drop the `vi` returned. - return broadcast(right, left) do r, l - return first(tilde_observe(context, r, l, vi)) - end -end - -function _pointwise_tilde_observe( - context, right::MultivariateDistribution, left::AbstractMatrix, vi::AbstractVarInfo -) - # We need to drop the `vi` returned. - return map(eachcol(left)) do l - return first(tilde_observe(context, right, l, vi)) - end -end - -""" - pointwise_loglikelihoods(model::Model, chain::Chains, keytype = String) - -Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}` -with keys corresponding to symbols of the observations, and values being matrices -of shape `(num_chains, num_samples)`. - -`keytype` specifies what the type of the keys used in the returned `OrderedDict` are. -Currently, only `String` and `VarName` are supported. - -# Notes -Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ` -both being `<:Real`. Then the *observe* (i.e. when the left-hand side is an -*observation*) statements can be implemented in three ways: -1. using a `for` loop: -```julia -for i in eachindex(y) - y[i] ~ Normal(μ, σ) -end -``` -2. using `.~`: -```julia -y .~ Normal(μ, σ) -``` -3. using `MvNormal`: -```julia -y ~ MvNormal(fill(μ, n), σ^2 * I) -``` - -In (1) and (2), `y` will be treated as a collection of `n` i.i.d. 1-dimensional variables, -while in (3) `y` will be treated as a _single_ n-dimensional observation. - -This is important to keep in mind, in particular if the computation is used -for downstream computations. - -# Examples -## From chain -```julia-repl -julia> using DynamicPPL, Turing - -julia> @model function demo(xs, y) - s ~ InverseGamma(2, 3) - m ~ Normal(0, √s) - for i in eachindex(xs) - xs[i] ~ Normal(m, √s) - end - - y ~ Normal(m, √s) - end -demo (generic function with 1 method) - -julia> model = demo(randn(3), randn()); - -julia> chain = sample(model, MH(), 10); - -julia> pointwise_loglikelihoods(model, chain) -OrderedDict{String,Array{Float64,2}} with 4 entries: - "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] - "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] - "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] - "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] - -julia> pointwise_loglikelihoods(model, chain, String) -OrderedDict{String,Array{Float64,2}} with 4 entries: - "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] - "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] - "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] - "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] - -julia> pointwise_loglikelihoods(model, chain, VarName) -OrderedDict{VarName,Array{Float64,2}} with 4 entries: - xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333] - xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359] - xs[3] => [-1.42862; -2.67573; … ; -1.66251; -1.66251] - y => [-1.51265; -0.914129; … ; -1.5499; -1.5499] -``` - -## Broadcasting -Note that `x .~ Dist()` will treat `x` as a collection of -_independent_ observations rather than as a single observation. - -```jldoctest; setup = :(using Distributions) -julia> @model function demo(x) - x .~ Normal() - end; - -julia> m = demo([1.0, ]); - -julia> ℓ = pointwise_loglikelihoods(m, VarInfo(m)); first(ℓ[@varname(x[1])]) --1.4189385332046727 - -julia> m = demo([1.0; 1.0]); - -julia> ℓ = pointwise_loglikelihoods(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) -(-1.4189385332046727, -1.4189385332046727) -``` - -""" -function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) where {T} - # Get the data by executing the model once - vi = VarInfo(model) - context = PointwiseLikelihoodContext(OrderedDict{T,Vector{Float64}}()) - - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - for (sample_idx, chain_idx) in iters - # Update the values - setval!(vi, chain, sample_idx, chain_idx) - - # Execute model - model(vi, context) - end - - niters = size(chain, 1) - nchains = size(chain, 3) - loglikelihoods = OrderedDict( - varname => reshape(logliks, niters, nchains) for - (varname, logliks) in context.loglikelihoods - ) - return loglikelihoods -end - -function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo) - context = PointwiseLikelihoodContext(OrderedDict{VarName,Vector{Float64}}()) - model(varinfo, context) - return context.loglikelihoods -end diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl new file mode 100644 index 000000000..0b2435c16 --- /dev/null +++ b/src/pointwise_logdensities.jl @@ -0,0 +1,402 @@ +# Context version +struct PointwiseLogdensityContext{A,Ctx} <: AbstractContext + logdensities::A + context::Ctx +end + +function PointwiseLogdensityContext( + likelihoods=OrderedDict{VarName,Vector{Float64}}(), + context::AbstractContext=DefaultContext(), +) + return PointwiseLogdensityContext{typeof(likelihoods),typeof(context)}( + likelihoods, context + ) +end + +NodeTrait(::PointwiseLogdensityContext) = IsParent() +childcontext(context::PointwiseLogdensityContext) = context.context +function setchildcontext(context::PointwiseLogdensityContext, child) + return PointwiseLogdensityContext(context.logdensities, child) +end + +function _include_prior(context::PointwiseLogdensityContext) + return leafcontext(context) isa Union{PriorContext,DefaultContext} +end +function _include_likelihood(context::PointwiseLogdensityContext) + return leafcontext(context) isa Union{LikelihoodContext,DefaultContext} +end + +function Base.push!( + context::PointwiseLogdensityContext{<:AbstractDict{VarName,Vector{Float64}}}, + vn::VarName, + logp::Real, +) + lookup = context.logdensities + ℓ = get!(lookup, vn, Float64[]) + return push!(ℓ, logp) +end + +function Base.push!( + context::PointwiseLogdensityContext{<:AbstractDict{VarName,Float64}}, + vn::VarName, + logp::Real, +) + return context.logdensities[vn] = logp +end + +function Base.push!( + context::PointwiseLogdensityContext{<:AbstractDict{String,Vector{Float64}}}, + vn::VarName, + logp::Real, +) + lookup = context.logdensities + ℓ = get!(lookup, string(vn), Float64[]) + return push!(ℓ, logp) +end + +function Base.push!( + context::PointwiseLogdensityContext{<:AbstractDict{String,Float64}}, + vn::VarName, + logp::Real, +) + return context.logdensities[string(vn)] = logp +end + +function Base.push!( + context::PointwiseLogdensityContext{<:AbstractDict{String,Vector{Float64}}}, + vn::String, + logp::Real, +) + lookup = context.logdensities + ℓ = get!(lookup, vn, Float64[]) + return push!(ℓ, logp) +end + +function Base.push!( + context::PointwiseLogdensityContext{<:AbstractDict{String,Float64}}, + vn::String, + logp::Real, +) + return context.logdensities[vn] = logp +end + +function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi) + # Defer literal `observe` to child-context. + return tilde_observe!!(context.context, right, left, vi) +end +function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi) + # Completely defer to child context if we are not tracking likelihoods. + if !(_include_likelihood(context)) + return tilde_observe!!(context.context, right, left, vn, vi) + end + + # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. + # we have to intercept the call to `tilde_observe!`. + logp, vi = tilde_observe(context.context, right, left, vi) + + # Track loglikelihood value. + push!(context, vn, logp) + + return left, acclogp!!(vi, logp) +end + +function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi) + # Defer literal `observe` to child-context. + return dot_tilde_observe!!(context.context, right, left, vi) +end +function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi) + # Completely defer to child context if we are not tracking likelihoods. + if !(_include_likelihood(context)) + return dot_tilde_observe!!(context.context, right, left, vn, vi) + end + + # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. + # we have to intercept the call to `dot_tilde_observe!`. + + # We want to treat `.~` as a collection of independent observations, + # hence we need the `logp` for each of them. Broadcasting the univariate + # `tilde_obseve` does exactly this. + logps = _pointwise_tilde_observe(context.context, right, left, vi) + + # Need to unwrap the `vn`, i.e. get one `VarName` for each entry in `left`. + _, _, vns = unwrap_right_left_vns(right, left, vn) + for (vn, logp) in zip(vns, logps) + # Track loglikelihood value. + push!(context, vn, logp) + end + + return left, acclogp!!(vi, sum(logps)) +end + +# FIXME: This is really not a good approach since it needs to stay in sync with +# the `dot_assume` implementations, but as things are _right now_ this is the best we can do. +function _pointwise_tilde_observe(context, right, left, vi) + # We need to drop the `vi` returned. + return broadcast(right, left) do r, l + return first(tilde_observe(context, r, l, vi)) + end +end + +function _pointwise_tilde_observe( + context, right::MultivariateDistribution, left::AbstractMatrix, vi::AbstractVarInfo +) + # We need to drop the `vi` returned. + return map(eachcol(left)) do l + return first(tilde_observe(context, right, l, vi)) + end +end + +function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi) + # Completely defer to child context if we are not tracking prior densities. + _include_prior(context) || return tilde_assume!!(context.context, right, vn, vi) + + # Otherwise, capture the return values. + value, logp, vi = tilde_assume(context.context, right, vn, vi) + # Track loglikelihood value. + push!(context, vn, logp) + + return value, acclogp!!(vi, logp) +end + +function dot_tilde_assume!!(context::PointwiseLogdensityContext, right, left, vns, vi) + # Completely defer to child context if we are not tracking prior densities. + if !(_include_prior(context)) + return dot_tilde_assume!!(context.context, right, left, vns, vi) + end + + value, logps = _pointwise_tilde_assume(context, right, left, vns, vi) + # Track loglikelihood values. + for (vn, logp) in zip(vns, logps) + push!(context, vn, logp) + end + return value, acclogp!!(vi, sum(logps)) +end + +function _pointwise_tilde_assume(context, right, left, vns, vi) + # We need to drop the `vi` returned. + values_and_logps = broadcast(right, left, vns) do r, l, vn + # HACK(torfjelde): This drops the `vi` returned, which means the `vi` is not updated + # in case of immutable varinfos. But a) atm we're only using mutable varinfos for this, + # and b) even if the variables aren't stored in the vi correctly, we're not going to use + # this vi for anything downstream anyways, i.e. I don't see a case where this would matter + # for this particular use case. + val, logp, _ = tilde_assume(context, r, vn, vi) + return val, logp + end + return map(first, values_and_logps), map(last, values_and_logps) +end + +function _pointwise_tilde_assume( + context, right::MultivariateDistribution, left::AbstractMatrix, vns, vi +) + # We need to drop the `vi` returned. + values_and_logps = map(eachcol(left), vns) do l, vn + val, logp, _ = tilde_assume(context, right, vn, vi) + return val, logp + end + # HACK(torfjelde): Due to the way we handle `.~`, we should use `recombine` to stay consistent. + # But this also means that we need to first flatten the entire `values` component before recombining. + values = recombine(right, mapreduce(vec ∘ first, vcat, values_and_logps), length(vns)) + return values, map(last, values_and_logps) +end + +""" + pointwise_logdensities(model::Model, chain::Chains[, keytype::Type, context::AbstractContext]) + +Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}` +with keys corresponding to symbols of the variables, and values being matrices +of shape `(num_chains, num_samples)`. + +# Arguments +- `model`: the `Model` to run. +- `chain`: the `Chains` to run the model on. +- `keytype`: the type of the keys used in the returned `OrderedDict` are. + Currently, only `String` and `VarName` are supported. +- `context`: the context to use when running the model. Default: `DefaultContext`. + The [`leafcontext`](@ref) is used to decide which variables to include. + +# Notes +Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ` +both being `<:Real`. Then the *observe* (i.e. when the left-hand side is an +*observation*) statements can be implemented in three ways: +1. using a `for` loop: +```julia +for i in eachindex(y) + y[i] ~ Normal(μ, σ) +end +``` +2. using `.~`: +```julia +y .~ Normal(μ, σ) +``` +3. using `MvNormal`: +```julia +y ~ MvNormal(fill(μ, n), σ^2 * I) +``` + +In (1) and (2), `y` will be treated as a collection of `n` i.i.d. 1-dimensional variables, +while in (3) `y` will be treated as a _single_ n-dimensional observation. + +This is important to keep in mind, in particular if the computation is used +for downstream computations. + +# Examples +## From chain +```julia-repl +julia> using DynamicPPL, Turing + +julia> @model function demo(xs, y) + s ~ InverseGamma(2, 3) + m ~ Normal(0, √s) + for i in eachindex(xs) + xs[i] ~ Normal(m, √s) + end + + y ~ Normal(m, √s) + end +demo (generic function with 1 method) + +julia> model = demo(randn(3), randn()); + +julia> chain = sample(model, MH(), 10); + +julia> pointwise_logdensities(model, chain) +OrderedDict{String,Array{Float64,2}} with 4 entries: + "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] + "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] + "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] + "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] + +julia> pointwise_logdensities(model, chain, String) +OrderedDict{String,Array{Float64,2}} with 4 entries: + "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] + "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] + "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] + "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] + +julia> pointwise_logdensities(model, chain, VarName) +OrderedDict{VarName,Array{Float64,2}} with 4 entries: + xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333] + xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359] + xs[3] => [-1.42862; -2.67573; … ; -1.66251; -1.66251] + y => [-1.51265; -0.914129; … ; -1.5499; -1.5499] +``` + +## Broadcasting +Note that `x .~ Dist()` will treat `x` as a collection of +_independent_ observations rather than as a single observation. + +```jldoctest; setup = :(using Distributions) +julia> @model function demo(x) + x .~ Normal() + end; + +julia> m = demo([1.0, ]); + +julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first(ℓ[@varname(x[1])]) +-1.4189385332046727 + +julia> m = demo([1.0; 1.0]); + +julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) +(-1.4189385332046727, -1.4189385332046727) +``` + +""" +function pointwise_logdensities( + model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext() +) where {T} + # Get the data by executing the model once + vi = VarInfo(model) + point_context = PointwiseLogdensityContext(OrderedDict{T,Vector{Float64}}(), context) + + iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) + for (sample_idx, chain_idx) in iters + # Update the values + setval!(vi, chain, sample_idx, chain_idx) + + # Execute model + model(vi, point_context) + end + + niters = size(chain, 1) + nchains = size(chain, 3) + logdensities = OrderedDict( + varname => reshape(logliks, niters, nchains) for + (varname, logliks) in point_context.logdensities + ) + return logdensities +end + +function pointwise_logdensities( + model::Model, varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext() +) + point_context = PointwiseLogdensityContext( + OrderedDict{VarName,Vector{Float64}}(), context + ) + model(varinfo, point_context) + return point_context.logdensities +end + +""" + pointwise_loglikelihoods(model, chain[, keytype, context]) + +Compute the pointwise log-likelihoods of the model given the chain. + +This is the same as `pointwise_logdensities(model, chain, context)`, but only +including the likelihood terms. + +See also: [`pointwise_logdensities`](@ref). +""" +function pointwise_loglikelihoods( + model::Model, + chain, + keytype::Type{T}=String, + context::AbstractContext=LikelihoodContext(), +) where {T} + if !(leafcontext(context) isa LikelihoodContext) + throw(ArgumentError("Leaf context should be a LikelihoodContext")) + end + + return pointwise_logdensities(model, chain, T, context) +end + +function pointwise_loglikelihoods( + model::Model, varinfo::AbstractVarInfo, context::AbstractContext=LikelihoodContext() +) + if !(leafcontext(context) isa LikelihoodContext) + throw(ArgumentError("Leaf context should be a LikelihoodContext")) + end + + return pointwise_logdensities(model, varinfo, context) +end + +""" + pointwise_prior_logdensities(model, chain[, keytype, context]) + +Compute the pointwise log-prior-densities of the model given the chain. + +This is the same as `pointwise_logdensities(model, chain, context)`, but only +including the prior terms. + +See also: [`pointwise_logdensities`](@ref). +""" +function pointwise_prior_logdensities( + model::Model, chain, keytype::Type{T}=String, context::AbstractContext=PriorContext() +) where {T} + if !(leafcontext(context) isa PriorContext) + throw(ArgumentError("Leaf context should be a PriorContext")) + end + + return pointwise_logdensities(model, chain, T, context) +end + +function pointwise_prior_logdensities( + model::Model, varinfo::AbstractVarInfo, context::AbstractContext=PriorContext() +) + if !(leafcontext(context) isa PriorContext) + throw(ArgumentError("Leaf context should be a PriorContext")) + end + + return pointwise_logdensities(model, varinfo, context) +end diff --git a/src/test_utils.jl b/src/test_utils.jl index 6f7481c40..85dfa71f4 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -1042,4 +1042,50 @@ function test_context_interface(context) end end +""" +Context that multiplies each log-prior by mod +used to test whether varwise_logpriors respects child-context. +""" +struct TestLogModifyingChildContext{T,Ctx} <: DynamicPPL.AbstractContext + mod::T + context::Ctx +end +function TestLogModifyingChildContext( + mod=1.2, + context::DynamicPPL.AbstractContext=DynamicPPL.DefaultContext(), + #OrderedDict{VarName,Vector{Float64}}(),PriorContext()), +) + return TestLogModifyingChildContext{typeof(mod),typeof(context)}( + mod, context + ) +end +# Samplers call leafcontext(model.context) when evaluating log-densities +# Hence, in order to be used need to say that its a leaf-context +#DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent() +DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsLeaf() +DynamicPPL.childcontext(context::TestLogModifyingChildContext) = context.context +function DynamicPPL.setchildcontext(context::TestLogModifyingChildContext, child) + return TestLogModifyingChildContext(context.mod, child) +end +function DynamicPPL.tilde_assume(context::TestLogModifyingChildContext, right, vn, vi) + #@info "TestLogModifyingChildContext tilde_assume!! called for $vn" + value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi) + return value, logp*context.mod, vi +end +function DynamicPPL.dot_tilde_assume(context::TestLogModifyingChildContext, right, left, vn, vi) + #@info "TestLogModifyingChildContext dot_tilde_assume!! called for $vn" + value, logp, vi = DynamicPPL.dot_tilde_assume(context.context, right, left, vn, vi) + return value, logp*context.mod, vi +end +function DynamicPPL.tilde_observe(context::TestLogModifyingChildContext, right, left, vi) + # @info "called tilde_observe TestLogModifyingChildContext for left=$left, right=$right" + logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi) + return logp*context.mod, vi +end +function DynamicPPL.dot_tilde_observe(context::TestLogModifyingChildContext, right, left, vi) + logp, vi = DynamicPPL.dot_tilde_observe(context.context, right, left, vi) + return logp*context.mod, vi +end + + end diff --git a/test/Project.toml b/test/Project.toml index 13267ee1d..a75909f95 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,6 +14,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -26,10 +27,10 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -Accessors = "0.1" ADTypes = "0.2, 1" AbstractMCMC = "5" AbstractPPL = "0.8.2" +Accessors = "0.1" Bijectors = "0.13" Compat = "4.3.0" Distributions = "0.25" diff --git a/test/contexts.jl b/test/contexts.jl index 11e2c99b7..4ec9ff945 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -8,7 +8,7 @@ using DynamicPPL: NodeTrait, IsLeaf, IsParent, - PointwiseLikelihoodContext, + PointwiseLogdensityContext, contextual_isassumption, ConditionContext, hasconditioned, @@ -67,7 +67,7 @@ end SamplingContext(), MiniBatchContext(DefaultContext(), 0.0), PrefixContext{:x}(DefaultContext()), - PointwiseLikelihoodContext(), + PointwiseLogdensityContext(), ConditionContext((x=1.0,)), ConditionContext((x=1.0,), ParentContext(ConditionContext((y=2.0,)))), ConditionContext((x=1.0,), PrefixContext{:a}(ConditionContext((var"a.y"=2.0,)))), diff --git a/test/loglikelihoods.jl b/test/deprecated.jl similarity index 90% rename from test/loglikelihoods.jl rename to test/deprecated.jl index 1075ce333..f5c400691 100644 --- a/test/loglikelihoods.jl +++ b/test/deprecated.jl @@ -13,12 +13,15 @@ if isempty(lls) # One of the models with literal observations, so we just skip. + # TODO: Think of better way to detect this special case continue end loglikelihood = sum(sum, values(lls)) loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(m, example_values...) + #priors = + @test loglikelihood ≈ loglikelihood_true end end diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl new file mode 100644 index 000000000..5214bf5a1 --- /dev/null +++ b/test/pointwise_logdensities.jl @@ -0,0 +1,88 @@ +@testset "logdensities_likelihoods.jl" begin + likelihood_context = LikelihoodContext() + prior_context = PriorContext() + mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2) + mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx) + #m = DynamicPPL.TestUtils.DEMO_MODELS[1] + @testset "$(m.f)" for (i, m) in enumerate(DynamicPPL.TestUtils.DEMO_MODELS) + #@show i + example_values = DynamicPPL.TestUtils.rand_prior_true(m) + + # Instantiate a `VarInfo` with the example values. + vi = VarInfo(m) + () -> begin # when interactively debugging, need the global keyword + for vn in DynamicPPL.TestUtils.varnames(m) + global vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) + end + end + for vn in DynamicPPL.TestUtils.varnames(m) + vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) + end + + loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(m, example_values...) + logp_true = logprior(m, vi) + + # Compute the pointwise loglikelihoods. + lls = pointwise_logdensities(m, vi, likelihood_context) + #lls2 = pointwise_loglikelihoods(m, vi) + loglikelihood = sum(sum, values(lls)) + if loglikelihood ≈ 0.0 #isempty(lls) + # One of the models with literal observations, so we just skip. + # TODO: Think of better way to detect this special case + loglikelihood_true = 0.0 + end + @test loglikelihood ≈ loglikelihood_true + + # Compute the pointwise logdensities of the priors. + lps_prior = pointwise_logdensities(m, vi, prior_context) + logp = sum(sum, values(lps_prior)) + if false # isempty(lps_prior) + # One of the models with only observations so we just skip. + else + logp1 = getlogp(vi) + @test !isfinite(logp_true) || logp ≈ logp_true + end + + # Compute both likelihood and logdensity of prior + # using the default DefaultContex + lps = pointwise_logdensities(m, vi) + logp = sum(sum, values(lps)) + @test logp ≈ (logp_true + loglikelihood_true) + + # Test that modifications of Setup are picked up + lps = pointwise_logdensities(m, vi, mod_ctx2) + logp = sum(sum, values(lps)) + @test logp ≈ (logp_true + loglikelihood_true) * 1.2 * 1.4 + end +end + + +@testset "pointwise_logdensities chain" begin + @model function demo(x, ::Type{TV}=Vector{Float64}) where {TV} + s ~ InverseGamma(2, 3) + m = TV(undef, length(x)) + for i in eachindex(x) + m[i] ~ Normal(0, √s) + end + x ~ MvNormal(m, √s) + end + x_true = [0.3290767977680923, 0.038972110187911684, -0.5797496780649221] + model = demo(x_true) + () -> begin + # generate the sample used below + chain = sample(model, MH(), MCMCThreads(), 10, 2) + arr0 = stack(Array(chain, append_chains=false)) + @show(arr0); + end + arr0 = [5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 0.9199555480151707 -0.1304320097505629 1.0669120062696917 -0.05253734412139093; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183;;; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 2.5409470583244933 1.7838744695696407 0.7013562890105632 -3.0843947804314658; 0.8296370582311665 1.5360702767879642 -1.5964695255693102 0.16928084806166913; 2.6246697053824954 0.8096845024785173 -1.2621822861663752 1.1414885535466166; 1.1304261861894538 0.7325784741344005 -1.1866016911837542 -0.1639319562090826; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 0.9838526141898173 -0.20198797220982412 2.0569535882007006 -1.1560724118010939] + chain = Chains(arr0, [:s, Symbol("m[1]"), Symbol("m[2]"), Symbol("m[3]")]); + tmp1 = pointwise_logdensities(model, chain) + vi = VarInfo(model) + i_sample, i_chain = (1,2) + DynamicPPL.setval!(vi, chain, i_sample, i_chain) + lp1 = DynamicPPL.pointwise_logdensities(model, vi) + # k = first(keys(lp1)) + for k in keys(lp1) + @test tmp1[string(k)][i_sample,i_chain] .≈ lp1[k][1] + end +end; \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index aa0883708..6de0cb7fe 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,6 +22,9 @@ using Pkg using Random using Serialization using Test +using Logging +using Distributions +using LinearAlgebra # Diagonal using DynamicPPL: getargs_dottilde, getargs_tilde, Selector @@ -31,6 +34,7 @@ const GROUP = get(ENV, "GROUP", "All") Random.seed!(100) +#include(joinpath(DIRECTORY_DynamicPPL,"test","test_util.jl")) include("test_util.jl") @testset "DynamicPPL.jl" begin @@ -53,9 +57,11 @@ include("test_util.jl") include("serialization.jl") - include("loglikelihoods.jl") + include("pointwise_logdensitiesjl") include("lkj.jl") + + include("deprecated.jl") end @testset "compat" begin