Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Attempt at implementation of VarNamedVector (Metadata alternative) #555

Merged
merged 222 commits into from
Oct 8, 2024

Conversation

torfjelde
Copy link
Member

@torfjelde torfjelde commented Nov 7, 2023

This is an attempt at a implementation similar to what was discussed in #528 (comment).

The idea is to introduce a simpler alternative to Metadata, which contains a fair bit of not-so-useful-anymore functionality and isn't particularly accessible.

With the current state of this PR, it's possible to evaluate a model with a VarNameVector:

julia> using DynamicPPL

julia> model = last(DynamicPPL.TestUtils.DEMO_MODELS);

julia> x = rand(OrderedDict, model)
OrderedDict{Any, Any} with 2 entries:
  s => [0.91846 1.26951]
  m => [0.379781, -2.30755]

julia> vnv = VarNameVector(x)
[s = [0.9184600730345488 1.2695082629237986], m = [0.3797814300783162, -2.307552998945772]

julia> logjoint(model, VarInfo(vnv))
-15.691863927983814

But it's not possible to sample.

What are we trying to achieve here?

The idea behind VarNameVector is to enable both an efficient and convenient representation of a varinfo, i.e. an efficient representation of a realization from a @model with additional information necessary for both a sampler developer and end-user to extract what they need.

In that vein, it seems like we need something that sort of acts like a Vector, and sort of acts like an OrderedDict, but is neither (at least, I don't think it fits in either).

  1. We need a Vector because we need:
    1. A straight-forward way to work with gradient-based samplers (and more generally, most sampler implementations assumes a simple Vector representation).
    2. Efficient representation of the realizations, e.g. having variables be in contiguous blocks of memory is good (e.g. in the case of for loop over variables x[i] in a model, we want the values for x to be stored in a contiguous chunk of memory).
  2. We need an OrderedDict because we need:
    1. A simple way for the end-user to get information related to a particular VarName, e.g. trace[vn] should result in the realization in it's "constrained" space (i.e. in the space that the distribution of which it came from commonly works in).

Moreover, the above should be achieved while ensuring that we allow both mutable and immutable implementations, in addition to maintaining type-stability whenever possible and falling back to type unstable when not.

Current implementation: VarInfo with Metadata

The current implementation of VarInfo with Metadata as it's underlying storage achieves some of these properties through a few different means:

  1. Type stability for varname-specific operations, e.g. getindex(varinfo, varname), is achieved by effectively grouping the varnames contained in a VarInfo by their symbol (which is part of the type), and putting each group in a separate Metadata in a NamedTuple. Basically, instead of representing pairs like (@varname(x) => 1.0, @varname(y) => 2.0) in a flattened way, we represent them as (x = (@varname(x) => 1.0,), y = (@varname(y) => 2.0,)), such that in the scenario where they different in types, e.g. y isa Int, we can dispatch on the sym in VarName{sym} to determine which entry of the NamedTuple to extract.
    1. This, IMO, is a good approach, and one we should continue as this grouping by symbol is generally what the end-user does anyways, e.g. x is continuous while y is discrete. It also provides very straight-forward guidelines on how to speed up models, i.e. "group variables into a single higher-variate random variable, e.g. x[1] ~ Normal(); x[2] ~ Normal() into x ~ fill(Normal(), 2)".
  2. Run once with type-unstable varinfo to get type-information, and then use this for the subsequent runs to obtain type-stability (and thus much improved performance).
    1. IMO, we should also keep this idea, as it works very well in practice.
    2. But the current impl of type-stable VarInfo is somewhat lacking, e.g. it does not support changing support even if the types stay the same,
  3. Interaction for the end-user is mainly done through varinfo[vn], but is currently a) limited, and b) very confusing.
    1. Currently, VarInfo implements a somewhat confusing subset of both AbstractDict and AbstractVector interfaces, but neither of which are "complete" in any sense.
    2. There is little coherency between the operations one can perform on VarInfo and Metadata, even though VarInfo is effectively just a wrapper around Metadata / NamedTuple containing Metadata, which adds further confusion.
    3. Metadata comes with a lot of additional fields and information that we have slowly been moving away from using, as the resulting codebase ends up being overly complicated and difficult to work with. This also significantly reduces the utility of VarInfo for the end-user.

