-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Suggestions for pointwise_logdensities
and siblings
#669
base: master
Are you sure you want to change the base?
Conversation
use loop for prior in example Unfortunately cannot make it a jldoctest, because relies on Turing for sampling
`_pointwise_tilde_observe` method
`pointwise_prior_logdensities` + a mechanism to determine what we should include in the resulting dictionary based on the leaf context
`pointwise_loglikelihoods` (now in `test/deprecated.jl`)
Similarly, one can compute the pointwise log-priors of each sampled random variable | ||
with [`varwise_logpriors`](@ref). | ||
Differently from `pointwise_loglikelihoods` it reports only a | ||
single value for `.~` assignements. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
single value for `.~` assignements. | |
single value for `.~` assignements. |
function TestLogModifyingChildContext( | ||
mod=1.2, | ||
context::DynamicPPL.AbstractContext=DynamicPPL.DefaultContext(), | ||
#OrderedDict{VarName,Vector{Float64}}(),PriorContext()), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
#OrderedDict{VarName,Vector{Float64}}(),PriorContext()), | |
#OrderedDict{VarName,Vector{Float64}}(),PriorContext()), |
return TestLogModifyingChildContext{typeof(mod),typeof(context)}( | ||
mod, context | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
return TestLogModifyingChildContext{typeof(mod),typeof(context)}( | |
mod, context | |
) | |
return TestLogModifyingChildContext{typeof(mod),typeof(context)}(mod, context) |
function DynamicPPL.tilde_assume(context::TestLogModifyingChildContext, right, vn, vi) | ||
#@info "TestLogModifyingChildContext tilde_assume!! called for $vn" | ||
value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi) | ||
return value, logp*context.mod, vi |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
return value, logp*context.mod, vi | |
return value, logp * context.mod, vi |
value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi) | ||
return value, logp*context.mod, vi | ||
end | ||
function DynamicPPL.dot_tilde_assume(context::TestLogModifyingChildContext, right, left, vn, vi) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
function DynamicPPL.dot_tilde_assume(context::TestLogModifyingChildContext, right, left, vn, vi) | |
function DynamicPPL.dot_tilde_assume( | |
context::TestLogModifyingChildContext, right, left, vn, vi | |
) |
arr0 = [5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 0.9199555480151707 -0.1304320097505629 1.0669120062696917 -0.05253734412139093; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183;;; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 2.5409470583244933 1.7838744695696407 0.7013562890105632 -3.0843947804314658; 0.8296370582311665 1.5360702767879642 -1.5964695255693102 0.16928084806166913; 2.6246697053824954 0.8096845024785173 -1.2621822861663752 1.1414885535466166; 1.1304261861894538 0.7325784741344005 -1.1866016911837542 -0.1639319562090826; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 0.9838526141898173 -0.20198797220982412 2.0569535882007006 -1.1560724118010939] | ||
chain = Chains(arr0, [:s, Symbol("m[1]"), Symbol("m[2]"), Symbol("m[3]")]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
arr0 = [5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 0.9199555480151707 -0.1304320097505629 1.0669120062696917 -0.05253734412139093; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183;;; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 2.5409470583244933 1.7838744695696407 0.7013562890105632 -3.0843947804314658; 0.8296370582311665 1.5360702767879642 -1.5964695255693102 0.16928084806166913; 2.6246697053824954 0.8096845024785173 -1.2621822861663752 1.1414885535466166; 1.1304261861894538 0.7325784741344005 -1.1866016911837542 -0.1639319562090826; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 0.9838526141898173 -0.20198797220982412 2.0569535882007006 -1.1560724118010939] | |
chain = Chains(arr0, [:s, Symbol("m[1]"), Symbol("m[2]"), Symbol("m[3]")]); | |
arr0 = [ | |
5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 0.9199555480151707 -0.1304320097505629 1.0669120062696917 -0.05253734412139093; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183;;; | |
3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 2.5409470583244933 1.7838744695696407 0.7013562890105632 -3.0843947804314658; 0.8296370582311665 1.5360702767879642 -1.5964695255693102 0.16928084806166913; 2.6246697053824954 0.8096845024785173 -1.2621822861663752 1.1414885535466166; 1.1304261861894538 0.7325784741344005 -1.1866016911837542 -0.1639319562090826; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 0.9838526141898173 -0.20198797220982412 2.0569535882007006 -1.1560724118010939 | |
] | |
chain = Chains(arr0, [:s, Symbol("m[1]"), Symbol("m[2]"), Symbol("m[3]")]) |
chain = Chains(arr0, [:s, Symbol("m[1]"), Symbol("m[2]"), Symbol("m[3]")]); | ||
tmp1 = pointwise_logdensities(model, chain) | ||
vi = VarInfo(model) | ||
i_sample, i_chain = (1,2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
i_sample, i_chain = (1,2) | |
i_sample, i_chain = (1, 2) |
vi = VarInfo(model) | ||
i_sample, i_chain = (1,2) | ||
DynamicPPL.setval!(vi, chain, i_sample, i_chain) | ||
lp1 = DynamicPPL.pointwise_logdensities(model, vi) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
lp1 = DynamicPPL.pointwise_logdensities(model, vi) | |
lp1 = DynamicPPL.pointwise_logdensities(model, vi) |
lp1 = DynamicPPL.pointwise_logdensities(model, vi) | ||
# k = first(keys(lp1)) | ||
for k in keys(lp1) | ||
@test tmp1[string(k)][i_sample,i_chain] .≈ lp1[k][1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
@test tmp1[string(k)][i_sample,i_chain] .≈ lp1[k][1] | |
@test tmp1[string(k)][i_sample, i_chain] .≈ lp1[k][1] |
for k in keys(lp1) | ||
@test tmp1[string(k)][i_sample,i_chain] .≈ lp1[k][1] | ||
end | ||
end; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
end; | |
end; |
…gdensities mostly taken from TuringLang#669
* implement pointwise_logpriors * implement varwise_logpriors * remove pointwise_logpriors * revert dot_assume to not explicitly resolve components of sum * docstring varwise_logpriores use loop for prior in example Unfortunately cannot make it a jldoctest, because relies on Turing for sampling * integrate pointwise_loglikelihoods and varwise_logpriors by pointwise_densities * record single prior components by forwarding dot_tilde_assume to tilde_assume * forward dot_tilde_assume to tilde_assume for Multivariate * avoid recording prior components on leaf-prior-context and avoid recording likelihoods when invoked with leaf-Likelihood context * undeprecate pointwise_loglikelihoods and implement pointwise_prior_logdensities mostly taken from #669 * drop vi instead of re-compute vi bgctw first forwared dot_tilde_assume to get a correct vi and then recomputed it for recording component prior densities. Replaced this by the Hack of torfjelde that completely drops vi and recombines the value, so that assume is called only once for each varName, * include docstrings of pointwise_logdensities pointwise_prior_logdensities int api.md docu * Update src/pointwise_logdensities.jl remove commented code Co-authored-by: Tor Erlend Fjelde <[email protected]> * Update src/pointwise_logdensities.jl remove commented code Co-authored-by: Tor Erlend Fjelde <[email protected]> * Update test/pointwise_logdensities.jl rename m to model Co-authored-by: Tor Erlend Fjelde <[email protected]> * Update test/pointwise_logdensities.jl remove unused code Co-authored-by: Tor Erlend Fjelde <[email protected]> * Update test/pointwise_logdensities.jl rename m to model Co-authored-by: Tor Erlend Fjelde <[email protected]> * Update test/pointwise_logdensities.jl rename m to model Co-authored-by: Tor Erlend Fjelde <[email protected]> * Update src/test_utils.jl remove old code Co-authored-by: Tor Erlend Fjelde <[email protected]> * rename m to model * JuliaFormatter * Update test/runtests.jl remove interactive code Co-authored-by: Tor Erlend Fjelde <[email protected]> * remove demo_dot_assume_matrix_dot_observe_matrix2 testcase testing higher dimensions better left for other PR * ignore local interactive development code * ignore temporary directory holding local interactive development code * Apply suggestions from code review: clean up comments and Imports Co-authored-by: Tor Erlend Fjelde <[email protected]> * Apply suggestions from code review: change test of applying to chains on already used model Co-authored-by: Tor Erlend Fjelde <[email protected]> * fix test on names in likelihood components to work with literal models * try to fix testset pointwise_logdensities chain * Update test/pointwise_logdensities.jl * Update .gitignore * Formtating * Fixed tests * Updated docs for `pointwise_logdensities` + made it a doctest not dependent on Turing.jl * Bump patch version * Remove blank line from `@model` in doctest to see if that fixes the parsing issues * Added doctest filter to handle the `;;]` at the end of lines for matrices --------- Co-authored-by: Tor Erlend Fjelde <[email protected]> Co-authored-by: Tor Erlend Fjelde <[email protected]>
This is based on #663 but with the following changes:
pointwise_loglikelihoods
(but it just usespointwise_logdensities
under the hood)pointwise_prior_logdensities
.dot_tilde_assume
so we can handle.~
correctly.@bgctw I wanted to make a PR to your PR but doesn't seem possible due to being on your fork. I recommend looking at the diffs I've added. We can discuss the changes here and then you can incorporate those we want in your PR so that you get proper credit 👍
EDIT: Ignore all the formatting stuff. I'm deliberately not doing that to make the diff between this branch and @bgctw 's branch simpler to read.