error with invlink!! and Dirichlet #504

Closed
miguelbiron opened this issue Jul 24, 2023 · 12 comments

@miguelbiron

miguelbiron commented Jul 24, 2023

The following code works with DynamicPPL v0.23.0 but fails with anything newer. This seems to be caused by the new parametrization of the simplex introduced in Bijectors v0.13.0.

using DynamicPPL, Distributions

@model function demo()
    x ~ Dirichlet(2, 1.0)
end
model = demo()
vi    = VarInfo(model)                                   # make VarInfo -> sample from prior and compute logdensity
getlogp(vi) ≈ 0.0                                        # zero because Dirichlet(1) == Uniform over Simplex
spl   = SampleFromPrior()                                # create dummy sampler for linking
DynamicPPL.link!!(vi, spl, model)                        # transform to unconstrained space
!(0.0 ≈ getlogp(last(DynamicPPL.evaluate!!(model, vi)))) # non-zero now due to log(abs(determinant(jacobian)))
x = vi[spl]                                              # extract unconstrained values
newx  = deepcopy(x)                                      # simulate making a change to x
vinew = VarInfo(vi, spl, newx)                           # make a new vi with the new unconstrained values
invlink!!(vinew,spl,model)                               # fails with Bijectors >= v0.13

The error is

julia> invlink!!(vinew,spl,model)
ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [1:2]
Stacktrace:
  [1] throw_boundserror(A::Vector{Float64}, I::Tuple{UnitRange{Int64}})
    @ Base ./abstractarray.jl:744
  [2] checkbounds
    @ ./abstractarray.jl:709 [inlined]
  [3] setindex!
    @ ./array.jl:992 [inlined]
  [4] setval!(md::DynamicPPL.Metadata{Dict{VarName{:x, Setfield.IdentityLens}, Int64}, Vector{Dirichlet{Float64, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Float64}}, Vector{VarName{:x, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, val::Vector{Float64}, vn::VarName{:x, Setfield.IdentityLens})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/qAXlC/src/varinfo.jl:328
  [5] setval!
    @ ~/.julia/packages/DynamicPPL/qAXlC/src/varinfo.jl:327 [inlined]
  [6] _inner_transform!(vi::TypedVarInfo{NamedTuple{(:x,), Tuple{DynamicPPL.Metadata{Dict{VarName{:x, Setfield.IdentityLens}, Int64}, Vector{Dirichlet{Float64, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Float64}}, Vector{VarName{:x, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, vn::VarName{:x, Setfield.IdentityLens}, dist::Dirichlet{Float64, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Float64}, f::Bijectors.Inverse{Bijectors.SimplexBijector})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/qAXlC/src/varinfo.jl:892
  [7] macro expansion
    @ ~/.julia/packages/DynamicPPL/qAXlC/src/varinfo.jl:869 [inlined]
  [8] _invlink!(metadata::NamedTuple{(:x,), Tuple{DynamicPPL.Metadata{Dict{VarName{:x, Setfield.IdentityLens}, Int64}, Vector{Dirichlet{Float64, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Float64}}, Vector{VarName{:x, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, vi::TypedVarInfo{NamedTuple{(:x,), Tuple{DynamicPPL.Metadata{Dict{VarName{:x, Setfield.IdentityLens}, Int64}, Vector{Dirichlet{Float64, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Float64}}, Vector{VarName{:x, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, vns::NamedTuple{(:x,), Tuple{Vector{VarName{:x, Setfield.IdentityLens}}}}, #unused#::Val{()})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/qAXlC/src/varinfo.jl:855
  [9] _invlink!
    @ ~/.julia/packages/DynamicPPL/qAXlC/src/varinfo.jl:853 [inlined]
 [10] _invlink!
    @ ~/.julia/packages/DynamicPPL/qAXlC/src/varinfo.jl:849 [inlined]
 [11] invlink!!
    @ ~/.julia/packages/DynamicPPL/qAXlC/src/varinfo.jl:801 [inlined]
 [12] invlink!!(vi::TypedVarInfo{NamedTuple{(:x,), Tuple{DynamicPPL.Metadata{Dict{VarName{:x, Setfield.IdentityLens}, Int64}, Vector{Dirichlet{Float64, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Float64}}, Vector{VarName{:x, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, spl::SampleFromPrior, model::Model{typeof(demo), (), (), (), Tuple{}, Tuple{}, DefaultContext})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/qAXlC/src/abstract_varinfo.jl:403
 [13] top-level scope
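
The `1-element Vector{Float64}` in the trace fits a parametrization that maps a K-element simplex to K - 1 unconstrained values, so the unconstrained vector no longer fits a buffer sized for the constrained one. A minimal Base-only sketch of that kind of dimension change follows. The helpers `to_unconstrained` and `to_simplex` are hypothetical and use an additive log-ratio map as a stand-in; this is not the Bijectors.jl implementation.

```julia
# Hypothetical helpers for illustration only; not the Bijectors.jl transform.

# Additive log-ratio map: K-element simplex -> K - 1 unconstrained reals.
to_unconstrained(x) = log.(x[1:end-1] ./ x[end])

# Inverse map: append a zero, exponentiate, and normalise to sum to one.
function to_simplex(y)
    z = exp.(vcat(y, 0.0))
    return z ./ sum(z)
end

x = [0.2, 0.3, 0.5]          # constrained: 3 entries summing to 1
y = to_unconstrained(x)      # unconstrained: only 2 entries
x_back = to_simplex(y)       # back to 3 entries

println(length(x), " -> ", length(y), " -> ", length(x_back))
```

Any code that writes `y` into storage allocated for `x` (or the reverse) has to account for the two lengths differing.
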
@yebai
Member

yebai commented Aug 2, 2023

cc @torfjelde @harisorgn

@torfjelde
Member

Yeah, this is a bug! Working on a fix now. Thanks for bringing this to our attention @miguelbiron !

@torfjelde
Member

Btw, what was the use-case where you ran into this error @miguelbiron ?

@miguelbiron
Author

Hi @torfjelde ! Thanks for working on this. I was implementing a model with a Dirichlet prior on a custom sampler package that leverages DynamicPPL. The code is >1y old so it does not use the newer LogDensityProblem interface.

@torfjelde
Member

Gotcha 👍 We're currently in the process of moving Turing to the more recent version of DynamicPPL (and thus Bijectors). Even though I haven't seen this issue show up yet (we've run into some other issues), this most certainly seems like something that could show up in our codebase too, so it's high on my priority list.

@torfjelde
Member

torfjelde commented Aug 5, 2023

Hmm, this seems to be somewhat of an annoying issue to address.

Basically, VarInfo was not originally designed to support scenarios with changing support, and so we've had to do a somewhat hacky solution to get it working. This works well internally in Turing.jl because we always just re-use a single VarInfo until the very last step where we decide to convert it into a Turing.Inference.Transition (which we can then convert into a MCMCChains.Chains). This means that we use the same underlying buffer both when calling link!! and invlink!!.

That is, we effectively do the following instead of what you wrote:

julia> using DynamicPPL, Distributions


julia> @model function demo()
           x ~ Dirichlet(2, 1.0)
       end
demo (generic function with 2 methods)

julia> model = demo();

julia> vi = VarInfo(model);

julia> getlogp(vi) ≈ 0.0
true

julia> spl = SampleFromPrior();

julia> DynamicPPL.link!!(vi, spl, model);

julia> !(0.0 ≈ getlogp(last(DynamicPPL.evaluate!!(model, vi))))
true

julia> x = vi[spl];

julia> newx = deepcopy(x);

julia> vinew = deepcopy(vi);

julia> vinew[spl] = newx;

julia> invlink!!(vinew, spl, model);

julia> vinew[spl]
2-element Vector{Float64}:
 0.5757209465831549
 0.42427905341684513

(It turns out that this actually requires a bugfix too, #513; we never really link using SampleFromPrior, which hits a particular implementation of setting the variables that also hadn't been updated properly.)

I think to address this particular issue properly, we actually need to keep track of both the shape-information of the constrained and the unconstrained representation, which is somewhat annoying 😕
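
One way to picture "keeping track of both shapes" is a container that records the constrained and unconstrained lengths next to a flat buffer and resizes that buffer whenever it switches space. The following is a rough sketch under that assumption only. `ShapedSlot`, `setvals!`, `link!`, `invlink!`, `alr`, and `alr_inv` are hypothetical names defined here for illustration; none of them are DynamicPPL or Bijectors APIs, and the additive log-ratio map merely stands in for the real simplex transform.

```julia
# Hypothetical sketch: `ShapedSlot` and its fields are illustrative, not DynamicPPL types.
mutable struct ShapedSlot
    vals::Vector{Float64}      # flat buffer holding the current representation
    constrained_len::Int       # K for a K-element simplex
    unconstrained_len::Int     # K - 1 under a reduced parametrization
    linked::Bool               # true when `vals` lives in unconstrained space
end

expected_len(s::ShapedSlot) = s.linked ? s.unconstrained_len : s.constrained_len

# Overwrite the buffer, refusing values of the wrong length for the current space.
function setvals!(s::ShapedSlot, v::AbstractVector{<:Real})
    n = expected_len(s)
    length(v) == n || throw(DimensionMismatch("expected $n values, got $(length(v))"))
    resize!(s.vals, n)
    s.vals .= v
    return s
end

# Stand-in transforms (additive log-ratio); not the Bijectors.jl ones.
alr(x) = log.(x[1:end-1] ./ x[end])
alr_inv(y) = (z = exp.(vcat(y, 0.0)); z ./ sum(z))

function link!(s::ShapedSlot)
    y = alr(s.vals)            # compute before flipping the flag
    s.linked = true
    return setvals!(s, y)      # buffer shrinks to the unconstrained length
end

function invlink!(s::ShapedSlot)
    x = alr_inv(s.vals)
    s.linked = false
    return setvals!(s, x)      # buffer grows back to the constrained length
end

slot = ShapedSlot([0.2, 0.3, 0.5], 3, 2, false)
link!(slot)                          # 3 entries -> 2 entries
setvals!(slot, slot.vals .+ 0.1)     # simulate an update in unconstrained space
invlink!(slot)                       # 2 entries -> 3 entries
```

Because the slot knows which length to expect in each space, a wrongly sized vector is rejected at the point of assignment rather than surfacing later as a bounds or broadcast error.
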

@torfjelde
Member

torfjelde commented Aug 5, 2023

But will the above snippet work for you @miguelbiron (once the PR has gone through)?

@miguelbiron
Author

Yes, that looks good to me. Thank you @torfjelde for quickly addressing the issue!

@miguelbiron
Author

miguelbiron commented Aug 8, 2023

Oops, just realized that the fix immediately fails if I increase the dimension of the Dirichlet from 2 to 3. I.e., the following fails

using DynamicPPL, Distributions

@model function demo()
    x ~ Dirichlet(3, 1.0)                                # increase K to 3
end
model = demo()
vi    = VarInfo(model)                                   # make VarInfo -> sample from prior and compute logdensity
getlogp(vi) ≈ 0.0                                        # zero because Dirichlet(1) == Uniform over Simplex
spl   = SampleFromPrior()                                # create dummy sampler for linking
DynamicPPL.link!!(vi, spl, model)                        # transform to unconstrained space
!(0.0 ≈ getlogp(last(DynamicPPL.evaluate!!(model, vi)))) # non-zero now due to log(abs(determinant(jacobian)))
x = vi[spl]                                              # extract unconstrained values
newx  = deepcopy(x)                                      # simulate making a change to x
vinew = deepcopy(vi)
vinew[spl] = newx;

with error

ERROR: DimensionMismatch: array could not be broadcast to match destination
Stacktrace:
  [1] check_broadcast_shape
    @ ./broadcast.jl:553 [inlined]
  [2] check_broadcast_axes
    @ ./broadcast.jl:556 [inlined]
  [3] instantiate
    @ ./broadcast.jl:297 [inlined]
  [4] materialize!
    @ ./broadcast.jl:884 [inlined]
  [5] materialize!
    @ ./broadcast.jl:881 [inlined]
  [6] macro expansion
    @ ~/.julia/packages/DynamicPPL/JvJLF/src/varinfo.jl:0 [inlined]
  [7] _setall!(metadata::NamedTuple{(:x,), Tuple{DynamicPPL.Metadata{Dict{VarName{:x, Setfield.IdentityLens}, Int64}, Vector{Dirichlet{Float64, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Float64}}, Vector{VarName{:x, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, val::Vector{Float64})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/JvJLF/src/varinfo.jl:363
  [8] setall!
    @ ~/.julia/packages/DynamicPPL/JvJLF/src/varinfo.jl:362 [inlined]
  [9] setindex!(vi::TypedVarInfo{NamedTuple{(:x,), Tuple{DynamicPPL.Metadata{Dict{VarName{:x, Setfield.IdentityLens}, Int64}, Vector{Dirichlet{Float64, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Float64}}, Vector{VarName{:x, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, val::Vector{Float64}, spl::SampleFromPrior)
    @ DynamicPPL ~/.julia/packages/DynamicPPL/JvJLF/src/varinfo.jl:1000

Manifest shows DynamicPPL is at version = "0.23.12" so this should already incorporate the PR. Any thoughts?

@yebai
Member

yebai commented Aug 8, 2023

An inconsistent Array shape causes it. Should be working after #516
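
The class of error in the trace above (a broadcast assignment inside `_setall!`) can be reproduced with plain arrays. The sketch below uses only Base Julia; the variable names are made up, and it does not touch DynamicPPL internals. It also shows that a one-element source broadcasts silently into a longer destination, which may be why a two-element Dirichlet did not surface the mismatch this way. That last point is a guess, not something confirmed in this thread.

```julia
# Base-only illustration; `buf`, `vals`, and `buf2` are hypothetical names.

buf  = zeros(3)          # storage sized for the constrained representation (K = 3)
vals = [0.1, 0.2]        # K - 1 = 2 unconstrained values

err = try
    buf .= vals          # lengths 3 and 2 cannot be broadcast together
    nothing
catch e
    e
end
println(typeof(err))     # DimensionMismatch

buf2 = zeros(2)          # storage sized for K = 2
buf2 .= [0.7]            # a length-1 source is expanded to fill both slots, no error
println(buf2)
```
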

@yebai
Member

yebai commented Aug 9, 2023

Fixed by #516

yebai closed this as completed Aug 9, 2023
@miguelbiron
Author

Fix is working great --- thank you guys!