Replacing Metadata with VarNameVector

VarNameVector is an attempt at replacing Metadata while preserving some of the good ideas from VarInfo with Metadata (and in the process, establish a bare-minimum of functionality required to implement a underlying storage for varinfo):

  1. VarNameVector (should) implement (almost) all operations for AbstractDict + all non-linear-algebra operations for a AbstractVector{<:Real}, to the point where a user will find it easy to use VarNameVector (and thus a VarInfo wrapping a VarNameVector).
    1. AbstractDict interface is implemented as if keys are of type VarName and values are of type corresponding to varinfo[varname].
    2. AbstractVector interface is implemented wrt. underlying "raw" storage, e.g. if we're working with a unconstrained representation, then the vector will be in unconstrained space (unlike varinfo[varname] which will be in constrained space).
  2. VarNameVector uses a contiguous chunk of memory to store the values in a flattened manner, both in the type stable and unstable scenarios, which overall should lead to better performance in both scenarios.
  3. VarNameVector reduces the number of fields + replaces some now-less-useful fields such as dist with fields requiring less information such as transformss (holding the transformations used to map from "unconstrained" to "constrained" representation).

Updates

2023-11-13

Okay, so I've made some new additions:

  1. Much improved testing of the core functionalities for VarNameVector.
  2. It's now possible to push! and update! a VarNameVector where:
    • push! means what it usually means, but the varname / key must be unique.
    • update! works with either existing keys or new keys.

push! & update!

The idea behind the implementation is as follows:

  1. When we see a new varname, push! (and update!, which is equivalent in this scenario), is straight-forward: we just call push! and setindex! on all the necessary underlying containers, e.g. ranges.
  2. When wee see an already existing varname, things become a bit more complicated for a few fields, namely ranges and values (everthing else is just call to setindex!):
    • If the new value requires less memory than the previous value, we can make use of the part of values already allocated to varname.
    • If the new value requires more memory than the previous value, things become a bit more difficult (once again). One approach would be to simply allocate a larger chunk of values to varname and then shift all the values occuring after this part; doing this on every update! will be expensive! Instead, we just move the new value to the end of values and mark the old location as "inactive". This leads to much more efficient update!, and then we can just perform a "sweep" to re-contiguify the underlying storage every now and then.

To make things a bit more concrete, consider the following example:

julia> using DynamicPPL

julia> vnv = VarNameVector(@varname(x) => 1.0, @varname(y) => [2.0]);

julia> vnv.varname_to_index
OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Int64} with 2 entries:
  x => 1
  y => 2

julia> vnv.ranges
2-element Vector{UnitRange{Int64}}:
 1:1
 2:2

julia> OrderedDict(vnv)
OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Any} with 2 entries:
  x => 1.0
  y => [2.0]

julia> vnv[:]
2-element Vector{Float64}:
 1.0
 2.0

Then we update the entry for @varname(x) to a differently sized value:

julia> DynamicPPL.update!(vnv, @varname(x), [3.0, 4.0, 5.0]);


julia> vnv.varname_to_index
OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Int64} with 2 entries:
  x => 1
  y => 2

julia> vnv.ranges
2-element Vector{UnitRange{Int64}}:
 3:5
 2:2

julia> OrderedDict(vnv)
OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Vector{Float64}} with 2 entries:
  x => [3.0, 4.0, 5.0]
  y => [2.0]

julia> vnv[:]
4-element Vector{Float64}:
 3.0
 4.0
 5.0
 2.0

Notice that the order is still preserved, even though the underlying ranges is no longer ordered.

But, if we inspect the underlying values, this contains now-inactive entries:

julia> vnv.vals
5-element Vector{Float64}:
 1.0
 2.0
 3.0
 4.0
 5.0

But in the scenario where we care about performance, we can easily fix this:

julia> DynamicPPL.inactive_ranges_sweep!(vnv);

julia> vnv.varname_to_index
OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Int64} with 2 entries:
  x => 1
  y => 2

