Possible Improvements to FixedContext #710

Open · wants to merge 5 commits into base: torfjelde/context-cleanup
Conversation

@torfjelde (Member) commented Nov 1, 2024

As pointed out by @penelopeysm and @mhauru in #702, FixedContext and ConditionContext don't quite do what we want for .~ statements. One of the reasons we first introduced fix was to avoid hitting the tilde_*_assume pipeline, to improve performance.

This PR implements a possible way of fixing #702, which involves overloading tilde_dot_assume for FixedContext to handle cases where only parts of the LHS are fixed.

With this branch we can do stuff like fixing only a subset of a .~ statement:

julia> @model function demo(n)
           m = Vector{Float64}(undef, n)
           m .~ Normal()
           s ~ Normal()

           return (; m)
       end
demo (generic function with 2 methods)

julia> model = fix(demo(2), @varname(m[1]) => 0.0);

julia> model()
(m = [0.0, -0.019050747797961672],)

julia> rand(model)
(var"m[2]" = 0.41118414073410653, s = -0.2798438193201507)

However, it requires overloading tilde_dot_assume for FixedContext, which does go slightly against what tilde_*_assume is meant to do (it's meant to be used for random variables, but clearly fixed variables are not random).
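The core idea of the overload can be pictured with a rough, pure-Base sketch. Note that the function name, the (sym, i) tuple keys, and the rand_elem callback are all hypothetical stand-ins for the actual DynamicPPL machinery (VarNames, contexts, and the tilde pipeline): when some elements on the LHS of a .~ are fixed, fall back to a per-element loop that takes the fixed value where one exists and samples otherwise.

```julia
# Hypothetical sketch of the tilde_dot_assume overload's core idea (not the real
# DynamicPPL signature): per-element dispatch between "fixed" and "sampled".
function dot_assume_with_fixed!(rand_elem, fixed::AbstractDict, sym::Symbol, x::AbstractVector)
    for i in eachindex(x)
        key = (sym, i)  # stand-in for constructing a per-element VarName like m[i]
        # Use the fixed value if present; otherwise draw a fresh one.
        x[i] = haskey(fixed, key) ? fixed[key] : rand_elem()
    end
    return x
end

m = Vector{Float64}(undef, 2)
dot_assume_with_fixed!(() -> randn(), Dict((:m, 1) => 0.0), :m, m)
# m[1] == 0.0, m[2] is a fresh draw
```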

Performance implications

IMO the interesting case is when we use fix(::Model, ::NamedTuple), since this is consistently what we consider the "fast mode" in Turing.jl / DynamicPPL.jl, and we can always ask the user to provide the values as a NamedTuple if they really want performance.

There are a few different approaches we can take with fix (and equally condition):

  1. Use the current approach. This calls contextual_isfixed + getfixed_nested(__context__, vn) in the main body of a @model. When it works, this is very performant, as it's just a compile-time-generated check of sym in names for VarName{sym} and NamedTuple{names}.
  2. Remove the current approach in favour of overloading the tilde_*_assume pipeline to extract the fixed values.
  3. A "hybrid" approach. Use (1) whenever possible, but if we hit the tilde_*_assume pipeline, also check in tilde_dot_assume (so that we cover the cases listed in "FixedContext and ConditionedContext don't use the same varnames as tilde-pipeline" #702) by iterating over all the variables and deferring to tilde_assume (i.e. without the dot).
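Why Approach 1 is so fast can be seen in a small sketch (the helper name is hypothetical): whether a symbol appears among a NamedTuple's names is known from the types alone, so the check is resolved at compile time and costs essentially nothing per model evaluation.

```julia
# Membership of `sym` in `names` depends only on the types Val{sym} and
# NamedTuple{names}, so the compiler can constant-fold the whole check.
is_fixed_sym(::Val{sym}, ::NamedTuple{names}) where {sym,names} = sym in names

fixed = (m = zeros(4),)
is_fixed_sym(Val(:m), fixed)  # true
is_fixed_sym(Val(:s), fixed)  # false
```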

I ran the following snippet for the different approaches:

using Revise, DynamicPPL, Distributions, BenchmarkTools, OrderedCollections

@model function demo(n)
    m = Vector{Float64}(undef, n)
    m .~ Normal()
    s ~ Normal()
end

iter = 1:14
suite = BenchmarkGroup()
for nlog2 in iter
    n = 2^nlog2
    model = fix(demo(n), m=zeros(n))
    vi = VarInfo(model)
    suite[n] = @benchmarkable $model($vi)
end
results = run(suite, seconds=60)
OrderedDict(2^nlog2 => results[2^nlog2] for nlog2 in iter)

On #master (Approach 1)

julia> OrderedDict(2^nlog2 => results[2^nlog2] for nlog2 in iter)
OrderedDict{Int64, BenchmarkTools.Trial} with 14 entries:
  2     => 125.000 ns
  4     => 125.000 ns
  8     => 125.000 ns
  16    => 125.000 ns
  32    => 84.000 ns
  64    => 125.000 ns
  128   => 83.000 ns
  256   => 166.000 ns
  512   => 208.000 ns
  1024  => 291.000 ns
  2048  => 333.000 ns
  4096  => 500.000 ns
  8192  => 1.000 μs
  16384 => 2.083 μs

On this branch (Approach 3)

julia> OrderedDict(2^nlog2 => results[2^nlog2] for nlog2 in iter)
OrderedDict{Int64, BenchmarkTools.Trial} with 14 entries:
  2     => 166.000 ns
  4     => 166.000 ns
  8     => 208.000 ns
  16    => 208.000 ns
  32    => 208.000 ns
  64    => 166.000 ns
  128   => 125.000 ns
  256   => 250.000 ns
  512   => 291.000 ns
  1024  => 375.000 ns
  2048  => 375.000 ns
  4096  => 583.000 ns
  8192  => 1.000 μs
  16384 => 2.125 μs

As we can see, the performance difference is very, very minor. However, note that this PR still includes the contextual_isfixed check in the main body of the model.

Replacing the current approach fully by overloading tilde (Approach 2)

If we remove this, i.e. only rely on overloading tilde-pipeline, we get the following result:

julia> OrderedDict(2^nlog2 => results[2^nlog2] for nlog2 in iter)
OrderedDict{Int64, BenchmarkTools.Trial} with 14 entries:
  2     => 167.000 ns
  4     => 166.000 ns
  8     => 208.000 ns
  16    => 250.000 ns
  32    => 208.000 ns
  64    => 209.000 ns
  128   => 291.000 ns
  256   => 416.000 ns
  512   => 583.000 ns
  1024  => 875.000 ns
  2048  => 1.500 μs
  4096  => 2.416 μs
  8192  => 4.875 μs
  16384 => 9.125 μs

As we see here, once we have to rely on a for-loop over the variables to check, we do incur a "significant" runtime overhead.

Conclusion

Performing the check in dot_tilde_assume only when explicitly needed doesn't really hurt performance much for fix(::Model, ::NamedTuple) (i.e. Approach 3 vs. Approach 1).

However, purely relying on Approach 2 (i.e. replacing the current approach completely with overloading tilde assume) does have quite a significant overhead even for plain evaluation (and presumably this will be even worse when computing gradients).

So I'm leaning towards Approach 3 (as implemented in this branch), even though it does make things a bit uglier.

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@torfjelde torfjelde changed the base branch from master to torfjelde/context-cleanup November 1, 2024 16:04
@penelopeysm (Member) left a comment

I certainly wish we didn't need to have added complexity, but I'm thoroughly convinced by the profiling 😄 Thank you for looking into it.

To mitigate that, I think we should at least elaborate in the docstring of contextual_isfixed on this case (where the FixedContext may contain variables that don't match those in the model)? I think that will at least help us (or me) the next time we revisit this, haha.

CI

Separately, there are a few tests that are failing. Most of them should be fixed by the suggested changes (a small typo). Some others will be fixed by #704, so we just need to merge master into this branch to fix those.

However, there's one more newly failing test @ test/turing/model.jl:6 specifically for demo_dot_assume_matrix_dot_observe_matrix. Here's a DPPL-only MWE:

using DynamicPPL, Distributions
@model function f()
    s = Array{Float64}(undef, 1, 2)
    s .~ product_distribution([InverseGamma(2, 3)])
    # also fails
    # s .~ MvNormal([0.0], [1.0])
end
model = f()

# this doesn't fix the variables, because the varnames are not
# concretised -- although this probably isn't a particularly huge deal
fix(model, @varname(s[:, 1], false) => [1.0], @varname(s[:, 2], false) => [2.0])()

# however, this version with concretised varnames errors, and
# `generated_quantities` calls this and in turn errors
s = Array{Float64}(undef, 1, 2)
fix(model, @varname(s[:, 1], true) => [1.0], @varname(s[:, 2], true) => [2.0])()

# e.g. like this (which is a simplified version of test/turing/model.jl:6)
using Turing
chain = sample(model, Prior(), 10)
generated_quantities(model, MCMCChains.get_sections(chain, :parameters))

@torfjelde (Member, Author)

Thanks for catching those typos @penelopeysm!

Regarding the failing test, this feels like something we should be able to fix 👍

@torfjelde (Member, Author)

Thoughts on the inconsistency of overloading tilde_dot_assume, though? @penelopeysm @mhauru @yebai

@torfjelde (Member, Author) commented Nov 4, 2024

however, this version with concretised varnames errors, and generated_quantities calls this and in turn errors

Btw, concretization doesn't handle Colon (AFAIK), but handles stuff like begin and end 😬 The reason why we're getting s[1, 1] and s[1, 2] in the chain is because at some point we "flatten" the structure completely (because MCMCChains.Chains requires this) using

DynamicPPL.jl/src/utils.jl, lines 1060 to 1085 in da6f9a0:

"""
    varname_and_value_leaves(vn::VarName, val)

Return an iterator over all varname-value pairs that are represented by `vn` on `val`.

# Examples

```jldoctest varname-and-value-leaves
julia> using DynamicPPL: varname_and_value_leaves

julia> foreach(println, varname_and_value_leaves(@varname(x), 1:2))
(x[1], 1)
(x[2], 2)

julia> foreach(println, varname_and_value_leaves(@varname(x[1:2]), 1:2))
(x[1:2][1], 1)
(x[1:2][2], 2)

julia> x = (y = 1, z = [[2.0], [3.0]]);

julia> foreach(println, varname_and_value_leaves(@varname(x), x))
(x.y, 1)
(x.z[1][1], 2.0)
(x.z[2][1], 3.0)
```
"""

So the error is caused (after fixing some broadcasting bug with Distributions.Product) by the fact that we get vn = @varname(s[:, 1]) in dot_tilde_assume but @varname(s[1, 1]) in the context.

@penelopeysm (Member)

Btw, concretization doesn't handle Colon (AFAIK)

That's true, @varname doesn't concretise colons unless forced. I don't know why, but the varname does come out concretised in the chain though (and generated_quantities will use it):

using Turing

@model function f()
    s = Array{Float64}(undef, 1, 2)
    s .~ product_distribution([InverseGamma(2, 3)])
end

chain = sample(f(), Prior(), 10)
dump(collect(keys(chain.info.varname_to_symbol))[1])

#=
AbstractPPL.VarName{:s, ComposedFunction{Accessors.IndexLens{Tuple{Int64}}, Accessors.IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}
  optic: (@o _[:, 1][1]) (function of type ComposedFunction{Accessors.IndexLens{Tuple{Int64}}, Accessors.IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}})
    outer: Accessors.IndexLens{Tuple{Int64}}
      indices: Tuple{Int64}
        1: Int64 1
    inner: Accessors.IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}
      indices: Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}
        1: AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}
          range: Base.OneTo{Int64}
            stop: Int64 1
        2: Int64 1
=#

@torfjelde (Member, Author)


I don't know why but the varname does come out concretised in the chain though (and generated_quantities will use it)

Because we use DynamicPPL.varname_and_value_leaves in Turing.jl to convert it into s[1, 1] and s[1, 2], etc.
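The flattening step can be sketched in pure Base Julia (the real helper is DynamicPPL.varname_and_value_leaves; the tuple keys here are a hypothetical stand-in for VarNames): an array-valued variable is split into one (key, value) pair per scalar element, which is why the chain ends up with s[1, 1] and s[1, 2] even though the model drew s[:, 1].

```julia
# Stand-in for varname_and_value_leaves: emit one (key, value) pair per element,
# using (sym, indices...) tuples in place of VarNames.
leaves(sym::Symbol, val::AbstractArray) =
    [((sym, Tuple(I)...), val[I]) for I in CartesianIndices(val)]

leaves(:s, [1.0 2.0])  # [((:s, 1, 1), 1.0), ((:s, 1, 2), 2.0)]
```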

@torfjelde (Member, Author) commented Nov 4, 2024

Hmm, this is actually quite an annoying issue 😕

It raises the question of whether DynamicPPL.hasvalue should support stuff like

DynamicPPL.hasvalue(OrderedDict(@varname(s[1,1]) => 0.0), @varname(s[:, 1]))

which I'm somewhat uncertain we want 😕

The current implementation assumes vn is a strict "sub-key" of any of the keys present in vals

DynamicPPL.jl/src/utils.jl, lines 872 to 895 in da6f9a0:

# For `dictlike` we need to check whether `vn` is "immediately" present, or
# if some ancestor of `vn` is present in `dictlike`.
function hasvalue(vals::AbstractDict, vn::VarName)
    # First we check if `vn` is present as is.
    haskey(vals, vn) && return true
    # If `vn` is not present, we check any parent-varnames by attempting
    # to split the optic into the key / `parent` and the extraction optic / `child`.
    # If `issuccess` is `true`, we found such a split, and hence `vn` is present.
    parent, child, issuccess = splitoptic(getoptic(vn)) do optic
        o = optic === nothing ? identity : optic
        haskey(vals, VarName(vn, o))
    end
    # When combined with `VarInfo`, `nothing` is equivalent to `identity`.
    keyoptic = parent === nothing ? identity : parent
    # Return early if no such split could be found.
    issuccess || return false
    # At this point we just need to check that we `canview` the value.
    value = vals[VarName(vn, keyoptic)]
    return canview(child, value)
end

In an ideal world, this would also handle stuff like

DynamicPPL.hasvalue(OrderedDict(@varname(s[1,1]) => 0.0), @varname(s[:, 1]))
DynamicPPL.hasvalue(OrderedDict(@varname(s[:,1]) => [0.0]), @varname(s[1, 1]))
DynamicPPL.hasvalue(OrderedDict(@varname(s[:,1]) => [0.0]), @varname(s))

but this will complicate the implementation of both hasvalue and getvalue quite a bit 😕

EDIT: Similarly, we would also need to add support for these in getvalue, which will be more of a pain.
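For intuition, the extra lookup directions can be illustrated in pure Base Julia. Everything here is a hypothetical stand-in: tuple keys in place of VarNames, Colon() in place of the : optic, and the sketch deliberately omits the completeness check a real implementation would need (e.g. verifying that stored elements jointly cover the whole queried slice).

```julia
# A stored key "covers" a query if each component matches exactly or is a Colon.
covers(a, b) = length(a) == length(b) &&
    all(x isa Colon || x == y for (x, y) in zip(a, b))

function hasvalue_subsuming(vals::AbstractDict, query)
    haskey(vals, query) && return true  # exact hit, as in the current implementation
    # Either a stored slice contains the queried element, or the queried slice
    # overlaps a stored element (coverage check omitted for brevity).
    return any(k -> covers(k, query) || covers(query, k), keys(vals))
end

hasvalue_subsuming(Dict((:s, 1, 1) => 0.0), (:s, Colon(), 1))    # true
hasvalue_subsuming(Dict((:s, Colon(), 1) => [0.0]), (:s, 1, 1))  # true
hasvalue_subsuming(Dict((:s, Colon(), 1) => [0.0]), (:s, 2, 2))  # false
```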

@mhauru (Member) commented Nov 4, 2024

This is a bit of a drive-by comment, but I've so far failed to wrap my head around how we use VarNames as keys in containers. I feel like sometimes we use subsumes, sometimes we don't, and maybe sometimes we do some ad-hoc halfway thing in between. This is complicated further by having multiple backend storage structures for (Simple)VarInfo, and by having getters and setters for them defined in a few different places (#654). Having to produce string representations for Chains is also a big contributor to the confusion.

I don't really have a proposal for how to change this, but for many cases like your above hasvalue examples it is not clear to me what I should expect to happen. A more systematic approach, with [mumblemumble interfacesomething design doc mumblemumble] might help solidify this.
