Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Immutable versions of link and invlink #525

Merged
merged 27 commits into from
Sep 1, 2023

Conversation

torfjelde
Copy link
Member

@torfjelde torfjelde commented Aug 31, 2023

For a while now I've been thinking that it would be nice with immutable and mutable versions of the different !! functions we have (the !! convention is supposed to mean "prefer mutation").

When we introduced the usage of !!, we simply said that the types themselves would indicate whether or not we were mutating, but often it's convenient for a user to be able to specify that they want immutability, even if the underlying type supports mutability (which would then be used in the corresponding !! method).

One particular aspect where immutable versions are convenient is that of link and invlink. For example, before returning the a transition to the user (and potentially a Chains), in Turing.jl we always want to make sure that the underlying varinfo has been both invlink-ed and deepcopy-ed so that a) the user sees the samples in constrained space rather than unconstrained, and b) to avoid accidentally passing back buffers which are then mutated in the evaluation of the samples (we've had several bugs where we forget a deepcopy + use mutable invlink!, resulting in chains being all the same value). With an immutable invlink, makes such mistakes becomes more difficult.

Moreover, currently it's not possible with the mutable invlinking:

julia> @model demo_dirichlet() = x ~ Dirichlet(ones(2))
demo_dirichlet (generic function with 2 methods)

julia> model = demo_dirichlet()
Model{typeof(demo_dirichlet), (), (), (), Tuple{}, Tuple{}, DefaultContext}(demo_dirichlet, NamedTuple(), NamedTuple(), DefaultContext())

julia> vi = VarInfo(model);

julia> vi[:]
2-element 
Vector{Float64}:
 0.12307494142994369
 0.8769250585700563

julia> vi.metadata.x.vals
2-element Vector{Float64}:
 0.12307494142994369
 0.8769250585700563

julia> vi_linked = DynamicPPL.link!!(vi, model);

julia> vi_linked[:]
1-element Vector{Float64}:
 -1.9636280869363587

julia> vi_linked.metadata.x.vals
2-element Vector{Float64}:
 -1.9636280869363587
  0.8769250585700563

julia> vi_linked_unflattened = DynamicPPL.unflatten(vi_linked, vi_linked[:]);

julia> vi_linked_unflattened[:]
1-element Vector{Float64}:
 -1.9636280869363587

julia> vi_linked_unflattened.metadata.x.vals
1-element Vector{Float64}:
 -1.9636280869363587

julia> vi_linked_invlinked = DynamicPPL.invlink!!(
           vi_linked_unflattened,
           model
       )
ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [1:2]
Stacktrace:
  [1] throw_boundserror(A::Vector{Float64}, I::Tuple{UnitRange{Int64}})
    @ Base ./abstractarray.jl:744
  [2] checkbounds
    @ ./abstractarray.jl:709 [inlined]
  [3] setindex!
    @ ./array.jl:992 [inlined]
  [4] setval!(md::DynamicPPL.Metadata{Dict{VarName{:x, Setfield.IdentityLens}, Int64}, Vector{Dirichlet{Float64, Vector{Float64}, Float64}}, Vector{VarName{:x, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, val::Vector{Float64}, vn::VarName{:x, Setfield.IdentityLens})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/slbWl/src/varinfo.jl:329
  [5] setval!
    @ ~/.julia/packages/DynamicPPL/slbWl/src/varinfo.jl:327 [inlined]
  [6] _inner_transform!(vi::TypedVarInfo{NamedTuple{(:x,), Tuple{DynamicPPL.Metadata{Dict{VarName{:x, Setfield.IdentityLens}, Int64}, Vector{Dirichlet{Float64, Vector{Float64}, Float64}}, Vector{VarName{:x, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, vn::VarName{:x, Setfield.IdentityLens}, dist::Dirichlet{Float64, Vector{Float64}, Float64}, f::Bijectors.Inverse{Bijectors.SimplexBijector})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/slbWl/src/varinfo.jl:897
  [7] macro expansion
    @ ~/.julia/packages/DynamicPPL/slbWl/src/varinfo.jl:874 [inlined]
  [8] _invlink!(metadata::NamedTuple{(:x,), Tuple{DynamicPPL.Metadata{Dict{VarName{:x, Setfield.IdentityLens}, Int64}, Vector{Dirichlet{Float64, Vector{Float64}, Float64}}, Vector{VarName{:x, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, vi::TypedVarInfo{NamedTuple{(:x,), Tuple{DynamicPPL.Metadata{Dict{VarName{:x, Setfield.IdentityLens}, Int64}, Vector{Dirichlet{Float64, Vector{Float64}, Float64}}, Vector{VarName{:x, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, vns::NamedTuple{(:x,), Tuple{Vector{VarName{:x, Setfield.IdentityLens}}}}, #unused#::Val{()})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/slbWl/src/varinfo.jl:860
  [9] _invlink!
    @ ~/.julia/packages/DynamicPPL/slbWl/src/varinfo.jl:858 [inlined]
 [10] _invlink!
    @ ~/.julia/packages/DynamicPPL/slbWl/src/varinfo.jl:854 [inlined]
 [11] invlink!!
    @ ~/.julia/packages/DynamicPPL/slbWl/src/varinfo.jl:806 [inlined]
 [12] invlink!!
    @ ~/.julia/packages/DynamicPPL/slbWl/src/abstract_varinfo.jl:403 [inlined]
 [13] invlink!!(vi::TypedVarInfo{NamedTuple{(:x,), Tuple{DynamicPPL.Metadata{Dict{VarName{:x, Setfield.IdentityLens}, Int64}, Vector{Dirichlet{Float64, Vector{Float64}, Float64}}, Vector{VarName{:x, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, model::Model{typeof(demo_dirichlet), (), (), (), Tuple{}, Tuple{}, DefaultContext})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/slbWl/src/abstract_varinfo.jl:397
 [14] top-level scope
    @ REPL[35]:1

This is because model contains distributions that are projected onto a smaller space upon linking, e.g. Dirichlet or LKJCholesky. Here unflatten will only extract the sub-part of varinfo.metadata.$var.vals that is "active" after link!!, and then our subsequent call to invlink!! tries to push a larger-dimensional vector into the too small container.

Note that this pattern is quite common in Turing.jl, e.g. https://github.com/TuringLang/Turing.jl/blob/0ad7368384e0885c5a76f5a43d1e6b630ef09926/src/inference/hmc.jl#L212-L222.

Fixing this in the mutable invlink!! for VarInfo is quite awkward + worsens performance, as we will have to resize the underlying varinfo.metadata.$var.vals as needed, and then insert the values in the buffer. This is also partially because we've had to shoehorn support for these things into VarInfo.

In an immutable invlink, implementing this is trivial (as is done in this PR).

Related: #504

src/varinfo.jl Outdated Show resolved Hide resolved
@codecov
Copy link

codecov bot commented Aug 31, 2023

Codecov Report

Patch coverage: 91.40% and project coverage change: +0.41% 🎉

Comparison is base (549d9b1) 80.40% compared to head (e05bb09) 80.81%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #525      +/-   ##
==========================================
+ Coverage   80.40%   80.81%   +0.41%     
==========================================
  Files          24       24              
  Lines        2776     2904     +128     
==========================================
+ Hits         2232     2347     +115     
- Misses        544      557      +13     
Files Changed Coverage Δ
src/DynamicPPL.jl 100.00% <ø> (ø)
src/threadsafe.jl 20.40% <0.00%> (-0.87%) ⬇️
src/abstract_varinfo.jl 91.17% <60.00%> (-3.39%) ⬇️
src/varinfo.jl 92.37% <95.38%> (-0.01%) ⬇️
src/test_utils.jl 86.36% <100.00%> (+1.66%) ⬆️
src/transforming.jl 100.00% <100.00%> (ø)
src/utils.jl 78.92% <100.00%> (+0.19%) ⬆️

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@coveralls
Copy link

coveralls commented Aug 31, 2023

Pull Request Test Coverage Report for Build 6041688975

  • 117 of 128 (91.41%) changed or added relevant lines in 6 files are covered.
  • 2 unchanged lines in 1 file lost coverage.
  • Overall coverage increased (+0.4%) to 80.82%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/varinfo.jl 62 65 95.38%
src/abstract_varinfo.jl 6 10 60.0%
src/threadsafe.jl 0 4 0.0%
Files with Coverage Reduction New Missed Lines %
src/varinfo.jl 2 92.38%
Totals Coverage Status
Change from base Build 6018544130: 0.4%
Covered Lines: 2347
Relevant Lines: 2904

💛 - Coveralls

@yebai yebai requested a review from sunxd3 August 31, 2023 16:32
src/test_utils.jl Outdated Show resolved Hide resolved
src/test_utils.jl Outdated Show resolved Hide resolved
test/varinfo.jl Outdated Show resolved Hide resolved
@torfjelde torfjelde mentioned this pull request Aug 31, 2023
src/abstract_varinfo.jl Outdated Show resolved Hide resolved
Copy link
Member

@yebai yebai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good work -- I left a few minor comments below.

src/abstract_varinfo.jl Outdated Show resolved Hide resolved

See also: [`default_transformation`](@ref), [`link`](@ref).
"""
invlink(vi::AbstractVarInfo, model::Model) = invlink(vi, SampleFromPrior(), model)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SampleFromPrior was introduced before we had the mechanism for passing the t::Transformation. It is now no longer necessary. In the longer run, we can consider replacing SampleFromPrior with a suitable context for clarity and consistency.

Suggested change
invlink(vi::AbstractVarInfo, model::Model) = invlink(vi, SampleFromPrior(), model)
invlink(vi::AbstractVarInfo, model::Model) = invlink(vi, SampleFromPrior(), model)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep! This has always been the idea:) But for now, we're not quite read to do that.

src/utils.jl Outdated
@@ -514,6 +514,13 @@ function BangBang.possible(
return BangBang.implements(setindex!, C) &&
promote_type(eltype(C), eltype(T)) <: eltype(C)
end
# NOTE: Makes it possible to use ranges, etc. for setting a vector.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is incomprehensible to me. Maybe @sunxd3 can check its correctness.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a more descriptive comment here; hopefully that clarifies things a bit.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, I think the issue is still there

julia> using DynamicPPL, BangBang

julia> svi = SimpleVarInfo(Dict((@varname(a))=>[1 2; 3 4]))
SimpleVarInfo(Dict(a => [1 2; 3 4]), 0.0)

julia> setindex!!(svi, [2 2; 2 2], @varname(a[1:2, 1:2]))
SimpleVarInfo(Dict{VarName{:a, Setfield.IdentityLens}, Matrix{Any}}(a => [2 2; 2 2]), 0.0)

@torfjelde if the memory is still fresh, maybe this is an easy fix? I can help also.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@torfjelde, for clarity, was this issue fixed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, I think the issue is still there

This is a separate issue to the one I fixed though; I specifically only fixed the one for Vector and ranges. We need a more general fix to address all types of arrays. Could you open an issue @sunxd3?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also worth pointing out that this has nothing to do with the intrduced functionality in this PR; these bugs are from before.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, that's why I didn't insist. I'll open an issue.

src/varinfo.jl Show resolved Hide resolved
@sunxd3
Copy link
Member

sunxd3 commented Sep 1, 2023

Question about VarInfo: what's the purpose of del flags?

@torfjelde torfjelde added this pull request to the merge queue Sep 1, 2023
@torfjelde
Copy link
Member Author

Question about VarInfo: what's the purpose of del flags?

Basically just used in places like this

function assume(
rng::Random.AbstractRNG,
sampler::Union{SampleFromPrior,SampleFromUniform},
dist::Distribution,
vn::VarName,
vi::VarInfoOrThreadSafeVarInfo,
)
if haskey(vi, vn)
# Always overwrite the parameters with new ones for `SampleFromUniform`.
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
unset_flag!(vi, vn, "del")
r = init(rng, dist, sampler)
BangBang.setindex!!(
vi, vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r)), vn
)
setorder!(vi, vn, get_num_produce(vi))
else
# Otherwise we just extract it.
r = vi[vn, dist]
end
else
r = init(rng, dist, sampler)
if istrans(vi)
push!!(vi, vn, reconstruct_and_link(dist, r), dist, sampler)
# By default `push!!` sets the transformed flag to `false`.
settrans!!(vi, true, vn)
else
push!!(vi, vn, r, dist, sampler)
end
end
# HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r)
return r, logpdf(dist, r) - logjac, vi
end

But now we have SamplingContext which should specify whether we are sampling or evaluating; del flag is from before its time (it's still being used by VarInfo code though).

Merged via the queue into master with commit ba16e3b Sep 1, 2023
12 of 13 checks passed
@torfjelde torfjelde deleted the torfjelde/immutable-versions-of-methods branch September 1, 2023 16:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants