Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pointpriors #663

Merged
merged 39 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
d05124c
implement pointwise_logpriors
bgctw Sep 16, 2024
4f46102
implement varwise_logpriors
bgctw Sep 17, 2024
c6653b9
remove pointwise_logpriors
bgctw Sep 17, 2024
216d50c
revert dot_assume to not explicitly resolve components of sum
bgctw Sep 17, 2024
fd8d3b2
docstring varwise_logpriores
bgctw Sep 18, 2024
5842656
integrate pointwise_loglikelihoods and varwise_logpriors by pointwise…
bgctw Sep 19, 2024
18beb57
record single prior components
bgctw Sep 21, 2024
d9945d7
forward dot_tilde_assume to tilde_assume for Multivariate
bgctw Sep 22, 2024
656a757
avoid recording prior components on leaf-prior-context
bgctw Sep 24, 2024
7aa9ebe
undeprecate pointwise_loglikelihoods and implement pointwise_prior_lo…
bgctw Sep 24, 2024
2f67c5b
drop vi instead of re-compute vi
bgctw Sep 24, 2024
9dfb9ed
include docstrings of pointwise_logdensities
bgctw Sep 24, 2024
c1939e0
Update src/pointwise_logdensities.jl remove commented code
bgctw Sep 25, 2024
790be1d
Update src/pointwise_logdensities.jl remove commented code
bgctw Sep 25, 2024
426df38
Update test/pointwise_logdensities.jl rename m to model
bgctw Sep 25, 2024
c32bf3b
Update test/pointwise_logdensities.jl remove unused code
bgctw Sep 25, 2024
6213249
Update test/pointwise_logdensities.jl rename m to model
bgctw Sep 25, 2024
3551b38
Update test/pointwise_logdensities.jl rename m to model
bgctw Sep 25, 2024
95c892b
Update src/test_utils.jl remove old code
bgctw Sep 25, 2024
a7a7e70
rename m to model
bgctw Sep 25, 2024
1653aba
JuliaFormatter
bgctw Sep 25, 2024
e4f0a1d
Merge branch 'pointpriors' of github.com:bgctw/DynamicPPL.jl into poi…
bgctw Sep 25, 2024
a99eab4
Update test/runtests.jl remove interactive code
bgctw Sep 26, 2024
64ce63a
remove demo_dot_assume_matrix_dot_observe_matrix2 testcase
bgctw Sep 26, 2024
456115c
ignore local interactive development code
bgctw Sep 26, 2024
222529a
ignore temporary directory holding local interactive development code
bgctw Sep 26, 2024
17b251a
Apply suggestions from code review: clean up comments and Imports
bgctw Sep 26, 2024
7e990f0
Apply suggestions from code review: change test of applying to chains…
bgctw Sep 26, 2024
8706f68
fix test on names in likelihood components
bgctw Sep 26, 2024
073a325
try to fix testset pointwise_logdensities chain
bgctw Sep 26, 2024
23e1711
Update test/pointwise_logdensities.jl
torfjelde Sep 26, 2024
34ae4f8
Update .gitignore
torfjelde Sep 26, 2024
1f251d1
Merge branch 'master' into pointpriors
torfjelde Sep 26, 2024
777624a
Formtating
torfjelde Sep 26, 2024
4864e60
Fixed tests
torfjelde Sep 26, 2024
4d3b0c0
Updated docs for `pointwise_logdensities` + made it a doctest not
torfjelde Sep 27, 2024
e54fa4e
Bump patch version
torfjelde Sep 27, 2024
bcd82a9
Remove blank line from `@model` in doctest to see if that fixes the
torfjelde Sep 27, 2024
cff0941
Added doctest filter to handle the `;;]` at the end of lines for matr…
torfjelde Sep 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,17 @@ For a chain of samples, one can compute the pointwise log-likelihoods of each ob
pointwise_loglikelihoods
```

Similarly, one can compute the pointwise log-priors of each sampled random variable
with [`varwise_logpriors`](@ref).
Differently from `pointwise_loglikelihoods` it reports only a
single value for `.~` assignements.
If one needs to access the parts for single indices, one can
reformulate the model to use an explicit loop instead.

```@docs
varwise_logpriors
```

For converting a chain into a format that can more easily be fed into a `Model` again, for example using `condition`, you can use [`value_iterator_from_chain`](@ref).

```@docs
Expand Down
2 changes: 2 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ export AbstractVarInfo,
logprior,
logjoint,
pointwise_loglikelihoods,
varwise_logpriors,
condition,
decondition,
fix,
Expand Down Expand Up @@ -182,6 +183,7 @@ include("simple_varinfo.jl")
include("context_implementations.jl")
include("compiler.jl")
include("loglikelihoods.jl")
include("logpriors_var.jl")
include("submodel_macro.jl")
include("test_utils.jl")
include("transforming.jl")
Expand Down
25 changes: 21 additions & 4 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,15 @@ function dot_assume(
var::AbstractMatrix,
vns::AbstractVector{<:VarName},
vi::AbstractVarInfo,
)
r, lp, vi = dot_assume_vec(dist, var, vns, vi)
return r, sum(lp), vi
end
function dot_assume_vec(
bgctw marked this conversation as resolved.
Show resolved Hide resolved
dist::MultivariateDistribution,
var::AbstractMatrix,
vns::AbstractVector{<:VarName},
vi::AbstractVarInfo,
)
@assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))"
# NOTE: We cannot work with `var` here because we might have a model of the form
Expand All @@ -434,7 +443,7 @@ function dot_assume(
#
# in which case `var` will have `undef` elements, even if `m` is present in `vi`.
r = vi[vns, dist]
lp = sum(zip(vns, eachcol(r))) do (vn, ri)
lp = map(zip(vns, eachcol(r))) do (vn, ri)
return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn))
end
return r, lp, vi
Expand All @@ -455,21 +464,29 @@ function dot_assume(
end

function dot_assume(
dist::Union{Distribution,AbstractArray{<:Distribution}}, var::AbstractArray, vns::AbstractArray{<:VarName}, vi
)
# possibility to acesss the single logpriors
r, lp, vi = dot_assume_vec(dist, var, vns, vi)
return r, sum(lp), vi
end

function dot_assume_vec(
dist::Distribution, var::AbstractArray, vns::AbstractArray{<:VarName}, vi
)
r = getindex.((vi,), vns, (dist,))
lp = sum(Bijectors.logpdf_with_trans.((dist,), r, istrans.((vi,), vns)))
lp = Bijectors.logpdf_with_trans.((dist,), r, istrans.((vi,), vns))
return r, lp, vi
end

function dot_assume(
function dot_assume_vec(
dists::AbstractArray{<:Distribution},
var::AbstractArray,
vns::AbstractArray{<:VarName},
vi,
)
r = getindex.((vi,), vns, dists)
lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns)))
lp = Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns))
return r, lp, vi
end

