From d05124cc69b4b57fae34c1d10305a202bf568ecb Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Mon, 16 Sep 2024 08:10:29 +0200 Subject: [PATCH 01/37] implement pointwise_logpriors --- src/DynamicPPL.jl | 2 + src/context_implementations.jl | 25 +++- src/logpriors.jl | 223 +++++++++++++++++++++++++++++++++ test/Project.toml | 3 +- test/loglikelihoods.jl | 50 ++++++++ test/runtests.jl | 2 + 6 files changed, 300 insertions(+), 5 deletions(-) create mode 100644 src/logpriors.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index eb027b45b..555744ea3 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -116,6 +116,7 @@ export AbstractVarInfo, logprior, logjoint, pointwise_loglikelihoods, + pointwise_logpriors, condition, decondition, fix, @@ -182,6 +183,7 @@ include("simple_varinfo.jl") include("context_implementations.jl") include("compiler.jl") include("loglikelihoods.jl") +include("logpriors.jl") include("submodel_macro.jl") include("test_utils.jl") include("transforming.jl") diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 13231837f..338749344 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -425,6 +425,15 @@ function dot_assume( var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi::AbstractVarInfo, +) + r, lp, vi = dot_assume_vec(dist, var, vns, vi) + return r, sum(lp), vi +end +function dot_assume_vec( + dist::MultivariateDistribution, + var::AbstractMatrix, + vns::AbstractVector{<:VarName}, + vi::AbstractVarInfo, ) @assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))" # NOTE: We cannot work with `var` here because we might have a model of the form @@ -434,7 +443,7 @@ function dot_assume( # # in which case `var` will have `undef` elements, even if `m` is present in `vi`. 
r = vi[vns, dist] - lp = sum(zip(vns, eachcol(r))) do (vn, ri) + lp = map(zip(vns, eachcol(r))) do (vn, ri) return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) end return r, lp, vi @@ -455,21 +464,29 @@ function dot_assume( end function dot_assume( + dist::Union{Distribution,AbstractArray{<:Distribution}}, var::AbstractArray, vns::AbstractArray{<:VarName}, vi +) + # possibility to acesss the single logpriors + r, lp, vi = dot_assume_vec(dist, var, vns, vi) + return r, sum(lp), vi +end + +function dot_assume_vec( dist::Distribution, var::AbstractArray, vns::AbstractArray{<:VarName}, vi ) r = getindex.((vi,), vns, (dist,)) - lp = sum(Bijectors.logpdf_with_trans.((dist,), r, istrans.((vi,), vns))) + lp = Bijectors.logpdf_with_trans.((dist,), r, istrans.((vi,), vns)) return r, lp, vi end -function dot_assume( +function dot_assume_vec( dists::AbstractArray{<:Distribution}, var::AbstractArray, vns::AbstractArray{<:VarName}, vi, ) r = getindex.((vi,), vns, dists) - lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns))) + lp = Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns)) return r, lp, vi end diff --git a/src/logpriors.jl b/src/logpriors.jl new file mode 100644 index 000000000..cdc00d699 --- /dev/null +++ b/src/logpriors.jl @@ -0,0 +1,223 @@ +# Context version +struct PointwisePriorContext{A,Ctx} <: AbstractContext + logpriors::A + context::Ctx +end + +function PointwisePriorContext( + priors=OrderedDict{VarName,Vector{Float64}}(), + context::AbstractContext=DynamicPPL.PriorContext(), + #OrderedDict{VarName,Vector{Float64}}(),PriorContext()), +) + return PointwisePriorContext{typeof(priors),typeof(context)}( + priors, context + ) +end + +NodeTrait(::PointwisePriorContext) = IsParent() +childcontext(context::PointwisePriorContext) = context.context +function setchildcontext(context::PointwisePriorContext, child) + return PointwisePriorContext(context.logpriors, child) +end + + +function tilde_assume!!(context::PointwisePriorContext, 
right, vn, vi) + #@info "PointwisePriorContext tilde_assume!! called for $vn" + value, logp, vi = tilde_assume(context, right, vn, vi) + push!(context, vn, logp) + return value, acclogp_assume!!(context, vi, logp) +end + +function dot_tilde_assume!!(context::PointwisePriorContext, right, left, vn, vi) + #@info "PointwisePriorContext dot_tilde_assume!! called for $vn" + # @show vn, left, right, typeof(context).name + value, logps, vi = dot_assume_vec(right, left, vn, vi) + # Need to unwrap the `vn`, i.e. get one `VarName` for each entry in `left`. + #_, _, vns = unwrap_right_left_vns(right, left, vn) + vns = vn + for (vn, logp) in zip(vns, logps) + # Track log-prior value. + push!(context, vn, logp) + end + return value, acclogp!!(vi, sum(logps)), vi +end + +function dot_tilde_assume!!(context::PointwisePriorContext, right, left, vn, vi) + #@info "PointwisePriorContext dot_tilde_assume!! called for $vn" + value, logps, vi = dot_assume_vec(right, left, vn, vi) + # Need to unwrap the `vn`, i.e. get one `VarName` for each entry in `left`? + #_, _, vns = unwrap_right_left_vns(right, left, vn) + vns = vn + for (vn, logp) in zip(vns, logps) + # Track log-prior value. + push!(context, vn, logp) + end + return value, acclogp!!(vi, sum(logps)), vi +end + +function tilde_observe(context::PointwisePriorContext, right, left, vi) + # Since we are evaluating the prior, the log probability of all the observations + # is set to 0. This has the effect of ignoring the likelihood. 
+ return 0.0, vi + #tmp = tilde_observe(context.context, SampleFromPrior(), right, left, vi) + #return tmp +end + +function Base.push!( + context::PointwisePriorContext{<:AbstractDict{VarName,Vector{Float64}}}, + vn::VarName, + logp::Real, +) + lookup = context.logpriors + ℓ = get!(lookup, vn, Float64[]) + return push!(ℓ, logp) +end + +function Base.push!( + context::PointwisePriorContext{<:AbstractDict{VarName,Float64}}, + vn::VarName, + logp::Real, +) + return context.logpriors[vn] = logp +end + +function Base.push!( + context::PointwisePriorContext{<:AbstractDict{String,Vector{Float64}}}, + vn::VarName, + logp::Real, +) + lookup = context.logpriors + ℓ = get!(lookup, string(vn), Float64[]) + return push!(ℓ, logp) +end + +function Base.push!( + context::PointwisePriorContext{<:AbstractDict{String,Float64}}, + vn::VarName, + logp::Real, +) + return context.logpriors[string(vn)] = logp +end + +function Base.push!( + context::PointwisePriorContext{<:AbstractDict{String,Vector{Float64}}}, + vn::String, + logp::Real, +) + lookup = context.logpriors + ℓ = get!(lookup, vn, Float64[]) + return push!(ℓ, logp) +end + +function Base.push!( + context::PointwisePriorContext{<:AbstractDict{String,Float64}}, + vn::String, + logp::Real, +) + return context.logpriors[vn] = logp +end + + +# """ +# pointwise_logpriors(model::Model, chain::Chains, keytype = String) + +# Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}` +# with keys corresponding to symbols of the observations, and values being matrices +# of shape `(num_chains, num_samples)`. + +# `keytype` specifies what the type of the keys used in the returned `OrderedDict` are. +# Currently, only `String` and `VarName` are supported. + +# # Notes +# Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ` +# both being `<:Real`. Then the *observe* (i.e. when the left-hand side is an +# *observation*) statements can be implemented in three ways: +# 1. 
using a `for` loop: +# ```julia +# for i in eachindex(y) +# y[i] ~ Normal(μ, σ) +# end +# ``` +# 2. using `.~`: +# ```julia +# y .~ Normal(μ, σ) +# ``` +# 3. using `MvNormal`: +# ```julia +# y ~ MvNormal(fill(μ, n), σ^2 * I) +# ``` + +# In (1) and (2), `y` will be treated as a collection of `n` i.i.d. 1-dimensional variables, +# while in (3) `y` will be treated as a _single_ n-dimensional observation. + +# This is important to keep in mind, in particular if the computation is used +# for downstream computations. + +# # Examples +# ## From chain +# ```julia-repl +# julia> using DynamicPPL, Turing + +# julia> @model function demo(xs, y) +# s ~ InverseGamma(2, 3) +# m ~ Normal(0, √s) +# for i in eachindex(xs) +# xs[i] ~ Normal(m, √s) +# end + +# y ~ Normal(m, √s) +# end +# demo (generic function with 1 method) + +# julia> model = demo(randn(3), randn()); + +# julia> chain = sample(model, MH(), 10); + +# julia> pointwise_logpriors(model, chain) +# OrderedDict{String,Array{Float64,2}} with 4 entries: +# "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] +# "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] +# "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] +# "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] + +# julia> pointwise_logpriors(model, chain, String) +# OrderedDict{String,Array{Float64,2}} with 4 entries: +# "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] +# "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] +# "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] +# "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] + +# julia> pointwise_logpriors(model, chain, VarName) +# OrderedDict{VarName,Array{Float64,2}} with 4 entries: +# xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333] +# xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359] +# xs[3] => [-1.42862; -2.67573; … ; -1.66251; -1.66251] +# y => [-1.51265; -0.914129; … ; -1.5499; -1.5499] +# ``` + +# ## Broadcasting +# Note that `x .~ Dist()` will treat `x` as a 
collection of +# _independent_ observations rather than as a single observation. + +# ```jldoctest; setup = :(using Distributions) +# julia> @model function demo(x) +# x .~ Normal() +# end; + +# julia> m = demo([1.0, ]); + +# julia> ℓ = pointwise_logpriors(m, VarInfo(m)); first(ℓ[@varname(x[1])]) +# -1.4189385332046727 + +# julia> m = demo([1.0; 1.0]); + +# julia> ℓ = pointwise_logpriors(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) +# (-1.4189385332046727, -1.4189385332046727) +# ``` + +# """ +function pointwise_logpriors(model::Model, varinfo::AbstractVarInfo) + context = PointwisePriorContext() + model(varinfo, context) + return context.logpriors +end diff --git a/test/Project.toml b/test/Project.toml index 13267ee1d..a75909f95 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,6 +14,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -26,10 +27,10 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -Accessors = "0.1" ADTypes = "0.2, 1" AbstractMCMC = "5" AbstractPPL = "0.8.2" +Accessors = "0.1" Bijectors = "0.13" Compat = "4.3.0" Distributions = "0.25" diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index 1075ce333..fbd405aaa 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -22,3 +22,53 @@ @test loglikelihood ≈ loglikelihood_true end end + +() -> begin + m = DynamicPPL.TestUtils.demo_assume_index_observe() + example_values = DynamicPPL.TestUtils.rand_prior_true(m) + vi = VarInfo(m) + for vn in DynamicPPL.TestUtils.varnames(m) + vi = DynamicPPL.setindex!!(vi, 
get(example_values, vn), vn) + end + ret_m = first(evaluate!!(m, vi, SamplingContext())) + @test sum(map(k -> sum(ret_m.ld[k]), eachindex(ret_m.ld))) ≈ ret_m.logp + () -> begin + #by_generated_quantities + s = vcat(example_values...) + vnames = ["s[1]", "s[2]", "m[1]", "m[2]"] + stup = (; zip(Symbol.(vnames), s)...) + ret_m = generated_quantities(m, stup) + @test sum(map(k -> sum(ret_m.ld[k]), eachindex(ret_m.ld))) ≈ ret_m.logp + #chn = Chains(reshape(s, 1, : , 1), vnames); + chn = Chains(reshape(s, 1, :, 1)) # causes warning but works + ret_m = @test_logs (:warn,) generated_quantities(m, chn)[1, 1] + @test sum(map(k -> sum(ret_m.ld[k]), eachindex(ret_m.ld))) ≈ ret_m.logp + end +end + +@testset "logpriors.jl" begin + #m = DynamicPPL.TestUtils.DEMO_MODELS[1] + @testset "$(m.f)" for (i, m) in enumerate(DynamicPPL.TestUtils.DEMO_MODELS) + #@show i + example_values = DynamicPPL.TestUtils.rand_prior_true(m) + + # Instantiate a `VarInfo` with the example values. + vi = VarInfo(m) + () -> begin + for vn in DynamicPPL.TestUtils.varnames(m) + global vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) + end + end + for vn in DynamicPPL.TestUtils.varnames(m) + vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) + end + + #chains = sample(m, SampleFromPrior(), 2; progress=false) + + # Compute the pointwise loglikelihoods. 
+ tmp = DynamicPPL.pointwise_logpriors(m, vi) + logp1 = getlogp(vi) + logp = logprior(m, vi) + @test !isfinite(getlogp(vi)) || sum(x -> sum(x), values(tmp)) ≈ logp + end; +end; diff --git a/test/runtests.jl b/test/runtests.jl index aa0883708..eafc0e652 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,6 +22,7 @@ using Pkg using Random using Serialization using Test +using Logging using DynamicPPL: getargs_dottilde, getargs_tilde, Selector @@ -31,6 +32,7 @@ const GROUP = get(ENV, "GROUP", "All") Random.seed!(100) +#include(joinpath(DIRECTORY_DynamicPPL,"test","test_util.jl")) include("test_util.jl") @testset "DynamicPPL.jl" begin From 4f46102805d1dcfafb5eadc40afdc29d1641101b Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Tue, 17 Sep 2024 08:13:29 +0200 Subject: [PATCH 02/37] implement varwise_logpriors --- docs/src/api.md | 11 +++ src/DynamicPPL.jl | 2 + src/logpriors.jl | 13 --- src/logpriors_var.jl | 201 +++++++++++++++++++++++++++++++++++++++++ src/test_utils.jl | 33 +++++++ test/loglikelihoods.jl | 132 +++++++++++++++++++++++---- test/runtests.jl | 2 + 7 files changed, 361 insertions(+), 33 deletions(-) create mode 100644 src/logpriors_var.jl diff --git a/docs/src/api.md b/docs/src/api.md index 97c48316e..38d9ee6b0 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -130,6 +130,17 @@ For a chain of samples, one can compute the pointwise log-likelihoods of each ob pointwise_loglikelihoods ``` +Similarly, one can compute the pointwise log-priors of each sampled random variable +with [`varwise_logpriors`](@ref). +Differently from `pointwise_loglikelihoods` it reports only a +single value for `.~` assignements. +If one needs to access the parts for single indices, one can +reformulate the model to use an explicit loop instead. + +```@docs +varwise_logpriors +``` + For converting a chain into a format that can more easily be fed into a `Model` again, for example using `condition`, you can use [`value_iterator_from_chain`](@ref). 
```@docs diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 555744ea3..e69cdd568 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -117,6 +117,7 @@ export AbstractVarInfo, logjoint, pointwise_loglikelihoods, pointwise_logpriors, + varwise_logpriors, condition, decondition, fix, @@ -184,6 +185,7 @@ include("context_implementations.jl") include("compiler.jl") include("loglikelihoods.jl") include("logpriors.jl") +include("logpriors_var.jl") include("submodel_macro.jl") include("test_utils.jl") include("transforming.jl") diff --git a/src/logpriors.jl b/src/logpriors.jl index cdc00d699..1f3d8b3ad 100644 --- a/src/logpriors.jl +++ b/src/logpriors.jl @@ -42,19 +42,6 @@ function dot_tilde_assume!!(context::PointwisePriorContext, right, left, vn, vi) return value, acclogp!!(vi, sum(logps)), vi end -function dot_tilde_assume!!(context::PointwisePriorContext, right, left, vn, vi) - #@info "PointwisePriorContext dot_tilde_assume!! called for $vn" - value, logps, vi = dot_assume_vec(right, left, vn, vi) - # Need to unwrap the `vn`, i.e. get one `VarName` for each entry in `left`? - #_, _, vns = unwrap_right_left_vns(right, left, vn) - vns = vn - for (vn, logp) in zip(vns, logps) - # Track log-prior value. - push!(context, vn, logp) - end - return value, acclogp!!(vi, sum(logps)), vi -end - function tilde_observe(context::PointwisePriorContext, right, left, vi) # Since we are evaluating the prior, the log probability of all the observations # is set to 0. This has the effect of ignoring the likelihood. diff --git a/src/logpriors_var.jl b/src/logpriors_var.jl new file mode 100644 index 000000000..1c534b60a --- /dev/null +++ b/src/logpriors_var.jl @@ -0,0 +1,201 @@ +""" +Context that records logp after tilde_assume!! +for each VarName used by [`varwise_logpriors`](@ref). 
+""" +struct VarwisePriorContext{A,Ctx} <: AbstractContext + logpriors::A + context::Ctx +end + +function VarwisePriorContext( + logpriors=OrderedDict{Symbol,Float64}(), + context::AbstractContext=DynamicPPL.PriorContext(), + #OrderedDict{Symbol,Vector{Float64}}(),PriorContext()), +) + return VarwisePriorContext{typeof(logpriors),typeof(context)}( + logpriors, context + ) +end + +NodeTrait(::VarwisePriorContext) = IsParent() +childcontext(context::VarwisePriorContext) = context.context +function setchildcontext(context::VarwisePriorContext, child) + return VarwisePriorContext(context.logpriors, child) +end + +function tilde_assume(context::VarwisePriorContext, right, vn, vi) + #@info "VarwisePriorContext tilde_assume!! called for $vn" + value, logp, vi = tilde_assume(context.context, right, vn, vi) + #sym = DynamicPPL.getsym(vn) + new_context = acc_logp!(context, vn, logp) + return value, logp, vi +end + +function dot_tilde_assume(context::VarwisePriorContext, right, left, vn, vi) + #@info "VarwisePriorContext dot_tilde_assume!! called for $vn" + # @show vn, left, right, typeof(context).name + value, logp, vi = dot_tilde_assume(context.context, right, left, vn, vi) + new_context = acc_logp!(context, vn, logp) + return value, logp, vi +end + + +function tilde_observe(context::VarwisePriorContext, right, left, vi) + # Since we are evaluating the prior, the log probability of all the observations + # is set to 0. This has the effect of ignoring the likelihood. 
+ return 0.0, vi + #tmp = tilde_observe(context.context, SampleFromPrior(), right, left, vi) + #return tmp +end + +function acc_logp!(context::VarwisePriorContext, vn::Union{VarName,AbstractVector{<:VarName}}, logp) + #sym = DynamicPPL.getsym(vn) # leads to duplicates + # if vn is a Vector leads to Symbol("VarName{:s, IndexLens{Tuple{Int64}}}[s[1], s[2]]") + sym = Symbol(vn) + context.logpriors[sym] = logp + return (context) +end + + +# """ +# pointwise_logpriors(model::Model, chain::Chains, keytype = String) + +# Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}` +# with keys corresponding to symbols of the observations, and values being matrices +# of shape `(num_chains, num_samples)`. + +# `keytype` specifies what the type of the keys used in the returned `OrderedDict` are. +# Currently, only `String` and `VarName` are supported. + +# # Notes +# Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ` +# both being `<:Real`. Then the *observe* (i.e. when the left-hand side is an +# *observation*) statements can be implemented in three ways: +# 1. using a `for` loop: +# ```julia +# for i in eachindex(y) +# y[i] ~ Normal(μ, σ) +# end +# ``` +# 2. using `.~`: +# ```julia +# y .~ Normal(μ, σ) +# ``` +# 3. using `MvNormal`: +# ```julia +# y ~ MvNormal(fill(μ, n), σ^2 * I) +# ``` + +# In (1) and (2), `y` will be treated as a collection of `n` i.i.d. 1-dimensional variables, +# while in (3) `y` will be treated as a _single_ n-dimensional observation. + +# This is important to keep in mind, in particular if the computation is used +# for downstream computations. 
+ +# # Examples +# ## From chain +# ```julia-repl +# julia> using DynamicPPL, Turing + +# julia> @model function demo(xs, y) +# s ~ InverseGamma(2, 3) +# m ~ Normal(0, √s) +# for i in eachindex(xs) +# xs[i] ~ Normal(m, √s) +# end + +# y ~ Normal(m, √s) +# end +# demo (generic function with 1 method) + +# julia> model = demo(randn(3), randn()); + +# julia> chain = sample(model, MH(), 10); + +# julia> pointwise_logpriors(model, chain) +# OrderedDict{String,Array{Float64,2}} with 4 entries: +# "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] +# "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] +# "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] +# "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] + +# julia> pointwise_logpriors(model, chain, String) +# OrderedDict{String,Array{Float64,2}} with 4 entries: +# "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] +# "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] +# "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] +# "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] + +# julia> pointwise_logpriors(model, chain, VarName) +# OrderedDict{VarName,Array{Float64,2}} with 4 entries: +# xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333] +# xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359] +# xs[3] => [-1.42862; -2.67573; … ; -1.66251; -1.66251] +# y => [-1.51265; -0.914129; … ; -1.5499; -1.5499] +# ``` + +# ## Broadcasting +# Note that `x .~ Dist()` will treat `x` as a collection of +# _independent_ observations rather than as a single observation. 
+ +# ```jldoctest; setup = :(using Distributions) +# julia> @model function demo(x) +# x .~ Normal() +# end; + +# julia> m = demo([1.0, ]); + +# julia> ℓ = pointwise_logpriors(m, VarInfo(m)); first(ℓ[@varname(x[1])]) +# -1.4189385332046727 + +# julia> m = demo([1.0; 1.0]); + +# julia> ℓ = pointwise_logpriors(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) +# (-1.4189385332046727, -1.4189385332046727) +# ``` + +# """ +function varwise_logpriors( + model::Model, varinfo::AbstractVarInfo, + context::AbstractContext=PriorContext() +) +# top_context = VarwisePriorContext(OrderedDict{Symbol,Float64}(), context) + top_context = VarwisePriorContext(OrderedDict{Symbol,Float64}(), context) + model(varinfo, top_context) + return top_context.logpriors +end + +function varwise_logpriors(model::Model, chain::AbstractChains, + context::AbstractContext=PriorContext(); + top_context::VarwisePriorContext{T} = VarwisePriorContext(OrderedDict{Symbol,Float64}(), context) + ) where T + # pass top-context as keyword to allow adapt Number type of log-prior + get_values = (vi) -> begin + model(vi, top_context) + values(top_context.logpriors) + end + arr = map_model(get_values, model, chain) + par_names = collect(keys(top_context.logpriors)) + return(arr, par_names) +end + +function map_model(get_values, model::Model, chain::AbstractChains) + niters = size(chain, 1) + nchains = size(chain, 3) + vi = VarInfo(model) + iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) + # initialize the array by the first result + (sample_idx, chain_idx), iters2 = Iterators.peel(iters) + setval!(vi, chain, sample_idx, chain_idx) + values1 = get_values(vi) + arr = Array{eltype(values1)}(undef, niters, length(values1), nchains) + arr[sample_idx, :, chain_idx] .= values1 + #(sample_idx, chain_idx), iters3 = Iterators.peel(iters2) + for (sample_idx, chain_idx) in iters2 + # Update the values + setval!(vi, chain, sample_idx, chain_idx) + values_i = get_values(vi) + arr[sample_idx, :, 
chain_idx] .= values_i + end + return(arr) +end diff --git a/src/test_utils.jl b/src/test_utils.jl index 6f7481c40..c65c75b44 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -1042,4 +1042,37 @@ function test_context_interface(context) end end +""" +Context that multiplies each log-prior by mod +used to test whether varwise_logpriors respects child-context. +""" +struct TestLogModifyingChildContext{T,Ctx} <: DynamicPPL.AbstractContext + mod::T + context::Ctx +end +function TestLogModifyingChildContext( + mod=1.2, + context::DynamicPPL.AbstractContext=DynamicPPL.DefaultContext(), + #OrderedDict{VarName,Vector{Float64}}(),PriorContext()), +) + return TestLogModifyingChildContext{typeof(mod),typeof(context)}( + mod, context + ) +end +DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent() +DynamicPPL.childcontext(context::TestLogModifyingChildContext) = context.context +function DynamicPPL.setchildcontext(context::TestLogModifyingChildContext, child) + return TestLogModifyingChildContext(context.mod, child) +end +function DynamicPPL.tilde_assume(context::TestLogModifyingChildContext, right, vn, vi) + #@info "TestLogModifyingChildContext tilde_assume!! called for $vn" + value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi) + return value, logp*context.mod, vi +end +function DynamicPPL.dot_tilde_assume(context::TestLogModifyingChildContext, right, left, vn, vi) + #@info "TestLogModifyingChildContext dot_tilde_assume!! 
called for $vn" + value, logp, vi = DynamicPPL.dot_tilde_assume(context.context, right, left, vn, vi) + return value, logp*context.mod, vi +end + end diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index fbd405aaa..dbbdc6245 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -23,28 +23,53 @@ end end -() -> begin - m = DynamicPPL.TestUtils.demo_assume_index_observe() - example_values = DynamicPPL.TestUtils.rand_prior_true(m) - vi = VarInfo(m) - for vn in DynamicPPL.TestUtils.varnames(m) +@testset "pointwise_logpriors" begin + # test equality of the single log-prios (not just the sum) + # by returning them from the model as generated values + @model function exdemo_assume_index_observe( + x=[1.5, 2.0], ::Type{TV}=Vector{Float64} + ) where {TV} + # `assume` with indexing and `observe` + s = TV(undef, length(x)) + for i in eachindex(s) + s[i] ~ InverseGamma(2, 3) + end + m = TV(undef, length(x)) + for i in eachindex(m) + m[i] ~ Normal(0, sqrt(s[i])) + end + x ~ MvNormal(m, Diagonal(s)) + # here also return the log-priors for testing + return (; + s=s, + m=m, + x=x, + logp=getlogp(__varinfo__), + ld=(; + s=logpdf.(Ref(InverseGamma(2, 3)), s), + m=[logpdf(Normal(0, sqrt(s[i])), m[i]) for i in eachindex(m)], + x=logpdf(MvNormal(m, Diagonal(s)), x), + ), + ) + end + mex = exdemo_assume_index_observe() + morig = DynamicPPL.TestUtils.demo_assume_index_observe() + example_values = DynamicPPL.TestUtils.rand_prior_true(morig) + vi = VarInfo(mex) + #Main.@infiltrate_main + for vn in DynamicPPL.TestUtils.varnames(mex) vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) end - ret_m = first(evaluate!!(m, vi, SamplingContext())) - @test sum(map(k -> sum(ret_m.ld[k]), eachindex(ret_m.ld))) ≈ ret_m.logp - () -> begin - #by_generated_quantities - s = vcat(example_values...) - vnames = ["s[1]", "s[2]", "m[1]", "m[2]"] - stup = (; zip(Symbol.(vnames), s)...) 
- ret_m = generated_quantities(m, stup) - @test sum(map(k -> sum(ret_m.ld[k]), eachindex(ret_m.ld))) ≈ ret_m.logp - #chn = Chains(reshape(s, 1, : , 1), vnames); - chn = Chains(reshape(s, 1, :, 1)) # causes warning but works - ret_m = @test_logs (:warn,) generated_quantities(m, chn)[1, 1] - @test sum(map(k -> sum(ret_m.ld[k]), eachindex(ret_m.ld))) ≈ ret_m.logp + () -> begin # for interactive execution at the repl need the global keyword + for vn in DynamicPPL.TestUtils.varnames(mex) + global vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) + end end -end + ret_m = first(evaluate!!(mex, vi, SamplingContext())) + true_logpriors = ret_m.ld[[:s, :m]] + logpriors = DynamicPPL.pointwise_logpriors(mex, vi) + @test all(vcat(values(logpriors)...) .≈ vcat(true_logpriors...)) +end; @testset "logpriors.jl" begin #m = DynamicPPL.TestUtils.DEMO_MODELS[1] @@ -70,5 +95,72 @@ end logp1 = getlogp(vi) logp = logprior(m, vi) @test !isfinite(getlogp(vi)) || sum(x -> sum(x), values(tmp)) ≈ logp - end; + end +end; + +@testset "logpriors_var.jl" begin + mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2, PriorContext()) + mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx) + #m = DynamicPPL.TestUtils.DEMO_MODELS[1] + # m = DynamicPPL.TestUtils.demo_assume_index_observe() # logp at i-level? + @testset "$(m.f)" for (i, m) in enumerate(DynamicPPL.TestUtils.DEMO_MODELS) + #@show i + example_values = DynamicPPL.TestUtils.rand_prior_true(m) + + # Instantiate a `VarInfo` with the example values. + vi = VarInfo(m) + () -> begin + for vn in DynamicPPL.TestUtils.varnames(m) + global vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) + end + end + for vn in DynamicPPL.TestUtils.varnames(m) + vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) + end + + #chains = sample(m, SampleFromPrior(), 2; progress=false) + + # Compute the pointwise loglikelihoods. 
+ logpriors = DynamicPPL.varwise_logpriors(m, vi) + logp1 = getlogp(vi) + logp = logprior(m, vi) + @test !isfinite(logp) || sum(x -> sum(x), values(logpriors)) ≈ logp + # + # test on modifying child-context + logpriors_mod = DynamicPPL.varwise_logpriors(m, vi, mod_ctx2) + logp1 = getlogp(vi) + # Following line assumes no Likelihood contributions + # requires lowest Context to be PriorContext + @test !isfinite(logp1) || sum(x -> sum(x), values(logpriors_mod)) ≈ logp1 # + @test all(values(logpriors_mod) .≈ values(logpriors) .* 1.2 .* 1.4) + end end; + +@testset "logpriors_var chain" begin + @model function demo(xs, y) + s ~ InverseGamma(2, 3) + m ~ Normal(0, √s) + for i in eachindex(xs) + xs[i] ~ Normal(m, √s) + end + y ~ Normal(m, √s) + end + xs_true, y_true = ([0.3290767977680923, 0.038972110187911684, -0.5797496780649221], -0.7321425592768186)#randn(3), randn() + model = demo(xs_true, y_true) + () -> begin + # generate the sample used below + chain = sample(model, MH(), 10) + arr0 = Array(chain) + end + arr0 = [1.8585322626573435 -0.05900855284939967; 1.7304068220366808 -0.6386249100228161; 1.7304068220366808 -0.6386249100228161; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039]; # generated in function above + # split into two chains for testing + arr1 = permutedims(reshape(arr0, 5,2,:),(1,3,2)) + chain = Chains(arr1, [:s, :m]); + tmp1 = varwise_logpriors(model, chain) + tmp = Chains(tmp1...); # can be used to create a Chains object + vi = VarInfo(model) + i_sample, i_chain = (1,2) + DynamicPPL.setval!(vi, chain, i_sample, i_chain) + lp1 = DynamicPPL.varwise_logpriors(model, vi) + @test all(tmp1[1][i_sample,:,i_chain] .≈ values(lp1)) +end; \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 
eafc0e652..a321151ab 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,6 +23,8 @@ using Random using Serialization using Test using Logging +using Distributions +using LinearAlgebra # Diagonal using DynamicPPL: getargs_dottilde, getargs_tilde, Selector From c6653b98f3fd86471c954b72a042c6cc240c8e13 Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Tue, 17 Sep 2024 08:29:51 +0200 Subject: [PATCH 03/37] remove pointwise_logpriors --- src/DynamicPPL.jl | 2 - src/logpriors.jl | 210 ----------------------------------------- test/loglikelihoods.jl | 75 --------------- 3 files changed, 287 deletions(-) delete mode 100644 src/logpriors.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index e69cdd568..774a8010f 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -116,7 +116,6 @@ export AbstractVarInfo, logprior, logjoint, pointwise_loglikelihoods, - pointwise_logpriors, varwise_logpriors, condition, decondition, @@ -184,7 +183,6 @@ include("simple_varinfo.jl") include("context_implementations.jl") include("compiler.jl") include("loglikelihoods.jl") -include("logpriors.jl") include("logpriors_var.jl") include("submodel_macro.jl") include("test_utils.jl") diff --git a/src/logpriors.jl b/src/logpriors.jl deleted file mode 100644 index 1f3d8b3ad..000000000 --- a/src/logpriors.jl +++ /dev/null @@ -1,210 +0,0 @@ -# Context version -struct PointwisePriorContext{A,Ctx} <: AbstractContext - logpriors::A - context::Ctx -end - -function PointwisePriorContext( - priors=OrderedDict{VarName,Vector{Float64}}(), - context::AbstractContext=DynamicPPL.PriorContext(), - #OrderedDict{VarName,Vector{Float64}}(),PriorContext()), -) - return PointwisePriorContext{typeof(priors),typeof(context)}( - priors, context - ) -end - -NodeTrait(::PointwisePriorContext) = IsParent() -childcontext(context::PointwisePriorContext) = context.context -function setchildcontext(context::PointwisePriorContext, child) - return PointwisePriorContext(context.logpriors, child) -end - - -function 
tilde_assume!!(context::PointwisePriorContext, right, vn, vi) - #@info "PointwisePriorContext tilde_assume!! called for $vn" - value, logp, vi = tilde_assume(context, right, vn, vi) - push!(context, vn, logp) - return value, acclogp_assume!!(context, vi, logp) -end - -function dot_tilde_assume!!(context::PointwisePriorContext, right, left, vn, vi) - #@info "PointwisePriorContext dot_tilde_assume!! called for $vn" - # @show vn, left, right, typeof(context).name - value, logps, vi = dot_assume_vec(right, left, vn, vi) - # Need to unwrap the `vn`, i.e. get one `VarName` for each entry in `left`. - #_, _, vns = unwrap_right_left_vns(right, left, vn) - vns = vn - for (vn, logp) in zip(vns, logps) - # Track log-prior value. - push!(context, vn, logp) - end - return value, acclogp!!(vi, sum(logps)), vi -end - -function tilde_observe(context::PointwisePriorContext, right, left, vi) - # Since we are evaluating the prior, the log probability of all the observations - # is set to 0. This has the effect of ignoring the likelihood. 
- return 0.0, vi - #tmp = tilde_observe(context.context, SampleFromPrior(), right, left, vi) - #return tmp -end - -function Base.push!( - context::PointwisePriorContext{<:AbstractDict{VarName,Vector{Float64}}}, - vn::VarName, - logp::Real, -) - lookup = context.logpriors - ℓ = get!(lookup, vn, Float64[]) - return push!(ℓ, logp) -end - -function Base.push!( - context::PointwisePriorContext{<:AbstractDict{VarName,Float64}}, - vn::VarName, - logp::Real, -) - return context.logpriors[vn] = logp -end - -function Base.push!( - context::PointwisePriorContext{<:AbstractDict{String,Vector{Float64}}}, - vn::VarName, - logp::Real, -) - lookup = context.logpriors - ℓ = get!(lookup, string(vn), Float64[]) - return push!(ℓ, logp) -end - -function Base.push!( - context::PointwisePriorContext{<:AbstractDict{String,Float64}}, - vn::VarName, - logp::Real, -) - return context.logpriors[string(vn)] = logp -end - -function Base.push!( - context::PointwisePriorContext{<:AbstractDict{String,Vector{Float64}}}, - vn::String, - logp::Real, -) - lookup = context.logpriors - ℓ = get!(lookup, vn, Float64[]) - return push!(ℓ, logp) -end - -function Base.push!( - context::PointwisePriorContext{<:AbstractDict{String,Float64}}, - vn::String, - logp::Real, -) - return context.logpriors[vn] = logp -end - - -# """ -# pointwise_logpriors(model::Model, chain::Chains, keytype = String) - -# Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}` -# with keys corresponding to symbols of the observations, and values being matrices -# of shape `(num_chains, num_samples)`. - -# `keytype` specifies what the type of the keys used in the returned `OrderedDict` are. -# Currently, only `String` and `VarName` are supported. - -# # Notes -# Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ` -# both being `<:Real`. Then the *observe* (i.e. when the left-hand side is an -# *observation*) statements can be implemented in three ways: -# 1. 
using a `for` loop: -# ```julia -# for i in eachindex(y) -# y[i] ~ Normal(μ, σ) -# end -# ``` -# 2. using `.~`: -# ```julia -# y .~ Normal(μ, σ) -# ``` -# 3. using `MvNormal`: -# ```julia -# y ~ MvNormal(fill(μ, n), σ^2 * I) -# ``` - -# In (1) and (2), `y` will be treated as a collection of `n` i.i.d. 1-dimensional variables, -# while in (3) `y` will be treated as a _single_ n-dimensional observation. - -# This is important to keep in mind, in particular if the computation is used -# for downstream computations. - -# # Examples -# ## From chain -# ```julia-repl -# julia> using DynamicPPL, Turing - -# julia> @model function demo(xs, y) -# s ~ InverseGamma(2, 3) -# m ~ Normal(0, √s) -# for i in eachindex(xs) -# xs[i] ~ Normal(m, √s) -# end - -# y ~ Normal(m, √s) -# end -# demo (generic function with 1 method) - -# julia> model = demo(randn(3), randn()); - -# julia> chain = sample(model, MH(), 10); - -# julia> pointwise_logpriors(model, chain) -# OrderedDict{String,Array{Float64,2}} with 4 entries: -# "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] -# "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] -# "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] -# "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] - -# julia> pointwise_logpriors(model, chain, String) -# OrderedDict{String,Array{Float64,2}} with 4 entries: -# "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] -# "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] -# "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] -# "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] - -# julia> pointwise_logpriors(model, chain, VarName) -# OrderedDict{VarName,Array{Float64,2}} with 4 entries: -# xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333] -# xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359] -# xs[3] => [-1.42862; -2.67573; … ; -1.66251; -1.66251] -# y => [-1.51265; -0.914129; … ; -1.5499; -1.5499] -# ``` - -# ## Broadcasting -# Note that `x .~ Dist()` will treat `x` as a 
collection of -# _independent_ observations rather than as a single observation. - -# ```jldoctest; setup = :(using Distributions) -# julia> @model function demo(x) -# x .~ Normal() -# end; - -# julia> m = demo([1.0, ]); - -# julia> ℓ = pointwise_logpriors(m, VarInfo(m)); first(ℓ[@varname(x[1])]) -# -1.4189385332046727 - -# julia> m = demo([1.0; 1.0]); - -# julia> ℓ = pointwise_logpriors(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) -# (-1.4189385332046727, -1.4189385332046727) -# ``` - -# """ -function pointwise_logpriors(model::Model, varinfo::AbstractVarInfo) - context = PointwisePriorContext() - model(varinfo, context) - return context.logpriors -end diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index dbbdc6245..b511a03d7 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -23,81 +23,6 @@ end end -@testset "pointwise_logpriors" begin - # test equality of the single log-prios (not just the sum) - # by returning them from the model as generated values - @model function exdemo_assume_index_observe( - x=[1.5, 2.0], ::Type{TV}=Vector{Float64} - ) where {TV} - # `assume` with indexing and `observe` - s = TV(undef, length(x)) - for i in eachindex(s) - s[i] ~ InverseGamma(2, 3) - end - m = TV(undef, length(x)) - for i in eachindex(m) - m[i] ~ Normal(0, sqrt(s[i])) - end - x ~ MvNormal(m, Diagonal(s)) - # here also return the log-priors for testing - return (; - s=s, - m=m, - x=x, - logp=getlogp(__varinfo__), - ld=(; - s=logpdf.(Ref(InverseGamma(2, 3)), s), - m=[logpdf(Normal(0, sqrt(s[i])), m[i]) for i in eachindex(m)], - x=logpdf(MvNormal(m, Diagonal(s)), x), - ), - ) - end - mex = exdemo_assume_index_observe() - morig = DynamicPPL.TestUtils.demo_assume_index_observe() - example_values = DynamicPPL.TestUtils.rand_prior_true(morig) - vi = VarInfo(mex) - #Main.@infiltrate_main - for vn in DynamicPPL.TestUtils.varnames(mex) - vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) - end - () -> begin # for interactive 
execution at the repl need the global keyword - for vn in DynamicPPL.TestUtils.varnames(mex) - global vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) - end - end - ret_m = first(evaluate!!(mex, vi, SamplingContext())) - true_logpriors = ret_m.ld[[:s, :m]] - logpriors = DynamicPPL.pointwise_logpriors(mex, vi) - @test all(vcat(values(logpriors)...) .≈ vcat(true_logpriors...)) -end; - -@testset "logpriors.jl" begin - #m = DynamicPPL.TestUtils.DEMO_MODELS[1] - @testset "$(m.f)" for (i, m) in enumerate(DynamicPPL.TestUtils.DEMO_MODELS) - #@show i - example_values = DynamicPPL.TestUtils.rand_prior_true(m) - - # Instantiate a `VarInfo` with the example values. - vi = VarInfo(m) - () -> begin - for vn in DynamicPPL.TestUtils.varnames(m) - global vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) - end - end - for vn in DynamicPPL.TestUtils.varnames(m) - vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) - end - - #chains = sample(m, SampleFromPrior(), 2; progress=false) - - # Compute the pointwise loglikelihoods. 
- tmp = DynamicPPL.pointwise_logpriors(m, vi) - logp1 = getlogp(vi) - logp = logprior(m, vi) - @test !isfinite(getlogp(vi)) || sum(x -> sum(x), values(tmp)) ≈ logp - end -end; - @testset "logpriors_var.jl" begin mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2, PriorContext()) mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx) From 216d50cb7343da0a0a1680d49e5e8c79e5344707 Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Tue, 17 Sep 2024 12:59:03 +0200 Subject: [PATCH 04/37] revert dot_assume to not explicitly resolve components of sum --- src/context_implementations.jl | 25 ++++--------------------- 1 file changed, 4 insertions(+), 21 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 338749344..13231837f 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -425,15 +425,6 @@ function dot_assume( var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi::AbstractVarInfo, -) - r, lp, vi = dot_assume_vec(dist, var, vns, vi) - return r, sum(lp), vi -end -function dot_assume_vec( - dist::MultivariateDistribution, - var::AbstractMatrix, - vns::AbstractVector{<:VarName}, - vi::AbstractVarInfo, ) @assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))" # NOTE: We cannot work with `var` here because we might have a model of the form @@ -443,7 +434,7 @@ function dot_assume_vec( # # in which case `var` will have `undef` elements, even if `m` is present in `vi`. 
r = vi[vns, dist] - lp = map(zip(vns, eachcol(r))) do (vn, ri) + lp = sum(zip(vns, eachcol(r))) do (vn, ri) return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) end return r, lp, vi @@ -464,29 +455,21 @@ end function dot_assume( - dist::Union{Distribution,AbstractArray{<:Distribution}}, var::AbstractArray, vns::AbstractArray{<:VarName}, vi -) - # possibility to acesss the single logpriors - r, lp, vi = dot_assume_vec(dist, var, vns, vi) - return r, sum(lp), vi -end - -function dot_assume_vec( dist::Distribution, var::AbstractArray, vns::AbstractArray{<:VarName}, vi ) r = getindex.((vi,), vns, (dist,)) - lp = Bijectors.logpdf_with_trans.((dist,), r, istrans.((vi,), vns)) + lp = sum(Bijectors.logpdf_with_trans.((dist,), r, istrans.((vi,), vns))) return r, lp, vi end -function dot_assume_vec( +function dot_assume( dists::AbstractArray{<:Distribution}, var::AbstractArray, vns::AbstractArray{<:VarName}, vi, ) r = getindex.((vi,), vns, dists) - lp = Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns)) + lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns))) return r, lp, vi end From fd8d3b244cb861eae2850a35415b62cf47b7994a Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Wed, 18 Sep 2024 10:09:40 +0200 Subject: [PATCH 05/37] docstring varwise_logpriors use loop for prior in example Unfortunately cannot make it a jldoctest, because it relies on Turing for sampling --- src/logpriors_var.jl | 154 +++++++++++++---------------------------- src/test_utils.jl | 14 +++- test/loglikelihoods.jl | 31 +++++---- 3 files changed, 77 insertions(+), 122 deletions(-) diff --git a/src/logpriors_var.jl b/src/logpriors_var.jl index 1c534b60a..1860d8ccd 100644 --- a/src/logpriors_var.jl +++ b/src/logpriors_var.jl @@ -40,13 +40,8 @@ function dot_tilde_assume(context::VarwisePriorContext, right, left, vn, vi) end -function tilde_observe(context::VarwisePriorContext, right, left, vi) - # Since we are evaluating the prior, the log
probability of all the observations - # is set to 0. This has the effect of ignoring the likelihood. - return 0.0, vi - #tmp = tilde_observe(context.context, SampleFromPrior(), right, left, vi) - #return tmp -end +tilde_observe(context::VarwisePriorContext, right, left, vi) = 0, vi +dot_tilde_observe(::VarwisePriorContext, right, left, vi) = 0, vi function acc_logp!(context::VarwisePriorContext, vn::Union{VarName,AbstractVector{<:VarName}}, logp) #sym = DynamicPPL.getsym(vn) # leads to duplicates @@ -56,105 +51,52 @@ function acc_logp!(context::VarwisePriorContext, vn::Union{VarName,AbstractVecto return (context) end - -# """ -# pointwise_logpriors(model::Model, chain::Chains, keytype = String) - -# Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}` -# with keys corresponding to symbols of the observations, and values being matrices -# of shape `(num_chains, num_samples)`. - -# `keytype` specifies what the type of the keys used in the returned `OrderedDict` are. -# Currently, only `String` and `VarName` are supported. - -# # Notes -# Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ` -# both being `<:Real`. Then the *observe* (i.e. when the left-hand side is an -# *observation*) statements can be implemented in three ways: -# 1. using a `for` loop: -# ```julia -# for i in eachindex(y) -# y[i] ~ Normal(μ, σ) -# end -# ``` -# 2. using `.~`: -# ```julia -# y .~ Normal(μ, σ) -# ``` -# 3. using `MvNormal`: -# ```julia -# y ~ MvNormal(fill(μ, n), σ^2 * I) -# ``` - -# In (1) and (2), `y` will be treated as a collection of `n` i.i.d. 1-dimensional variables, -# while in (3) `y` will be treated as a _single_ n-dimensional observation. - -# This is important to keep in mind, in particular if the computation is used -# for downstream computations. 
- -# # Examples -# ## From chain -# ```julia-repl -# julia> using DynamicPPL, Turing - -# julia> @model function demo(xs, y) -# s ~ InverseGamma(2, 3) -# m ~ Normal(0, √s) -# for i in eachindex(xs) -# xs[i] ~ Normal(m, √s) -# end - -# y ~ Normal(m, √s) -# end -# demo (generic function with 1 method) - -# julia> model = demo(randn(3), randn()); - -# julia> chain = sample(model, MH(), 10); - -# julia> pointwise_logpriors(model, chain) -# OrderedDict{String,Array{Float64,2}} with 4 entries: -# "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] -# "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] -# "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] -# "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] - -# julia> pointwise_logpriors(model, chain, String) -# OrderedDict{String,Array{Float64,2}} with 4 entries: -# "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] -# "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] -# "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] -# "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] - -# julia> pointwise_logpriors(model, chain, VarName) -# OrderedDict{VarName,Array{Float64,2}} with 4 entries: -# xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333] -# xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359] -# xs[3] => [-1.42862; -2.67573; … ; -1.66251; -1.66251] -# y => [-1.51265; -0.914129; … ; -1.5499; -1.5499] -# ``` - -# ## Broadcasting -# Note that `x .~ Dist()` will treat `x` as a collection of -# _independent_ observations rather than as a single observation. 
- -# ```jldoctest; setup = :(using Distributions) -# julia> @model function demo(x) -# x .~ Normal() -# end; - -# julia> m = demo([1.0, ]); - -# julia> ℓ = pointwise_logpriors(m, VarInfo(m)); first(ℓ[@varname(x[1])]) -# -1.4189385332046727 - -# julia> m = demo([1.0; 1.0]); - -# julia> ℓ = pointwise_logpriors(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) -# (-1.4189385332046727, -1.4189385332046727) -# ``` - -# """ +""" + varwise_logpriors(model::Model, chain::Chains; context) + +Runs `model` on each sample in `chain` returning a tuple `(values, var_names)` +with var_names corresponding to symbols of the prior components, and values being +array of shape `(num_samples, num_components, num_chains)`. + +`context` specifies child context that handles computation of log-priors. + +# Example +```julia; setup = :(using Distributions) +using DynamicPPL, Turing + +@model function demo(x, ::Type{TV}=Vector{Float64}) where {TV} + s ~ InverseGamma(2, 3) + m = TV(undef, length(x)) + for i in eachindex(x) + m[i] ~ Normal(0, √s) + end + x ~ MvNormal(m, √s) + end + +model = demo(randn(3), randn()); + +chain = sample(model, MH(), 10); + +lp = varwise_logpriors(model, chain) +# Can be used to construct a new Chains object +#lpc = MCMCChains(varwise_logpriors(model, chain)...) + +# got a logdensity for each parameter prior +(but fewer if used `.~` assignments, see below) +lp[2] == names(chain, :parameters) + +# for each sample in the Chains object +size(lp[1])[[1,3]] == size(chain)[[1,3]] +``` + +# Broadcasting +Note that `m .~ Dist()` will treat `m` as a collection of +_independent_ prior rather than as a single prior, +but `varwise_logpriors` returns the single +sum of log-likelihood of components of `m` only. +If one needs the log-density of the components, one needs to rewrite +the model with an explicit loop. 
+""" function varwise_logpriors( model::Model, varinfo::AbstractVarInfo, context::AbstractContext=PriorContext() diff --git a/src/test_utils.jl b/src/test_utils.jl index c65c75b44..9db8c37fd 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -1059,7 +1059,10 @@ function TestLogModifyingChildContext( mod, context ) end -DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent() +# Samplers call leafcontext(model.context) when evaluating log-densities +# Hence, in order to be used need to say that its a leaf-context +#DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent() +DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsLeaf() DynamicPPL.childcontext(context::TestLogModifyingChildContext) = context.context function DynamicPPL.setchildcontext(context::TestLogModifyingChildContext, child) return TestLogModifyingChildContext(context.mod, child) @@ -1074,5 +1077,14 @@ function DynamicPPL.dot_tilde_assume(context::TestLogModifyingChildContext, righ value, logp, vi = DynamicPPL.dot_tilde_assume(context.context, right, left, vn, vi) return value, logp*context.mod, vi end +function DynamicPPL.tilde_observe(context::TestLogModifyingChildContext, right, left, vi) + value, logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi) + return value, logp*context.mod, vi +end +function DynamicPPL.dot_tilde_observe(context::TestLogModifyingChildContext, right, left, vi) + return DynamicPPL.dot_tilde_observe(context.context, right, left, vi) + return value, logp*context.mod, vi +end + end diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index b511a03d7..17185c4ad 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -28,6 +28,7 @@ end mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx) #m = DynamicPPL.TestUtils.DEMO_MODELS[1] # m = DynamicPPL.TestUtils.demo_assume_index_observe() # logp at i-level? 
+ # m = DynamicPPL.TestUtils.demo_assume_dot_observe() # failing test? @testset "$(m.f)" for (i, m) in enumerate(DynamicPPL.TestUtils.DEMO_MODELS) #@show i example_values = DynamicPPL.TestUtils.rand_prior_true(m) @@ -53,7 +54,8 @@ end # # test on modifying child-context logpriors_mod = DynamicPPL.varwise_logpriors(m, vi, mod_ctx2) - logp1 = getlogp(vi) + logp1 = getlogp(vi) + #logp_mod = logprior(m, vi) # uses prior context but not mod_ctx2 # Following line assumes no Likelihood contributions # requires lowest Context to be PriorContext @test !isfinite(logp1) || sum(x -> sum(x), values(logpriors_mod)) ≈ logp1 # @@ -62,25 +64,24 @@ end end; @testset "logpriors_var chain" begin - @model function demo(xs, y) + @model function demo(x, ::Type{TV}=Vector{Float64}) where {TV} s ~ InverseGamma(2, 3) - m ~ Normal(0, √s) - for i in eachindex(xs) - xs[i] ~ Normal(m, √s) + m = TV(undef, length(x)) + for i in eachindex(x) + m[i] ~ Normal(0, √s) end - y ~ Normal(m, √s) - end - xs_true, y_true = ([0.3290767977680923, 0.038972110187911684, -0.5797496780649221], -0.7321425592768186)#randn(3), randn() - model = demo(xs_true, y_true) + x ~ MvNormal(m, √s) + end + x_true = [0.3290767977680923, 0.038972110187911684, -0.5797496780649221] + model = demo(x_true) () -> begin # generate the sample used below - chain = sample(model, MH(), 10) - arr0 = Array(chain) + chain = sample(model, MH(), MCMCThreads(), 10, 2) + arr0 = stack(Array(chain, append_chains=false)) + @show(arr0); end - arr0 = [1.8585322626573435 -0.05900855284939967; 1.7304068220366808 -0.6386249100228161; 1.7304068220366808 -0.6386249100228161; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039]; # generated in function above - # split into two chains for testing - arr1 = 
permutedims(reshape(arr0, 5,2,:),(1,3,2)) - chain = Chains(arr1, [:s, :m]); + arr0 = [5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 0.9199555480151707 -0.1304320097505629 1.0669120062696917 -0.05253734412139093; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183;;; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 2.5409470583244933 1.7838744695696407 0.7013562890105632 -3.0843947804314658; 0.8296370582311665 1.5360702767879642 -1.5964695255693102 0.16928084806166913; 2.6246697053824954 0.8096845024785173 -1.2621822861663752 1.1414885535466166; 1.1304261861894538 0.7325784741344005 -1.1866016911837542 -0.1639319562090826; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 0.9838526141898173 -0.20198797220982412 2.0569535882007006 -1.1560724118010939] + chain = Chains(arr0, [:s, Symbol("m[1]"), Symbol("m[2]"), Symbol("m[3]")]); tmp1 = varwise_logpriors(model, chain) tmp = Chains(tmp1...); # can be used to create a Chains object vi = VarInfo(model) From 5842656154a5b2f9a0377c45a4d4438933971a11 Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Thu, 19 Sep 2024 
10:54:05 +0200 Subject: [PATCH 06/37] integrate pointwise_loglikelihoods and varwise_logpriors by pointwise_densities --- src/DynamicPPL.jl | 6 +- src/deprecated.jl | 9 ++ src/logpriors_var.jl | 143 ------------------ ...kelihoods.jl => pointwise_logdensities.jl} | 128 +++++++++++----- src/test_utils.jl | 9 +- test/contexts.jl | 4 +- test/deprecated.jl | 29 ++++ ...ikelihoods.jl => pointwise_logdensitiesjl} | 96 ++++++------ test/runtests.jl | 4 +- 9 files changed, 184 insertions(+), 244 deletions(-) create mode 100644 src/deprecated.jl delete mode 100644 src/logpriors_var.jl rename src/{loglikelihoods.jl => pointwise_logdensities.jl} (59%) create mode 100644 test/deprecated.jl rename test/{loglikelihoods.jl => pointwise_logdensitiesjl} (65%) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 774a8010f..9d870c9e8 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -116,7 +116,7 @@ export AbstractVarInfo, logprior, logjoint, pointwise_loglikelihoods, - varwise_logpriors, + pointwise_logdensities, condition, decondition, fix, @@ -182,8 +182,7 @@ include("varinfo.jl") include("simple_varinfo.jl") include("context_implementations.jl") include("compiler.jl") -include("loglikelihoods.jl") -include("logpriors_var.jl") +include("pointwise_logdensities.jl") include("submodel_macro.jl") include("test_utils.jl") include("transforming.jl") @@ -191,6 +190,7 @@ include("logdensityfunction.jl") include("model_utils.jl") include("extract_priors.jl") include("values_as_in_model.jl") +include("deprecated.jl") include("debug_utils.jl") using .DebugUtils diff --git a/src/deprecated.jl b/src/deprecated.jl new file mode 100644 index 000000000..09e1ac84d --- /dev/null +++ b/src/deprecated.jl @@ -0,0 +1,9 @@ +# https://invenia.github.io/blog/2022/06/17/deprecating-in-julia/ + +Base.@deprecate pointwise_loglikelihoods(model::Model, chain, keytype) pointwise_logdensities( + model::Model, LikelihoodContext(), chain, keytype) + +Base.@deprecate pointwise_loglikelihoods( + 
model::Model, varinfo::AbstractVarInfo) pointwise_logdensities( + model::Model, varinfo, LikelihoodContext()) + diff --git a/src/logpriors_var.jl b/src/logpriors_var.jl deleted file mode 100644 index 1860d8ccd..000000000 --- a/src/logpriors_var.jl +++ /dev/null @@ -1,143 +0,0 @@ -""" -Context that records logp after tilde_assume!! -for each VarName used by [`varwise_logpriors`](@ref). -""" -struct VarwisePriorContext{A,Ctx} <: AbstractContext - logpriors::A - context::Ctx -end - -function VarwisePriorContext( - logpriors=OrderedDict{Symbol,Float64}(), - context::AbstractContext=DynamicPPL.PriorContext(), - #OrderedDict{Symbol,Vector{Float64}}(),PriorContext()), -) - return VarwisePriorContext{typeof(logpriors),typeof(context)}( - logpriors, context - ) -end - -NodeTrait(::VarwisePriorContext) = IsParent() -childcontext(context::VarwisePriorContext) = context.context -function setchildcontext(context::VarwisePriorContext, child) - return VarwisePriorContext(context.logpriors, child) -end - -function tilde_assume(context::VarwisePriorContext, right, vn, vi) - #@info "VarwisePriorContext tilde_assume!! called for $vn" - value, logp, vi = tilde_assume(context.context, right, vn, vi) - #sym = DynamicPPL.getsym(vn) - new_context = acc_logp!(context, vn, logp) - return value, logp, vi -end - -function dot_tilde_assume(context::VarwisePriorContext, right, left, vn, vi) - #@info "VarwisePriorContext dot_tilde_assume!! 
called for $vn" - # @show vn, left, right, typeof(context).name - value, logp, vi = dot_tilde_assume(context.context, right, left, vn, vi) - new_context = acc_logp!(context, vn, logp) - return value, logp, vi -end - - -tilde_observe(context::VarwisePriorContext, right, left, vi) = 0, vi -dot_tilde_observe(::VarwisePriorContext, right, left, vi) = 0, vi - -function acc_logp!(context::VarwisePriorContext, vn::Union{VarName,AbstractVector{<:VarName}}, logp) - #sym = DynamicPPL.getsym(vn) # leads to duplicates - # if vn is a Vector leads to Symbol("VarName{:s, IndexLens{Tuple{Int64}}}[s[1], s[2]]") - sym = Symbol(vn) - context.logpriors[sym] = logp - return (context) -end - -""" - varwise_logpriors(model::Model, chain::Chains; context) - -Runs `model` on each sample in `chain` returning a tuple `(values, var_names)` -with var_names corresponding to symbols of the prior components, and values being -array of shape `(num_samples, num_components, num_chains)`. - -`context` specifies child context that handles computation of log-priors. - -# Example -```julia; setup = :(using Distributions) -using DynamicPPL, Turing - -@model function demo(x, ::Type{TV}=Vector{Float64}) where {TV} - s ~ InverseGamma(2, 3) - m = TV(undef, length(x)) - for i in eachindex(x) - m[i] ~ Normal(0, √s) - end - x ~ MvNormal(m, √s) - end - -model = demo(randn(3), randn()); - -chain = sample(model, MH(), 10); - -lp = varwise_logpriors(model, chain) -# Can be used to construct a new Chains object -#lpc = MCMCChains(varwise_logpriors(model, chain)...) - -# got a logdensity for each parameter prior -(but fewer if used `.~` assignments, see below) -lp[2] == names(chain, :parameters) - -# for each sample in the Chains object -size(lp[1])[[1,3]] == size(chain)[[1,3]] -``` - -# Broadcasting -Note that `m .~ Dist()` will treat `m` as a collection of -_independent_ prior rather than as a single prior, -but `varwise_logpriors` returns the single -sum of log-likelihood of components of `m` only. 
-If one needs the log-density of the components, one needs to rewrite -the model with an explicit loop. -""" -function varwise_logpriors( - model::Model, varinfo::AbstractVarInfo, - context::AbstractContext=PriorContext() -) -# top_context = VarwisePriorContext(OrderedDict{Symbol,Float64}(), context) - top_context = VarwisePriorContext(OrderedDict{Symbol,Float64}(), context) - model(varinfo, top_context) - return top_context.logpriors -end - -function varwise_logpriors(model::Model, chain::AbstractChains, - context::AbstractContext=PriorContext(); - top_context::VarwisePriorContext{T} = VarwisePriorContext(OrderedDict{Symbol,Float64}(), context) - ) where T - # pass top-context as keyword to allow adapt Number type of log-prior - get_values = (vi) -> begin - model(vi, top_context) - values(top_context.logpriors) - end - arr = map_model(get_values, model, chain) - par_names = collect(keys(top_context.logpriors)) - return(arr, par_names) -end - -function map_model(get_values, model::Model, chain::AbstractChains) - niters = size(chain, 1) - nchains = size(chain, 3) - vi = VarInfo(model) - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - # initialize the array by the first result - (sample_idx, chain_idx), iters2 = Iterators.peel(iters) - setval!(vi, chain, sample_idx, chain_idx) - values1 = get_values(vi) - arr = Array{eltype(values1)}(undef, niters, length(values1), nchains) - arr[sample_idx, :, chain_idx] .= values1 - #(sample_idx, chain_idx), iters3 = Iterators.peel(iters2) - for (sample_idx, chain_idx) in iters2 - # Update the values - setval!(vi, chain, sample_idx, chain_idx) - values_i = get_values(vi) - arr[sample_idx, :, chain_idx] .= values_i - end - return(arr) -end diff --git a/src/loglikelihoods.jl b/src/pointwise_logdensities.jl similarity index 59% rename from src/loglikelihoods.jl rename to src/pointwise_logdensities.jl index 227a70889..73d3e5c0c 100644 --- a/src/loglikelihoods.jl +++ b/src/pointwise_logdensities.jl @@ -1,83 +1,83 @@ # 
Context version -struct PointwiseLikelihoodContext{A,Ctx} <: AbstractContext - loglikelihoods::A +struct PointwiseLogdensityContext{A,Ctx} <: AbstractContext + logdensities::A context::Ctx end -function PointwiseLikelihoodContext( +function PointwiseLogdensityContext( likelihoods=OrderedDict{VarName,Vector{Float64}}(), - context::AbstractContext=LikelihoodContext(), + context::AbstractContext=DefaultContext(), ) - return PointwiseLikelihoodContext{typeof(likelihoods),typeof(context)}( + return PointwiseLogdensityContext{typeof(likelihoods),typeof(context)}( likelihoods, context ) end -NodeTrait(::PointwiseLikelihoodContext) = IsParent() -childcontext(context::PointwiseLikelihoodContext) = context.context -function setchildcontext(context::PointwiseLikelihoodContext, child) - return PointwiseLikelihoodContext(context.loglikelihoods, child) +NodeTrait(::PointwiseLogdensityContext) = IsParent() +childcontext(context::PointwiseLogdensityContext) = context.context +function setchildcontext(context::PointwiseLogdensityContext, child) + return PointwiseLogdensityContext(context.logdensities, child) end function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{VarName,Vector{Float64}}}, + context::PointwiseLogdensityContext{<:AbstractDict{VarName,Vector{Float64}}}, vn::VarName, logp::Real, ) - lookup = context.loglikelihoods + lookup = context.logdensities ℓ = get!(lookup, vn, Float64[]) return push!(ℓ, logp) end function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{VarName,Float64}}, + context::PointwiseLogdensityContext{<:AbstractDict{VarName,Float64}}, vn::VarName, logp::Real, ) - return context.loglikelihoods[vn] = logp + return context.logdensities[vn] = logp end function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{String,Vector{Float64}}}, + context::PointwiseLogdensityContext{<:AbstractDict{String,Vector{Float64}}}, vn::VarName, logp::Real, ) - lookup = context.loglikelihoods + lookup = context.logdensities ℓ 
= get!(lookup, string(vn), Float64[]) return push!(ℓ, logp) end function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{String,Float64}}, + context::PointwiseLogdensityContext{<:AbstractDict{String,Float64}}, vn::VarName, logp::Real, ) - return context.loglikelihoods[string(vn)] = logp + return context.logdensities[string(vn)] = logp end function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{String,Vector{Float64}}}, + context::PointwiseLogdensityContext{<:AbstractDict{String,Vector{Float64}}}, vn::String, logp::Real, ) - lookup = context.loglikelihoods + lookup = context.logdensities ℓ = get!(lookup, vn, Float64[]) return push!(ℓ, logp) end function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{String,Float64}}, + context::PointwiseLogdensityContext{<:AbstractDict{String,Float64}}, vn::String, logp::Real, ) - return context.loglikelihoods[vn] = logp + return context.logdensities[vn] = logp end -function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vi) +function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi) # Defer literal `observe` to child-context. return tilde_observe!!(context.context, right, left, vi) end -function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, vi) +function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi) # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. # we have to intercept the call to `tilde_observe!`. logp, vi = tilde_observe(context.context, right, left, vi) @@ -88,11 +88,11 @@ function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, v return left, acclogp!!(vi, logp) end -function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vi) +function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi) # Defer literal `observe` to child-context. 
return dot_tilde_observe!!(context.context, right, left, vi) end -function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, vi) +function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi) # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. # we have to intercept the call to `dot_tilde_observe!`. @@ -129,8 +129,49 @@ function _pointwise_tilde_observe( end end +function tilde_assume(context::PointwiseLogdensityContext, right, vn, vi) + #@info "PointwiseLogdensityContext tilde_assume!! called for $vn" + value, logp, vi = tilde_assume(context.context, right, vn, vi) + #sym = DynamicPPL.getsym(vn) + new_context = acc_logp!(context, vn, logp) + return value, logp, vi +end + +function dot_tilde_assume(context::PointwiseLogdensityContext, right, left, vn, vi) + #@info "PointwiseLogdensityContext dot_tilde_assume!! called for $vn" + # @show vn, left, right, typeof(context).name + value, logp, vi = dot_tilde_assume(context.context, right, left, vn, vi) + new_context = acc_logp!(context, vn, logp) + return value, logp, vi +end + +function acc_logp!(context::PointwiseLogdensityContext, vn::VarName, logp) + push!(context, vn, logp) + return (context) +end + +function acc_logp!(context::PointwiseLogdensityContext, vns::AbstractVector{<:VarName}, logp) + # construct a new VarName from given sequence of VarName + # assume that all items in vns have an IndexLens optic + indices = tuplejoin(map(vn -> getoptic(vn).indices, vns)...) + vn = VarName(first(vns), Accessors.IndexLens(indices)) + push!(context, vn, logp) + return (context) +end + +#https://discourse.julialang.org/t/efficient-tuple-concatenation/5398/8 +@inline tuplejoin(x) = x +@inline tuplejoin(x, y) = (x..., y...) +@inline tuplejoin(x, y, z...) = (x..., tuplejoin(y, z...)...) 
+ +() -> begin + # code that generates julia-repl in docstring below + # using DynamicPPL, Turing + # TODO when Turing version that is compatible with DynamicPPL 0.29 becomes available +end + """ - pointwise_loglikelihoods(model::Model, chain::Chains, keytype = String) + pointwise_logdensities(model::Model, chain::Chains, keytype = String) Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}` with keys corresponding to symbols of the observations, and values being matrices @@ -184,21 +225,21 @@ julia> model = demo(randn(3), randn()); julia> chain = sample(model, MH(), 10); -julia> pointwise_loglikelihoods(model, chain) +julia> pointwise_logdensities(model, chain) OrderedDict{String,Array{Float64,2}} with 4 entries: "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] -julia> pointwise_loglikelihoods(model, chain, String) +julia> pointwise_logdensities(model, chain, String) OrderedDict{String,Array{Float64,2}} with 4 entries: "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] -julia> pointwise_loglikelihoods(model, chain, VarName) +julia> pointwise_logdensities(model, chain, VarName) OrderedDict{VarName,Array{Float64,2}} with 4 entries: xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333] xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359] @@ -217,20 +258,21 @@ julia> @model function demo(x) julia> m = demo([1.0, ]); -julia> ℓ = pointwise_loglikelihoods(m, VarInfo(m)); first(ℓ[@varname(x[1])]) +julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first(ℓ[@varname(x[1])]) -1.4189385332046727 julia> m = demo([1.0; 1.0]); -julia> ℓ = pointwise_loglikelihoods(m, VarInfo(m)); first.((ℓ[@varname(x[1])], 
ℓ[@varname(x[2])])) +julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) (-1.4189385332046727, -1.4189385332046727) ``` """ -function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) where {T} +function pointwise_logdensities(model::Model, chain, + context::AbstractContext=DefaultContext(), keytype::Type{T}=String) where {T} # Get the data by executing the model once vi = VarInfo(model) - context = PointwiseLikelihoodContext(OrderedDict{T,Vector{Float64}}()) + point_context = PointwiseLogdensityContext(OrderedDict{T,Vector{Float64}}(), context) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) for (sample_idx, chain_idx) in iters @@ -238,20 +280,24 @@ function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) setval!(vi, chain, sample_idx, chain_idx) # Execute model - model(vi, context) + model(vi, point_context) end niters = size(chain, 1) nchains = size(chain, 3) - loglikelihoods = OrderedDict( + logdensities = OrderedDict( varname => reshape(logliks, niters, nchains) for - (varname, logliks) in context.loglikelihoods + (varname, logliks) in point_context.logdensities ) - return loglikelihoods + return logdensities end -function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo) - context = PointwiseLikelihoodContext(OrderedDict{VarName,Vector{Float64}}()) - model(varinfo, context) - return context.loglikelihoods +function pointwise_logdensities(model::Model, + varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext()) + point_context = PointwiseLogdensityContext( + OrderedDict{VarName,Vector{Float64}}(), context) + model(varinfo, point_context) + return point_context.logdensities end + + diff --git a/src/test_utils.jl b/src/test_utils.jl index 9db8c37fd..85dfa71f4 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -1078,12 +1078,13 @@ function DynamicPPL.dot_tilde_assume(context::TestLogModifyingChildContext, righ return value, 
logp*context.mod, vi end function DynamicPPL.tilde_observe(context::TestLogModifyingChildContext, right, left, vi) - value, logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi) - return value, logp*context.mod, vi + # @info "called tilde_observe TestLogModifyingChildContext for left=$left, right=$right" + logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi) + return logp*context.mod, vi end function DynamicPPL.dot_tilde_observe(context::TestLogModifyingChildContext, right, left, vi) - return DynamicPPL.dot_tilde_observe(context.context, right, left, vi) - return value, logp*context.mod, vi + logp, vi = DynamicPPL.dot_tilde_observe(context.context, right, left, vi) + return logp*context.mod, vi end diff --git a/test/contexts.jl b/test/contexts.jl index 11e2c99b7..4ec9ff945 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -8,7 +8,7 @@ using DynamicPPL: NodeTrait, IsLeaf, IsParent, - PointwiseLikelihoodContext, + PointwiseLogdensityContext, contextual_isassumption, ConditionContext, hasconditioned, @@ -67,7 +67,7 @@ end SamplingContext(), MiniBatchContext(DefaultContext(), 0.0), PrefixContext{:x}(DefaultContext()), - PointwiseLikelihoodContext(), + PointwiseLogdensityContext(), ConditionContext((x=1.0,)), ConditionContext((x=1.0,), ParentContext(ConditionContext((y=2.0,)))), ConditionContext((x=1.0,), PrefixContext{:a}(ConditionContext((var"a.y"=2.0,)))), diff --git a/test/deprecated.jl b/test/deprecated.jl new file mode 100644 index 000000000..24fb3a55e --- /dev/null +++ b/test/deprecated.jl @@ -0,0 +1,29 @@ +@testset "loglikelihoods.jl" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + example_values = DynamicPPL.TestUtils.rand_prior_true(m) + + # Instantiate a `VarInfo` with the example values. + vi = VarInfo(m) + for vn in DynamicPPL.TestUtils.varnames(m) + vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) + end + + # Compute the pointwise loglikelihoods. 
+ lls = pointwise_loglikelihoods(m, vi) + loglikelihood = sum(sum, values(lls)) + + #if isempty(lls) + if loglikelihood ≈ 0.0 #isempty(lls) + # One of the models with literal observations, so we just skip. + # TODO: Think of better way to detect this special case + continue + end + + loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(m, example_values...) + + #priors = + + @test loglikelihood ≈ loglikelihood_true + end +end + diff --git a/test/loglikelihoods.jl b/test/pointwise_logdensitiesjl similarity index 65% rename from test/loglikelihoods.jl rename to test/pointwise_logdensitiesjl index 17185c4ad..5214bf5a1 100644 --- a/test/loglikelihoods.jl +++ b/test/pointwise_logdensitiesjl @@ -1,41 +1,16 @@ -@testset "loglikelihoods.jl" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - example_values = DynamicPPL.TestUtils.rand_prior_true(m) - - # Instantiate a `VarInfo` with the example values. - vi = VarInfo(m) - for vn in DynamicPPL.TestUtils.varnames(m) - vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) - end - - # Compute the pointwise loglikelihoods. - lls = pointwise_loglikelihoods(m, vi) - - if isempty(lls) - # One of the models with literal observations, so we just skip. - continue - end - - loglikelihood = sum(sum, values(lls)) - loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(m, example_values...) - - @test loglikelihood ≈ loglikelihood_true - end -end - -@testset "logpriors_var.jl" begin - mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2, PriorContext()) +@testset "logdensities_likelihoods.jl" begin + likelihood_context = LikelihoodContext() + prior_context = PriorContext() + mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2) mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx) #m = DynamicPPL.TestUtils.DEMO_MODELS[1] - # m = DynamicPPL.TestUtils.demo_assume_index_observe() # logp at i-level? 
- # m = DynamicPPL.TestUtils.demo_assume_dot_observe() # failing test? @testset "$(m.f)" for (i, m) in enumerate(DynamicPPL.TestUtils.DEMO_MODELS) #@show i example_values = DynamicPPL.TestUtils.rand_prior_true(m) # Instantiate a `VarInfo` with the example values. vi = VarInfo(m) - () -> begin + () -> begin # when interactively debugging, need the global keyword for vn in DynamicPPL.TestUtils.varnames(m) global vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) end @@ -44,26 +19,45 @@ end vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) end - #chains = sample(m, SampleFromPrior(), 2; progress=false) + loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(m, example_values...) + logp_true = logprior(m, vi) # Compute the pointwise loglikelihoods. - logpriors = DynamicPPL.varwise_logpriors(m, vi) - logp1 = getlogp(vi) - logp = logprior(m, vi) - @test !isfinite(logp) || sum(x -> sum(x), values(logpriors)) ≈ logp - # - # test on modifying child-context - logpriors_mod = DynamicPPL.varwise_logpriors(m, vi, mod_ctx2) - logp1 = getlogp(vi) - #logp_mod = logprior(m, vi) # uses prior context but not mod_ctx2 - # Following line assumes no Likelihood contributions - # requires lowest Context to be PriorContext - @test !isfinite(logp1) || sum(x -> sum(x), values(logpriors_mod)) ≈ logp1 # - @test all(values(logpriors_mod) .≈ values(logpriors) .* 1.2 .* 1.4) + lls = pointwise_logdensities(m, vi, likelihood_context) + #lls2 = pointwise_loglikelihoods(m, vi) + loglikelihood = sum(sum, values(lls)) + if loglikelihood ≈ 0.0 #isempty(lls) + # One of the models with literal observations, so we just skip. + # TODO: Think of better way to detect this special case + loglikelihood_true = 0.0 + end + @test loglikelihood ≈ loglikelihood_true + + # Compute the pointwise logdensities of the priors. 
+ lps_prior = pointwise_logdensities(m, vi, prior_context) + logp = sum(sum, values(lps_prior)) + if false # isempty(lps_prior) + # One of the models with only observations so we just skip. + else + logp1 = getlogp(vi) + @test !isfinite(logp_true) || logp ≈ logp_true + end + + # Compute both likelihood and logdensity of prior + # using the default DefaultContex + lps = pointwise_logdensities(m, vi) + logp = sum(sum, values(lps)) + @test logp ≈ (logp_true + loglikelihood_true) + + # Test that modifications of Setup are picked up + lps = pointwise_logdensities(m, vi, mod_ctx2) + logp = sum(sum, values(lps)) + @test logp ≈ (logp_true + loglikelihood_true) * 1.2 * 1.4 end -end; +end + -@testset "logpriors_var chain" begin +@testset "pointwise_logdensities chain" begin @model function demo(x, ::Type{TV}=Vector{Float64}) where {TV} s ~ InverseGamma(2, 3) m = TV(undef, length(x)) @@ -82,11 +76,13 @@ end; end arr0 = [5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 0.9199555480151707 -0.1304320097505629 1.0669120062696917 -0.05253734412139093; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183;;; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 2.5409470583244933 1.7838744695696407 0.7013562890105632 -3.0843947804314658; 0.8296370582311665 
1.5360702767879642 -1.5964695255693102 0.16928084806166913; 2.6246697053824954 0.8096845024785173 -1.2621822861663752 1.1414885535466166; 1.1304261861894538 0.7325784741344005 -1.1866016911837542 -0.1639319562090826; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 0.9838526141898173 -0.20198797220982412 2.0569535882007006 -1.1560724118010939] chain = Chains(arr0, [:s, Symbol("m[1]"), Symbol("m[2]"), Symbol("m[3]")]); - tmp1 = varwise_logpriors(model, chain) - tmp = Chains(tmp1...); # can be used to create a Chains object + tmp1 = pointwise_logdensities(model, chain) vi = VarInfo(model) i_sample, i_chain = (1,2) DynamicPPL.setval!(vi, chain, i_sample, i_chain) - lp1 = DynamicPPL.varwise_logpriors(model, vi) - @test all(tmp1[1][i_sample,:,i_chain] .≈ values(lp1)) + lp1 = DynamicPPL.pointwise_logdensities(model, vi) + # k = first(keys(lp1)) + for k in keys(lp1) + @test tmp1[string(k)][i_sample,i_chain] .≈ lp1[k][1] + end end; \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index a321151ab..6de0cb7fe 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -57,9 +57,11 @@ include("test_util.jl") include("serialization.jl") - include("loglikelihoods.jl") + include("pointwise_logdensitiesjl") include("lkj.jl") + + include("deprecated.jl") end @testset "compat" begin From 18beb57bd37126c534157cf1aeb1795ad810fd14 Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Sat, 21 Sep 2024 13:02:20 +0200 Subject: [PATCH 07/37] record single prior components by forwarding dot_tilde_assume to tilde_assume --- src/pointwise_logdensities.jl | 54 +++++++++++-------- ...gdensitiesjl => pointwise_logdensities.jl} | 2 +- test/runtests.jl | 2 +- 3 files changed, 35 insertions(+), 23 deletions(-) rename test/{pointwise_logdensitiesjl => pointwise_logdensities.jl} (99%) diff 
--git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 73d3e5c0c..000b83d32 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -132,37 +132,49 @@ end function tilde_assume(context::PointwiseLogdensityContext, right, vn, vi) #@info "PointwiseLogdensityContext tilde_assume!! called for $vn" value, logp, vi = tilde_assume(context.context, right, vn, vi) - #sym = DynamicPPL.getsym(vn) - new_context = acc_logp!(context, vn, logp) + push!(context, vn, logp) return value, logp, vi end -function dot_tilde_assume(context::PointwiseLogdensityContext, right, left, vn, vi) - #@info "PointwiseLogdensityContext dot_tilde_assume!! called for $vn" - # @show vn, left, right, typeof(context).name - value, logp, vi = dot_tilde_assume(context.context, right, left, vn, vi) - new_context = acc_logp!(context, vn, logp) +function dot_tilde_assume(context::PointwiseLogdensityContext, right, left, vns, vi) + #@info "PointwiseLogdensityContext dot_tilde_assume called for $vns" + value, logp, vi_new = dot_tilde_assume(context.context, right, left, vns, vi) + # dispatch recording of log-densities based on type of right + logps = record_dot_tilde_assume(context, right, left, vns, vi, logp) + sum(logps) ≈ logp || error("Expected sum of individual logp equal origina, but differed sum($(join(logps, ","))) != $logp_orig") return value, logp, vi end -function acc_logp!(context::PointwiseLogdensityContext, vn::VarName, logp) - push!(context, vn, logp) - return (context) +function record_dot_tilde_assume(context::PointwiseLogdensityContext, right::UnivariateDistribution, left, vns, vi, logp) + # forward to tilde_assume for each variable + map(vns) do vn + value_i, logp_i, vi_i = tilde_assume(context, right, vn, vi) + logp_i + end end -function acc_logp!(context::PointwiseLogdensityContext, vns::AbstractVector{<:VarName}, logp) - # construct a new VarName from given sequence of VarName - # assume that all items in vns have an IndexLens optic - indices 
= tuplejoin(map(vn -> getoptic(vn).indices, vns)...) - vn = VarName(first(vns), Accessors.IndexLens(indices)) - push!(context, vn, logp) - return (context) +function record_dot_tilde_assume(context::PointwiseLogdensityContext, rights::AbstractVector{<:Distribution}, left, vns, vi, logp) + # forward to tilde_assume for each variable and distribution + logps = map(vns, rights) do vn, right + # use current context to record vn + value_i, logp_i, vi_i = tilde_assume(context, right, vn, vi) + logp_i + end end -#https://discourse.julialang.org/t/efficient-tuple-concatenation/5398/8 -@inline tuplejoin(x) = x -@inline tuplejoin(x, y) = (x..., y...) -@inline tuplejoin(x, y, z...) = (x..., tuplejoin(y, z...)...) +function record_dot_tilde_assume(context::PointwiseLogdensityContext, right::MultivariateDistribution, left, vns, vi, logp) + #@info "PointwiseLogdensityContext record_dot_tilde_assume multivariate called for $vns" + # For multivariate distribution on the right there is only a single density. + # Need to construct a combined VarName. + # Assume that all vns have an IndexLens with a Colon at the first position + # and a single number at the second position. + indices = map(vn -> getoptic(vn).indices[2], vns) + indices_combined = (:,indices) + #indices = tuplejoin(map(vn -> getoptic(vn).indices[2], vns)...) 
+ vn = VarName(first(vns), Accessors.IndexLens(indices_combined)) + push!(context, vn, logp) + return logp +end () -> begin # code that generates julia-repl in docstring below diff --git a/test/pointwise_logdensitiesjl b/test/pointwise_logdensities.jl similarity index 99% rename from test/pointwise_logdensitiesjl rename to test/pointwise_logdensities.jl index 5214bf5a1..77522edb3 100644 --- a/test/pointwise_logdensitiesjl +++ b/test/pointwise_logdensities.jl @@ -3,7 +3,7 @@ prior_context = PriorContext() mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2) mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx) - #m = DynamicPPL.TestUtils.DEMO_MODELS[1] + #m = DynamicPPL.TestUtils.DEMO_MODELS[12] @testset "$(m.f)" for (i, m) in enumerate(DynamicPPL.TestUtils.DEMO_MODELS) #@show i example_values = DynamicPPL.TestUtils.rand_prior_true(m) diff --git a/test/runtests.jl b/test/runtests.jl index 6de0cb7fe..3632879f5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -57,7 +57,7 @@ include("test_util.jl") include("serialization.jl") - include("pointwise_logdensitiesjl") + include("pointwise_logdensities.jl") include("lkj.jl") From d9945d72c0d4db75b209ff5244098019466a2fab Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Sun, 22 Sep 2024 07:00:22 +0200 Subject: [PATCH 08/37] forward dot_tilde_assume to tilde_assume for Multivariate --- src/pointwise_logdensities.jl | 24 +++----------- src/test_utils.jl | 57 ++++++++++++++++++++++++++++++++++ test/pointwise_logdensities.jl | 7 +++-- 3 files changed, 66 insertions(+), 22 deletions(-) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 000b83d32..834a17b8d 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -98,7 +98,7 @@ function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, v # We want to treat `.~` as a collection of independent observations, # hence we need the `logp` for each of them. 
Broadcasting the univariate - # `tilde_obseve` does exactly this. + # `tilde_observe` does exactly this. logps = _pointwise_tilde_observe(context.context, right, left, vi) # Need to unwrap the `vn`, i.e. get one `VarName` for each entry in `left`. @@ -129,8 +129,8 @@ function _pointwise_tilde_observe( end end -function tilde_assume(context::PointwiseLogdensityContext, right, vn, vi) - #@info "PointwiseLogdensityContext tilde_assume!! called for $vn" +function tilde_assume(context::PointwiseLogdensityContext, right::Distribution, vn, vi) + #@info "PointwiseLogdensityContext tilde_assume called for $vn" value, logp, vi = tilde_assume(context.context, right, vn, vi) push!(context, vn, logp) return value, logp, vi @@ -145,7 +145,7 @@ function dot_tilde_assume(context::PointwiseLogdensityContext, right, left, vns, return value, logp, vi end -function record_dot_tilde_assume(context::PointwiseLogdensityContext, right::UnivariateDistribution, left, vns, vi, logp) +function record_dot_tilde_assume(context::PointwiseLogdensityContext, right::Distribution, left, vns, vi, logp) # forward to tilde_assume for each variable map(vns) do vn value_i, logp_i, vi_i = tilde_assume(context, right, vn, vi) @@ -155,27 +155,13 @@ end function record_dot_tilde_assume(context::PointwiseLogdensityContext, rights::AbstractVector{<:Distribution}, left, vns, vi, logp) # forward to tilde_assume for each variable and distribution - logps = map(vns, rights) do vn, right + map(vns, rights) do vn, right # use current context to record vn value_i, logp_i, vi_i = tilde_assume(context, right, vn, vi) logp_i end end -function record_dot_tilde_assume(context::PointwiseLogdensityContext, right::MultivariateDistribution, left, vns, vi, logp) - #@info "PointwiseLogdensityContext record_dot_tilde_assume multivariate called for $vns" - # For multivariate distribution on the right there is only a single density. - # Need to construct a combined VarName. 
- # Assume that all vns have an IndexLens with a Colon at the first position - # and a single number at the second position. - indices = map(vn -> getoptic(vn).indices[2], vns) - indices_combined = (:,indices) - #indices = tuplejoin(map(vn -> getoptic(vn).indices[2], vns)...) - vn = VarName(first(vns), Accessors.IndexLens(indices_combined)) - push!(context, vn, logp) - return logp -end - () -> begin # code that generates julia-repl in docstring below # using DynamicPPL, Turing diff --git a/src/test_utils.jl b/src/test_utils.jl index 85dfa71f4..03b4b1fd7 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -667,6 +667,62 @@ function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix) return [@varname(s[:, 1], true), @varname(s[:, 2], true), @varname(m)] end +# model with truly two 3 columns in of d=2 (rather than ProductMatrix of single factor) to explore dot_tilde_assume with MultivariateDistribution +@model function demo_dot_assume_matrix_dot_observe_matrix2( + x=transpose([1.5 2.0;1.6 2.1;1.45 2.05]), ::Type{TV}=Array{Float64} +) where {TV} + d = size(x,1) + n = size(x,2) + s = TV(undef, d, n) + # for i in 1:n + # s[:,i] ~ product_distribution([InverseGamma(2, 3) for _ in 1:d]) + # end + s .~ product_distribution([InverseGamma(2, 3) for _ in 1:d]) + m = TV(undef, d, n) + Sigma_x = Diagonal(s[:,1]) + for i in 1:n + diag_s = Diagonal(s[:,i]) + m[:,i] ~ MvNormal(zeros(d), diag_s) + x[:,i] ~ MvNormal(m[:,i], Sigma_x) + end + + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) +end +function logprior_true( + model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix2)}, s, m +) + n = size(model.args.x,1) + d = size(model.args.x,2) + logd = map(1:d) do i_d + s_vec = Diagonal(s[:,i_d]) + loglikelihood(InverseGamma(2, 3), s_vec) + + logpdf(MvNormal(zeros(n), Diagonal(s_vec)), m[:,i_d]) + end + return sum(logd) +end +function loglikelihood_true( + model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix2)}, s, m +) + n = 
size(model.args.x,2) + d = size(model.args.x,1) + s_vec = Diagonal(s[:,1]) + logd = map(1:n) do i + loglikelihood(MvNormal(m[:,i], Diagonal(s_vec)), model.args.x[:,i]) + end + return sum(logd) +end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix2)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end +function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix2)}) + s = m = zeros(2, 3) # used for varname concretization only + return [@varname(s[:, 1], true), @varname(s[:, 2], true), @varname(s[:, 3], true), + @varname(m[:,1], true), @varname(m[:,2], true), @varname(m[:,3], true)] +end + + @model function demo_assume_matrix_dot_observe_matrix( x=transpose([1.5 2.0;]), ::Type{TV}=Array{Float64} ) where {TV} @@ -748,6 +804,7 @@ const MultivariateAssumeDemoModels = Union{ Model{typeof(demo_dot_assume_observe_submodel)}, Model{typeof(demo_dot_assume_dot_observe_matrix)}, Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, + Model{typeof(demo_dot_assume_matrix_dot_observe_matrix2)}, } function posterior_mean(model::MultivariateAssumeDemoModels) # Get some containers to fill. diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index 77522edb3..b4b2df646 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -4,6 +4,7 @@ mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2) mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx) #m = DynamicPPL.TestUtils.DEMO_MODELS[12] + #m = model = DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix2() @testset "$(m.f)" for (i, m) in enumerate(DynamicPPL.TestUtils.DEMO_MODELS) #@show i example_values = DynamicPPL.TestUtils.rand_prior_true(m) @@ -25,13 +26,13 @@ # Compute the pointwise loglikelihoods. 
lls = pointwise_logdensities(m, vi, likelihood_context) #lls2 = pointwise_loglikelihoods(m, vi) - loglikelihood = sum(sum, values(lls)) - if loglikelihood ≈ 0.0 #isempty(lls) + loglikelihood_sum = sum(sum, values(lls)) + if loglikelihood_sum ≈ 0.0 #isempty(lls) # One of the models with literal observations, so we just skip. # TODO: Think of better way to detect this special case loglikelihood_true = 0.0 end - @test loglikelihood ≈ loglikelihood_true + @test loglikelihood_sum ≈ loglikelihood_true # Compute the pointwise logdensities of the priors. lps_prior = pointwise_logdensities(m, vi, prior_context) From 656a75793a671b625d6a0b0774114b5b4a756e07 Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Tue, 24 Sep 2024 07:42:01 +0200 Subject: [PATCH 09/37] avoid recording prior components on leaf-prior-context and avoid recording likelihoods when invoked with leaf-Likelihood context --- src/pointwise_logdensities.jl | 33 ++++++++++++++++++++++++++++----- src/test_utils.jl | 6 ++---- test/deprecated.jl | 5 ++--- test/pointwise_logdensities.jl | 28 +++++++++++++--------------- 4 files changed, 45 insertions(+), 27 deletions(-) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 834a17b8d..f045f9ae2 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -73,11 +73,26 @@ function Base.push!( return context.logdensities[vn] = logp end + +function _include_prior(context::PointwiseLogdensityContext) + return leafcontext(context) isa Union{PriorContext,DefaultContext} +end +function _include_likelihood(context::PointwiseLogdensityContext) + return leafcontext(context) isa Union{LikelihoodContext,DefaultContext} +end + + + function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi) # Defer literal `observe` to child-context. 
return tilde_observe!!(context.context, right, left, vi) end function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi) + # Completely defer to child context if we are not tracking likelihoods. + if !(_include_likelihood(context)) + return tilde_observe!!(context.context, right, left, vn, vi) + end + # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. # we have to intercept the call to `tilde_observe!`. logp, vi = tilde_observe(context.context, right, left, vi) @@ -93,6 +108,11 @@ function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, v return dot_tilde_observe!!(context.context, right, left, vi) end function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi) + # Completely defer to child context if we are not tracking likelihoods. + if !(_include_likelihood(context)) + return dot_tilde_observe!!(context.context, right, left, vn, vi) + end + # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. # we have to intercept the call to `dot_tilde_observe!`. 
@@ -132,16 +152,19 @@ end function tilde_assume(context::PointwiseLogdensityContext, right::Distribution, vn, vi) #@info "PointwiseLogdensityContext tilde_assume called for $vn" value, logp, vi = tilde_assume(context.context, right, vn, vi) - push!(context, vn, logp) + if _include_prior(context) + push!(context, vn, logp) + end return value, logp, vi end function dot_tilde_assume(context::PointwiseLogdensityContext, right, left, vns, vi) #@info "PointwiseLogdensityContext dot_tilde_assume called for $vns" value, logp, vi_new = dot_tilde_assume(context.context, right, left, vns, vi) - # dispatch recording of log-densities based on type of right - logps = record_dot_tilde_assume(context, right, left, vns, vi, logp) - sum(logps) ≈ logp || error("Expected sum of individual logp equal origina, but differed sum($(join(logps, ","))) != $logp_orig") + if _include_prior(context) + logps = record_dot_tilde_assume(context, right, left, vns, vi, logp) + sum(logps) ≈ logp || error("Expected sum of individual logp equal origina, but differed sum($(join(logps, ","))) != $logp_orig") + end return value, logp, vi end @@ -172,7 +195,7 @@ end pointwise_logdensities(model::Model, chain::Chains, keytype = String) Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}` -with keys corresponding to symbols of the observations, and values being matrices +with keys corresponding to symbols of the variables, and values being matrices of shape `(num_chains, num_samples)`. `keytype` specifies what the type of the keys used in the returned `OrderedDict` are. 
diff --git a/src/test_utils.jl b/src/test_utils.jl index 03b4b1fd7..c6ed5cd0c 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -1116,10 +1116,8 @@ function TestLogModifyingChildContext( mod, context ) end -# Samplers call leafcontext(model.context) when evaluating log-densities -# Hence, in order to be used need to say that its a leaf-context -#DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent() -DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsLeaf() + +DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent() DynamicPPL.childcontext(context::TestLogModifyingChildContext) = context.context function DynamicPPL.setchildcontext(context::TestLogModifyingChildContext, child) return TestLogModifyingChildContext(context.mod, child) diff --git a/test/deprecated.jl b/test/deprecated.jl index 24fb3a55e..322029ec8 100644 --- a/test/deprecated.jl +++ b/test/deprecated.jl @@ -10,15 +10,14 @@ # Compute the pointwise loglikelihoods. lls = pointwise_loglikelihoods(m, vi) - loglikelihood = sum(sum, values(lls)) #if isempty(lls) - if loglikelihood ≈ 0.0 #isempty(lls) + if isempty(lls) # One of the models with literal observations, so we just skip. - # TODO: Think of better way to detect this special case continue end + loglikelihood = sum(sum, values(lls)) loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(m, example_values...) 
#priors = diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index b4b2df646..d91416a1f 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -3,9 +3,12 @@ prior_context = PriorContext() mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2) mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx) - #m = DynamicPPL.TestUtils.DEMO_MODELS[12] + #m = DynamicPPL.TestUtils.DEMO_MODELS[1] #m = model = DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix2() - @testset "$(m.f)" for (i, m) in enumerate(DynamicPPL.TestUtils.DEMO_MODELS) + demo_models = ( + DynamicPPL.TestUtils.DEMO_MODELS..., + DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix2()) + @testset "$(m.f)" for (i, m) in enumerate(demo_models) #@show i example_values = DynamicPPL.TestUtils.rand_prior_true(m) @@ -26,23 +29,19 @@ # Compute the pointwise loglikelihoods. lls = pointwise_logdensities(m, vi, likelihood_context) #lls2 = pointwise_loglikelihoods(m, vi) - loglikelihood_sum = sum(sum, values(lls)) - if loglikelihood_sum ≈ 0.0 #isempty(lls) + if isempty(lls) # One of the models with literal observations, so we just skip. - # TODO: Think of better way to detect this special case loglikelihood_true = 0.0 + else + loglikelihood_sum = sum(sum, values(lls)) + @test loglikelihood_sum ≈ loglikelihood_true end - @test loglikelihood_sum ≈ loglikelihood_true # Compute the pointwise logdensities of the priors. lps_prior = pointwise_logdensities(m, vi, prior_context) logp = sum(sum, values(lps_prior)) - if false # isempty(lps_prior) - # One of the models with only observations so we just skip. 
- else - logp1 = getlogp(vi) - @test !isfinite(logp_true) || logp ≈ logp_true - end + logp1 = getlogp(vi) + @test !isfinite(logp_true) || logp ≈ logp_true # Compute both likelihood and logdensity of prior # using the default DefaultContex @@ -57,7 +56,6 @@ end end - @testset "pointwise_logdensities chain" begin @model function demo(x, ::Type{TV}=Vector{Float64}) where {TV} s ~ InverseGamma(2, 3) @@ -73,9 +71,9 @@ end # generate the sample used below chain = sample(model, MH(), MCMCThreads(), 10, 2) arr0 = stack(Array(chain, append_chains=false)) - @show(arr0); + @show(arr0[1:2,:,:]); end - arr0 = [5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 0.9199555480151707 -0.1304320097505629 1.0669120062696917 -0.05253734412139093; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183;;; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 2.5409470583244933 1.7838744695696407 0.7013562890105632 -3.0843947804314658; 0.8296370582311665 1.5360702767879642 -1.5964695255693102 0.16928084806166913; 2.6246697053824954 0.8096845024785173 -1.2621822861663752 1.1414885535466166; 1.1304261861894538 0.7325784741344005 -1.1866016911837542 -0.1639319562090826; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 
-0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 0.9838526141898173 -0.20198797220982412 2.0569535882007006 -1.1560724118010939] + arr0[1:2, :, :] = [5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317;;; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497] chain = Chains(arr0, [:s, Symbol("m[1]"), Symbol("m[2]"), Symbol("m[3]")]); tmp1 = pointwise_logdensities(model, chain) vi = VarInfo(model) From 7aa9ebe203e17a3f214edd937f2e838ba26fb4b2 Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Tue, 24 Sep 2024 08:17:46 +0200 Subject: [PATCH 10/37] undeprecate pointwise_loglikelihoods and implement pointwise_prior_logdensities mostly taken from #669 --- src/DynamicPPL.jl | 4 +-- src/deprecated.jl | 9 ------ src/pointwise_logdensities.jl | 59 ++++++++++++++++++++++++++++++++++ test/deprecated.jl | 28 ---------------- test/pointwise_logdensities.jl | 5 +-- test/runtests.jl | 2 -- 6 files changed, 64 insertions(+), 43 deletions(-) delete mode 100644 src/deprecated.jl delete mode 100644 test/deprecated.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 9d870c9e8..de304fd28 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -115,8 +115,9 @@ export AbstractVarInfo, # Convenience functions logprior, logjoint, - pointwise_loglikelihoods, + pointwise_prior_logdensities, pointwise_logdensities, + pointwise_loglikelihoods, condition, decondition, fix, @@ -190,7 +191,6 @@ include("logdensityfunction.jl") include("model_utils.jl") include("extract_priors.jl") include("values_as_in_model.jl") -include("deprecated.jl") include("debug_utils.jl") using .DebugUtils diff --git a/src/deprecated.jl b/src/deprecated.jl deleted file mode 100644 index 09e1ac84d..000000000 --- 
a/src/deprecated.jl +++ /dev/null @@ -1,9 +0,0 @@ -# https://invenia.github.io/blog/2022/06/17/deprecating-in-julia/ - -Base.@deprecate pointwise_loglikelihoods(model::Model, chain, keytype) pointwise_logdensities( - model::Model, LikelihoodContext(), chain, keytype) - -Base.@deprecate pointwise_loglikelihoods( - model::Model, varinfo::AbstractVarInfo) pointwise_logdensities( - model::Model, varinfo, LikelihoodContext()) - diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index f045f9ae2..9faca254a 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -322,3 +322,62 @@ function pointwise_logdensities(model::Model, end + + +""" + pointwise_loglikelihoods(model, chain[, keytype, context]) +Compute the pointwise log-likelihoods of the model given the chain. +This is the same as `pointwise_logdensities(model, chain, context)`, but only +including the likelihood terms. +See also: [`pointwise_logdensities`](@ref). +""" +function pointwise_loglikelihoods( + model::Model, + chain, + keytype::Type{T}=String, + context::AbstractContext=LikelihoodContext(), +) where {T} + if !(leafcontext(context) isa LikelihoodContext) + throw(ArgumentError("Leaf context should be a LikelihoodContext")) + end + + return pointwise_logdensities(model, chain, T, context) +end + +function pointwise_loglikelihoods( + model::Model, varinfo::AbstractVarInfo, context::AbstractContext=LikelihoodContext() +) + if !(leafcontext(context) isa LikelihoodContext) + throw(ArgumentError("Leaf context should be a LikelihoodContext")) + end + + return pointwise_logdensities(model, varinfo, context) +end + +""" + pointwise_prior_logdensities(model, chain[, keytype, context]) +Compute the pointwise log-prior-densities of the model given the chain. +This is the same as `pointwise_logdensities(model, chain, context)`, but only +including the prior terms. +See also: [`pointwise_logdensities`](@ref). 
+""" +function pointwise_prior_logdensities( + model::Model, chain, keytype::Type{T}=String, context::AbstractContext=PriorContext() +) where {T} + if !(leafcontext(context) isa PriorContext) + throw(ArgumentError("Leaf context should be a PriorContext")) + end + + return pointwise_logdensities(model, chain, T, context) +end + +function pointwise_prior_logdensities( + model::Model, varinfo::AbstractVarInfo, context::AbstractContext=PriorContext() +) + if !(leafcontext(context) isa PriorContext) + throw(ArgumentError("Leaf context should be a PriorContext")) + end + + return pointwise_logdensities(model, varinfo, context) +end + diff --git a/test/deprecated.jl b/test/deprecated.jl deleted file mode 100644 index 322029ec8..000000000 --- a/test/deprecated.jl +++ /dev/null @@ -1,28 +0,0 @@ -@testset "loglikelihoods.jl" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - example_values = DynamicPPL.TestUtils.rand_prior_true(m) - - # Instantiate a `VarInfo` with the example values. - vi = VarInfo(m) - for vn in DynamicPPL.TestUtils.varnames(m) - vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) - end - - # Compute the pointwise loglikelihoods. - lls = pointwise_loglikelihoods(m, vi) - - #if isempty(lls) - if isempty(lls) - # One of the models with literal observations, so we just skip. - continue - end - - loglikelihood = sum(sum, values(lls)) - loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(m, example_values...) - - #priors = - - @test loglikelihood ≈ loglikelihood_true - end -end - diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index d91416a1f..d53c194dc 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -27,7 +27,7 @@ logp_true = logprior(m, vi) # Compute the pointwise loglikelihoods. 
- lls = pointwise_logdensities(m, vi, likelihood_context) + lls = pointwise_loglikelihoods(m, vi) #lls2 = pointwise_loglikelihoods(m, vi) if isempty(lls) # One of the models with literal observations, so we just skip. @@ -38,7 +38,7 @@ end # Compute the pointwise logdensities of the priors. - lps_prior = pointwise_logdensities(m, vi, prior_context) + lps_prior = pointwise_prior_logdensities(m, vi) logp = sum(sum, values(lps_prior)) logp1 = getlogp(vi) @test !isfinite(logp_true) || logp ≈ logp_true @@ -56,6 +56,7 @@ end end + @testset "pointwise_logdensities chain" begin @model function demo(x, ::Type{TV}=Vector{Float64}) where {TV} s ~ InverseGamma(2, 3) diff --git a/test/runtests.jl b/test/runtests.jl index 3632879f5..c78e8f941 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -60,8 +60,6 @@ include("test_util.jl") include("pointwise_logdensities.jl") include("lkj.jl") - - include("deprecated.jl") end @testset "compat" begin From 2f67c5b3691b36078998cfccda33b735975fb11a Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Tue, 24 Sep 2024 08:41:52 +0200 Subject: [PATCH 11/37] drop vi instead of re-compute vi bgctw first forwared dot_tilde_assume to get a correct vi and then recomputed it for recording component prior densities. 
Replaced this by the Hack of torfjelde that completely drops vi and recombines the value, so that assume is called only once for each varName, --- src/pointwise_logdensities.jl | 98 +++++++++++++++++++++++++--------- test/pointwise_logdensities.jl | 6 +-- 2 files changed, 76 insertions(+), 28 deletions(-) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 9faca254a..aaf2128aa 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -149,40 +149,86 @@ function _pointwise_tilde_observe( end end -function tilde_assume(context::PointwiseLogdensityContext, right::Distribution, vn, vi) - #@info "PointwiseLogdensityContext tilde_assume called for $vn" +# function tilde_assume(context::PointwiseLogdensityContext, right::Distribution, vn, vi) +# #@info "PointwiseLogdensityContext tilde_assume called for $vn" +# value, logp, vi = tilde_assume(context.context, right, vn, vi) +# if _include_prior(context) +# push!(context, vn, logp) +# end +# return value, logp, vi +# end + +# function dot_tilde_assume(context::PointwiseLogdensityContext, right, left, vns, vi) +# #@info "PointwiseLogdensityContext dot_tilde_assume called for $vns" +# value, logp, vi_new = dot_tilde_assume(context.context, right, left, vns, vi) +# if _include_prior(context) +# logps = record_dot_tilde_assume(context, right, left, vns, vi, logp) +# sum(logps) ≈ logp || error("Expected sum of individual logp equal origina, but differed sum($(join(logps, ","))) != $logp_orig") +# end +# return value, logp, vi +# end + +# function record_dot_tilde_assume(context::PointwiseLogdensityContext, right::Distribution, left, vns, vi, logp) +# # forward to tilde_assume for each variable +# map(vns) do vn +# value_i, logp_i, vi_i = tilde_assume(context, right, vn, vi) +# logp_i +# end +# end + +# function record_dot_tilde_assume(context::PointwiseLogdensityContext, rights::AbstractVector{<:Distribution}, left, vns, vi, logp) +# # forward to tilde_assume for each variable 
and distribution +# map(vns, rights) do vn, right +# # use current context to record vn +# value_i, logp_i, vi_i = tilde_assume(context, right, vn, vi) +# logp_i +# end +# end + +function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi) + !_include_prior(context) && return(tilde_assume!!(context.context, right, vn, vi)) value, logp, vi = tilde_assume(context.context, right, vn, vi) - if _include_prior(context) - push!(context, vn, logp) - end - return value, logp, vi + # Track loglikelihood value. + push!(context, vn, logp) + return value, acclogp!!(vi, logp) end -function dot_tilde_assume(context::PointwiseLogdensityContext, right, left, vns, vi) - #@info "PointwiseLogdensityContext dot_tilde_assume called for $vns" - value, logp, vi_new = dot_tilde_assume(context.context, right, left, vns, vi) - if _include_prior(context) - logps = record_dot_tilde_assume(context, right, left, vns, vi, logp) - sum(logps) ≈ logp || error("Expected sum of individual logp equal origina, but differed sum($(join(logps, ","))) != $logp_orig") +function dot_tilde_assume!!(context::PointwiseLogdensityContext, right, left, vns, vi) + !_include_prior(context) && return( + dot_tilde_assume!!(context.context, right, left, vns, vi)) + value, logps = _pointwise_tilde_assume(context, right, left, vns, vi) + # Track loglikelihood values. + for (vn, logp) in zip(vns, logps) + push!(context, vn, logp) end - return value, logp, vi + return value, acclogp!!(vi, sum(logps)) end -function record_dot_tilde_assume(context::PointwiseLogdensityContext, right::Distribution, left, vns, vi, logp) - # forward to tilde_assume for each variable - map(vns) do vn - value_i, logp_i, vi_i = tilde_assume(context, right, vn, vi) - logp_i +function _pointwise_tilde_assume(context, right, left, vns, vi) + # We need to drop the `vi` returned. 
+ values_and_logps = broadcast(right, left, vns) do r, l, vn + # HACK(torfjelde): This drops the `vi` returned, which means the `vi` is not updated + # in case of immutable varinfos. But a) atm we're only using mutable varinfos for this, + # and b) even if the variables aren't stored in the vi correctly, we're not going to use + # this vi for anything downstream anyways, i.e. I don't see a case where this would matter + # for this particular use case. + val, logp, _ = tilde_assume(context, r, vn, vi) + return val, logp end + return map(first, values_and_logps), map(last, values_and_logps) end - -function record_dot_tilde_assume(context::PointwiseLogdensityContext, rights::AbstractVector{<:Distribution}, left, vns, vi, logp) - # forward to tilde_assume for each variable and distribution - map(vns, rights) do vn, right - # use current context to record vn - value_i, logp_i, vi_i = tilde_assume(context, right, vn, vi) - logp_i +function _pointwise_tilde_assume( + context, right::MultivariateDistribution, left::AbstractMatrix, vns, vi +) + # We need to drop the `vi` returned. + values_and_logps = map(eachcol(left), vns) do l, vn + val, logp, _ = tilde_assume(context, right, vn, vi) + return val, logp end + # HACK(torfjelde): Due to the way we handle `.~`, we should use `recombine` to stay consistent. + # But this also means that we need to first flatten the entire `values` component before recombining. + values = recombine(right, mapreduce(vec ∘ first, vcat, values_and_logps), length(vns)) + return values, map(last, values_and_logps) end () -> begin @@ -326,6 +372,7 @@ end """ pointwise_loglikelihoods(model, chain[, keytype, context]) + Compute the pointwise log-likelihoods of the model given the chain. This is the same as `pointwise_logdensities(model, chain, context)`, but only including the likelihood terms. 
@@ -356,6 +403,7 @@ end """ pointwise_prior_logdensities(model, chain[, keytype, context]) + Compute the pointwise log-prior-densities of the model given the chain. This is the same as `pointwise_logdensities(model, chain, context)`, but only including the prior terms. diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index d53c194dc..f3d9cbac8 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -1,6 +1,4 @@ @testset "logdensities_likelihoods.jl" begin - likelihood_context = LikelihoodContext() - prior_context = PriorContext() mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2) mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx) #m = DynamicPPL.TestUtils.DEMO_MODELS[1] @@ -28,6 +26,7 @@ # Compute the pointwise loglikelihoods. lls = pointwise_loglikelihoods(m, vi) + @test :s ∉ getsym.(keys(lls)) #lls2 = pointwise_loglikelihoods(m, vi) if isempty(lls) # One of the models with literal observations, so we just skip. @@ -39,6 +38,7 @@ # Compute the pointwise logdensities of the priors. 
lps_prior = pointwise_prior_logdensities(m, vi) + @test :x ∉ getsym.(keys(lps_prior)) logp = sum(sum, values(lps_prior)) logp1 = getlogp(vi) @test !isfinite(logp_true) || logp ≈ logp_true @@ -74,7 +74,7 @@ end arr0 = stack(Array(chain, append_chains=false)) @show(arr0[1:2,:,:]); end - arr0[1:2, :, :] = [5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317;;; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497] + arr0 = [5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317;;; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497] chain = Chains(arr0, [:s, Symbol("m[1]"), Symbol("m[2]"), Symbol("m[3]")]); tmp1 = pointwise_logdensities(model, chain) vi = VarInfo(model) From 9dfb9ed866731db3a7e8941a90bcd2705f027e9e Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Tue, 24 Sep 2024 08:58:41 +0200 Subject: [PATCH 12/37] include docstrings of pointwise_logdensities pointwise_prior_logdensities int api.md docu --- docs/src/api.md | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 38d9ee6b0..41cdf5843 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -124,22 +124,16 @@ Return values of the model function for a collection of samples can be obtained generated_quantities ``` -For a chain of samples, one can compute the pointwise log-likelihoods of each observed random variable with [`pointwise_loglikelihoods`](@ref). +For a chain of samples, one can compute the pointwise log-likelihoods of each observed random variable with [`pointwise_loglikelihoods`](@ref). 
Similarly, the log-densities of the priors using +[`pointwise_prior_logdensities`](@ref) or both, i.e. all variables, using +[`pointwise_logdensities`](@ref). ```@docs +pointwise_logdensities pointwise_loglikelihoods +pointwise_prior_logdensities ``` -Similarly, one can compute the pointwise log-priors of each sampled random variable -with [`varwise_logpriors`](@ref). -Differently from `pointwise_loglikelihoods` it reports only a -single value for `.~` assignements. -If one needs to access the parts for single indices, one can -reformulate the model to use an explicit loop instead. - -```@docs -varwise_logpriors -``` For converting a chain into a format that can more easily be fed into a `Model` again, for example using `condition`, you can use [`value_iterator_from_chain`](@ref). From c1939e0d9a70e80285c92484d994c90f7472c425 Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Wed, 25 Sep 2024 07:27:20 +0200 Subject: [PATCH 13/37] Update src/pointwise_logdensities.jl remove commented code Co-authored-by: Tor Erlend Fjelde --- src/pointwise_logdensities.jl | 35 ----------------------------------- 1 file changed, 35 deletions(-) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index aaf2128aa..150b2a7ac 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -149,41 +149,6 @@ function _pointwise_tilde_observe( end end -# function tilde_assume(context::PointwiseLogdensityContext, right::Distribution, vn, vi) -# #@info "PointwiseLogdensityContext tilde_assume called for $vn" -# value, logp, vi = tilde_assume(context.context, right, vn, vi) -# if _include_prior(context) -# push!(context, vn, logp) -# end -# return value, logp, vi -# end - -# function dot_tilde_assume(context::PointwiseLogdensityContext, right, left, vns, vi) -# #@info "PointwiseLogdensityContext dot_tilde_assume called for $vns" -# value, logp, vi_new = dot_tilde_assume(context.context, right, left, vns, vi) -# if _include_prior(context) -# logps = 
record_dot_tilde_assume(context, right, left, vns, vi, logp) -# sum(logps) ≈ logp || error("Expected sum of individual logp equal origina, but differed sum($(join(logps, ","))) != $logp_orig") -# end -# return value, logp, vi -# end - -# function record_dot_tilde_assume(context::PointwiseLogdensityContext, right::Distribution, left, vns, vi, logp) -# # forward to tilde_assume for each variable -# map(vns) do vn -# value_i, logp_i, vi_i = tilde_assume(context, right, vn, vi) -# logp_i -# end -# end - -# function record_dot_tilde_assume(context::PointwiseLogdensityContext, rights::AbstractVector{<:Distribution}, left, vns, vi, logp) -# # forward to tilde_assume for each variable and distribution -# map(vns, rights) do vn, right -# # use current context to record vn -# value_i, logp_i, vi_i = tilde_assume(context, right, vn, vi) -# logp_i -# end -# end function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi) !_include_prior(context) && return(tilde_assume!!(context.context, right, vn, vi)) From 790be1da1d9eadcbf06e888bb28f24b0120b5ced Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Wed, 25 Sep 2024 07:27:58 +0200 Subject: [PATCH 14/37] Update src/pointwise_logdensities.jl remove commented code Co-authored-by: Tor Erlend Fjelde --- src/pointwise_logdensities.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 150b2a7ac..9ac6df755 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -196,12 +196,6 @@ function _pointwise_tilde_assume( return values, map(last, values_and_logps) end -() -> begin - # code that generates julia-repl in docstring below - # using DynamicPPL, Turing - # TODO when Turing version that is compatible with DynamicPPL 0.29 becomes available -end - """ pointwise_logdensities(model::Model, chain::Chains, keytype = String) From 426df38c627992d1a19678f67160e37645c8649a Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Wed, 25 Sep 
2024 07:38:13 +0200 Subject: [PATCH 15/37] Update test/pointwise_logdensities.jl rename m to model Co-authored-by: Tor Erlend Fjelde --- test/pointwise_logdensities.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index f3d9cbac8..423ca1bc4 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -37,7 +37,7 @@ end # Compute the pointwise logdensities of the priors. - lps_prior = pointwise_prior_logdensities(m, vi) + lps_prior = pointwise_prior_logdensities(model, vi) @test :x ∉ getsym.(keys(lps_prior)) logp = sum(sum, values(lps_prior)) logp1 = getlogp(vi) From c32bf3b0e0ffca53b1acaf9545316976ef61a47e Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Wed, 25 Sep 2024 07:42:52 +0200 Subject: [PATCH 16/37] Update test/pointwise_logdensities.jl remove unused code Co-authored-by: Tor Erlend Fjelde --- test/pointwise_logdensities.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index 423ca1bc4..6cccede34 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -40,7 +40,6 @@ lps_prior = pointwise_prior_logdensities(model, vi) @test :x ∉ getsym.(keys(lps_prior)) logp = sum(sum, values(lps_prior)) - logp1 = getlogp(vi) @test !isfinite(logp_true) || logp ≈ logp_true # Compute both likelihood and logdensity of prior From 6213249d27be6a106c6b82142cae17501b1639df Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Wed, 25 Sep 2024 07:43:23 +0200 Subject: [PATCH 17/37] Update test/pointwise_logdensities.jl rename m to model Co-authored-by: Tor Erlend Fjelde --- test/pointwise_logdensities.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index 6cccede34..9864d87c3 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -44,7 +44,7 @@ # Compute both likelihood and 
logdensity of prior # using the default DefaultContex - lps = pointwise_logdensities(m, vi) + lps = pointwise_logdensities(model, vi) logp = sum(sum, values(lps)) @test logp ≈ (logp_true + loglikelihood_true) From 3551b385f77c397ad7f8d7453ad789898672dc5a Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Wed, 25 Sep 2024 07:43:46 +0200 Subject: [PATCH 18/37] Update test/pointwise_logdensities.jl rename m to model Co-authored-by: Tor Erlend Fjelde --- test/pointwise_logdensities.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index 9864d87c3..4de499f00 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -49,7 +49,7 @@ @test logp ≈ (logp_true + loglikelihood_true) # Test that modifications of Setup are picked up - lps = pointwise_logdensities(m, vi, mod_ctx2) + lps = pointwise_logdensities(model, vi, mod_ctx2) logp = sum(sum, values(lps)) @test logp ≈ (logp_true + loglikelihood_true) * 1.2 * 1.4 end From 95c892b531a62dd0de753fa9216a2d2a744d3839 Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Wed, 25 Sep 2024 08:19:08 +0200 Subject: [PATCH 19/37] Update src/test_utils.jl remove old code Co-authored-by: Tor Erlend Fjelde --- src/test_utils.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index c6ed5cd0c..5059fe94f 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -1110,7 +1110,6 @@ end function TestLogModifyingChildContext( mod=1.2, context::DynamicPPL.AbstractContext=DynamicPPL.DefaultContext(), - #OrderedDict{VarName,Vector{Float64}}(),PriorContext()), ) return TestLogModifyingChildContext{typeof(mod),typeof(context)}( mod, context From a7a7e70794fefd991368149e7919337eed709493 Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Wed, 25 Sep 2024 08:39:45 +0200 Subject: [PATCH 20/37] rename m to model --- test/pointwise_logdensities.jl | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 
deletions(-) diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index 4de499f00..26b9fe1bd 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -1,33 +1,32 @@ @testset "logdensities_likelihoods.jl" begin mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2) mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx) - #m = DynamicPPL.TestUtils.DEMO_MODELS[1] - #m = model = DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix2() + #model = DynamicPPL.TestUtils.DEMO_MODELS[1] + #model = model = DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix2() demo_models = ( DynamicPPL.TestUtils.DEMO_MODELS..., DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix2()) - @testset "$(m.f)" for (i, m) in enumerate(demo_models) + @testset "$(model.f)" for (i, model) in enumerate(demo_models) #@show i - example_values = DynamicPPL.TestUtils.rand_prior_true(m) + example_values = DynamicPPL.TestUtils.rand_prior_true(model) # Instantiate a `VarInfo` with the example values. - vi = VarInfo(m) + vi = VarInfo(model) () -> begin # when interactively debugging, need the global keyword - for vn in DynamicPPL.TestUtils.varnames(m) + for vn in DynamicPPL.TestUtils.varnames(model) global vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) end end - for vn in DynamicPPL.TestUtils.varnames(m) + for vn in DynamicPPL.TestUtils.varnames(model) vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) end - loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(m, example_values...) - logp_true = logprior(m, vi) + loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(model, example_values...) + logp_true = logprior(model, vi) # Compute the pointwise loglikelihoods. 
- lls = pointwise_loglikelihoods(m, vi) + lls = pointwise_loglikelihoods(model, vi) @test :s ∉ getsym.(keys(lls)) - #lls2 = pointwise_loglikelihoods(m, vi) if isempty(lls) # One of the models with literal observations, so we just skip. loglikelihood_true = 0.0 From 1653aba6cfc409d742f255a0f2ad3ffeb1c3a755 Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Wed, 25 Sep 2024 08:40:15 +0200 Subject: [PATCH 21/37] JuliaFormatter --- docs/src/api.md | 3 +- src/DynamicPPL.jl | 2 +- src/pointwise_logdensities.jl | 33 +++++++---------- src/test_utils.jl | 67 +++++++++++++++++++--------------- test/pointwise_logdensities.jl | 33 ++++++++++------- 5 files changed, 72 insertions(+), 66 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 41cdf5843..156b51e03 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -125,7 +125,7 @@ generated_quantities ``` For a chain of samples, one can compute the pointwise log-likelihoods of each observed random variable with [`pointwise_loglikelihoods`](@ref). Similarly, the log-densities of the priors using -[`pointwise_prior_logdensities`](@ref) or both, i.e. all variables, using +[`pointwise_prior_logdensities`](@ref) or both, i.e. all variables, using [`pointwise_logdensities`](@ref). ```@docs @@ -134,7 +134,6 @@ pointwise_loglikelihoods pointwise_prior_logdensities ``` - For converting a chain into a format that can more easily be fed into a `Model` again, for example using `condition`, you can use [`value_iterator_from_chain`](@ref). 
```@docs diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index de304fd28..777c770d4 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -115,7 +115,7 @@ export AbstractVarInfo, # Convenience functions logprior, logjoint, - pointwise_prior_logdensities, + pointwise_prior_logdensities, pointwise_logdensities, pointwise_loglikelihoods, condition, diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 9ac6df755..9d6e70109 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -73,7 +73,6 @@ function Base.push!( return context.logdensities[vn] = logp end - function _include_prior(context::PointwiseLogdensityContext) return leafcontext(context) isa Union{PriorContext,DefaultContext} end @@ -81,18 +80,16 @@ function _include_likelihood(context::PointwiseLogdensityContext) return leafcontext(context) isa Union{LikelihoodContext,DefaultContext} end - - function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi) # Defer literal `observe` to child-context. return tilde_observe!!(context.context, right, left, vi) end function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi) - # Completely defer to child context if we are not tracking likelihoods. + # Completely defer to child context if we are not tracking likelihoods. if !(_include_likelihood(context)) return tilde_observe!!(context.context, right, left, vn, vi) - end - + end + # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. # we have to intercept the call to `tilde_observe!`. 
logp, vi = tilde_observe(context.context, right, left, vi) @@ -149,9 +146,8 @@ function _pointwise_tilde_observe( end end - function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi) - !_include_prior(context) && return(tilde_assume!!(context.context, right, vn, vi)) + !_include_prior(context) && return (tilde_assume!!(context.context, right, vn, vi)) value, logp, vi = tilde_assume(context.context, right, vn, vi) # Track loglikelihood value. push!(context, vn, logp) @@ -159,8 +155,8 @@ function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi) end function dot_tilde_assume!!(context::PointwiseLogdensityContext, right, left, vns, vi) - !_include_prior(context) && return( - dot_tilde_assume!!(context.context, right, left, vns, vi)) + !_include_prior(context) && + return (dot_tilde_assume!!(context.context, right, left, vns, vi)) value, logps = _pointwise_tilde_assume(context, right, left, vns, vi) # Track loglikelihood values. for (vn, logp) in zip(vns, logps) @@ -294,8 +290,9 @@ julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ``` """ -function pointwise_logdensities(model::Model, chain, - context::AbstractContext=DefaultContext(), keytype::Type{T}=String) where {T} +function pointwise_logdensities( + model::Model, chain, context::AbstractContext=DefaultContext(), keytype::Type{T}=String +) where {T} # Get the data by executing the model once vi = VarInfo(model) point_context = PointwiseLogdensityContext(OrderedDict{T,Vector{Float64}}(), context) @@ -318,17 +315,16 @@ function pointwise_logdensities(model::Model, chain, return logdensities end -function pointwise_logdensities(model::Model, - varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext()) +function pointwise_logdensities( + model::Model, varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext() +) point_context = PointwiseLogdensityContext( - OrderedDict{VarName,Vector{Float64}}(), context) + 
OrderedDict{VarName,Vector{Float64}}(), context + ) model(varinfo, point_context) return point_context.logdensities end - - - """ pointwise_loglikelihoods(model, chain[, keytype, context]) @@ -387,4 +383,3 @@ function pointwise_prior_logdensities( return pointwise_logdensities(model, varinfo, context) end - diff --git a/src/test_utils.jl b/src/test_utils.jl index c6ed5cd0c..3715db3a4 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -667,23 +667,24 @@ function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix) return [@varname(s[:, 1], true), @varname(s[:, 2], true), @varname(m)] end -# model with truly two 3 columns in of d=2 (rather than ProductMatrix of single factor) to explore dot_tilde_assume with MultivariateDistribution +# model with truly 3 columns (rather than ProductMatrix of single factor, d=1) +# to explore dot_tilde_assume with MultivariateDistribution @model function demo_dot_assume_matrix_dot_observe_matrix2( - x=transpose([1.5 2.0;1.6 2.1;1.45 2.05]), ::Type{TV}=Array{Float64} + x=transpose([1.5 2.0; 1.6 2.1; 1.45 2.05]), ::Type{TV}=Array{Float64} ) where {TV} - d = size(x,1) - n = size(x,2) + d = size(x, 1) + n = size(x, 2) s = TV(undef, d, n) # for i in 1:n # s[:,i] ~ product_distribution([InverseGamma(2, 3) for _ in 1:d]) # end s .~ product_distribution([InverseGamma(2, 3) for _ in 1:d]) m = TV(undef, d, n) - Sigma_x = Diagonal(s[:,1]) + Sigma_x = Diagonal(s[:, 1]) for i in 1:n - diag_s = Diagonal(s[:,i]) - m[:,i] ~ MvNormal(zeros(d), diag_s) - x[:,i] ~ MvNormal(m[:,i], Sigma_x) + diag_s = Diagonal(s[:, i]) + m[:, i] ~ MvNormal(zeros(d), diag_s) + x[:, i] ~ MvNormal(m[:, i], Sigma_x) end return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) @@ -691,23 +692,23 @@ end function logprior_true( model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix2)}, s, m ) - n = size(model.args.x,1) - d = size(model.args.x,2) + n = size(model.args.x, 1) + d = size(model.args.x, 2) logd = map(1:d) do i_d - s_vec = 
Diagonal(s[:,i_d]) + s_vec = Diagonal(s[:, i_d]) loglikelihood(InverseGamma(2, 3), s_vec) + - logpdf(MvNormal(zeros(n), Diagonal(s_vec)), m[:,i_d]) + logpdf(MvNormal(zeros(n), Diagonal(s_vec)), m[:, i_d]) end return sum(logd) end function loglikelihood_true( model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix2)}, s, m ) - n = size(model.args.x,2) - d = size(model.args.x,1) - s_vec = Diagonal(s[:,1]) + n = size(model.args.x, 2) + d = size(model.args.x, 1) + s_vec = Diagonal(s[:, 1]) logd = map(1:n) do i - loglikelihood(MvNormal(m[:,i], Diagonal(s_vec)), model.args.x[:,i]) + loglikelihood(MvNormal(m[:, i], Diagonal(s_vec)), model.args.x[:, i]) end return sum(logd) end @@ -718,11 +719,16 @@ function logprior_true_with_logabsdet_jacobian( end function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix2)}) s = m = zeros(2, 3) # used for varname concretization only - return [@varname(s[:, 1], true), @varname(s[:, 2], true), @varname(s[:, 3], true), - @varname(m[:,1], true), @varname(m[:,2], true), @varname(m[:,3], true)] + return [ + @varname(s[:, 1], true), + @varname(s[:, 2], true), + @varname(s[:, 3], true), + @varname(m[:, 1], true), + @varname(m[:, 2], true), + @varname(m[:, 3], true) + ] end - @model function demo_assume_matrix_dot_observe_matrix( x=transpose([1.5 2.0;]), ::Type{TV}=Array{Float64} ) where {TV} @@ -1110,11 +1116,9 @@ end function TestLogModifyingChildContext( mod=1.2, context::DynamicPPL.AbstractContext=DynamicPPL.DefaultContext(), - #OrderedDict{VarName,Vector{Float64}}(),PriorContext()), + #OrderedDict{VarName,Vector{Float64}}(),PriorContext()), ) - return TestLogModifyingChildContext{typeof(mod),typeof(context)}( - mod, context - ) + return TestLogModifyingChildContext{typeof(mod),typeof(context)}(mod, context) end DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent() @@ -1125,22 +1129,25 @@ end function DynamicPPL.tilde_assume(context::TestLogModifyingChildContext, right, vn, vi) #@info 
"TestLogModifyingChildContext tilde_assume!! called for $vn" value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi) - return value, logp*context.mod, vi + return value, logp * context.mod, vi end -function DynamicPPL.dot_tilde_assume(context::TestLogModifyingChildContext, right, left, vn, vi) +function DynamicPPL.dot_tilde_assume( + context::TestLogModifyingChildContext, right, left, vn, vi +) #@info "TestLogModifyingChildContext dot_tilde_assume!! called for $vn" value, logp, vi = DynamicPPL.dot_tilde_assume(context.context, right, left, vn, vi) - return value, logp*context.mod, vi + return value, logp * context.mod, vi end function DynamicPPL.tilde_observe(context::TestLogModifyingChildContext, right, left, vi) # @info "called tilde_observe TestLogModifyingChildContext for left=$left, right=$right" logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi) - return logp*context.mod, vi + return logp * context.mod, vi end -function DynamicPPL.dot_tilde_observe(context::TestLogModifyingChildContext, right, left, vi) +function DynamicPPL.dot_tilde_observe( + context::TestLogModifyingChildContext, right, left, vi +) logp, vi = DynamicPPL.dot_tilde_observe(context.context, right, left, vi) - return logp*context.mod, vi + return logp * context.mod, vi end - end diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index 26b9fe1bd..0b76bce00 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -4,8 +4,9 @@ #model = DynamicPPL.TestUtils.DEMO_MODELS[1] #model = model = DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix2() demo_models = ( - DynamicPPL.TestUtils.DEMO_MODELS..., - DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix2()) + DynamicPPL.TestUtils.DEMO_MODELS..., + DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix2(), + ) @testset "$(model.f)" for (i, model) in enumerate(demo_models) #@show i example_values = 
DynamicPPL.TestUtils.rand_prior_true(model) @@ -21,7 +22,9 @@ vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) end - loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(model, example_values...) + loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true( + model, example_values... + ) logp_true = logprior(model, vi) # Compute the pointwise loglikelihoods. @@ -54,7 +57,6 @@ end end - @testset "pointwise_logdensities chain" begin @model function demo(x, ::Type{TV}=Vector{Float64}) where {TV} s ~ InverseGamma(2, 3) @@ -62,25 +64,28 @@ end for i in eachindex(x) m[i] ~ Normal(0, √s) end - x ~ MvNormal(m, √s) - end + return x ~ MvNormal(m, √s) + end x_true = [0.3290767977680923, 0.038972110187911684, -0.5797496780649221] model = demo(x_true) () -> begin # generate the sample used below chain = sample(model, MH(), MCMCThreads(), 10, 2) - arr0 = stack(Array(chain, append_chains=false)) - @show(arr0[1:2,:,:]); + arr0 = stack(Array(chain; append_chains=false)) + @show(arr0[1:2, :, :]) end - arr0 = [5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317;;; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497] - chain = Chains(arr0, [:s, Symbol("m[1]"), Symbol("m[2]"), Symbol("m[3]")]); + arr0 = [ + 5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317;;; + 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497 + ] + chain = Chains(arr0, [:s, Symbol("m[1]"), Symbol("m[2]"), Symbol("m[3]")]) tmp1 = pointwise_logdensities(model, chain) vi = VarInfo(model) - i_sample, i_chain = (1,2) + i_sample, i_chain = (1, 2) DynamicPPL.setval!(vi, chain, i_sample, 
i_chain) - lp1 = DynamicPPL.pointwise_logdensities(model, vi) + lp1 = DynamicPPL.pointwise_logdensities(model, vi) # k = first(keys(lp1)) for k in keys(lp1) - @test tmp1[string(k)][i_sample,i_chain] .≈ lp1[k][1] + @test tmp1[string(k)][i_sample, i_chain] .≈ lp1[k][1] end -end; \ No newline at end of file +end; From a99eab41459bb1ef9fe1b368504b6a6e86c38cbc Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Thu, 26 Sep 2024 08:34:33 +0200 Subject: [PATCH 22/37] Update test/runtests.jl remove interactive code Co-authored-by: Tor Erlend Fjelde --- test/runtests.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index c78e8f941..918ab61d2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,7 +34,6 @@ const GROUP = get(ENV, "GROUP", "All") Random.seed!(100) -#include(joinpath(DIRECTORY_DynamicPPL,"test","test_util.jl")) include("test_util.jl") @testset "DynamicPPL.jl" begin From 64ce63a10037f63e7d659f21d6708c72dfee04dc Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Thu, 26 Sep 2024 08:53:41 +0200 Subject: [PATCH 23/37] remove demo_dot_assume_matrix_dot_observe_matrix2 testcase testing higher dimensions better left for other PR --- src/test_utils.jl | 63 ---------------------------------- test/pointwise_logdensities.jl | 14 +------- test/runtests.jl | 2 +- 3 files changed, 2 insertions(+), 77 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index c4769a1b9..6c0971f8f 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -667,68 +667,6 @@ function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix) return [@varname(s[:, 1], true), @varname(s[:, 2], true), @varname(m)] end -# model with truly 3 columns (rather than ProductMatrix of single factor, d=1) -# to explore dot_tilde_assume with MultivariateDistribution -@model function demo_dot_assume_matrix_dot_observe_matrix2( - x=transpose([1.5 2.0; 1.6 2.1; 1.45 2.05]), ::Type{TV}=Array{Float64} -) where {TV} - d = size(x, 1) - n = size(x, 
2) - s = TV(undef, d, n) - # for i in 1:n - # s[:,i] ~ product_distribution([InverseGamma(2, 3) for _ in 1:d]) - # end - s .~ product_distribution([InverseGamma(2, 3) for _ in 1:d]) - m = TV(undef, d, n) - Sigma_x = Diagonal(s[:, 1]) - for i in 1:n - diag_s = Diagonal(s[:, i]) - m[:, i] ~ MvNormal(zeros(d), diag_s) - x[:, i] ~ MvNormal(m[:, i], Sigma_x) - end - - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) -end -function logprior_true( - model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix2)}, s, m -) - n = size(model.args.x, 1) - d = size(model.args.x, 2) - logd = map(1:d) do i_d - s_vec = Diagonal(s[:, i_d]) - loglikelihood(InverseGamma(2, 3), s_vec) + - logpdf(MvNormal(zeros(n), Diagonal(s_vec)), m[:, i_d]) - end - return sum(logd) -end -function loglikelihood_true( - model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix2)}, s, m -) - n = size(model.args.x, 2) - d = size(model.args.x, 1) - s_vec = Diagonal(s[:, 1]) - logd = map(1:n) do i - loglikelihood(MvNormal(m[:, i], Diagonal(s_vec)), model.args.x[:, i]) - end - return sum(logd) -end -function logprior_true_with_logabsdet_jacobian( - model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix2)}, s, m -) - return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) -end -function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix2)}) - s = m = zeros(2, 3) # used for varname concretization only - return [ - @varname(s[:, 1], true), - @varname(s[:, 2], true), - @varname(s[:, 3], true), - @varname(m[:, 1], true), - @varname(m[:, 2], true), - @varname(m[:, 3], true) - ] -end - @model function demo_assume_matrix_dot_observe_matrix( x=transpose([1.5 2.0;]), ::Type{TV}=Array{Float64} ) where {TV} @@ -810,7 +748,6 @@ const MultivariateAssumeDemoModels = Union{ Model{typeof(demo_dot_assume_observe_submodel)}, Model{typeof(demo_dot_assume_dot_observe_matrix)}, Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, - 
Model{typeof(demo_dot_assume_matrix_dot_observe_matrix2)}, } function posterior_mean(model::MultivariateAssumeDemoModels) # Get some containers to fill. diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index 0b76bce00..58373f2f8 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -1,23 +1,11 @@ @testset "logdensities_likelihoods.jl" begin mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2) mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx) - #model = DynamicPPL.TestUtils.DEMO_MODELS[1] - #model = model = DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix2() - demo_models = ( - DynamicPPL.TestUtils.DEMO_MODELS..., - DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix2(), - ) - @testset "$(model.f)" for (i, model) in enumerate(demo_models) - #@show i + @testset "$(model.f)" for (i, model) in enumerate(DynamicPPL.TestUtils.DEMO_MODELS) example_values = DynamicPPL.TestUtils.rand_prior_true(model) # Instantiate a `VarInfo` with the example values. 
vi = VarInfo(model) - () -> begin # when interactively debugging, need the global keyword - for vn in DynamicPPL.TestUtils.varnames(model) - global vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) - end - end for vn in DynamicPPL.TestUtils.varnames(model) vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) end diff --git a/test/runtests.jl b/test/runtests.jl index 918ab61d2..65bc2c9a0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,7 +26,7 @@ using Logging using Distributions using LinearAlgebra # Diagonal -using DynamicPPL: getargs_dottilde, getargs_tilde, Selector +using DynamicPPL: getargs_dottilde, getargs_tilde, Selector, getsym const DIRECTORY_DynamicPPL = dirname(dirname(pathof(DynamicPPL))) const DIRECTORY_Turing_tests = joinpath(DIRECTORY_DynamicPPL, "test", "turing") From 456115c8004aca61584b0b8cfaa7b9ac784d33e6 Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Thu, 26 Sep 2024 08:54:16 +0200 Subject: [PATCH 24/37] ignore local interactive development code --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 198907c73..7a42d083a 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ .DS_Store Manifest.toml **.~undo-tree~ +tmp/interactive.jl From 222529ad05936f494ec245595f54409a3a079aad Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Thu, 26 Sep 2024 08:57:05 +0200 Subject: [PATCH 25/37] ignore temporary directory holding local interactive development code --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 7a42d083a..5db129e27 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,4 @@ .DS_Store Manifest.toml **.~undo-tree~ -tmp/interactive.jl +/tmp From 17b251ae477c38fa41be37d2cebff09b09756542 Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Thu, 26 Sep 2024 12:12:26 +0200 Subject: [PATCH 26/37] Apply suggestions from code review: clean up comments and Imports Co-authored-by: Tor Erlend Fjelde --- src/test_utils.jl | 
3 --- test/pointwise_logdensities.jl | 10 +++++----- test/runtests.jl | 3 +-- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 6c0971f8f..8489f2684 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -1062,19 +1062,16 @@ function DynamicPPL.setchildcontext(context::TestLogModifyingChildContext, child return TestLogModifyingChildContext(context.mod, child) end function DynamicPPL.tilde_assume(context::TestLogModifyingChildContext, right, vn, vi) - #@info "TestLogModifyingChildContext tilde_assume!! called for $vn" value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi) return value, logp * context.mod, vi end function DynamicPPL.dot_tilde_assume( context::TestLogModifyingChildContext, right, left, vn, vi ) - #@info "TestLogModifyingChildContext dot_tilde_assume!! called for $vn" value, logp, vi = DynamicPPL.dot_tilde_assume(context.context, right, left, vn, vi) return value, logp * context.mod, vi end function DynamicPPL.tilde_observe(context::TestLogModifyingChildContext, right, left, vi) - # @info "called tilde_observe TestLogModifyingChildContext for left=$left, right=$right" logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi) return logp * context.mod, vi end diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index 58373f2f8..c94ad3173 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -1,7 +1,7 @@ @testset "logdensities_likelihoods.jl" begin mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2) mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx) - @testset "$(model.f)" for (i, model) in enumerate(DynamicPPL.TestUtils.DEMO_MODELS) + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS example_values = DynamicPPL.TestUtils.rand_prior_true(model) # Instantiate a `VarInfo` with the example values. @@ -17,9 +17,9 @@ # Compute the pointwise loglikelihoods. 
lls = pointwise_loglikelihoods(model, vi) - @test :s ∉ getsym.(keys(lls)) + @test [:x] == unique(DynamicPPL.getsym.(keys(lls))) if isempty(lls) - # One of the models with literal observations, so we just skip. + # One of the models with literal observations, so we'll set this to 0 for subsequent comparisons. loglikelihood_true = 0.0 else loglikelihood_sum = sum(sum, values(lls)) @@ -28,9 +28,9 @@ # Compute the pointwise logdensities of the priors. lps_prior = pointwise_prior_logdensities(model, vi) - @test :x ∉ getsym.(keys(lps_prior)) + @test :x ∉ DynamicPPL.getsym.(keys(lps_prior)) logp = sum(sum, values(lps_prior)) - @test !isfinite(logp_true) || logp ≈ logp_true + @test logp ≈ logp_true # Compute both likelihood and logdensity of prior # using the default DefaultContex diff --git a/test/runtests.jl b/test/runtests.jl index 65bc2c9a0..bc8e54b39 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,11 +22,10 @@ using Pkg using Random using Serialization using Test -using Logging using Distributions using LinearAlgebra # Diagonal -using DynamicPPL: getargs_dottilde, getargs_tilde, Selector, getsym +using DynamicPPL: getargs_dottilde, getargs_tilde, Selector const DIRECTORY_DynamicPPL = dirname(dirname(pathof(DynamicPPL))) const DIRECTORY_Turing_tests = joinpath(DIRECTORY_DynamicPPL, "test", "turing") From 7e990f0a03239366e579624fec78d659a0eba08a Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Thu, 26 Sep 2024 12:14:52 +0200 Subject: [PATCH 27/37] Apply suggestions from code review: change test of applying to chains on already used model Co-authored-by: Tor Erlend Fjelde --- test/Project.toml | 1 - test/pointwise_logdensities.jl | 49 +++++++++++++++------------------- 2 files changed, 21 insertions(+), 29 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index a75909f95..2b4b23df5 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,7 +14,6 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = 
"37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" -Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index c94ad3173..6843d235b 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -46,34 +46,27 @@ end @testset "pointwise_logdensities chain" begin - @model function demo(x, ::Type{TV}=Vector{Float64}) where {TV} - s ~ InverseGamma(2, 3) - m = TV(undef, length(x)) - for i in eachindex(x) - m[i] ~ Normal(0, √s) - end - return x ~ MvNormal(m, √s) - end - x_true = [0.3290767977680923, 0.038972110187911684, -0.5797496780649221] - model = demo(x_true) - () -> begin - # generate the sample used below - chain = sample(model, MH(), MCMCThreads(), 10, 2) - arr0 = stack(Array(chain; append_chains=false)) - @show(arr0[1:2, :, :]) - end - arr0 = [ - 5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317;;; - 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497 + # We'll just test one, since `pointwise_logdensities(::Model, ::AbstractVarInfo)` is tested extensively, + # and this is what is used to implement `pointwise_logdensities(::Model, ::Chains)`. This test suite is just + # to ensure that we don't accidentally break the the version on `Chains`. + model = DynamicPPL.TestUtils.demo_dot_assume_dot_observe() + # FIXME(torfjelde): Make use of `varname_and_value_leaves` once we've introduced + # an impl of this for containers. 
+ vns = DynamicPPL.TestUtils.varnames(model) + # Get some random `NamedTuple` samples from the prior. + vals = [DynamicPPL.TestUtils.rand_prior_true(model) for _ = 1:5] + # Concatenate the vector representations and create a `Chains` from it. + vals_arr = reduce(hcat, (mapreduce(DynamicPPL.tovec, vcat, values(nt) for nt in vals)) + chain = Chains(permutedims(vals_arr), map(Symbol, vns)) + logjoints_pointwise = pointwise_logdensities(model, chain) + # Get the sum of the logjoints for each of the iterations. + logjoints = [ + sum(logjoints_pointwise[vn][idx] for vn in vns) + for idx = 1:5 ] - chain = Chains(arr0, [:s, Symbol("m[1]"), Symbol("m[2]"), Symbol("m[3]")]) - tmp1 = pointwise_logdensities(model, chain) - vi = VarInfo(model) - i_sample, i_chain = (1, 2) - DynamicPPL.setval!(vi, chain, i_sample, i_chain) - lp1 = DynamicPPL.pointwise_logdensities(model, vi) - # k = first(keys(lp1)) - for k in keys(lp1) - @test tmp1[string(k)][i_sample, i_chain] .≈ lp1[k][1] + for (val, logp) in zip(vals, logjoints) + # Compare true logjoint with the one obtained from `pointwise_logdensities`. + logjoint_true = DynamicPPL.TestUtils.logjoint_true(model, val...) + @test logp ≈ logjoint_true end end; From 8706f680070f5cad990551e5c8fae102fa5d8f75 Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Thu, 26 Sep 2024 12:24:36 +0200 Subject: [PATCH 28/37] fix test on names in likelihood components to work with literal models --- test/pointwise_logdensities.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index 6843d235b..914ce1d13 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -17,11 +17,11 @@ # Compute the pointwise loglikelihoods. lls = pointwise_loglikelihoods(model, vi) - @test [:x] == unique(DynamicPPL.getsym.(keys(lls))) if isempty(lls) # One of the models with literal observations, so we'll set this to 0 for subsequent comparisons. 
loglikelihood_true = 0.0 else + @test [:x] == unique(DynamicPPL.getsym.(keys(lls))) loglikelihood_sum = sum(sum, values(lls)) @test loglikelihood_sum ≈ loglikelihood_true end From 073a325bc998ccc1bcee0bed0897c75e7801a89e Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Thu, 26 Sep 2024 12:36:31 +0200 Subject: [PATCH 29/37] try to fix testset pointwise_logdensities chain --- test/pointwise_logdensities.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index 914ce1d13..e796c89ef 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -56,12 +56,12 @@ end # Get some random `NamedTuple` samples from the prior. vals = [DynamicPPL.TestUtils.rand_prior_true(model) for _ = 1:5] # Concatenate the vector representations and create a `Chains` from it. - vals_arr = reduce(hcat, (mapreduce(DynamicPPL.tovec, vcat, values(nt) for nt in vals)) - chain = Chains(permutedims(vals_arr), map(Symbol, vns)) + vals_arr = reduce(hcat, mapreduce(DynamicPPL.tovec, vcat, values(nt)) for nt in vals) + chain = Chains(permutedims(vals_arr), map(Symbol, vns)); logjoints_pointwise = pointwise_logdensities(model, chain) # Get the sum of the logjoints for each of the iterations. 
logjoints = [ - sum(logjoints_pointwise[vn][idx] for vn in vns) + sum(logjoints_pointwise[string(vn)][idx] for vn in vns) for idx = 1:5 ] for (val, logp) in zip(vals, logjoints) From 23e17118629aa8025d3f10d4b2624cc0c6738bb0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 26 Sep 2024 12:38:36 +0100 Subject: [PATCH 30/37] Update test/pointwise_logdensities.jl --- test/pointwise_logdensities.jl | 43 ++++++++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 7 deletions(-) diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index e796c89ef..a422965ef 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -52,21 +52,50 @@ end model = DynamicPPL.TestUtils.demo_dot_assume_dot_observe() # FIXME(torfjelde): Make use of `varname_and_value_leaves` once we've introduced # an impl of this for containers. + # NOTE(torfjelde): This only returns the varnames of the _random_ variables, i.e. excl. observed. vns = DynamicPPL.TestUtils.varnames(model) # Get some random `NamedTuple` samples from the prior. - vals = [DynamicPPL.TestUtils.rand_prior_true(model) for _ = 1:5] + num_iters = 3 + vals = [DynamicPPL.TestUtils.rand_prior_true(model) for _ in 1:num_iters] # Concatenate the vector representations and create a `Chains` from it. vals_arr = reduce(hcat, mapreduce(DynamicPPL.tovec, vcat, values(nt)) for nt in vals) - chain = Chains(permutedims(vals_arr), map(Symbol, vns)); + chain = Chains(permutedims(vals_arr), map(Symbol, vns)) + + # Compute the different pointwise logdensities. logjoints_pointwise = pointwise_logdensities(model, chain) + logpriors_pointwise = pointwise_prior_logdensities(model, chain) + loglikelihoods_pointwise = pointwise_loglikelihoods(model, chain) + + # Check that they contain the correct variables. 
+ @test all(string(vn) in keys(logjoints_pointwise) for vn in vns) + @test all(string(vn) in keys(logpriors_pointwise) for vn in vns) + @test !any(Base.Fix2(startswith, "x"), keys(logpriors_pointwise)) + @test !any(string(vn) in keys(loglikelihoods_pointwise) for vn in vns) + @test all(Base.Fix2(startswith, "x"), keys(loglikelihoods_pointwise)) + # Get the sum of the logjoints for each of the iterations. logjoints = [ - sum(logjoints_pointwise[string(vn)][idx] for vn in vns) - for idx = 1:5 + sum(logjoints_pointwise[vn][idx] for vn in keys(logjoints_pointwise)) for + idx in 1:num_iters + ] + logpriors = [ + sum(logpriors_pointwise[vn][idx] for vn in keys(logpriors_pointwise)) for + idx in 1:num_iters ] - for (val, logp) in zip(vals, logjoints) + loglikelihoods = [ + sum(loglikelihoods_pointwise[vn][idx] for vn in keys(loglikelihoods_pointwise)) for + idx in 1:num_iters + ] + + for (val, logjoint, logprior, loglikelihood) in + zip(vals, logjoints, logpriors, loglikelihoods) # Compare true logjoint with the one obtained from `pointwise_logdensities`. logjoint_true = DynamicPPL.TestUtils.logjoint_true(model, val...) - @test logp ≈ logjoint_true + logprior_true = DynamicPPL.TestUtils.logprior_true(model, val...) + loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(model, val...) 
+ + @test logjoint ≈ logjoint_true + @test logprior ≈ logprior_true + @test loglikelihood ≈ loglikelihood_true end -end; +end From 34ae4f8ab68cd39c154d2559f92cceb785f1bc73 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 26 Sep 2024 12:40:27 +0100 Subject: [PATCH 31/37] Update .gitignore --- .gitignore | 1 - 1 file changed, 1 deletion(-) diff --git a/.gitignore b/.gitignore index 5db129e27..198907c73 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,3 @@ .DS_Store Manifest.toml **.~undo-tree~ -/tmp From 777624a5aadab4b06a7974d6567cc96ea3b3074d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 26 Sep 2024 12:50:31 +0100 Subject: [PATCH 32/37] Formtating --- test/pointwise_logdensities.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index a422965ef..93b7c59be 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -21,7 +21,7 @@ # One of the models with literal observations, so we'll set this to 0 for subsequent comparisons. 
loglikelihood_true = 0.0 else - @test [:x] == unique(DynamicPPL.getsym.(keys(lls))) + @test [:x] == unique(DynamicPPL.getsym.(keys(lls))) loglikelihood_sum = sum(sum, values(lls)) @test loglikelihood_sum ≈ loglikelihood_true end From 4864e6003569b4e7f80e2585450a61160b4fa400 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 26 Sep 2024 13:12:33 +0100 Subject: [PATCH 33/37] Fixed tests --- src/pointwise_logdensities.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 9d6e70109..de4646746 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -291,7 +291,7 @@ julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], """ function pointwise_logdensities( - model::Model, chain, context::AbstractContext=DefaultContext(), keytype::Type{T}=String + model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext() ) where {T} # Get the data by executing the model once vi = VarInfo(model) From 4d3b0c0d990de1f54afcb6710d169919a5e86c94 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 27 Sep 2024 14:38:20 +0100 Subject: [PATCH 34/37] Updated docs for `pointwise_logdensities` + made it a doctest not dependent on Turing.jl --- src/pointwise_logdensities.jl | 51 +++++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 20 deletions(-) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index de4646746..676149b73 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -229,8 +229,8 @@ for downstream computations. 
# Examples ## From chain -```julia-repl -julia> using DynamicPPL, Turing +```jldoctest pointwise-logdensities-chains; setup=:(using Distributions) +julia> using MCMCChains julia> @model function demo(xs, y) s ~ InverseGamma(2, 3) @@ -241,32 +241,43 @@ julia> @model function demo(xs, y) y ~ Normal(m, √s) end -demo (generic function with 1 method) +demo (generic function with 2 methods) -julia> model = demo(randn(3), randn()); +julia> # Example observations. + model = demo([1.0, 2.0, 3.0], [4.0]); -julia> chain = sample(model, MH(), 10); +julia> # A chain with 3 iterations. + chain = Chains( + reshape(1.:6., 3, 2), + [:s, :m] + ); julia> pointwise_logdensities(model, chain) -OrderedDict{String,Array{Float64,2}} with 4 entries: - "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] - "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] - "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] - "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] +OrderedDict{String, Matrix{Float64}} with 6 entries: + "s" => [-0.802775; -1.38222; -2.09861;;] + "m" => [-8.91894; -7.51551; -7.46824;;] + "xs[1]" => [-5.41894; -5.26551; -5.63491;;] + "xs[2]" => [-2.91894; -3.51551; -4.13491;;] + "xs[3]" => [-1.41894; -2.26551; -2.96824;;] + "y" => [-0.918939; -1.51551; -2.13491;;] julia> pointwise_logdensities(model, chain, String) -OrderedDict{String,Array{Float64,2}} with 4 entries: - "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] - "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] - "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] - "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] +OrderedDict{String, Matrix{Float64}} with 6 entries: + "s" => [-0.802775; -1.38222; -2.09861;;] + "m" => [-8.91894; -7.51551; -7.46824;;] + "xs[1]" => [-5.41894; -5.26551; -5.63491;;] + "xs[2]" => [-2.91894; -3.51551; -4.13491;;] + "xs[3]" => [-1.41894; -2.26551; -2.96824;;] + "y" => [-0.918939; -1.51551; -2.13491;;] julia> pointwise_logdensities(model, chain, VarName) 
-OrderedDict{VarName,Array{Float64,2}} with 4 entries: - xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333] - xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359] - xs[3] => [-1.42862; -2.67573; … ; -1.66251; -1.66251] - y => [-1.51265; -0.914129; … ; -1.5499; -1.5499] +OrderedDict{VarName, Matrix{Float64}} with 6 entries: + s => [-0.802775; -1.38222; -2.09861;;] + m => [-8.91894; -7.51551; -7.46824;;] + xs[1] => [-5.41894; -5.26551; -5.63491;;] + xs[2] => [-2.91894; -3.51551; -4.13491;;] + xs[3] => [-1.41894; -2.26551; -2.96824;;] + y => [-0.918939; -1.51551; -2.13491;;] ``` ## Broadcasting From e54fa4e1384b3777246d72960c9d1fc3a67cd56b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 27 Sep 2024 14:42:51 +0100 Subject: [PATCH 35/37] Bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 700f040c7..95ce8cde2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.29.1" +version = "0.29.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From bcd82a92a26f1ca4ad340a67ce542fe461602579 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 27 Sep 2024 17:28:03 +0100 Subject: [PATCH 36/37] Remove blank line from `@model` in doctest to see if that fixes the parsing issues --- src/pointwise_logdensities.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 676149b73..47b969e6c 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -238,7 +238,6 @@ julia> @model function demo(xs, y) for i in eachindex(xs) xs[i] ~ Normal(m, √s) end - y ~ Normal(m, √s) end demo (generic function with 2 methods) From cff0941c4c00bb6759bc6b70a29dce971e74699e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 29 Sep 2024 09:22:14 +0100 Subject: [PATCH 37/37] Added doctest filter to handle the `;;]` at the 
end of lines for matrices --- test/runtests.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index bc8e54b39..b9a1d92bd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -95,6 +95,9 @@ include("test_util.jl") # Errors from macros sometimes result in `LoadError: LoadError:` # rather than `LoadError:`, depending on Julia version. r"ERROR: (LoadError:\s)+", + # Older versions do not have `;;]` but instead just `]` at end of the line + # => need to treat `;;]` and `]` as the same, i.e. ignore them if at the end of a line + r"(;;){0,1}\]$"m, ] doctest(DynamicPPL; manual=false, doctestfilters=doctestfilters) end