julia> vnv.ranges
2-element Vector{UnitRange{Int64}}:
 1:3
 4:4

julia> OrderedDict(vnv)
OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Vector{Float64}} with 2 entries:
  x => [3.0, 4.0, 5.0]
  y => [2.0]

julia> vnv[:]
4-element Vector{Float64}:
 3.0
 4.0
 5.0
 2.0

Type-stable sampling for dynamic suppport

The result of this is that we can even perform type-stable sampling for models with changing support:

julia> using Distributions

julia> @model function demo_random_num_variables(::Type{TV}=Vector{Float64}) where {TV}
           α ~ Dirichlet(ones(10))
           d ~ Categorical(α)
           x = TV(undef, d)
           for i = 1:d
               x[i] ~ Normal()
           end

           return (; α, d, x)
       end
demo_random_num_variables (generic function with 4 methods)

julia> model = demo_random_num_variables();

julia> x = rand(OrderedDict, model);

julia> vi = VarInfo(VarNameVector(x));

julia> first(DynamicPPL.evaluate!!(model, empty!!(vi), SamplingContext()))
(α = [0.38811824437465864, 0.04808086905976418, 0.2038431779706373, 0.07337088670210494, 0.039133882408812014, 0.012910572919725475, 0.014187317531812365, 0.05987651255436473, 0.12002476163571087, 0.040453774842409536], d = 1, x = [-0.26582545522147744])