Expand Down
201 changes: 201 additions & 0 deletions src/logpriors_var.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
"""
Context that records logp after tilde_assume!!
for each VarName used by [`varwise_logpriors`](@ref).
"""
struct VarwisePriorContext{A,Ctx} <: AbstractContext
logpriors::A
context::Ctx
end

function VarwisePriorContext(
logpriors=OrderedDict{Symbol,Float64}(),
context::AbstractContext=DynamicPPL.PriorContext(),
#OrderedDict{Symbol,Vector{Float64}}(),PriorContext()),
)
return VarwisePriorContext{typeof(logpriors),typeof(context)}(
logpriors, context
)
end

NodeTrait(::VarwisePriorContext) = IsParent()
childcontext(context::VarwisePriorContext) = context.context
function setchildcontext(context::VarwisePriorContext, child)
return VarwisePriorContext(context.logpriors, child)
end

function tilde_assume(context::VarwisePriorContext, right, vn, vi)
#@info "VarwisePriorContext tilde_assume!! called for $vn"
value, logp, vi = tilde_assume(context.context, right, vn, vi)
#sym = DynamicPPL.getsym(vn)
new_context = acc_logp!(context, vn, logp)
return value, logp, vi
end

function dot_tilde_assume(context::VarwisePriorContext, right, left, vn, vi)
#@info "VarwisePriorContext dot_tilde_assume!! called for $vn"
# @show vn, left, right, typeof(context).name
value, logp, vi = dot_tilde_assume(context.context, right, left, vn, vi)
new_context = acc_logp!(context, vn, logp)
return value, logp, vi
end


