From d05124cc69b4b57fae34c1d10305a202bf568ecb Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Mon, 16 Sep 2024 08:10:29 +0200 Subject: [PATCH 01/14] implement pointwise_logpriors --- src/DynamicPPL.jl | 2 + src/context_implementations.jl | 25 +++- src/logpriors.jl | 223 +++++++++++++++++++++++++++++++++ test/Project.toml | 3 +- test/loglikelihoods.jl | 50 ++++++++ test/runtests.jl | 2 + 6 files changed, 300 insertions(+), 5 deletions(-) create mode 100644 src/logpriors.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index eb027b45b..555744ea3 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -116,6 +116,7 @@ export AbstractVarInfo, logprior, logjoint, pointwise_loglikelihoods, + pointwise_logpriors, condition, decondition, fix, @@ -182,6 +183,7 @@ include("simple_varinfo.jl") include("context_implementations.jl") include("compiler.jl") include("loglikelihoods.jl") +include("logpriors.jl") include("submodel_macro.jl") include("test_utils.jl") include("transforming.jl") diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 13231837f..338749344 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -425,6 +425,15 @@ function dot_assume( var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi::AbstractVarInfo, +) + r, lp, vi = dot_assume_vec(dist, var, vns, vi) + return r, sum(lp), vi +end +function dot_assume_vec( + dist::MultivariateDistribution, + var::AbstractMatrix, + vns::AbstractVector{<:VarName}, + vi::AbstractVarInfo, ) @assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))" # NOTE: We cannot work with `var` here because we might have a model of the form @@ -434,7 +443,7 @@ function dot_assume( # # in which case `var` will have `undef` elements, even if `m` is present in `vi`. 
r = vi[vns, dist] - lp = sum(zip(vns, eachcol(r))) do (vn, ri) + lp = map(zip(vns, eachcol(r))) do (vn, ri) return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) end return r, lp, vi @@ -455,21 +464,29 @@ function dot_assume( end function dot_assume( + dist::Union{Distribution,AbstractArray{<:Distribution}}, var::AbstractArray, vns::AbstractArray{<:VarName}, vi +) + # possibility to acesss the single logpriors + r, lp, vi = dot_assume_vec(dist, var, vns, vi) + return r, sum(lp), vi +end + +function dot_assume_vec( dist::Distribution, var::AbstractArray, vns::AbstractArray{<:VarName}, vi ) r = getindex.((vi,), vns, (dist,)) - lp = sum(Bijectors.logpdf_with_trans.((dist,), r, istrans.((vi,), vns))) + lp = Bijectors.logpdf_with_trans.((dist,), r, istrans.((vi,), vns)) return r, lp, vi end -function dot_assume( +function dot_assume_vec( dists::AbstractArray{<:Distribution}, var::AbstractArray, vns::AbstractArray{<:VarName}, vi, ) r = getindex.((vi,), vns, dists) - lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns))) + lp = Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns)) return r, lp, vi end diff --git a/src/logpriors.jl b/src/logpriors.jl new file mode 100644 index 000000000..cdc00d699 --- /dev/null +++ b/src/logpriors.jl @@ -0,0 +1,223 @@ +# Context version +struct PointwisePriorContext{A,Ctx} <: AbstractContext + logpriors::A + context::Ctx +end + +function PointwisePriorContext( + priors=OrderedDict{VarName,Vector{Float64}}(), + context::AbstractContext=DynamicPPL.PriorContext(), + #OrderedDict{VarName,Vector{Float64}}(),PriorContext()), +) + return PointwisePriorContext{typeof(priors),typeof(context)}( + priors, context + ) +end + +NodeTrait(::PointwisePriorContext) = IsParent() +childcontext(context::PointwisePriorContext) = context.context +function setchildcontext(context::PointwisePriorContext, child) + return PointwisePriorContext(context.logpriors, child) +end + + +function tilde_assume!!(context::PointwisePriorContext, 
right, vn, vi) + #@info "PointwisePriorContext tilde_assume!! called for $vn" + value, logp, vi = tilde_assume(context, right, vn, vi) + push!(context, vn, logp) + return value, acclogp_assume!!(context, vi, logp) +end + +function dot_tilde_assume!!(context::PointwisePriorContext, right, left, vn, vi) + #@info "PointwisePriorContext dot_tilde_assume!! called for $vn" + # @show vn, left, right, typeof(context).name + value, logps, vi = dot_assume_vec(right, left, vn, vi) + # Need to unwrap the `vn`, i.e. get one `VarName` for each entry in `left`. + #_, _, vns = unwrap_right_left_vns(right, left, vn) + vns = vn + for (vn, logp) in zip(vns, logps) + # Track log-prior value. + push!(context, vn, logp) + end + return value, acclogp!!(vi, sum(logps)), vi +end + +function dot_tilde_assume!!(context::PointwisePriorContext, right, left, vn, vi) + #@info "PointwisePriorContext dot_tilde_assume!! called for $vn" + value, logps, vi = dot_assume_vec(right, left, vn, vi) + # Need to unwrap the `vn`, i.e. get one `VarName` for each entry in `left`? + #_, _, vns = unwrap_right_left_vns(right, left, vn) + vns = vn + for (vn, logp) in zip(vns, logps) + # Track log-prior value. + push!(context, vn, logp) + end + return value, acclogp!!(vi, sum(logps)), vi +end + +function tilde_observe(context::PointwisePriorContext, right, left, vi) + # Since we are evaluating the prior, the log probability of all the observations + # is set to 0. This has the effect of ignoring the likelihood. 
+ return 0.0, vi + #tmp = tilde_observe(context.context, SampleFromPrior(), right, left, vi) + #return tmp +end + +function Base.push!( + context::PointwisePriorContext{<:AbstractDict{VarName,Vector{Float64}}}, + vn::VarName, + logp::Real, +) + lookup = context.logpriors + ℓ = get!(lookup, vn, Float64[]) + return push!(ℓ, logp) +end + +function Base.push!( + context::PointwisePriorContext{<:AbstractDict{VarName,Float64}}, + vn::VarName, + logp::Real, +) + return context.logpriors[vn] = logp +end + +function Base.push!( + context::PointwisePriorContext{<:AbstractDict{String,Vector{Float64}}}, + vn::VarName, + logp::Real, +) + lookup = context.logpriors + ℓ = get!(lookup, string(vn), Float64[]) + return push!(ℓ, logp) +end + +function Base.push!( + context::PointwisePriorContext{<:AbstractDict{String,Float64}}, + vn::VarName, + logp::Real, +) + return context.logpriors[string(vn)] = logp +end + +function Base.push!( + context::PointwisePriorContext{<:AbstractDict{String,Vector{Float64}}}, + vn::String, + logp::Real, +) + lookup = context.logpriors + ℓ = get!(lookup, vn, Float64[]) + return push!(ℓ, logp) +end + +function Base.push!( + context::PointwisePriorContext{<:AbstractDict{String,Float64}}, + vn::String, + logp::Real, +) + return context.logpriors[vn] = logp +end + + +# """ +# pointwise_logpriors(model::Model, chain::Chains, keytype = String) + +# Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}` +# with keys corresponding to symbols of the observations, and values being matrices +# of shape `(num_chains, num_samples)`. + +# `keytype` specifies what the type of the keys used in the returned `OrderedDict` are. +# Currently, only `String` and `VarName` are supported. + +# # Notes +# Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ` +# both being `<:Real`. Then the *observe* (i.e. when the left-hand side is an +# *observation*) statements can be implemented in three ways: +# 1. 
using a `for` loop: +# ```julia +# for i in eachindex(y) +# y[i] ~ Normal(μ, σ) +# end +# ``` +# 2. using `.~`: +# ```julia +# y .~ Normal(μ, σ) +# ``` +# 3. using `MvNormal`: +# ```julia +# y ~ MvNormal(fill(μ, n), σ^2 * I) +# ``` + +# In (1) and (2), `y` will be treated as a collection of `n` i.i.d. 1-dimensional variables, +# while in (3) `y` will be treated as a _single_ n-dimensional observation. + +# This is important to keep in mind, in particular if the computation is used +# for downstream computations. + +# # Examples +# ## From chain +# ```julia-repl +# julia> using DynamicPPL, Turing + +# julia> @model function demo(xs, y) +# s ~ InverseGamma(2, 3) +# m ~ Normal(0, √s) +# for i in eachindex(xs) +# xs[i] ~ Normal(m, √s) +# end + +# y ~ Normal(m, √s) +# end +# demo (generic function with 1 method) + +# julia> model = demo(randn(3), randn()); + +# julia> chain = sample(model, MH(), 10); + +# julia> pointwise_logpriors(model, chain) +# OrderedDict{String,Array{Float64,2}} with 4 entries: +# "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] +# "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] +# "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] +# "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] + +# julia> pointwise_logpriors(model, chain, String) +# OrderedDict{String,Array{Float64,2}} with 4 entries: +# "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] +# "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] +# "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] +# "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] + +# julia> pointwise_logpriors(model, chain, VarName) +# OrderedDict{VarName,Array{Float64,2}} with 4 entries: +# xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333] +# xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359] +# xs[3] => [-1.42862; -2.67573; … ; -1.66251; -1.66251] +# y => [-1.51265; -0.914129; … ; -1.5499; -1.5499] +# ``` + +# ## Broadcasting +# Note that `x .~ Dist()` will treat `x` as a 
collection of +# _independent_ observations rather than as a single observation. + +# ```jldoctest; setup = :(using Distributions) +# julia> @model function demo(x) +# x .~ Normal() +# end; + +# julia> m = demo([1.0, ]); + +# julia> ℓ = pointwise_logpriors(m, VarInfo(m)); first(ℓ[@varname(x[1])]) +# -1.4189385332046727 + +# julia> m = demo([1.0; 1.0]); + +# julia> ℓ = pointwise_logpriors(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) +# (-1.4189385332046727, -1.4189385332046727) +# ``` + +# """ +function pointwise_logpriors(model::Model, varinfo::AbstractVarInfo) + context = PointwisePriorContext() + model(varinfo, context) + return context.logpriors +end diff --git a/test/Project.toml b/test/Project.toml index 13267ee1d..a75909f95 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,6 +14,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -26,10 +27,10 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -Accessors = "0.1" ADTypes = "0.2, 1" AbstractMCMC = "5" AbstractPPL = "0.8.2" +Accessors = "0.1" Bijectors = "0.13" Compat = "4.3.0" Distributions = "0.25" diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index 1075ce333..fbd405aaa 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -22,3 +22,53 @@ @test loglikelihood ≈ loglikelihood_true end end + +() -> begin + m = DynamicPPL.TestUtils.demo_assume_index_observe() + example_values = DynamicPPL.TestUtils.rand_prior_true(m) + vi = VarInfo(m) + for vn in DynamicPPL.TestUtils.varnames(m) + vi = DynamicPPL.setindex!!(vi, 
get(example_values, vn), vn) + end + ret_m = first(evaluate!!(m, vi, SamplingContext())) + @test sum(map(k -> sum(ret_m.ld[k]), eachindex(ret_m.ld))) ≈ ret_m.logp + () -> begin + #by_generated_quantities + s = vcat(example_values...) + vnames = ["s[1]", "s[2]", "m[1]", "m[2]"] + stup = (; zip(Symbol.(vnames), s)...) + ret_m = generated_quantities(m, stup) + @test sum(map(k -> sum(ret_m.ld[k]), eachindex(ret_m.ld))) ≈ ret_m.logp + #chn = Chains(reshape(s, 1, : , 1), vnames); + chn = Chains(reshape(s, 1, :, 1)) # causes warning but works + ret_m = @test_logs (:warn,) generated_quantities(m, chn)[1, 1] + @test sum(map(k -> sum(ret_m.ld[k]), eachindex(ret_m.ld))) ≈ ret_m.logp + end +end + +@testset "logpriors.jl" begin + #m = DynamicPPL.TestUtils.DEMO_MODELS[1] + @testset "$(m.f)" for (i, m) in enumerate(DynamicPPL.TestUtils.DEMO_MODELS) + #@show i + example_values = DynamicPPL.TestUtils.rand_prior_true(m) + + # Instantiate a `VarInfo` with the example values. + vi = VarInfo(m) + () -> begin + for vn in DynamicPPL.TestUtils.varnames(m) + global vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) + end + end + for vn in DynamicPPL.TestUtils.varnames(m) + vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) + end + + #chains = sample(m, SampleFromPrior(), 2; progress=false) + + # Compute the pointwise loglikelihoods. 
+ tmp = DynamicPPL.pointwise_logpriors(m, vi) + logp1 = getlogp(vi) + logp = logprior(m, vi) + @test !isfinite(getlogp(vi)) || sum(x -> sum(x), values(tmp)) ≈ logp + end; +end; diff --git a/test/runtests.jl b/test/runtests.jl index aa0883708..eafc0e652 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,6 +22,7 @@ using Pkg using Random using Serialization using Test +using Logging using DynamicPPL: getargs_dottilde, getargs_tilde, Selector @@ -31,6 +32,7 @@ const GROUP = get(ENV, "GROUP", "All") Random.seed!(100) +#include(joinpath(DIRECTORY_DynamicPPL,"test","test_util.jl")) include("test_util.jl") @testset "DynamicPPL.jl" begin From 4f46102805d1dcfafb5eadc40afdc29d1641101b Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Tue, 17 Sep 2024 08:13:29 +0200 Subject: [PATCH 02/14] implement varwise_logpriors --- docs/src/api.md | 11 +++ src/DynamicPPL.jl | 2 + src/logpriors.jl | 13 --- src/logpriors_var.jl | 201 +++++++++++++++++++++++++++++++++++++++++ src/test_utils.jl | 33 +++++++ test/loglikelihoods.jl | 132 +++++++++++++++++++++++---- test/runtests.jl | 2 + 7 files changed, 361 insertions(+), 33 deletions(-) create mode 100644 src/logpriors_var.jl diff --git a/docs/src/api.md b/docs/src/api.md index 97c48316e..38d9ee6b0 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -130,6 +130,17 @@ For a chain of samples, one can compute the pointwise log-likelihoods of each ob pointwise_loglikelihoods ``` +Similarly, one can compute the pointwise log-priors of each sampled random variable +with [`varwise_logpriors`](@ref). +Differently from `pointwise_loglikelihoods` it reports only a +single value for `.~` assignements. +If one needs to access the parts for single indices, one can +reformulate the model to use an explicit loop instead. + +```@docs +varwise_logpriors +``` + For converting a chain into a format that can more easily be fed into a `Model` again, for example using `condition`, you can use [`value_iterator_from_chain`](@ref). 
```@docs diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 555744ea3..e69cdd568 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -117,6 +117,7 @@ export AbstractVarInfo, logjoint, pointwise_loglikelihoods, pointwise_logpriors, + varwise_logpriors, condition, decondition, fix, @@ -184,6 +185,7 @@ include("context_implementations.jl") include("compiler.jl") include("loglikelihoods.jl") include("logpriors.jl") +include("logpriors_var.jl") include("submodel_macro.jl") include("test_utils.jl") include("transforming.jl") diff --git a/src/logpriors.jl b/src/logpriors.jl index cdc00d699..1f3d8b3ad 100644 --- a/src/logpriors.jl +++ b/src/logpriors.jl @@ -42,19 +42,6 @@ function dot_tilde_assume!!(context::PointwisePriorContext, right, left, vn, vi) return value, acclogp!!(vi, sum(logps)), vi end -function dot_tilde_assume!!(context::PointwisePriorContext, right, left, vn, vi) - #@info "PointwisePriorContext dot_tilde_assume!! called for $vn" - value, logps, vi = dot_assume_vec(right, left, vn, vi) - # Need to unwrap the `vn`, i.e. get one `VarName` for each entry in `left`? - #_, _, vns = unwrap_right_left_vns(right, left, vn) - vns = vn - for (vn, logp) in zip(vns, logps) - # Track log-prior value. - push!(context, vn, logp) - end - return value, acclogp!!(vi, sum(logps)), vi -end - function tilde_observe(context::PointwisePriorContext, right, left, vi) # Since we are evaluating the prior, the log probability of all the observations # is set to 0. This has the effect of ignoring the likelihood. diff --git a/src/logpriors_var.jl b/src/logpriors_var.jl new file mode 100644 index 000000000..1c534b60a --- /dev/null +++ b/src/logpriors_var.jl @@ -0,0 +1,201 @@ +""" +Context that records logp after tilde_assume!! +for each VarName used by [`varwise_logpriors`](@ref). 
+""" +struct VarwisePriorContext{A,Ctx} <: AbstractContext + logpriors::A + context::Ctx +end + +function VarwisePriorContext( + logpriors=OrderedDict{Symbol,Float64}(), + context::AbstractContext=DynamicPPL.PriorContext(), + #OrderedDict{Symbol,Vector{Float64}}(),PriorContext()), +) + return VarwisePriorContext{typeof(logpriors),typeof(context)}( + logpriors, context + ) +end + +NodeTrait(::VarwisePriorContext) = IsParent() +childcontext(context::VarwisePriorContext) = context.context +function setchildcontext(context::VarwisePriorContext, child) + return VarwisePriorContext(context.logpriors, child) +end + +function tilde_assume(context::VarwisePriorContext, right, vn, vi) + #@info "VarwisePriorContext tilde_assume!! called for $vn" + value, logp, vi = tilde_assume(context.context, right, vn, vi) + #sym = DynamicPPL.getsym(vn) + new_context = acc_logp!(context, vn, logp) + return value, logp, vi +end + +function dot_tilde_assume(context::VarwisePriorContext, right, left, vn, vi) + #@info "VarwisePriorContext dot_tilde_assume!! called for $vn" + # @show vn, left, right, typeof(context).name + value, logp, vi = dot_tilde_assume(context.context, right, left, vn, vi) + new_context = acc_logp!(context, vn, logp) + return value, logp, vi +end + + +function tilde_observe(context::VarwisePriorContext, right, left, vi) + # Since we are evaluating the prior, the log probability of all the observations + # is set to 0. This has the effect of ignoring the likelihood. 
+ return 0.0, vi + #tmp = tilde_observe(context.context, SampleFromPrior(), right, left, vi) + #return tmp +end + +function acc_logp!(context::VarwisePriorContext, vn::Union{VarName,AbstractVector{<:VarName}}, logp) + #sym = DynamicPPL.getsym(vn) # leads to duplicates + # if vn is a Vector leads to Symbol("VarName{:s, IndexLens{Tuple{Int64}}}[s[1], s[2]]") + sym = Symbol(vn) + context.logpriors[sym] = logp + return (context) +end + + +# """ +# pointwise_logpriors(model::Model, chain::Chains, keytype = String) + +# Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}` +# with keys corresponding to symbols of the observations, and values being matrices +# of shape `(num_chains, num_samples)`. + +# `keytype` specifies what the type of the keys used in the returned `OrderedDict` are. +# Currently, only `String` and `VarName` are supported. + +# # Notes +# Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ` +# both being `<:Real`. Then the *observe* (i.e. when the left-hand side is an +# *observation*) statements can be implemented in three ways: +# 1. using a `for` loop: +# ```julia +# for i in eachindex(y) +# y[i] ~ Normal(μ, σ) +# end +# ``` +# 2. using `.~`: +# ```julia +# y .~ Normal(μ, σ) +# ``` +# 3. using `MvNormal`: +# ```julia +# y ~ MvNormal(fill(μ, n), σ^2 * I) +# ``` + +# In (1) and (2), `y` will be treated as a collection of `n` i.i.d. 1-dimensional variables, +# while in (3) `y` will be treated as a _single_ n-dimensional observation. + +# This is important to keep in mind, in particular if the computation is used +# for downstream computations. 
+ +# # Examples +# ## From chain +# ```julia-repl +# julia> using DynamicPPL, Turing + +# julia> @model function demo(xs, y) +# s ~ InverseGamma(2, 3) +# m ~ Normal(0, √s) +# for i in eachindex(xs) +# xs[i] ~ Normal(m, √s) +# end + +# y ~ Normal(m, √s) +# end +# demo (generic function with 1 method) + +# julia> model = demo(randn(3), randn()); + +# julia> chain = sample(model, MH(), 10); + +# julia> pointwise_logpriors(model, chain) +# OrderedDict{String,Array{Float64,2}} with 4 entries: +# "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] +# "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] +# "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] +# "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] + +# julia> pointwise_logpriors(model, chain, String) +# OrderedDict{String,Array{Float64,2}} with 4 entries: +# "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] +# "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] +# "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] +# "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] + +# julia> pointwise_logpriors(model, chain, VarName) +# OrderedDict{VarName,Array{Float64,2}} with 4 entries: +# xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333] +# xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359] +# xs[3] => [-1.42862; -2.67573; … ; -1.66251; -1.66251] +# y => [-1.51265; -0.914129; … ; -1.5499; -1.5499] +# ``` + +# ## Broadcasting +# Note that `x .~ Dist()` will treat `x` as a collection of +# _independent_ observations rather than as a single observation. 
+ +# ```jldoctest; setup = :(using Distributions) +# julia> @model function demo(x) +# x .~ Normal() +# end; + +# julia> m = demo([1.0, ]); + +# julia> ℓ = pointwise_logpriors(m, VarInfo(m)); first(ℓ[@varname(x[1])]) +# -1.4189385332046727 + +# julia> m = demo([1.0; 1.0]); + +# julia> ℓ = pointwise_logpriors(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) +# (-1.4189385332046727, -1.4189385332046727) +# ``` + +# """ +function varwise_logpriors( + model::Model, varinfo::AbstractVarInfo, + context::AbstractContext=PriorContext() +) +# top_context = VarwisePriorContext(OrderedDict{Symbol,Float64}(), context) + top_context = VarwisePriorContext(OrderedDict{Symbol,Float64}(), context) + model(varinfo, top_context) + return top_context.logpriors +end + +function varwise_logpriors(model::Model, chain::AbstractChains, + context::AbstractContext=PriorContext(); + top_context::VarwisePriorContext{T} = VarwisePriorContext(OrderedDict{Symbol,Float64}(), context) + ) where T + # pass top-context as keyword to allow adapt Number type of log-prior + get_values = (vi) -> begin + model(vi, top_context) + values(top_context.logpriors) + end + arr = map_model(get_values, model, chain) + par_names = collect(keys(top_context.logpriors)) + return(arr, par_names) +end + +function map_model(get_values, model::Model, chain::AbstractChains) + niters = size(chain, 1) + nchains = size(chain, 3) + vi = VarInfo(model) + iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) + # initialize the array by the first result + (sample_idx, chain_idx), iters2 = Iterators.peel(iters) + setval!(vi, chain, sample_idx, chain_idx) + values1 = get_values(vi) + arr = Array{eltype(values1)}(undef, niters, length(values1), nchains) + arr[sample_idx, :, chain_idx] .= values1 + #(sample_idx, chain_idx), iters3 = Iterators.peel(iters2) + for (sample_idx, chain_idx) in iters2 + # Update the values + setval!(vi, chain, sample_idx, chain_idx) + values_i = get_values(vi) + arr[sample_idx, :, 
chain_idx] .= values_i + end + return(arr) +end diff --git a/src/test_utils.jl b/src/test_utils.jl index 6f7481c40..c65c75b44 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -1042,4 +1042,37 @@ function test_context_interface(context) end end +""" +Context that multiplies each log-prior by mod +used to test whether varwise_logpriors respects child-context. +""" +struct TestLogModifyingChildContext{T,Ctx} <: DynamicPPL.AbstractContext + mod::T + context::Ctx +end +function TestLogModifyingChildContext( + mod=1.2, + context::DynamicPPL.AbstractContext=DynamicPPL.DefaultContext(), + #OrderedDict{VarName,Vector{Float64}}(),PriorContext()), +) + return TestLogModifyingChildContext{typeof(mod),typeof(context)}( + mod, context + ) +end +DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent() +DynamicPPL.childcontext(context::TestLogModifyingChildContext) = context.context +function DynamicPPL.setchildcontext(context::TestLogModifyingChildContext, child) + return TestLogModifyingChildContext(context.mod, child) +end +function DynamicPPL.tilde_assume(context::TestLogModifyingChildContext, right, vn, vi) + #@info "TestLogModifyingChildContext tilde_assume!! called for $vn" + value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi) + return value, logp*context.mod, vi +end +function DynamicPPL.dot_tilde_assume(context::TestLogModifyingChildContext, right, left, vn, vi) + #@info "TestLogModifyingChildContext dot_tilde_assume!! 
called for $vn" + value, logp, vi = DynamicPPL.dot_tilde_assume(context.context, right, left, vn, vi) + return value, logp*context.mod, vi +end + end diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index fbd405aaa..dbbdc6245 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -23,28 +23,53 @@ end end -() -> begin - m = DynamicPPL.TestUtils.demo_assume_index_observe() - example_values = DynamicPPL.TestUtils.rand_prior_true(m) - vi = VarInfo(m) - for vn in DynamicPPL.TestUtils.varnames(m) +@testset "pointwise_logpriors" begin + # test equality of the single log-prios (not just the sum) + # by returning them from the model as generated values + @model function exdemo_assume_index_observe( + x=[1.5, 2.0], ::Type{TV}=Vector{Float64} + ) where {TV} + # `assume` with indexing and `observe` + s = TV(undef, length(x)) + for i in eachindex(s) + s[i] ~ InverseGamma(2, 3) + end + m = TV(undef, length(x)) + for i in eachindex(m) + m[i] ~ Normal(0, sqrt(s[i])) + end + x ~ MvNormal(m, Diagonal(s)) + # here also return the log-priors for testing + return (; + s=s, + m=m, + x=x, + logp=getlogp(__varinfo__), + ld=(; + s=logpdf.(Ref(InverseGamma(2, 3)), s), + m=[logpdf(Normal(0, sqrt(s[i])), m[i]) for i in eachindex(m)], + x=logpdf(MvNormal(m, Diagonal(s)), x), + ), + ) + end + mex = exdemo_assume_index_observe() + morig = DynamicPPL.TestUtils.demo_assume_index_observe() + example_values = DynamicPPL.TestUtils.rand_prior_true(morig) + vi = VarInfo(mex) + #Main.@infiltrate_main + for vn in DynamicPPL.TestUtils.varnames(mex) vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) end - ret_m = first(evaluate!!(m, vi, SamplingContext())) - @test sum(map(k -> sum(ret_m.ld[k]), eachindex(ret_m.ld))) ≈ ret_m.logp - () -> begin - #by_generated_quantities - s = vcat(example_values...) - vnames = ["s[1]", "s[2]", "m[1]", "m[2]"] - stup = (; zip(Symbol.(vnames), s)...) 
- ret_m = generated_quantities(m, stup) - @test sum(map(k -> sum(ret_m.ld[k]), eachindex(ret_m.ld))) ≈ ret_m.logp - #chn = Chains(reshape(s, 1, : , 1), vnames); - chn = Chains(reshape(s, 1, :, 1)) # causes warning but works - ret_m = @test_logs (:warn,) generated_quantities(m, chn)[1, 1] - @test sum(map(k -> sum(ret_m.ld[k]), eachindex(ret_m.ld))) ≈ ret_m.logp + () -> begin # for interactive execution at the repl need the global keyword + for vn in DynamicPPL.TestUtils.varnames(mex) + global vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) + end end -end + ret_m = first(evaluate!!(mex, vi, SamplingContext())) + true_logpriors = ret_m.ld[[:s, :m]] + logpriors = DynamicPPL.pointwise_logpriors(mex, vi) + @test all(vcat(values(logpriors)...) .≈ vcat(true_logpriors...)) +end; @testset "logpriors.jl" begin #m = DynamicPPL.TestUtils.DEMO_MODELS[1] @@ -70,5 +95,72 @@ end logp1 = getlogp(vi) logp = logprior(m, vi) @test !isfinite(getlogp(vi)) || sum(x -> sum(x), values(tmp)) ≈ logp - end; + end +end; + +@testset "logpriors_var.jl" begin + mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2, PriorContext()) + mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx) + #m = DynamicPPL.TestUtils.DEMO_MODELS[1] + # m = DynamicPPL.TestUtils.demo_assume_index_observe() # logp at i-level? + @testset "$(m.f)" for (i, m) in enumerate(DynamicPPL.TestUtils.DEMO_MODELS) + #@show i + example_values = DynamicPPL.TestUtils.rand_prior_true(m) + + # Instantiate a `VarInfo` with the example values. + vi = VarInfo(m) + () -> begin + for vn in DynamicPPL.TestUtils.varnames(m) + global vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) + end + end + for vn in DynamicPPL.TestUtils.varnames(m) + vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) + end + + #chains = sample(m, SampleFromPrior(), 2; progress=false) + + # Compute the pointwise loglikelihoods. 
+ logpriors = DynamicPPL.varwise_logpriors(m, vi) + logp1 = getlogp(vi) + logp = logprior(m, vi) + @test !isfinite(logp) || sum(x -> sum(x), values(logpriors)) ≈ logp + # + # test on modifying child-context + logpriors_mod = DynamicPPL.varwise_logpriors(m, vi, mod_ctx2) + logp1 = getlogp(vi) + # Following line assumes no Likelihood contributions + # requires lowest Context to be PriorContext + @test !isfinite(logp1) || sum(x -> sum(x), values(logpriors_mod)) ≈ logp1 # + @test all(values(logpriors_mod) .≈ values(logpriors) .* 1.2 .* 1.4) + end end; + +@testset "logpriors_var chain" begin + @model function demo(xs, y) + s ~ InverseGamma(2, 3) + m ~ Normal(0, √s) + for i in eachindex(xs) + xs[i] ~ Normal(m, √s) + end + y ~ Normal(m, √s) + end + xs_true, y_true = ([0.3290767977680923, 0.038972110187911684, -0.5797496780649221], -0.7321425592768186)#randn(3), randn() + model = demo(xs_true, y_true) + () -> begin + # generate the sample used below + chain = sample(model, MH(), 10) + arr0 = Array(chain) + end + arr0 = [1.8585322626573435 -0.05900855284939967; 1.7304068220366808 -0.6386249100228161; 1.7304068220366808 -0.6386249100228161; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039]; # generated in function above + # split into two chains for testing + arr1 = permutedims(reshape(arr0, 5,2,:),(1,3,2)) + chain = Chains(arr1, [:s, :m]); + tmp1 = varwise_logpriors(model, chain) + tmp = Chains(tmp1...); # can be used to create a Chains object + vi = VarInfo(model) + i_sample, i_chain = (1,2) + DynamicPPL.setval!(vi, chain, i_sample, i_chain) + lp1 = DynamicPPL.varwise_logpriors(model, vi) + @test all(tmp1[1][i_sample,:,i_chain] .≈ values(lp1)) +end; \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 
eafc0e652..a321151ab 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,6 +23,8 @@ using Random using Serialization using Test using Logging +using Distributions +using LinearAlgebra # Diagonal using DynamicPPL: getargs_dottilde, getargs_tilde, Selector From c6653b98f3fd86471c954b72a042c6cc240c8e13 Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Tue, 17 Sep 2024 08:29:51 +0200 Subject: [PATCH 03/14] remove pointwise_logpriors --- src/DynamicPPL.jl | 2 - src/logpriors.jl | 210 ----------------------------------------- test/loglikelihoods.jl | 75 --------------- 3 files changed, 287 deletions(-) delete mode 100644 src/logpriors.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index e69cdd568..774a8010f 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -116,7 +116,6 @@ export AbstractVarInfo, logprior, logjoint, pointwise_loglikelihoods, - pointwise_logpriors, varwise_logpriors, condition, decondition, @@ -184,7 +183,6 @@ include("simple_varinfo.jl") include("context_implementations.jl") include("compiler.jl") include("loglikelihoods.jl") -include("logpriors.jl") include("logpriors_var.jl") include("submodel_macro.jl") include("test_utils.jl") diff --git a/src/logpriors.jl b/src/logpriors.jl deleted file mode 100644 index 1f3d8b3ad..000000000 --- a/src/logpriors.jl +++ /dev/null @@ -1,210 +0,0 @@ -# Context version -struct PointwisePriorContext{A,Ctx} <: AbstractContext - logpriors::A - context::Ctx -end - -function PointwisePriorContext( - priors=OrderedDict{VarName,Vector{Float64}}(), - context::AbstractContext=DynamicPPL.PriorContext(), - #OrderedDict{VarName,Vector{Float64}}(),PriorContext()), -) - return PointwisePriorContext{typeof(priors),typeof(context)}( - priors, context - ) -end - -NodeTrait(::PointwisePriorContext) = IsParent() -childcontext(context::PointwisePriorContext) = context.context -function setchildcontext(context::PointwisePriorContext, child) - return PointwisePriorContext(context.logpriors, child) -end - - -function 
tilde_assume!!(context::PointwisePriorContext, right, vn, vi) - #@info "PointwisePriorContext tilde_assume!! called for $vn" - value, logp, vi = tilde_assume(context, right, vn, vi) - push!(context, vn, logp) - return value, acclogp_assume!!(context, vi, logp) -end - -function dot_tilde_assume!!(context::PointwisePriorContext, right, left, vn, vi) - #@info "PointwisePriorContext dot_tilde_assume!! called for $vn" - # @show vn, left, right, typeof(context).name - value, logps, vi = dot_assume_vec(right, left, vn, vi) - # Need to unwrap the `vn`, i.e. get one `VarName` for each entry in `left`. - #_, _, vns = unwrap_right_left_vns(right, left, vn) - vns = vn - for (vn, logp) in zip(vns, logps) - # Track log-prior value. - push!(context, vn, logp) - end - return value, acclogp!!(vi, sum(logps)), vi -end - -function tilde_observe(context::PointwisePriorContext, right, left, vi) - # Since we are evaluating the prior, the log probability of all the observations - # is set to 0. This has the effect of ignoring the likelihood. 
- return 0.0, vi - #tmp = tilde_observe(context.context, SampleFromPrior(), right, left, vi) - #return tmp -end - -function Base.push!( - context::PointwisePriorContext{<:AbstractDict{VarName,Vector{Float64}}}, - vn::VarName, - logp::Real, -) - lookup = context.logpriors - ℓ = get!(lookup, vn, Float64[]) - return push!(ℓ, logp) -end - -function Base.push!( - context::PointwisePriorContext{<:AbstractDict{VarName,Float64}}, - vn::VarName, - logp::Real, -) - return context.logpriors[vn] = logp -end - -function Base.push!( - context::PointwisePriorContext{<:AbstractDict{String,Vector{Float64}}}, - vn::VarName, - logp::Real, -) - lookup = context.logpriors - ℓ = get!(lookup, string(vn), Float64[]) - return push!(ℓ, logp) -end - -function Base.push!( - context::PointwisePriorContext{<:AbstractDict{String,Float64}}, - vn::VarName, - logp::Real, -) - return context.logpriors[string(vn)] = logp -end - -function Base.push!( - context::PointwisePriorContext{<:AbstractDict{String,Vector{Float64}}}, - vn::String, - logp::Real, -) - lookup = context.logpriors - ℓ = get!(lookup, vn, Float64[]) - return push!(ℓ, logp) -end - -function Base.push!( - context::PointwisePriorContext{<:AbstractDict{String,Float64}}, - vn::String, - logp::Real, -) - return context.logpriors[vn] = logp -end - - -# """ -# pointwise_logpriors(model::Model, chain::Chains, keytype = String) - -# Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}` -# with keys corresponding to symbols of the observations, and values being matrices -# of shape `(num_chains, num_samples)`. - -# `keytype` specifies what the type of the keys used in the returned `OrderedDict` are. -# Currently, only `String` and `VarName` are supported. - -# # Notes -# Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ` -# both being `<:Real`. Then the *observe* (i.e. when the left-hand side is an -# *observation*) statements can be implemented in three ways: -# 1. 
using a `for` loop: -# ```julia -# for i in eachindex(y) -# y[i] ~ Normal(μ, σ) -# end -# ``` -# 2. using `.~`: -# ```julia -# y .~ Normal(μ, σ) -# ``` -# 3. using `MvNormal`: -# ```julia -# y ~ MvNormal(fill(μ, n), σ^2 * I) -# ``` - -# In (1) and (2), `y` will be treated as a collection of `n` i.i.d. 1-dimensional variables, -# while in (3) `y` will be treated as a _single_ n-dimensional observation. - -# This is important to keep in mind, in particular if the computation is used -# for downstream computations. - -# # Examples -# ## From chain -# ```julia-repl -# julia> using DynamicPPL, Turing - -# julia> @model function demo(xs, y) -# s ~ InverseGamma(2, 3) -# m ~ Normal(0, √s) -# for i in eachindex(xs) -# xs[i] ~ Normal(m, √s) -# end - -# y ~ Normal(m, √s) -# end -# demo (generic function with 1 method) - -# julia> model = demo(randn(3), randn()); - -# julia> chain = sample(model, MH(), 10); - -# julia> pointwise_logpriors(model, chain) -# OrderedDict{String,Array{Float64,2}} with 4 entries: -# "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] -# "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] -# "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] -# "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] - -# julia> pointwise_logpriors(model, chain, String) -# OrderedDict{String,Array{Float64,2}} with 4 entries: -# "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] -# "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] -# "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] -# "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] - -# julia> pointwise_logpriors(model, chain, VarName) -# OrderedDict{VarName,Array{Float64,2}} with 4 entries: -# xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333] -# xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359] -# xs[3] => [-1.42862; -2.67573; … ; -1.66251; -1.66251] -# y => [-1.51265; -0.914129; … ; -1.5499; -1.5499] -# ``` - -# ## Broadcasting -# Note that `x .~ Dist()` will treat `x` as a 
collection of -# _independent_ observations rather than as a single observation. - -# ```jldoctest; setup = :(using Distributions) -# julia> @model function demo(x) -# x .~ Normal() -# end; - -# julia> m = demo([1.0, ]); - -# julia> ℓ = pointwise_logpriors(m, VarInfo(m)); first(ℓ[@varname(x[1])]) -# -1.4189385332046727 - -# julia> m = demo([1.0; 1.0]); - -# julia> ℓ = pointwise_logpriors(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) -# (-1.4189385332046727, -1.4189385332046727) -# ``` - -# """ -function pointwise_logpriors(model::Model, varinfo::AbstractVarInfo) - context = PointwisePriorContext() - model(varinfo, context) - return context.logpriors -end diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index dbbdc6245..b511a03d7 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -23,81 +23,6 @@ end end -@testset "pointwise_logpriors" begin - # test equality of the single log-prios (not just the sum) - # by returning them from the model as generated values - @model function exdemo_assume_index_observe( - x=[1.5, 2.0], ::Type{TV}=Vector{Float64} - ) where {TV} - # `assume` with indexing and `observe` - s = TV(undef, length(x)) - for i in eachindex(s) - s[i] ~ InverseGamma(2, 3) - end - m = TV(undef, length(x)) - for i in eachindex(m) - m[i] ~ Normal(0, sqrt(s[i])) - end - x ~ MvNormal(m, Diagonal(s)) - # here also return the log-priors for testing - return (; - s=s, - m=m, - x=x, - logp=getlogp(__varinfo__), - ld=(; - s=logpdf.(Ref(InverseGamma(2, 3)), s), - m=[logpdf(Normal(0, sqrt(s[i])), m[i]) for i in eachindex(m)], - x=logpdf(MvNormal(m, Diagonal(s)), x), - ), - ) - end - mex = exdemo_assume_index_observe() - morig = DynamicPPL.TestUtils.demo_assume_index_observe() - example_values = DynamicPPL.TestUtils.rand_prior_true(morig) - vi = VarInfo(mex) - #Main.@infiltrate_main - for vn in DynamicPPL.TestUtils.varnames(mex) - vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) - end - () -> begin # for interactive 
execution at the repl need the global keyword - for vn in DynamicPPL.TestUtils.varnames(mex) - global vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) - end - end - ret_m = first(evaluate!!(mex, vi, SamplingContext())) - true_logpriors = ret_m.ld[[:s, :m]] - logpriors = DynamicPPL.pointwise_logpriors(mex, vi) - @test all(vcat(values(logpriors)...) .≈ vcat(true_logpriors...)) -end; - -@testset "logpriors.jl" begin - #m = DynamicPPL.TestUtils.DEMO_MODELS[1] - @testset "$(m.f)" for (i, m) in enumerate(DynamicPPL.TestUtils.DEMO_MODELS) - #@show i - example_values = DynamicPPL.TestUtils.rand_prior_true(m) - - # Instantiate a `VarInfo` with the example values. - vi = VarInfo(m) - () -> begin - for vn in DynamicPPL.TestUtils.varnames(m) - global vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) - end - end - for vn in DynamicPPL.TestUtils.varnames(m) - vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) - end - - #chains = sample(m, SampleFromPrior(), 2; progress=false) - - # Compute the pointwise loglikelihoods. 
- tmp = DynamicPPL.pointwise_logpriors(m, vi) - logp1 = getlogp(vi) - logp = logprior(m, vi) - @test !isfinite(getlogp(vi)) || sum(x -> sum(x), values(tmp)) ≈ logp - end -end; - @testset "logpriors_var.jl" begin mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2, PriorContext()) mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx) From 216d50cb7343da0a0a1680d49e5e8c79e5344707 Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Tue, 17 Sep 2024 12:59:03 +0200 Subject: [PATCH 04/14] revert dot_assume to not explicitly resolve components of sum --- src/context_implementations.jl | 25 ++++--------------------- 1 file changed, 4 insertions(+), 21 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 338749344..13231837f 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -425,15 +425,6 @@ function dot_assume( var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi::AbstractVarInfo, -) - r, lp, vi = dot_assume_vec(dist, var, vns, vi) - return r, sum(lp), vi -end -function dot_assume_vec( - dist::MultivariateDistribution, - var::AbstractMatrix, - vns::AbstractVector{<:VarName}, - vi::AbstractVarInfo, ) @assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))" # NOTE: We cannot work with `var` here because we might have a model of the form @@ -443,7 +434,7 @@ function dot_assume_vec( # # in which case `var` will have `undef` elements, even if `m` is present in `vi`. 
r = vi[vns, dist] - lp = map(zip(vns, eachcol(r))) do (vn, ri) + lp = sum(zip(vns, eachcol(r))) do (vn, ri) return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) end return r, lp, vi @@ -464,29 +455,21 @@ function dot_assume( end function dot_assume( - dist::Union{Distribution,AbstractArray{<:Distribution}}, var::AbstractArray, vns::AbstractArray{<:VarName}, vi -) - # possibility to acesss the single logpriors - r, lp, vi = dot_assume_vec(dist, var, vns, vi) - return r, sum(lp), vi -end - -function dot_assume_vec( dist::Distribution, var::AbstractArray, vns::AbstractArray{<:VarName}, vi ) r = getindex.((vi,), vns, (dist,)) - lp = Bijectors.logpdf_with_trans.((dist,), r, istrans.((vi,), vns)) + lp = sum(Bijectors.logpdf_with_trans.((dist,), r, istrans.((vi,), vns))) return r, lp, vi end -function dot_assume_vec( +function dot_assume( dists::AbstractArray{<:Distribution}, var::AbstractArray, vns::AbstractArray{<:VarName}, vi, ) r = getindex.((vi,), vns, dists) - lp = Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns)) + lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns))) return r, lp, vi end From fd8d3b244cb861eae2850a35415b62cf47b7994a Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Wed, 18 Sep 2024 10:09:40 +0200 Subject: [PATCH 05/14] docstring varwise_logpriores use loop for prior in example Unfortunately cannot make it a jldoctest, because relies on Turing for sampling --- src/logpriors_var.jl | 154 +++++++++++++---------------------------- src/test_utils.jl | 14 +++- test/loglikelihoods.jl | 31 +++++---- 3 files changed, 77 insertions(+), 122 deletions(-) diff --git a/src/logpriors_var.jl b/src/logpriors_var.jl index 1c534b60a..1860d8ccd 100644 --- a/src/logpriors_var.jl +++ b/src/logpriors_var.jl @@ -40,13 +40,8 @@ function dot_tilde_assume(context::VarwisePriorContext, right, left, vn, vi) end -function tilde_observe(context::VarwisePriorContext, right, left, vi) - # Since we are evaluating the prior, the log 
probability of all the observations - # is set to 0. This has the effect of ignoring the likelihood. - return 0.0, vi - #tmp = tilde_observe(context.context, SampleFromPrior(), right, left, vi) - #return tmp -end +tilde_observe(context::VarwisePriorContext, right, left, vi) = 0, vi +dot_tilde_observe(::VarwisePriorContext, right, left, vi) = 0, vi function acc_logp!(context::VarwisePriorContext, vn::Union{VarName,AbstractVector{<:VarName}}, logp) #sym = DynamicPPL.getsym(vn) # leads to duplicates @@ -56,105 +51,52 @@ function acc_logp!(context::VarwisePriorContext, vn::Union{VarName,AbstractVecto return (context) end - -# """ -# pointwise_logpriors(model::Model, chain::Chains, keytype = String) - -# Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}` -# with keys corresponding to symbols of the observations, and values being matrices -# of shape `(num_chains, num_samples)`. - -# `keytype` specifies what the type of the keys used in the returned `OrderedDict` are. -# Currently, only `String` and `VarName` are supported. - -# # Notes -# Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ` -# both being `<:Real`. Then the *observe* (i.e. when the left-hand side is an -# *observation*) statements can be implemented in three ways: -# 1. using a `for` loop: -# ```julia -# for i in eachindex(y) -# y[i] ~ Normal(μ, σ) -# end -# ``` -# 2. using `.~`: -# ```julia -# y .~ Normal(μ, σ) -# ``` -# 3. using `MvNormal`: -# ```julia -# y ~ MvNormal(fill(μ, n), σ^2 * I) -# ``` - -# In (1) and (2), `y` will be treated as a collection of `n` i.i.d. 1-dimensional variables, -# while in (3) `y` will be treated as a _single_ n-dimensional observation. - -# This is important to keep in mind, in particular if the computation is used -# for downstream computations. 
- -# # Examples -# ## From chain -# ```julia-repl -# julia> using DynamicPPL, Turing - -# julia> @model function demo(xs, y) -# s ~ InverseGamma(2, 3) -# m ~ Normal(0, √s) -# for i in eachindex(xs) -# xs[i] ~ Normal(m, √s) -# end - -# y ~ Normal(m, √s) -# end -# demo (generic function with 1 method) - -# julia> model = demo(randn(3), randn()); - -# julia> chain = sample(model, MH(), 10); - -# julia> pointwise_logpriors(model, chain) -# OrderedDict{String,Array{Float64,2}} with 4 entries: -# "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] -# "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] -# "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] -# "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] - -# julia> pointwise_logpriors(model, chain, String) -# OrderedDict{String,Array{Float64,2}} with 4 entries: -# "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] -# "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] -# "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] -# "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] - -# julia> pointwise_logpriors(model, chain, VarName) -# OrderedDict{VarName,Array{Float64,2}} with 4 entries: -# xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333] -# xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359] -# xs[3] => [-1.42862; -2.67573; … ; -1.66251; -1.66251] -# y => [-1.51265; -0.914129; … ; -1.5499; -1.5499] -# ``` - -# ## Broadcasting -# Note that `x .~ Dist()` will treat `x` as a collection of -# _independent_ observations rather than as a single observation. 
- -# ```jldoctest; setup = :(using Distributions) -# julia> @model function demo(x) -# x .~ Normal() -# end; - -# julia> m = demo([1.0, ]); - -# julia> ℓ = pointwise_logpriors(m, VarInfo(m)); first(ℓ[@varname(x[1])]) -# -1.4189385332046727 - -# julia> m = demo([1.0; 1.0]); - -# julia> ℓ = pointwise_logpriors(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) -# (-1.4189385332046727, -1.4189385332046727) -# ``` - -# """ +""" + varwise_logpriors(model::Model, chain::Chains; context) + +Runs `model` on each sample in `chain` returning a tuple `(values, var_names)` +with var_names corresponding to symbols of the prior components, and values being +array of shape `(num_samples, num_components, num_chains)`. + +`context` specifies child context that handles computation of log-priors. + +# Example +```julia; setup = :(using Distributions) +using DynamicPPL, Turing + +@model function demo(x, ::Type{TV}=Vector{Float64}) where {TV} + s ~ InverseGamma(2, 3) + m = TV(undef, length(x)) + for i in eachindex(x) + m[i] ~ Normal(0, √s) + end + x ~ MvNormal(m, √s) + end + +model = demo(randn(3), randn()); + +chain = sample(model, MH(), 10); + +lp = varwise_logpriors(model, chain) +# Can be used to construct a new Chains object +#lpc = MCMCChains(varwise_logpriors(model, chain)...) + +# got a logdensity for each parameter prior +(but fewer if used `.~` assignments, see below) +lp[2] == names(chain, :parameters) + +# for each sample in the Chains object +size(lp[1])[[1,3]] == size(chain)[[1,3]] +``` + +# Broadcasting +Note that `m .~ Dist()` will treat `m` as a collection of +_independent_ prior rather than as a single prior, +but `varwise_logpriors` returns the single +sum of log-likelihood of components of `m` only. +If one needs the log-density of the components, one needs to rewrite +the model with an explicit loop. 
+""" function varwise_logpriors( model::Model, varinfo::AbstractVarInfo, context::AbstractContext=PriorContext() diff --git a/src/test_utils.jl b/src/test_utils.jl index c65c75b44..9db8c37fd 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -1059,7 +1059,10 @@ function TestLogModifyingChildContext( mod, context ) end -DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent() +# Samplers call leafcontext(model.context) when evaluating log-densities +# Hence, in order to be used need to say that its a leaf-context +#DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent() +DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsLeaf() DynamicPPL.childcontext(context::TestLogModifyingChildContext) = context.context function DynamicPPL.setchildcontext(context::TestLogModifyingChildContext, child) return TestLogModifyingChildContext(context.mod, child) @@ -1074,5 +1077,14 @@ function DynamicPPL.dot_tilde_assume(context::TestLogModifyingChildContext, righ value, logp, vi = DynamicPPL.dot_tilde_assume(context.context, right, left, vn, vi) return value, logp*context.mod, vi end +function DynamicPPL.tilde_observe(context::TestLogModifyingChildContext, right, left, vi) + value, logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi) + return value, logp*context.mod, vi +end +function DynamicPPL.dot_tilde_observe(context::TestLogModifyingChildContext, right, left, vi) + return DynamicPPL.dot_tilde_observe(context.context, right, left, vi) + return value, logp*context.mod, vi +end + end diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index b511a03d7..17185c4ad 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -28,6 +28,7 @@ end mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx) #m = DynamicPPL.TestUtils.DEMO_MODELS[1] # m = DynamicPPL.TestUtils.demo_assume_index_observe() # logp at i-level? 
+ # m = DynamicPPL.TestUtils.demo_assume_dot_observe() # failing test? @testset "$(m.f)" for (i, m) in enumerate(DynamicPPL.TestUtils.DEMO_MODELS) #@show i example_values = DynamicPPL.TestUtils.rand_prior_true(m) @@ -53,7 +54,8 @@ end # # test on modifying child-context logpriors_mod = DynamicPPL.varwise_logpriors(m, vi, mod_ctx2) - logp1 = getlogp(vi) + logp1 = getlogp(vi) + #logp_mod = logprior(m, vi) # uses prior context but not mod_ctx2 # Following line assumes no Likelihood contributions # requires lowest Context to be PriorContext @test !isfinite(logp1) || sum(x -> sum(x), values(logpriors_mod)) ≈ logp1 # @@ -62,25 +64,24 @@ end end; @testset "logpriors_var chain" begin - @model function demo(xs, y) + @model function demo(x, ::Type{TV}=Vector{Float64}) where {TV} s ~ InverseGamma(2, 3) - m ~ Normal(0, √s) - for i in eachindex(xs) - xs[i] ~ Normal(m, √s) + m = TV(undef, length(x)) + for i in eachindex(x) + m[i] ~ Normal(0, √s) end - y ~ Normal(m, √s) - end - xs_true, y_true = ([0.3290767977680923, 0.038972110187911684, -0.5797496780649221], -0.7321425592768186)#randn(3), randn() - model = demo(xs_true, y_true) + x ~ MvNormal(m, √s) + end + x_true = [0.3290767977680923, 0.038972110187911684, -0.5797496780649221] + model = demo(x_true) () -> begin # generate the sample used below - chain = sample(model, MH(), 10) - arr0 = Array(chain) + chain = sample(model, MH(), MCMCThreads(), 10, 2) + arr0 = stack(Array(chain, append_chains=false)) + @show(arr0); end - arr0 = [1.8585322626573435 -0.05900855284939967; 1.7304068220366808 -0.6386249100228161; 1.7304068220366808 -0.6386249100228161; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039]; # generated in function above - # split into two chains for testing - arr1 = 
permutedims(reshape(arr0, 5,2,:),(1,3,2)) - chain = Chains(arr1, [:s, :m]); + arr0 = [5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 0.9199555480151707 -0.1304320097505629 1.0669120062696917 -0.05253734412139093; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183;;; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 2.5409470583244933 1.7838744695696407 0.7013562890105632 -3.0843947804314658; 0.8296370582311665 1.5360702767879642 -1.5964695255693102 0.16928084806166913; 2.6246697053824954 0.8096845024785173 -1.2621822861663752 1.1414885535466166; 1.1304261861894538 0.7325784741344005 -1.1866016911837542 -0.1639319562090826; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 0.9838526141898173 -0.20198797220982412 2.0569535882007006 -1.1560724118010939] + chain = Chains(arr0, [:s, Symbol("m[1]"), Symbol("m[2]"), Symbol("m[3]")]); tmp1 = varwise_logpriors(model, chain) tmp = Chains(tmp1...); # can be used to create a Chains object vi = VarInfo(model) From 5842656154a5b2f9a0377c45a4d4438933971a11 Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Thu, 19 Sep 2024 
10:54:05 +0200
Subject: [PATCH 06/14] integrate pointwise_loglikelihoods and varwise_logpriors
 by pointwise_logdensities

---
 src/DynamicPPL.jl                             |   6 +-
 src/deprecated.jl                             |   9 ++
 src/logpriors_var.jl                          | 143 ------------------
 ...kelihoods.jl => pointwise_logdensities.jl} | 128 +++++++++++-----
 src/test_utils.jl                             |   9 +-
 test/contexts.jl                              |   4 +-
 test/deprecated.jl                            |  29 ++++
 ...ikelihoods.jl => pointwise_logdensitiesjl} |  96 ++++++------
 test/runtests.jl                              |   4 +-
 9 files changed, 184 insertions(+), 244 deletions(-)
 create mode 100644 src/deprecated.jl
 delete mode 100644 src/logpriors_var.jl
 rename src/{loglikelihoods.jl => pointwise_logdensities.jl} (59%)
 create mode 100644 test/deprecated.jl
 rename test/{loglikelihoods.jl => pointwise_logdensitiesjl} (65%)

diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl
index 774a8010f..9d870c9e8 100644
--- a/src/DynamicPPL.jl
+++ b/src/DynamicPPL.jl
@@ -116,7 +116,7 @@ export AbstractVarInfo,
     logprior,
     logjoint,
     pointwise_loglikelihoods,
-    varwise_logpriors,
+    pointwise_logdensities,
     condition,
     decondition,
     fix,
@@ -182,8 +182,7 @@ include("varinfo.jl")
 include("simple_varinfo.jl")
 include("context_implementations.jl")
 include("compiler.jl")
-include("loglikelihoods.jl")
-include("logpriors_var.jl")
+include("pointwise_logdensities.jl")
 include("submodel_macro.jl")
 include("test_utils.jl")
 include("transforming.jl")
@@ -191,6 +190,7 @@ include("logdensityfunction.jl")
 include("model_utils.jl")
 include("extract_priors.jl")
 include("values_as_in_model.jl")
+include("deprecated.jl")
 include("debug_utils.jl")
 using .DebugUtils
diff --git a/src/deprecated.jl b/src/deprecated.jl
new file mode 100644
index 000000000..09e1ac84d
--- /dev/null
+++ b/src/deprecated.jl
@@ -0,0 +1,9 @@
+# https://invenia.github.io/blog/2022/06/17/deprecating-in-julia/
+
+Base.@deprecate pointwise_loglikelihoods(model::Model, chain, keytype) pointwise_logdensities(
+    model::Model, chain, LikelihoodContext(), keytype)
+
+Base.@deprecate pointwise_loglikelihoods(
+
model::Model, varinfo::AbstractVarInfo) pointwise_logdensities( + model::Model, varinfo, LikelihoodContext()) + diff --git a/src/logpriors_var.jl b/src/logpriors_var.jl deleted file mode 100644 index 1860d8ccd..000000000 --- a/src/logpriors_var.jl +++ /dev/null @@ -1,143 +0,0 @@ -""" -Context that records logp after tilde_assume!! -for each VarName used by [`varwise_logpriors`](@ref). -""" -struct VarwisePriorContext{A,Ctx} <: AbstractContext - logpriors::A - context::Ctx -end - -function VarwisePriorContext( - logpriors=OrderedDict{Symbol,Float64}(), - context::AbstractContext=DynamicPPL.PriorContext(), - #OrderedDict{Symbol,Vector{Float64}}(),PriorContext()), -) - return VarwisePriorContext{typeof(logpriors),typeof(context)}( - logpriors, context - ) -end - -NodeTrait(::VarwisePriorContext) = IsParent() -childcontext(context::VarwisePriorContext) = context.context -function setchildcontext(context::VarwisePriorContext, child) - return VarwisePriorContext(context.logpriors, child) -end - -function tilde_assume(context::VarwisePriorContext, right, vn, vi) - #@info "VarwisePriorContext tilde_assume!! called for $vn" - value, logp, vi = tilde_assume(context.context, right, vn, vi) - #sym = DynamicPPL.getsym(vn) - new_context = acc_logp!(context, vn, logp) - return value, logp, vi -end - -function dot_tilde_assume(context::VarwisePriorContext, right, left, vn, vi) - #@info "VarwisePriorContext dot_tilde_assume!! 
called for $vn" - # @show vn, left, right, typeof(context).name - value, logp, vi = dot_tilde_assume(context.context, right, left, vn, vi) - new_context = acc_logp!(context, vn, logp) - return value, logp, vi -end - - -tilde_observe(context::VarwisePriorContext, right, left, vi) = 0, vi -dot_tilde_observe(::VarwisePriorContext, right, left, vi) = 0, vi - -function acc_logp!(context::VarwisePriorContext, vn::Union{VarName,AbstractVector{<:VarName}}, logp) - #sym = DynamicPPL.getsym(vn) # leads to duplicates - # if vn is a Vector leads to Symbol("VarName{:s, IndexLens{Tuple{Int64}}}[s[1], s[2]]") - sym = Symbol(vn) - context.logpriors[sym] = logp - return (context) -end - -""" - varwise_logpriors(model::Model, chain::Chains; context) - -Runs `model` on each sample in `chain` returning a tuple `(values, var_names)` -with var_names corresponding to symbols of the prior components, and values being -array of shape `(num_samples, num_components, num_chains)`. - -`context` specifies child context that handles computation of log-priors. - -# Example -```julia; setup = :(using Distributions) -using DynamicPPL, Turing - -@model function demo(x, ::Type{TV}=Vector{Float64}) where {TV} - s ~ InverseGamma(2, 3) - m = TV(undef, length(x)) - for i in eachindex(x) - m[i] ~ Normal(0, √s) - end - x ~ MvNormal(m, √s) - end - -model = demo(randn(3), randn()); - -chain = sample(model, MH(), 10); - -lp = varwise_logpriors(model, chain) -# Can be used to construct a new Chains object -#lpc = MCMCChains(varwise_logpriors(model, chain)...) - -# got a logdensity for each parameter prior -(but fewer if used `.~` assignments, see below) -lp[2] == names(chain, :parameters) - -# for each sample in the Chains object -size(lp[1])[[1,3]] == size(chain)[[1,3]] -``` - -# Broadcasting -Note that `m .~ Dist()` will treat `m` as a collection of -_independent_ prior rather than as a single prior, -but `varwise_logpriors` returns the single -sum of log-likelihood of components of `m` only. 
-If one needs the log-density of the components, one needs to rewrite -the model with an explicit loop. -""" -function varwise_logpriors( - model::Model, varinfo::AbstractVarInfo, - context::AbstractContext=PriorContext() -) -# top_context = VarwisePriorContext(OrderedDict{Symbol,Float64}(), context) - top_context = VarwisePriorContext(OrderedDict{Symbol,Float64}(), context) - model(varinfo, top_context) - return top_context.logpriors -end - -function varwise_logpriors(model::Model, chain::AbstractChains, - context::AbstractContext=PriorContext(); - top_context::VarwisePriorContext{T} = VarwisePriorContext(OrderedDict{Symbol,Float64}(), context) - ) where T - # pass top-context as keyword to allow adapt Number type of log-prior - get_values = (vi) -> begin - model(vi, top_context) - values(top_context.logpriors) - end - arr = map_model(get_values, model, chain) - par_names = collect(keys(top_context.logpriors)) - return(arr, par_names) -end - -function map_model(get_values, model::Model, chain::AbstractChains) - niters = size(chain, 1) - nchains = size(chain, 3) - vi = VarInfo(model) - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - # initialize the array by the first result - (sample_idx, chain_idx), iters2 = Iterators.peel(iters) - setval!(vi, chain, sample_idx, chain_idx) - values1 = get_values(vi) - arr = Array{eltype(values1)}(undef, niters, length(values1), nchains) - arr[sample_idx, :, chain_idx] .= values1 - #(sample_idx, chain_idx), iters3 = Iterators.peel(iters2) - for (sample_idx, chain_idx) in iters2 - # Update the values - setval!(vi, chain, sample_idx, chain_idx) - values_i = get_values(vi) - arr[sample_idx, :, chain_idx] .= values_i - end - return(arr) -end diff --git a/src/loglikelihoods.jl b/src/pointwise_logdensities.jl similarity index 59% rename from src/loglikelihoods.jl rename to src/pointwise_logdensities.jl index 227a70889..73d3e5c0c 100644 --- a/src/loglikelihoods.jl +++ b/src/pointwise_logdensities.jl @@ -1,83 +1,83 @@ # 
Context version -struct PointwiseLikelihoodContext{A,Ctx} <: AbstractContext - loglikelihoods::A +struct PointwiseLogdensityContext{A,Ctx} <: AbstractContext + logdensities::A context::Ctx end -function PointwiseLikelihoodContext( +function PointwiseLogdensityContext( likelihoods=OrderedDict{VarName,Vector{Float64}}(), - context::AbstractContext=LikelihoodContext(), + context::AbstractContext=DefaultContext(), ) - return PointwiseLikelihoodContext{typeof(likelihoods),typeof(context)}( + return PointwiseLogdensityContext{typeof(likelihoods),typeof(context)}( likelihoods, context ) end -NodeTrait(::PointwiseLikelihoodContext) = IsParent() -childcontext(context::PointwiseLikelihoodContext) = context.context -function setchildcontext(context::PointwiseLikelihoodContext, child) - return PointwiseLikelihoodContext(context.loglikelihoods, child) +NodeTrait(::PointwiseLogdensityContext) = IsParent() +childcontext(context::PointwiseLogdensityContext) = context.context +function setchildcontext(context::PointwiseLogdensityContext, child) + return PointwiseLogdensityContext(context.logdensities, child) end function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{VarName,Vector{Float64}}}, + context::PointwiseLogdensityContext{<:AbstractDict{VarName,Vector{Float64}}}, vn::VarName, logp::Real, ) - lookup = context.loglikelihoods + lookup = context.logdensities ℓ = get!(lookup, vn, Float64[]) return push!(ℓ, logp) end function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{VarName,Float64}}, + context::PointwiseLogdensityContext{<:AbstractDict{VarName,Float64}}, vn::VarName, logp::Real, ) - return context.loglikelihoods[vn] = logp + return context.logdensities[vn] = logp end function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{String,Vector{Float64}}}, + context::PointwiseLogdensityContext{<:AbstractDict{String,Vector{Float64}}}, vn::VarName, logp::Real, ) - lookup = context.loglikelihoods + lookup = context.logdensities ℓ 
= get!(lookup, string(vn), Float64[]) return push!(ℓ, logp) end function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{String,Float64}}, + context::PointwiseLogdensityContext{<:AbstractDict{String,Float64}}, vn::VarName, logp::Real, ) - return context.loglikelihoods[string(vn)] = logp + return context.logdensities[string(vn)] = logp end function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{String,Vector{Float64}}}, + context::PointwiseLogdensityContext{<:AbstractDict{String,Vector{Float64}}}, vn::String, logp::Real, ) - lookup = context.loglikelihoods + lookup = context.logdensities ℓ = get!(lookup, vn, Float64[]) return push!(ℓ, logp) end function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{String,Float64}}, + context::PointwiseLogdensityContext{<:AbstractDict{String,Float64}}, vn::String, logp::Real, ) - return context.loglikelihoods[vn] = logp + return context.logdensities[vn] = logp end -function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vi) +function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi) # Defer literal `observe` to child-context. return tilde_observe!!(context.context, right, left, vi) end -function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, vi) +function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi) # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. # we have to intercept the call to `tilde_observe!`. logp, vi = tilde_observe(context.context, right, left, vi) @@ -88,11 +88,11 @@ function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, v return left, acclogp!!(vi, logp) end -function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vi) +function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi) # Defer literal `observe` to child-context. 
return dot_tilde_observe!!(context.context, right, left, vi) end -function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, vi) +function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi) # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. # we have to intercept the call to `dot_tilde_observe!`. @@ -129,8 +129,49 @@ function _pointwise_tilde_observe( end end +function tilde_assume(context::PointwiseLogdensityContext, right, vn, vi) + #@info "PointwiseLogdensityContext tilde_assume!! called for $vn" + value, logp, vi = tilde_assume(context.context, right, vn, vi) + #sym = DynamicPPL.getsym(vn) + new_context = acc_logp!(context, vn, logp) + return value, logp, vi +end + +function dot_tilde_assume(context::PointwiseLogdensityContext, right, left, vn, vi) + #@info "PointwiseLogdensityContext dot_tilde_assume!! called for $vn" + # @show vn, left, right, typeof(context).name + value, logp, vi = dot_tilde_assume(context.context, right, left, vn, vi) + new_context = acc_logp!(context, vn, logp) + return value, logp, vi +end + +function acc_logp!(context::PointwiseLogdensityContext, vn::VarName, logp) + push!(context, vn, logp) + return (context) +end + +function acc_logp!(context::PointwiseLogdensityContext, vns::AbstractVector{<:VarName}, logp) + # construct a new VarName from given sequence of VarName + # assume that all items in vns have an IndexLens optic + indices = tuplejoin(map(vn -> getoptic(vn).indices, vns)...) + vn = VarName(first(vns), Accessors.IndexLens(indices)) + push!(context, vn, logp) + return (context) +end + +#https://discourse.julialang.org/t/efficient-tuple-concatenation/5398/8 +@inline tuplejoin(x) = x +@inline tuplejoin(x, y) = (x..., y...) +@inline tuplejoin(x, y, z...) = (x..., tuplejoin(y, z...)...) 
+ +() -> begin + # code that generates julia-repl in docstring below + # using DynamicPPL, Turing + # TODO when Turing version that is compatible with DynamicPPL 0.29 becomes available +end + """ - pointwise_loglikelihoods(model::Model, chain::Chains, keytype = String) + pointwise_logdensities(model::Model, chain::Chains, keytype = String) Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}` with keys corresponding to symbols of the observations, and values being matrices @@ -184,21 +225,21 @@ julia> model = demo(randn(3), randn()); julia> chain = sample(model, MH(), 10); -julia> pointwise_loglikelihoods(model, chain) +julia> pointwise_logdensities(model, chain) OrderedDict{String,Array{Float64,2}} with 4 entries: "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] -julia> pointwise_loglikelihoods(model, chain, String) +julia> pointwise_logdensities(model, chain, String) OrderedDict{String,Array{Float64,2}} with 4 entries: "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] -julia> pointwise_loglikelihoods(model, chain, VarName) +julia> pointwise_logdensities(model, chain, VarName) OrderedDict{VarName,Array{Float64,2}} with 4 entries: xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333] xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359] @@ -217,20 +258,21 @@ julia> @model function demo(x) julia> m = demo([1.0, ]); -julia> ℓ = pointwise_loglikelihoods(m, VarInfo(m)); first(ℓ[@varname(x[1])]) +julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first(ℓ[@varname(x[1])]) -1.4189385332046727 julia> m = demo([1.0; 1.0]); -julia> ℓ = pointwise_loglikelihoods(m, VarInfo(m)); first.((ℓ[@varname(x[1])], 
ℓ[@varname(x[2])])) +julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) (-1.4189385332046727, -1.4189385332046727) ``` """ -function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) where {T} +function pointwise_logdensities(model::Model, chain, + context::AbstractContext=DefaultContext(), keytype::Type{T}=String) where {T} # Get the data by executing the model once vi = VarInfo(model) - context = PointwiseLikelihoodContext(OrderedDict{T,Vector{Float64}}()) + point_context = PointwiseLogdensityContext(OrderedDict{T,Vector{Float64}}(), context) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) for (sample_idx, chain_idx) in iters @@ -238,20 +280,24 @@ function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) setval!(vi, chain, sample_idx, chain_idx) # Execute model - model(vi, context) + model(vi, point_context) end niters = size(chain, 1) nchains = size(chain, 3) - loglikelihoods = OrderedDict( + logdensities = OrderedDict( varname => reshape(logliks, niters, nchains) for - (varname, logliks) in context.loglikelihoods + (varname, logliks) in point_context.logdensities ) - return loglikelihoods + return logdensities end -function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo) - context = PointwiseLikelihoodContext(OrderedDict{VarName,Vector{Float64}}()) - model(varinfo, context) - return context.loglikelihoods +function pointwise_logdensities(model::Model, + varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext()) + point_context = PointwiseLogdensityContext( + OrderedDict{VarName,Vector{Float64}}(), context) + model(varinfo, point_context) + return point_context.logdensities end + + diff --git a/src/test_utils.jl b/src/test_utils.jl index 9db8c37fd..85dfa71f4 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -1078,12 +1078,13 @@ function DynamicPPL.dot_tilde_assume(context::TestLogModifyingChildContext, righ return value, 
logp*context.mod, vi end function DynamicPPL.tilde_observe(context::TestLogModifyingChildContext, right, left, vi) - value, logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi) - return value, logp*context.mod, vi + # @info "called tilde_observe TestLogModifyingChildContext for left=$left, right=$right" + logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi) + return logp*context.mod, vi end function DynamicPPL.dot_tilde_observe(context::TestLogModifyingChildContext, right, left, vi) - return DynamicPPL.dot_tilde_observe(context.context, right, left, vi) - return value, logp*context.mod, vi + logp, vi = DynamicPPL.dot_tilde_observe(context.context, right, left, vi) + return logp*context.mod, vi end diff --git a/test/contexts.jl b/test/contexts.jl index 11e2c99b7..4ec9ff945 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -8,7 +8,7 @@ using DynamicPPL: NodeTrait, IsLeaf, IsParent, - PointwiseLikelihoodContext, + PointwiseLogdensityContext, contextual_isassumption, ConditionContext, hasconditioned, @@ -67,7 +67,7 @@ end SamplingContext(), MiniBatchContext(DefaultContext(), 0.0), PrefixContext{:x}(DefaultContext()), - PointwiseLikelihoodContext(), + PointwiseLogdensityContext(), ConditionContext((x=1.0,)), ConditionContext((x=1.0,), ParentContext(ConditionContext((y=2.0,)))), ConditionContext((x=1.0,), PrefixContext{:a}(ConditionContext((var"a.y"=2.0,)))), diff --git a/test/deprecated.jl b/test/deprecated.jl new file mode 100644 index 000000000..24fb3a55e --- /dev/null +++ b/test/deprecated.jl @@ -0,0 +1,29 @@ +@testset "loglikelihoods.jl" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + example_values = DynamicPPL.TestUtils.rand_prior_true(m) + + # Instantiate a `VarInfo` with the example values. + vi = VarInfo(m) + for vn in DynamicPPL.TestUtils.varnames(m) + vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) + end + + # Compute the pointwise loglikelihoods. 
+ lls = pointwise_loglikelihoods(m, vi) + loglikelihood = sum(sum, values(lls)) + + #if isempty(lls) + if loglikelihood ≈ 0.0 #isempty(lls) + # One of the models with literal observations, so we just skip. + # TODO: Think of better way to detect this special case + continue + end + + loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(m, example_values...) + + #priors = + + @test loglikelihood ≈ loglikelihood_true + end +end + diff --git a/test/loglikelihoods.jl b/test/pointwise_logdensitiesjl similarity index 65% rename from test/loglikelihoods.jl rename to test/pointwise_logdensitiesjl index 17185c4ad..5214bf5a1 100644 --- a/test/loglikelihoods.jl +++ b/test/pointwise_logdensitiesjl @@ -1,41 +1,16 @@ -@testset "loglikelihoods.jl" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - example_values = DynamicPPL.TestUtils.rand_prior_true(m) - - # Instantiate a `VarInfo` with the example values. - vi = VarInfo(m) - for vn in DynamicPPL.TestUtils.varnames(m) - vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) - end - - # Compute the pointwise loglikelihoods. - lls = pointwise_loglikelihoods(m, vi) - - if isempty(lls) - # One of the models with literal observations, so we just skip. - continue - end - - loglikelihood = sum(sum, values(lls)) - loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(m, example_values...) - - @test loglikelihood ≈ loglikelihood_true - end -end - -@testset "logpriors_var.jl" begin - mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2, PriorContext()) +@testset "logdensities_likelihoods.jl" begin + likelihood_context = LikelihoodContext() + prior_context = PriorContext() + mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2) mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx) #m = DynamicPPL.TestUtils.DEMO_MODELS[1] - # m = DynamicPPL.TestUtils.demo_assume_index_observe() # logp at i-level? 
- # m = DynamicPPL.TestUtils.demo_assume_dot_observe() # failing test? @testset "$(m.f)" for (i, m) in enumerate(DynamicPPL.TestUtils.DEMO_MODELS) #@show i example_values = DynamicPPL.TestUtils.rand_prior_true(m) # Instantiate a `VarInfo` with the example values. vi = VarInfo(m) - () -> begin + () -> begin # when interactively debugging, need the global keyword for vn in DynamicPPL.TestUtils.varnames(m) global vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) end @@ -44,26 +19,45 @@ end vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) end - #chains = sample(m, SampleFromPrior(), 2; progress=false) + loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(m, example_values...) + logp_true = logprior(m, vi) # Compute the pointwise loglikelihoods. - logpriors = DynamicPPL.varwise_logpriors(m, vi) - logp1 = getlogp(vi) - logp = logprior(m, vi) - @test !isfinite(logp) || sum(x -> sum(x), values(logpriors)) ≈ logp - # - # test on modifying child-context - logpriors_mod = DynamicPPL.varwise_logpriors(m, vi, mod_ctx2) - logp1 = getlogp(vi) - #logp_mod = logprior(m, vi) # uses prior context but not mod_ctx2 - # Following line assumes no Likelihood contributions - # requires lowest Context to be PriorContext - @test !isfinite(logp1) || sum(x -> sum(x), values(logpriors_mod)) ≈ logp1 # - @test all(values(logpriors_mod) .≈ values(logpriors) .* 1.2 .* 1.4) + lls = pointwise_logdensities(m, vi, likelihood_context) + #lls2 = pointwise_loglikelihoods(m, vi) + loglikelihood = sum(sum, values(lls)) + if loglikelihood ≈ 0.0 #isempty(lls) + # One of the models with literal observations, so we just skip. + # TODO: Think of better way to detect this special case + loglikelihood_true = 0.0 + end + @test loglikelihood ≈ loglikelihood_true + + # Compute the pointwise logdensities of the priors. 
+ lps_prior = pointwise_logdensities(m, vi, prior_context) + logp = sum(sum, values(lps_prior)) + if false # isempty(lps_prior) + # One of the models with only observations so we just skip. + else + logp1 = getlogp(vi) + @test !isfinite(logp_true) || logp ≈ logp_true + end + + # Compute both likelihood and logdensity of prior + # using the default DefaultContex + lps = pointwise_logdensities(m, vi) + logp = sum(sum, values(lps)) + @test logp ≈ (logp_true + loglikelihood_true) + + # Test that modifications of Setup are picked up + lps = pointwise_logdensities(m, vi, mod_ctx2) + logp = sum(sum, values(lps)) + @test logp ≈ (logp_true + loglikelihood_true) * 1.2 * 1.4 end -end; +end + -@testset "logpriors_var chain" begin +@testset "pointwise_logdensities chain" begin @model function demo(x, ::Type{TV}=Vector{Float64}) where {TV} s ~ InverseGamma(2, 3) m = TV(undef, length(x)) @@ -82,11 +76,13 @@ end; end arr0 = [5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 0.9199555480151707 -0.1304320097505629 1.0669120062696917 -0.05253734412139093; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183;;; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 2.5409470583244933 1.7838744695696407 0.7013562890105632 -3.0843947804314658; 0.8296370582311665 
1.5360702767879642 -1.5964695255693102 0.16928084806166913; 2.6246697053824954 0.8096845024785173 -1.2621822861663752 1.1414885535466166; 1.1304261861894538 0.7325784741344005 -1.1866016911837542 -0.1639319562090826; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 0.9838526141898173 -0.20198797220982412 2.0569535882007006 -1.1560724118010939] chain = Chains(arr0, [:s, Symbol("m[1]"), Symbol("m[2]"), Symbol("m[3]")]); - tmp1 = varwise_logpriors(model, chain) - tmp = Chains(tmp1...); # can be used to create a Chains object + tmp1 = pointwise_logdensities(model, chain) vi = VarInfo(model) i_sample, i_chain = (1,2) DynamicPPL.setval!(vi, chain, i_sample, i_chain) - lp1 = DynamicPPL.varwise_logpriors(model, vi) - @test all(tmp1[1][i_sample,:,i_chain] .≈ values(lp1)) + lp1 = DynamicPPL.pointwise_logdensities(model, vi) + # k = first(keys(lp1)) + for k in keys(lp1) + @test tmp1[string(k)][i_sample,i_chain] .≈ lp1[k][1] + end end; \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index a321151ab..6de0cb7fe 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -57,9 +57,11 @@ include("test_util.jl") include("serialization.jl") - include("loglikelihoods.jl") + include("pointwise_logdensitiesjl") include("lkj.jl") + + include("deprecated.jl") end @testset "compat" begin From d54bfdcdc7590cb892a6a3fad2728b7c386997f8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 20 Sep 2024 18:45:51 +0100 Subject: [PATCH 07/14] Replaced `acc_logp!` in favour of something similar to the `_pointwise_tilde_observe` method --- src/pointwise_logdensities.jl | 76 +++++++++++++++++------------------ 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 73d3e5c0c..ec1ac011a 100644 --- 
a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -129,45 +129,44 @@ function _pointwise_tilde_observe( end end -function tilde_assume(context::PointwiseLogdensityContext, right, vn, vi) - #@info "PointwiseLogdensityContext tilde_assume!! called for $vn" +function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi) value, logp, vi = tilde_assume(context.context, right, vn, vi) - #sym = DynamicPPL.getsym(vn) - new_context = acc_logp!(context, vn, logp) - return value, logp, vi -end + # Track loglikelihood value. + push!(context, vn, logp) -function dot_tilde_assume(context::PointwiseLogdensityContext, right, left, vn, vi) - #@info "PointwiseLogdensityContext dot_tilde_assume!! called for $vn" - # @show vn, left, right, typeof(context).name - value, logp, vi = dot_tilde_assume(context.context, right, left, vn, vi) - new_context = acc_logp!(context, vn, logp) - return value, logp, vi + return value, acclogp!!(vi, logp) end -function acc_logp!(context::PointwiseLogdensityContext, vn::VarName, logp) - push!(context, vn, logp) - return (context) +function dot_tilde_assume!!(context::PointwiseLogdensityContext, right, left, vns, vi) + value, logps = _pointwise_tilde_assume(context, right, left, vns, vi) + # Track loglikelihood values. + for (vn, logp) in zip(vns, logps) + push!(context, vn, logp) + end + return value, acclogp!!(vi, sum(logps)) end -function acc_logp!(context::PointwiseLogdensityContext, vns::AbstractVector{<:VarName}, logp) - # construct a new VarName from given sequence of VarName - # assume that all items in vns have an IndexLens optic - indices = tuplejoin(map(vn -> getoptic(vn).indices, vns)...) - vn = VarName(first(vns), Accessors.IndexLens(indices)) - push!(context, vn, logp) - return (context) +function _pointwise_tilde_assume(context, right, left, vns, vi) + # We need to drop the `vi` returned. 
+ values_and_logps = broadcast(right, left, vns) do r, l, vn + val, logp, _ = tilde_assume(context, r, vn, vi) + return val, logp + end + return map(first, values_and_logps), map(last, values_and_logps) end -#https://discourse.julialang.org/t/efficient-tuple-concatenation/5398/8 -@inline tuplejoin(x) = x -@inline tuplejoin(x, y) = (x..., y...) -@inline tuplejoin(x, y, z...) = (x..., tuplejoin(y, z...)...) - -() -> begin - # code that generates julia-repl in docstring below - # using DynamicPPL, Turing - # TODO when Turing version that is compatible with DynamicPPL 0.29 becomes available +function _pointwise_tilde_assume( + context, right::MultivariateDistribution, left::AbstractMatrix, vns, vi +) + # We need to drop the `vi` returned. + values_and_logps = map(eachcol(left), vns) do l, vn + val, logp, _ = tilde_assume(context, right, vn, vi) + return val, logp + end + # HACK(torfjelde): Due to the way we handle `.~`, we should use `recombine` to stay consistent. + # But this also means that we need to first flatten the entire `values` component before recombining. 
+ values = recombine(right, mapreduce(vec ∘ first, vcat, values_and_logps), length(vns)) + return values, map(last, values_and_logps) end """ @@ -268,8 +267,9 @@ julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ``` """ -function pointwise_logdensities(model::Model, chain, - context::AbstractContext=DefaultContext(), keytype::Type{T}=String) where {T} +function pointwise_logdensities( + model::Model, chain, context::AbstractContext=DefaultContext(), keytype::Type{T}=String +) where {T} # Get the data by executing the model once vi = VarInfo(model) point_context = PointwiseLogdensityContext(OrderedDict{T,Vector{Float64}}(), context) @@ -292,12 +292,12 @@ function pointwise_logdensities(model::Model, chain, return logdensities end -function pointwise_logdensities(model::Model, - varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext()) +function pointwise_logdensities( + model::Model, varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext() +) point_context = PointwiseLogdensityContext( - OrderedDict{VarName,Vector{Float64}}(), context) + OrderedDict{VarName,Vector{Float64}}(), context + ) model(varinfo, point_context) return point_context.logdensities end - - From f5032c45859f194b8083d4e41138d32017e1a6b8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 20 Sep 2024 18:49:12 +0100 Subject: [PATCH 08/14] Renamed `test/pointwise_logdensitiesjl` to `test/pointwise_logdensities.jl` --- test/{pointwise_logdensitiesjl => pointwise_logdensities.jl} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test/{pointwise_logdensitiesjl => pointwise_logdensities.jl} (100%) diff --git a/test/pointwise_logdensitiesjl b/test/pointwise_logdensities.jl similarity index 100% rename from test/pointwise_logdensitiesjl rename to test/pointwise_logdensities.jl From 3d3a97e7025c945957d25ba336b8781b13cc0aa9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 20 Sep 2024 19:06:16 +0100 Subject: [PATCH 09/14] Removed 
deprecated --- src/DynamicPPL.jl | 1 - src/deprecated.jl | 9 --------- 2 files changed, 10 deletions(-) delete mode 100644 src/deprecated.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 9d870c9e8..14b66ee36 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -190,7 +190,6 @@ include("logdensityfunction.jl") include("model_utils.jl") include("extract_priors.jl") include("values_as_in_model.jl") -include("deprecated.jl") include("debug_utils.jl") using .DebugUtils diff --git a/src/deprecated.jl b/src/deprecated.jl deleted file mode 100644 index 09e1ac84d..000000000 --- a/src/deprecated.jl +++ /dev/null @@ -1,9 +0,0 @@ -# https://invenia.github.io/blog/2022/06/17/deprecating-in-julia/ - -Base.@deprecate pointwise_loglikelihoods(model::Model, chain, keytype) pointwise_logdensities( - model::Model, LikelihoodContext(), chain, keytype) - -Base.@deprecate pointwise_loglikelihoods( - model::Model, varinfo::AbstractVarInfo) pointwise_logdensities( - model::Model, varinfo, LikelihoodContext()) - From 49ad8b0a24c9a2983d91a9994ad9e48b2ea27ea5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 20 Sep 2024 19:11:24 +0100 Subject: [PATCH 10/14] Added back `pointwise_loglikelihoods` and a new function `pointwise_prior_logdensities` + a mechanism to determine what we should include in the resulting dictionary based on the leaf context --- src/pointwise_logdensities.jl | 90 ++++++++++++++++++++++++++++++++++- 1 file changed, 88 insertions(+), 2 deletions(-) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index ec1ac011a..37634979c 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -19,6 +19,13 @@ function setchildcontext(context::PointwiseLogdensityContext, child) return PointwiseLogdensityContext(context.logdensities, child) end +function _include_prior(context::PointwiseLogdensityContext) + return leafcontext(context) isa Union{PriorContext,DefaultContext} +end +function 
_include_likelihood(context::PointwiseLogdensityContext) + return leafcontext(context) isa Union{LikelihoodContext,DefaultContext} +end + function Base.push!( context::PointwiseLogdensityContext{<:AbstractDict{VarName,Vector{Float64}}}, vn::VarName, @@ -78,6 +85,11 @@ function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi) return tilde_observe!!(context.context, right, left, vi) end function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi) + # Completely defer to child context if we are not tracking likelihoods. + if !(_include_likelihood(context)) + return tilde_observe!!(context.context, right, left, vn, vi) + end + # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. # we have to intercept the call to `tilde_observe!`. logp, vi = tilde_observe(context.context, right, left, vi) @@ -93,6 +105,11 @@ function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, v return dot_tilde_observe!!(context.context, right, left, vi) end function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi) + # Completely defer to child context if we are not tracking likelihoods. + if !(_include_likelihood(context)) + return dot_tilde_observe!!(context.context, right, left, vn, vi) + end + # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. # we have to intercept the call to `dot_tilde_observe!`. @@ -130,6 +147,10 @@ function _pointwise_tilde_observe( end function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi) + # Completely defer to child context if we are not tracking prior densities. + _include_prior(context) || return tilde_assume!!(context.context, right, vn, vi) + + # Otherwise, capture the return values. value, logp, vi = tilde_assume(context.context, right, vn, vi) # Track loglikelihood value. 
push!(context, vn, logp) @@ -138,6 +159,11 @@ function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi) end function dot_tilde_assume!!(context::PointwiseLogdensityContext, right, left, vns, vi) + # Completely defer to child context if we are not tracking prior densities. + if !(_include_prior(context)) + return dot_tilde_assume!!(context.context, right, left, vns, vi) + end + value, logps = _pointwise_tilde_assume(context, right, left, vns, vi) # Track loglikelihood values. for (vn, logp) in zip(vns, logps) @@ -173,7 +199,7 @@ end pointwise_logdensities(model::Model, chain::Chains, keytype = String) Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}` -with keys corresponding to symbols of the observations, and values being matrices +with keys corresponding to symbols of the variables, and values being matrices of shape `(num_chains, num_samples)`. `keytype` specifies what the type of the keys used in the returned `OrderedDict` are. @@ -268,7 +294,7 @@ julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], """ function pointwise_logdensities( - model::Model, chain, context::AbstractContext=DefaultContext(), keytype::Type{T}=String + model::Model, chain, context::AbstractContext=DefaultContext() ) where {T} # Get the data by executing the model once vi = VarInfo(model) @@ -301,3 +327,63 @@ function pointwise_logdensities( model(varinfo, point_context) return point_context.logdensities end + +""" + pointwise_loglikelihoods(model, chain[, context]) + +Compute the pointwise log-likelihoods of the model given the chain. + +This is the same as `pointwise_logdensities(model, chain, context)`, but only +including the likelihood terms. + +See also: [`pointwise_logdensities`](@ref). 
+""" +function pointwise_loglikelihoods( + model::Model, chain, context::AbstractContext=LikelihoodContext() +) where {T} + if !(leafcontext(context) isa LikelihoodContext) + throw(ArgumentError("Leaf context should be a LikelihoodContext")) + end + + return pointwise_logdensities(model, chain, context) +end + +function pointwise_loglikelihoods( + model::Model, varinfo::AbstractVarInfo, context::AbstractContext=LikelihoodContext() +) where {T} + if !(leafcontext(context) isa LikelihoodContext) + throw(ArgumentError("Leaf context should be a LikelihoodContext")) + end + + return pointwise_logdensities(model, chain, context) +end + +""" + pointwise_prior_logdensities(model, chain[, context]) + +Compute the pointwise log-prior-densities of the model given the chain. + +This is the same as `pointwise_logdensities(model, chain, context)`, but only +including the prior terms. + +See also: [`pointwise_logdensities`](@ref). +""" +function pointwise_prior_logdensities( + model::Model, chain, context::AbstractContext=PriorContext() +) where {T} + if !(leafcontext(context) isa PriorContext) + throw(ArgumentError("Leaf context should be a PriorContext")) + end + + return pointwise_logdensities(model, chain, context) +end + +function pointwise_prior_logdensities( + model::Model, varinfo::AbstractVarInfo, context::AbstractContext=PriorContext() +) where {T} + if !(leafcontext(context) isa PriorContext) + throw(ArgumentError("Leaf context should be a PriorContext")) + end + + return pointwise_logdensities(model, chain, context) +end From 85913757466ee88e8b19f7ccffa8c2fbb1902549 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 20 Sep 2024 19:14:11 +0100 Subject: [PATCH 11/14] Accidentally removed the `keytype` argument --- src/pointwise_logdensities.jl | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 37634979c..167299064 100644 --- a/src/pointwise_logdensities.jl 
+++ b/src/pointwise_logdensities.jl @@ -294,7 +294,7 @@ julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], """ function pointwise_logdensities( - model::Model, chain, context::AbstractContext=DefaultContext() + model::Model, chain, context::AbstractContext=DefaultContext(), keytype::Type{T}=String ) where {T} # Get the data by executing the model once vi = VarInfo(model) @@ -339,23 +339,26 @@ including the likelihood terms. See also: [`pointwise_logdensities`](@ref). """ function pointwise_loglikelihoods( - model::Model, chain, context::AbstractContext=LikelihoodContext() + model::Model, + chain, + context::AbstractContext=LikelihoodContext(), + keytype::Type{T}=String, ) where {T} if !(leafcontext(context) isa LikelihoodContext) throw(ArgumentError("Leaf context should be a LikelihoodContext")) end - return pointwise_logdensities(model, chain, context) + return pointwise_logdensities(model, chain, context, keytype) end function pointwise_loglikelihoods( model::Model, varinfo::AbstractVarInfo, context::AbstractContext=LikelihoodContext() -) where {T} +) if !(leafcontext(context) isa LikelihoodContext) throw(ArgumentError("Leaf context should be a LikelihoodContext")) end - return pointwise_logdensities(model, chain, context) + return pointwise_logdensities(model, varinfo, context) end """ @@ -369,21 +372,21 @@ including the prior terms. See also: [`pointwise_logdensities`](@ref). 
""" function pointwise_prior_logdensities( - model::Model, chain, context::AbstractContext=PriorContext() + model::Model, chain, context::AbstractContext=PriorContext(), keytype::Type{T}=String ) where {T} if !(leafcontext(context) isa PriorContext) throw(ArgumentError("Leaf context should be a PriorContext")) end - return pointwise_logdensities(model, chain, context) + return pointwise_logdensities(model, chain, context, keytype) end function pointwise_prior_logdensities( model::Model, varinfo::AbstractVarInfo, context::AbstractContext=PriorContext() -) where {T} +) if !(leafcontext(context) isa PriorContext) throw(ArgumentError("Leaf context should be a PriorContext")) end - return pointwise_logdensities(model, chain, context) + return pointwise_logdensities(model, varinfo, context) end From b8d033ab6ab1bd8a3ebd039e3c4f34b4fa07e5f7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 20 Sep 2024 19:27:29 +0100 Subject: [PATCH 12/14] Fixed keytype argument --- src/pointwise_logdensities.jl | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 167299064..3ab1101e3 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -196,14 +196,19 @@ function _pointwise_tilde_assume( end """ - pointwise_logdensities(model::Model, chain::Chains, keytype = String) + pointwise_logdensities(model::Model, chain::Chains[, keytype::Type, context::AbstractContext]) Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}` with keys corresponding to symbols of the variables, and values being matrices of shape `(num_chains, num_samples)`. -`keytype` specifies what the type of the keys used in the returned `OrderedDict` are. -Currently, only `String` and `VarName` are supported. +# Arguments +- `model`: the `Model` to run. +- `chain`: the `Chains` to run the model on. 
+- `keytype`: the type of the keys used in the returned `OrderedDict` are. + Currently, only `String` and `VarName` are supported. +- `context`: the context to use when running the model. Default: `DefaultContext`. + The [`leafcontext`](@ref) is used to decide which variables to include. # Notes Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ` @@ -294,7 +299,7 @@ julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], """ function pointwise_logdensities( - model::Model, chain, context::AbstractContext=DefaultContext(), keytype::Type{T}=String + model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext() ) where {T} # Get the data by executing the model once vi = VarInfo(model) @@ -329,7 +334,7 @@ function pointwise_logdensities( end """ - pointwise_loglikelihoods(model, chain[, context]) + pointwise_loglikelihoods(model, chain[, keytype, context]) Compute the pointwise log-likelihoods of the model given the chain. @@ -341,14 +346,14 @@ See also: [`pointwise_logdensities`](@ref). function pointwise_loglikelihoods( model::Model, chain, - context::AbstractContext=LikelihoodContext(), keytype::Type{T}=String, + context::AbstractContext=LikelihoodContext(), ) where {T} if !(leafcontext(context) isa LikelihoodContext) throw(ArgumentError("Leaf context should be a LikelihoodContext")) end - return pointwise_logdensities(model, chain, context, keytype) + return pointwise_logdensities(model, chain, T, context) end function pointwise_loglikelihoods( @@ -362,7 +367,7 @@ function pointwise_loglikelihoods( end """ - pointwise_prior_logdensities(model, chain[, context]) + pointwise_prior_logdensities(model, chain[, keytype, context]) Compute the pointwise log-prior-densities of the model given the chain. @@ -372,13 +377,13 @@ including the prior terms. See also: [`pointwise_logdensities`](@ref). 
""" function pointwise_prior_logdensities( - model::Model, chain, context::AbstractContext=PriorContext(), keytype::Type{T}=String + model::Model, chain, keytype::Type{T}=String, context::AbstractContext=PriorContext() ) where {T} if !(leafcontext(context) isa PriorContext) throw(ArgumentError("Leaf context should be a PriorContext")) end - return pointwise_logdensities(model, chain, context, keytype) + return pointwise_logdensities(model, chain, T, context) end function pointwise_prior_logdensities( From 91a69079d5f28aaf9c57d657c9e338d94c10d834 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 20 Sep 2024 19:27:39 +0100 Subject: [PATCH 13/14] Reverted the introdces to the testing of `pointwise_loglikelihoods` (now in `test/deprecated.jl`) --- test/deprecated.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/deprecated.jl b/test/deprecated.jl index 24fb3a55e..f5c400691 100644 --- a/test/deprecated.jl +++ b/test/deprecated.jl @@ -10,15 +10,14 @@ # Compute the pointwise loglikelihoods. lls = pointwise_loglikelihoods(m, vi) - loglikelihood = sum(sum, values(lls)) - #if isempty(lls) - if loglikelihood ≈ 0.0 #isempty(lls) + if isempty(lls) # One of the models with literal observations, so we just skip. # TODO: Think of better way to detect this special case continue end + loglikelihood = sum(sum, values(lls)) loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(m, example_values...) 
#priors = @@ -26,4 +25,3 @@ @test loglikelihood ≈ loglikelihood_true end end - From 8bd2085098208fc58d1e33bbe48ec56e7efcd691 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 23 Sep 2024 10:49:49 +0100 Subject: [PATCH 14/14] Added comment on a potential issue with `_pointwise_tilde_assume` --- src/pointwise_logdensities.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 3ab1101e3..0b2435c16 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -175,6 +175,11 @@ end function _pointwise_tilde_assume(context, right, left, vns, vi) # We need to drop the `vi` returned. values_and_logps = broadcast(right, left, vns) do r, l, vn + # HACK(torfjelde): This drops the `vi` returned, which means the `vi` is not updated + # in case of immutable varinfos. But a) atm we're only using mutable varinfos for this, + # and b) even if the variables aren't stored in the vi correctly, we're not going to use + # this vi for anything downstream anyways, i.e. I don't see a case where this would matter + # for this particular use case. val, logp, _ = tilde_assume(context, r, vn, vi) return val, logp end