docstring varwise_logpriors
use loop for prior in example

Unfortunately this cannot be made a jldoctest, because it relies on Turing for sampling.
bgctw committed Sep 18, 2024
1 parent 216d50c commit fd8d3b2
Showing 3 changed files with 77 additions and 122 deletions.
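As context for the diff below, a minimal sketch of the workflow the new docstring documents. It assumes Turing.jl supplies `MH` and `sample` (which is why the example cannot be a jldoctest) and that `varwise_logpriors` is available as in the tests below.

```julia
using DynamicPPL, Turing

@model function demo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, √s)
    x ~ Normal(m, √s)
end

model = demo(0.5)
chain = sample(model, MH(), 10)

# Per the docstring: a tuple (values, var_names), with values of shape
# (num_samples, num_components, num_chains).
lp_vals, lp_names = varwise_logpriors(model, chain)
```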
154 changes: 48 additions & 106 deletions src/logpriors_var.jl
@@ -40,13 +40,8 @@ function dot_tilde_assume(context::VarwisePriorContext, right, left, vn, vi)
end


function tilde_observe(context::VarwisePriorContext, right, left, vi)
# Since we are evaluating the prior, the log probability of all the observations
# is set to 0. This has the effect of ignoring the likelihood.
return 0.0, vi
#tmp = tilde_observe(context.context, SampleFromPrior(), right, left, vi)
#return tmp
end
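# Observe statements contribute 0 here, so the likelihood is ignored and
# only the prior log-densities are accumulated.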
tilde_observe(context::VarwisePriorContext, right, left, vi) = 0, vi
dot_tilde_observe(::VarwisePriorContext, right, left, vi) = 0, vi

function acc_logp!(context::VarwisePriorContext, vn::Union{VarName,AbstractVector{<:VarName}}, logp)
#sym = DynamicPPL.getsym(vn) # leads to duplicates
@@ -56,105 +51,52 @@ function acc_logp!(context::VarwisePriorContext, vn::Union{VarName,AbstractVector{<:VarName}}, logp)
return (context)
end


# """
# pointwise_logpriors(model::Model, chain::Chains, keytype = String)

# Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}`
# with keys corresponding to symbols of the observations, and values being matrices
# of shape `(num_chains, num_samples)`.

# `keytype` specifies what the type of the keys used in the returned `OrderedDict` are.
# Currently, only `String` and `VarName` are supported.

# # Notes
# Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ`
# both being `<:Real`. Then the *observe* (i.e. when the left-hand side is an
# *observation*) statements can be implemented in three ways:
# 1. using a `for` loop:
# ```julia
# for i in eachindex(y)
# y[i] ~ Normal(μ, σ)
# end
# ```
# 2. using `.~`:
# ```julia
# y .~ Normal(μ, σ)
# ```
# 3. using `MvNormal`:
# ```julia
# y ~ MvNormal(fill(μ, n), σ^2 * I)
# ```

# In (1) and (2), `y` will be treated as a collection of `n` i.i.d. 1-dimensional variables,
# while in (3) `y` will be treated as a _single_ n-dimensional observation.

# This is important to keep in mind, in particular if the computation is used
# for downstream computations.

# # Examples
# ## From chain
# ```julia-repl
# julia> using DynamicPPL, Turing

# julia> @model function demo(xs, y)
# s ~ InverseGamma(2, 3)
# m ~ Normal(0, √s)
# for i in eachindex(xs)
# xs[i] ~ Normal(m, √s)
# end

# y ~ Normal(m, √s)
# end
# demo (generic function with 1 method)

# julia> model = demo(randn(3), randn());

# julia> chain = sample(model, MH(), 10);

# julia> pointwise_logpriors(model, chain)
# OrderedDict{String,Array{Float64,2}} with 4 entries:
# "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
# "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
# "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
# "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499]

# julia> pointwise_logpriors(model, chain, String)
# OrderedDict{String,Array{Float64,2}} with 4 entries:
# "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
# "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
# "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
# "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499]

# julia> pointwise_logpriors(model, chain, VarName)
# OrderedDict{VarName,Array{Float64,2}} with 4 entries:
# xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
# xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
# xs[3] => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
# y => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
# ```

# ## Broadcasting
# Note that `x .~ Dist()` will treat `x` as a collection of
# _independent_ observations rather than as a single observation.

# ```jldoctest; setup = :(using Distributions)
# julia> @model function demo(x)
# x .~ Normal()
# end;

# julia> m = demo([1.0, ]);

# julia> ℓ = pointwise_logpriors(m, VarInfo(m)); first(ℓ[@varname(x[1])])
# -1.4189385332046727

# julia> m = demo([1.0; 1.0]);

# julia> ℓ = pointwise_logpriors(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])]))
# (-1.4189385332046727, -1.4189385332046727)
# ```

# """
"""
varwise_logpriors(model::Model, chain::Chains; context)
Runs `model` on each sample in `chain`, returning a tuple `(values, var_names)`
where `var_names` contains the symbols of the prior components and `values` is
an array of shape `(num_samples, num_components, num_chains)`.
`context` specifies the child context that handles the computation of the log-priors.
# Example
```julia
using DynamicPPL, Turing
@model function demo(x, ::Type{TV}=Vector{Float64}) where {TV}
s ~ InverseGamma(2, 3)
m = TV(undef, length(x))
for i in eachindex(x)
m[i] ~ Normal(0, √s)
end
x ~ MvNormal(m, √s)
end
model = demo(randn(3), randn());
chain = sample(model, MH(), 10);
lp = varwise_logpriors(model, chain)
# can be used to construct a new Chains object:
#lpc = Chains(varwise_logpriors(model, chain)...)
# one log-density entry for each parameter prior
# (but fewer if `.~` assignments are used, see below):
lp[2] == names(chain, :parameters)
# and one entry for each sample in the Chains object:
size(lp[1])[[1,3]] == size(chain)[[1,3]]
```
# Broadcasting
Note that `m .~ Dist()` will treat `m` as a collection of
_independent_ priors rather than as a single prior,
but `varwise_logpriors` reports only the single
sum of the log-priors of the components of `m`.
If the log-density of each component is needed, the model has to be
rewritten with an explicit loop, as sketched below.
"""
function varwise_logpriors(
model::Model, varinfo::AbstractVarInfo,
context::AbstractContext=PriorContext()
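To make the broadcasting caveat in the docstring concrete, here is a hedged sketch (the model names `demo_dot` and `demo_loop` are hypothetical) of the rewrite from `.~` to an explicit loop:

```julia
using DynamicPPL, Distributions

# With `.~`, varwise_logpriors reports one summed log-prior entry for m.
@model function demo_dot(x)
    s ~ InverseGamma(2, 3)
    m = Vector{Float64}(undef, length(x))
    m .~ Normal(0, √s)
    x ~ MvNormal(m, √s)
end

# With an explicit loop, each m[i] gets its own log-prior entry.
@model function demo_loop(x)
    s ~ InverseGamma(2, 3)
    m = Vector{Float64}(undef, length(x))
    for i in eachindex(m)
        m[i] ~ Normal(0, √s)
    end
    x ~ MvNormal(m, √s)
end
```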
14 changes: 13 additions & 1 deletion src/test_utils.jl
@@ -1059,7 +1059,10 @@ function TestLogModifyingChildContext(
mod, context
)
end
DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent()
# Samplers call leafcontext(model.context) when evaluating log-densities.
# Hence, to be usable with samplers this context must declare itself a leaf context.
#DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent()
DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsLeaf()
DynamicPPL.childcontext(context::TestLogModifyingChildContext) = context.context
function DynamicPPL.setchildcontext(context::TestLogModifyingChildContext, child)
return TestLogModifyingChildContext(context.mod, child)
@@ -1074,5 +1077,14 @@ function DynamicPPL.dot_tilde_assume(context::TestLogModifyingChildContext, right, left, vn, vi)
value, logp, vi = DynamicPPL.dot_tilde_assume(context.context, right, left, vn, vi)
return value, logp*context.mod, vi
end
function DynamicPPL.tilde_observe(context::TestLogModifyingChildContext, right, left, vi)
    logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi)
    return logp * context.mod, vi
end
function DynamicPPL.dot_tilde_observe(context::TestLogModifyingChildContext, right, left, vi)
    logp, vi = DynamicPPL.dot_tilde_observe(context.context, right, left, vi)
    return logp * context.mod, vi
end


end
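As a hedged illustration of how this test context composes: the two-argument constructor mirrors its use in test/loglikelihoods.jl below, while the innermost `PriorContext` here is an assumption chosen to match the test comment that the lowest context must be a PriorContext.

```julia
using DynamicPPL

# Each context scales the log-density of every tilde statement by its
# `mod` factor and forwards the actual computation to its child context.
mod_ctx  = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2, DynamicPPL.PriorContext())
mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx)
# Evaluating a model under mod_ctx2 should thus accumulate
# 1.4 * 1.2 times each statement's log-prior.
```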
31 changes: 16 additions & 15 deletions test/loglikelihoods.jl
@@ -28,6 +28,7 @@ end
mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx)
#m = DynamicPPL.TestUtils.DEMO_MODELS[1]
# m = DynamicPPL.TestUtils.demo_assume_index_observe() # logp at i-level?
# m = DynamicPPL.TestUtils.demo_assume_dot_observe() # failing test?
@testset "$(m.f)" for (i, m) in enumerate(DynamicPPL.TestUtils.DEMO_MODELS)
#@show i
example_values = DynamicPPL.TestUtils.rand_prior_true(m)
@@ -53,7 +54,8 @@ end
#
# test on modifying child-context
logpriors_mod = DynamicPPL.varwise_logpriors(m, vi, mod_ctx2)
logp1 = getlogp(vi)
logp1 = getlogp(vi)
#logp_mod = logprior(m, vi) # uses prior context but not mod_ctx2
# Following line assumes no Likelihood contributions
# requires lowest Context to be PriorContext
@test !isfinite(logp1) || sum(x -> sum(x), values(logpriors_mod)) ≈ logp1 #
@@ -62,25 +64,24 @@
end;

@testset "logpriors_var chain" begin
@model function demo(xs, y)
@model function demo(x, ::Type{TV}=Vector{Float64}) where {TV}
s ~ InverseGamma(2, 3)
m ~ Normal(0, s)
for i in eachindex(xs)
xs[i] ~ Normal(m, s)
m = TV(undef, length(x))
for i in eachindex(x)
m[i] ~ Normal(0, s)
end
y ~ Normal(m, s)
end
xs_true, y_true = ([0.3290767977680923, 0.038972110187911684, -0.5797496780649221], -0.7321425592768186)#randn(3), randn()
model = demo(xs_true, y_true)
x ~ MvNormal(m, s)
end
x_true = [0.3290767977680923, 0.038972110187911684, -0.5797496780649221]
model = demo(x_true)
() -> begin
# generate the sample used below
chain = sample(model, MH(), 10)
arr0 = Array(chain)
chain = sample(model, MH(), MCMCThreads(), 10, 2)
arr0 = stack(Array(chain, append_chains=false))
@show(arr0);
end
arr0 = [1.8585322626573435 -0.05900855284939967; 1.7304068220366808 -0.6386249100228161; 1.7304068220366808 -0.6386249100228161; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039]; # generated in function above
# split into two chains for testing
arr1 = permutedims(reshape(arr0, 5,2,:),(1,3,2))
chain = Chains(arr1, [:s, :m]);
arr0 = [5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 0.9199555480151707 -0.1304320097505629 1.0669120062696917 -0.05253734412139093; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183;;; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 2.5409470583244933 1.7838744695696407 0.7013562890105632 -3.0843947804314658; 0.8296370582311665 1.5360702767879642 -1.5964695255693102 0.16928084806166913; 2.6246697053824954 0.8096845024785173 -1.2621822861663752 1.1414885535466166; 1.1304261861894538 0.7325784741344005 -1.1866016911837542 -0.1639319562090826; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 0.9838526141898173 -0.20198797220982412 2.0569535882007006 -1.1560724118010939]
chain = Chains(arr0, [:s, Symbol("m[1]"), Symbol("m[2]"), Symbol("m[3]")]);
tmp1 = varwise_logpriors(model, chain)
tmp = Chains(tmp1...); # can be used to create a Chains object
vi = VarInfo(model)
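Restating the shape relations this test and the docstring rely on, as a hedged sketch:

```julia
lp_vals, lp_names = varwise_logpriors(model, chain)
# rows match samples, the third dimension matches chains ...
@assert size(lp_vals)[[1, 3]] == size(chain)[[1, 3]]
# ... and there is one named column per prior component
@assert length(lp_names) == size(lp_vals, 2)
```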