function tilde_observe(context::VarwisePriorContext, right, left, vi)
# Since we are evaluating the prior, the log probability of all the observations
# is set to 0. This has the effect of ignoring the likelihood.
return 0.0, vi
#tmp = tilde_observe(context.context, SampleFromPrior(), right, left, vi)
#return tmp
end

function acc_logp!(context::VarwisePriorContext, vn::Union{VarName,AbstractVector{<:VarName}}, logp)
#sym = DynamicPPL.getsym(vn) # leads to duplicates
# if vn is a Vector leads to Symbol("VarName{:s, IndexLens{Tuple{Int64}}}[s[1], s[2]]")
sym = Symbol(vn)
context.logpriors[sym] = logp
return (context)
end


# """
# pointwise_logpriors(model::Model, chain::Chains, keytype = String)

# Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}`
# with keys corresponding to symbols of the observations, and values being matrices
# of shape `(num_chains, num_samples)`.

# `keytype` specifies what the type of the keys used in the returned `OrderedDict` are.
# Currently, only `String` and `VarName` are supported.

# # Notes
# Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ`
# both being `<:Real`. Then the *observe* (i.e. when the left-hand side is an
# *observation*) statements can be implemented in three ways:
# 1. using a `for` loop:
# ```julia
# for i in eachindex(y)
# y[i] ~ Normal(μ, σ)
# end
# ```
# 2. using `.~`:
# ```julia
# y .~ Normal(μ, σ)
# ```
# 3. using `MvNormal`:
# ```julia
# y ~ MvNormal(fill(μ, n), σ^2 * I)
# ```

# In (1) and (2), `y` will be treated as a collection of `n` i.i.d. 1-dimensional variables,
# while in (3) `y` will be treated as a _single_ n-dimensional observation.

# This is important to keep in mind, in particular if the computation is used
# for downstream computations.

# # Examples
# ## From chain
# ```julia-repl
# julia> using DynamicPPL, Turing

# julia> @model function demo(xs, y)
# s ~ InverseGamma(2, 3)
# m ~ Normal(0, √s)
# for i in eachindex(xs)
# xs[i] ~ Normal(m, √s)
# end

# y ~ Normal(m, √s)
# end
# demo (generic function with 1 method)

# julia> model = demo(randn(3), randn());

# julia> chain = sample(model, MH(), 10);

# julia> pointwise_logpriors(model, chain)
# OrderedDict{String,Array{Float64,2}} with 4 entries:
# "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
# "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
# "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
# "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499]

# julia> pointwise_logpriors(model, chain, String)
# OrderedDict{String,Array{Float64,2}} with 4 entries:
# "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
# "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
# "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
# "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499]

# julia> pointwise_logpriors(model, chain, VarName)
# OrderedDict{VarName,Array{Float64,2}} with 4 entries:
# xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
# xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
# xs[3] => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
# y => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
# ```

# ## Broadcasting
# Note that `x .~ Dist()` will treat `x` as a collection of
# _independent_ observations rather than as a single observation.

# ```jldoctest; setup = :(using Distributions)
# julia> @model function demo(x)
# x .~ Normal()
# end;

# julia> m = demo([1.0, ]);

# julia> ℓ = pointwise_logpriors(m, VarInfo(m)); first(ℓ[@varname(x[1])])
# -1.4189385332046727

# julia> m = demo([1.0; 1.0]);

# julia> ℓ = pointwise_logpriors(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])]))
# (-1.4189385332046727, -1.4189385332046727)
# ```

# """
function varwise_logpriors(
model::Model, varinfo::AbstractVarInfo,
context::AbstractContext=PriorContext()
)
# top_context = VarwisePriorContext(OrderedDict{Symbol,Float64}(), context)
top_context = VarwisePriorContext(OrderedDict{Symbol,Float64}(), context)
model(varinfo, top_context)
return top_context.logpriors
end

function varwise_logpriors(model::Model, chain::AbstractChains,
context::AbstractContext=PriorContext();
top_context::VarwisePriorContext{T} = VarwisePriorContext(OrderedDict{Symbol,Float64}(), context)
) where T
# pass top-context as keyword to allow adapt Number type of log-prior
get_values = (vi) -> begin
model(vi, top_context)
values(top_context.logpriors)
end
arr = map_model(get_values, model, chain)
par_names = collect(keys(top_context.logpriors))
return(arr, par_names)
end

function map_model(get_values, model::Model, chain::AbstractChains)
niters = size(chain, 1)
nchains = size(chain, 3)
vi = VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
# initialize the array by the first result
(sample_idx, chain_idx), iters2 = Iterators.peel(iters)
setval!(vi, chain, sample_idx, chain_idx)
values1 = get_values(vi)
arr = Array{eltype(values1)}(undef, niters, length(values1), nchains)
arr[sample_idx, :, chain_idx] .= values1
#(sample_idx, chain_idx), iters3 = Iterators.peel(iters2)
for (sample_idx, chain_idx) in iters2
# Update the values
setval!(vi, chain, sample_idx, chain_idx)
values_i = get_values(vi)
arr[sample_idx, :, chain_idx] .= values_i
end
return(arr)
end
33 changes: 33 additions & 0 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1042,4 +1042,37 @@ function test_context_interface(context)
end
end

"""
Context that multiplies each log-prior by mod
used to test whether varwise_logpriors respects child-context.
"""
struct TestLogModifyingChildContext{T,Ctx} <: DynamicPPL.AbstractContext
mod::T
context::Ctx
end
function TestLogModifyingChildContext(
mod=1.2,
context::DynamicPPL.AbstractContext=DynamicPPL.DefaultContext(),
#OrderedDict{VarName,Vector{Float64}}(),PriorContext()),
bgctw marked this conversation as resolved.
Show resolved Hide resolved
)
return TestLogModifyingChildContext{typeof(mod),typeof(context)}(
mod, context
)
end
DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent()
DynamicPPL.childcontext(context::TestLogModifyingChildContext) = context.context
function DynamicPPL.setchildcontext(context::TestLogModifyingChildContext, child)
return TestLogModifyingChildContext(context.mod, child)
end
function DynamicPPL.tilde_assume(context::TestLogModifyingChildContext, right, vn, vi)
#@info "TestLogModifyingChildContext tilde_assume!! called for $vn"
bgctw marked this conversation as resolved.
Show resolved Hide resolved
value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi)
return value, logp*context.mod, vi
end
function DynamicPPL.dot_tilde_assume(context::TestLogModifyingChildContext, right, left, vn, vi)
#@info "TestLogModifyingChildContext dot_tilde_assume!! called for $vn"
bgctw marked this conversation as resolved.
Show resolved Hide resolved
value, logp, vi = DynamicPPL.dot_tilde_assume(context.context, right, left, vn, vi)
return value, logp*context.mod, vi
end

end
3 changes: 2 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
bgctw marked this conversation as resolved.
Show resolved Hide resolved
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand All @@ -26,10 +27,10 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Accessors = "0.1"
ADTypes = "0.2, 1"
AbstractMCMC = "5"
AbstractPPL = "0.8.2"
Accessors = "0.1"
Bijectors = "0.13"
Compat = "4.3.0"
Distributions = "0.25"
Expand Down
Loading