julia> OrderedDict(last(DynamicPPL.evaluate!!(model, empty!!(vi), SamplingContext())).metadata)
OrderedDict{VarName, Any} with 8 entries:
  α    => [0.0206437, 0.048932, 0.106592, 0.016048, 0.00882465, 0.383303, 0.173889, 0.06695  d    => 6.0
  x[1] => 0.755316
  x[2] => -1.22984
  x[3] => 1.0898
  x[4] => 0.678128
  x[5] => -1.17279
  x[6] => 0.654284

julia> OrderedDict(last(DynamicPPL.evaluate!!(model, empty!!(vi), SamplingContext())).metadata)
OrderedDict{VarName, Any} with 9 entries:
  α    => [0.115859, 0.120242, 0.213961, 0.069815, 0.0541034, 0.207712, 0.0977965, 0.062019  d    => 7.0
  x[1] => -1.78224
  x[2] => 0.357474
  x[3] => -0.39225
  x[4] => -0.154527
  x[5] => -1.00411
  x[6] => 1.01872
  x[7] => 0.262885

The really nice thing here is that, unlike with TypedVarInfo, we don't need to mess around with boolean flags to indicate whether something should be resampled, etc. Instead we just call similar on the VarNameVector and push! onto this.

We can also make this work nicely with TypedVarInfo:)

2024-01-26T15:00:38

Okay, so now all tests should be passing and we should have, effectively, feature parity with VarInfo using Metadata.

But this required quite a lot of code, and there are a few annoyances that are worth pointing out / discussing:

Transformations are (still) a pain

My original idea was that we really did not want to attach the entire Distribution from which the random variable came from to the metadata for a couple of reasons:

  1. This can technically change between evaluations of a model (or, maybe a more realistic scenario, we can't use a VarInfo from ModelA in ModelB, even though they only differ by the RHS of a single ~ statement).
  2. We're really only using the Distribution to determine the transformation from the vectorized / flattened representation to the original one we want and linking.

In fact, the above is only partially true: getindex inside a model evaluation actually uses getindex(vi, vn, dist) and uses the "tilde-local" dist for the reshaping and linking, not the dist present in vi.

Hence it seemed to me that one immediate improvement of Metadata is to remove the dist completely, and instead just store the transformation from the vectorized / flattened representation to whatever desired form we want (be that linked or invlinked).

And this is what we do currently with VarNameVector.

This then "resolves" (1) since now all we need is a map f that takes a vector and outputs something we can work with.

However, (2) is still an issue: we still need the "tilde-local" dist to determine the necessary transformation for the particular realization we're interested in, while simultaenously wanting getindex(vi, vn) to also function as intended outside of a model, i.e. I should still be able to do

@model demo() = x ~ LKJCholesky(2, 1.0)
vi = DynamicPPL.link(VarInfo(model), model)
vi[@varname(x)] == vi[@varname(x), LKJCholesky(2, 1.0)] # => true

Sooo I don't think we can get around having to keep transformations in the metadata object that might not actually get used within a model evaluation if we want allow the indexing behavior above 😕

2024-01-31: Update on transformations

See #575

@torfjelde torfjelde marked this pull request as draft November 7, 2023 11:09
src/varinfo.jl Outdated Show resolved Hide resolved
src/varnamevector.jl Outdated Show resolved Hide resolved
src/varnamevector.jl Outdated Show resolved Hide resolved
src/varnamevector.jl Outdated Show resolved Hide resolved
src/varnamevector.jl Outdated Show resolved Hide resolved
test/varnamevector.jl Outdated Show resolved Hide resolved
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@coveralls
Copy link

coveralls commented Nov 7, 2023

Pull Request Test Coverage Report for Build 11218828080

Details

  • 570 of 672 (84.82%) changed or added relevant lines in 10 files are covered.
  • 121 unchanged lines in 3 files lost coverage.
  • Overall coverage increased (+1.3%) to 79.445%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/abstract_varinfo.jl 0 1 0.0%
src/utils.jl 36 41 87.8%
src/varnamedvector.jl 373 416 89.66%
src/varinfo.jl 113 166 68.07%
Files with Coverage Reduction New Missed Lines %
src/utils.jl 2 82.76%
src/abstract_varinfo.jl 6 78.62%
src/varinfo.jl 113 79.82%
Totals Coverage Status
Change from base Build 11160832460: 1.3%
Covered Lines: 3320
Relevant Lines: 4179

💛 - Coveralls

Copy link

codecov bot commented Nov 7, 2023

Codecov Report

Attention: Patch coverage is 84.85714% with 106 lines in your changes missing coverage. Please review.

Project coverage is 79.02%. Comparing base (7f91c07) to head (b23d4e2).
Report is 1 commits behind head on master.

Files with missing lines Patch % Lines
src/varinfo.jl 69.07% 60 Missing ⚠️
src/varnamedvector.jl 90.14% 41 Missing ⚠️
src/utils.jl 90.24% 4 Missing ⚠️
src/abstract_varinfo.jl 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #555      +/-   ##
==========================================
+ Coverage   77.69%   79.02%   +1.33%     
==========================================
  Files          29       30       +1     
  Lines        3591     4201     +610     
==========================================
+ Hits         2790     3320     +530     
- Misses        801      881      +80     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@torfjelde
Copy link
Member Author

Replied to the comments @mhauru , only very minor changes suggested, aaaand then I think we're good to maybe hit that big green button 👀

@mhauru
Copy link
Member

mhauru commented Oct 4, 2024

@torfjelde, I think I addressed all of them.

Note that the procedure for merging should be

  1. Agree that the code is ready.
  2. Change the default metadata type back to Metadata
  3. Check that tests still pass.
  4. Merge.

@torfjelde
Copy link
Member Author

I unfortuantely can't approve this because I'm still technically the "owner" of the PR, but this is looking good to me 🎉

Copy link
Member

@mhauru mhauru left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the default metadata backend to Metadata again. VarNamedVector will be in master, but not, by default, used by anything.

I think this is all ready to go. If anyone wants to put in any last minute comments, now is the time.

Copy link
Member

@yebai yebai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm glad to see this finally gone through!

@mhauru mhauru enabled auto-merge October 8, 2024 13:44
@mhauru mhauru added this pull request to the merge queue Oct 8, 2024
@mhauru
Copy link
Member

mhauru commented Oct 8, 2024

Thanks @torfjelde for doing the bulk of the work for this, and everyone, but especially @willtebbutt and @yebai, for great feedback!

Merged via the queue into master with commit c38e65f Oct 8, 2024
14 checks passed
@mhauru mhauru deleted the torfjelde/varnamevector branch October 8, 2024 14:16
@penelopeysm penelopeysm restored the torfjelde/varnamevector branch October 29, 2024 20:05
@penelopeysm penelopeysm deleted the torfjelde/varnamevector branch November 24, 2024 03:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants