
[WIP] More flexibility in RHS of ~, e.g. MeasureTheory.jl #292

Closed
wants to merge 6 commits into master from tor/measuretheory

Conversation

torfjelde
Member

@torfjelde torfjelde commented Jul 27, 2021

I've recently been thinking a bit about how it would be nice to support more than just Distribution from Distributions.jl on the RHS of a ~ statement (and other values as samples, i.e. the LHS of ~), and as I was looking through some code today I realized that it shouldn't be too difficult.

Hence this PR, which demonstrates what it would take to add this feature, using MeasureTheory.jl as an example use-case. All changes outside of src/measuretheory.jl are general changes required to accommodate non-Distributions on the RHS of ~.

julia> using DynamicPPL, MeasureTheory

julia> @model function demo(; x=missing, n = x isa AbstractArray ? length(x) : 1)
           m ~ Normal(μ=0.0, σ=1.0)
           x ~ For(1:n) do i
               Normal(μ=m, σ=1.0)
           end
       end

demo (generic function with 1 method)

julia> m() # sampling
3-element Vector{Float64}:
 -1.106403180421966
  0.40711759833021666
 -2.46921196310957

julia> vi = VarInfo(m); m(vi, DefaultContext()) # evaluation
3-element view(::Vector{Float64}, 1:3) with eltype Float64:
 -0.10556975107508737
  0.578507546477508
 -1.4482491679503848

@cscherrer :)

@torfjelde torfjelde marked this pull request as draft July 27, 2021 00:57
@cscherrer

THIS WOULD BE SO GREAT!!!

Thanks @torfjelde for taking the initiative on this. BTW, a Soss model is already a measure, so this would make it easy to use Soss from within Turing. I wonder, what would it take to make a Turing model a measure, or even to have a wrapper around one that would make it behave in this way? That could be a good way to get things working together more easily.

@@ -192,7 +192,7 @@ function Metadata()
     Vector{VarName}(),
     Vector{UnitRange{Int}}(),
     vals,
-    Vector{Distribution}(),
+    Vector(),
Member Author

I pinged you @mohamed82008 because I figured you might have something useful to say about this change :)

Member

Hmm I think it's fine.

Member

Can also change it to a Union{Distribution, AbstractMeasure} just to communicate that not anything goes.

Member

In fact it might be better to have a const DistOrMeasure = Union{Distribution, AbstractMeasure} and sprinkle that everywhere instead of Distribution.

Member Author

If there are no perf considerations, then IMO we should just remove it completely rather than introducing some other type because:

  1. We're going to error in the tilde-check anyways.
  2. It will allow additional extensions by simply implementing the tilde functionality for a particular type, e.g. we could allow iterators of distributions to be used on the RHS of .~ rather than only arrays, etc.

end

# TODO: Transformed variables.
return r, MeasureTheory.logdensity(dist, r)
Member Author

@cscherrer How do I get the "transformed" logdensity here, i.e. with the logabsdetjac factor?


So far I've been using TransformVariables for this, so it all follows that interface. Measures have `as` methods for computing the transform, and TV handles the computation. I'm open to making this more general; TV is just what I started with. The most important thing was the dimensionality, as in the Dirichlet case, and of course performance.


@torfjelde
Member Author

BTW, a Soss model is already a measure, so this would make it easy to use Soss from within Turing. I wonder, what would it take to make a Turing model a measure, or even to have a wrapper around one that would make it behave in this way? That could be a good way to get things working together more easily.

That's a good point! Honestly, I think the main issue is the linearization that DPPL currently requires. If you have a good way of linearizing the nested-tuple samples from Soss, then it shouldn't be much of a leap from this branch :)

@cscherrer

There are some tricks in https://github.com/cscherrer/NestedTuples.jl, maybe something from there can help. We can't throw away the structure, but maybe this "leaf setter" thing?
https://github.com/cscherrer/NestedTuples.jl#leaf-setter

Yeah, naming things is hard :)

@@ -129,5 +129,6 @@ include("prob_macro.jl")
include("compat/ad.jl")
include("loglikelihoods.jl")
include("submodel_macro.jl")
include("measuretheory.jl")
Contributor

Not that I know much about DynamicPPL, but shouldn't this be a https://github.com/JuliaPackaging/Requires.jl include, because MeasureTheory is quite a large package? 🙂

Member Author

So this PR as it is now isn't intended to make it into master because of exactly this, i.e. MeasureTheory.jl is too large of a dependency (btw @cscherrer have you considered reducing the number of deps?). Instead this PR demonstrates how we could allow such extensions, e.g. in a DynamicPPLMeasureTheory.jl bridge package or even just adding these overloads in Turing.jl. I just added it here in case people wanted to try the branch out.

And Requires.jl isn't without its own costs btw and will increase compilation times, so we're probably not going to be using Requires.jl in DPPL in the near future.

Contributor

And Requires.jl isn't without its own costs btw and will increase compilation times, so we're probably not going to be using Requires.jl in DPPL in the near future.

Interesting to learn that! Thanks

Member

@cscherrer Would it be possible to define the AbstractMeasure interface functions in another lightweight package, so Turing only needs to depend on the lightweight package? One possibility is AbstractPPL.


@yebai that's a great idea. I had started MeasureBase for this a while back, but it's out of date now. I think this can be the core, and MeasureTheory can define the actual parameterized measures, etc.

I like the concept of AbstractPPL, but I still need to understand better what it would look like to recast Soss in a way that uses this. Maybe we should have a call about this some time after JuliaCon?

Member Author

Removing some dependencies, or at least making it possible to depend on MeasureTheory.jl without all the extras, would be dope @cscherrer 🎉

@torfjelde
Member Author

There are some tricks in https://github.com/cscherrer/NestedTuples.jl, maybe something from there can help. We can't throw away the structure, but maybe this "leaf setter" thing?
https://github.com/cscherrer/NestedTuples.jl#leaf-setter

Yeah, naming things is hard :)

But how do you use DynamicHMC then? Surely HMC requires some form of linearization of the parameters?

And we don't need to throw away the structure, we only need to temporarily hide it and revert the change once we reach the Soss model:)

@cscherrer

But how do you use DynamicHMC then? Surely HMC requires some form of linearization of the parameters?

You define a transform, like say

t = as((a = asℝ₊, b = as((b1 = asℝ, b2 = as(Vector, as𝕀, 3), b3=CorrCholesky(4)))))

Soss automates composing one of these for a given model. Yes, these are linearized, but there are no names, and everything is stretched in order to have a density over ℝⁿ. I had thought this is also how Turing does things using Bijectors, but maybe that's wrong?
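
For concreteness, a minimal sketch of how such a TransformVariables.jl transform linearizes and un-linearizes values; the transform t here is illustrative, not one generated by Soss:

using TransformVariables

t = as((a = asℝ₊, b = as𝕀))              # ℝ² → (a ∈ (0, ∞), b ∈ (0, 1))
v = randn(dimension(t))                   # flat, unconstrained vector of length 2
x = transform(t, v)                       # structured, constrained named tuple
x2, logjac = transform_and_logjac(t, v)   # same, plus log |det J| of the transform
v2 = inverse(t, x)                        # back to the flat vector; v2 ≈ v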

And we don't need to throw away the structure, we only need to temporarily hide it and revert the change once we reach the Soss model:)

Thinking some more about this, it seems like ParameterHandling.flatten could work well:
https://invenia.github.io/ParameterHandling.jl/dev/#ParameterHandling.flatten
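
For reference, a minimal sketch of the ParameterHandling.flatten interface linked above (the nested tuple is illustrative):

using ParameterHandling

x = (a = 1.0, b = (c = 2.0,))
v, unflatten = ParameterHandling.flatten(x)   # v == [1.0, 2.0], plus a closure
unflatten(v)                                  # rebuilds (a = 1.0, b = (c = 2.0,))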

@torfjelde
Member Author

I'll respond to you here to keep things a bit organized:)

But how do you use DynamicHMC then? Surely HMC requires some form of linearization of the parameters?

You define a transform, like say

t = as((a = asℝ₊, b = as((b1 = asℝ, b2 = as(Vector, as𝕀, 3), b3=CorrCholesky(4)))))

As things are currently, when storing a variable in the trace (VarInfo), we essentially flatten it into a vector and then store the variable names, with their corresponding ranges into that vector, in a separate vector. Hence we'd need to do the same with a Soss-model's output, i.e. extract the linear shape + the variable names used. Essentially, I want whatever the above `as` does internally to turn the value into a vector.

Soss automates composing one of these for a given model. Yes, these are linearized, but there are no names, and everything is stretched in order to have a density over ℝⁿ.

No that's right, and that's my point:) But

I had thought this is also how Turing does things using Bijectors, but maybe that's wrong?

This is completely independent of the usage of Bijectors.jl though; we'll reshape back into the original shape of the variable before getting the transform from Bijectors.jl. We could also define bijector(measure) for the different measures to return the transformation taking us from the domain of the measure to ℝⁿ, but that will only allow us to share some code for the transformation (e.g. Bijectors.logpdf_with_trans); it won't do anything to address the issue that we need a flattened representation in VarInfo.
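
As a sketch of that bijector(measure) option (MyUnitIntervalMeasure is a hypothetical stand-in, not a real MeasureTheory type; Bijectors.Logit maps (0, 1) to ℝ):

using Bijectors

struct MyUnitIntervalMeasure end  # hypothetical measure supported on (0, 1)

# Return the map from the measure's support to unconstrained space,
# so the transform code can be shared with the Distribution code path.
Bijectors.bijector(::MyUnitIntervalMeasure) = Bijectors.Logit(0.0, 1.0)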

Thinking some more about this, it seems like ParameterHandling.flatten could work well:
https://invenia.github.io/ParameterHandling.jl/dev/#ParameterHandling.flatten

Exactly, but we want the symbols, e.g. say we have this

x = (a = 1.0, b = (c = 2.0, ))
x ~ SossModel()

then internally we'd want something like

vals = [1.0, 2.0]
varnames = [VarName{Symbol("x.a")}(), VarName{Symbol("x.b.c")}()]

We also want a transformation for which we can compute the logdensity, but this should take x in its original shape, not the vector vals, i.e. that's a separate issue.
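
For illustration, a rough sketch of that kind of flattening; the function name flatten_with_names and the plain-Symbol names are hypothetical, not an existing DynamicPPL API:

# Recursively flatten a nested NamedTuple into values and dotted names.
flatten_with_names(x::Real, prefix) = ([float(x)], [Symbol(prefix)])
function flatten_with_names(x::NamedTuple, prefix)
    vals, names = Float64[], Symbol[]
    for k in keys(x)
        v, n = flatten_with_names(x[k], string(prefix, ".", k))
        append!(vals, v)
        append!(names, n)
    end
    return vals, names
end

flatten_with_names((a = 1.0, b = (c = 2.0,)), "x")
# ([1.0, 2.0], [Symbol("x.a"), Symbol("x.b.c")])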

@cscherrer

Ok, I think I see. NestedTuples has a lenses function that... well maybe an example is best:

julia> x
(a = 2.0509701709447876, b = (b1 = -0.31411507894223795, b2 = [0.19141842948352514, 0.3248182896463582, 0.6726129111118845], b3 = LinearAlgebra.Cholesky{Float64, LinearAlgebra.UpperTriangular{Float64, Matrix{Float64}}}([1.0 -0.24542409737564375 -0.09619507309595421 -0.3347648223848326; 0.0 0.969415809870744 -0.6940440834677906 -0.25974375198705785; 0.0 0.0 0.7134769219220889 -0.14376094280898585; 0.0 0.0 0.0 0.8943145354515987], 'U', 0)))

julia> NestedTuples.lenses(x)
((@optic _.a), (@optic _.b1) ∘ (@optic _.b), (@optic _.b2) ∘ (@optic _.b), (@optic _.b3) ∘ (@optic _.b))

julia> typeof(NestedTuples.lenses(x)[3])
ComposedFunction{Accessors.PropertyLens{:b2}, Accessors.PropertyLens{:b}}

Currently I'm stopping when I hit an array, but Accessors can also handle these, for example

julia> @optic _.b.b2[3]
(@optic _[3]) ∘ (@optic _.b2) ∘ (@optic _.b)

@cscherrer

@torfjelde Is it correct that you need things entirely unrolled, so each name is for a scalar? Also, do you need to be able to reconstruct everything from the names alone, or can there be something carried along with it to make this easier?

@torfjelde
Member Author

Currently I'm stopping when I hit an array, but Accessors can also handle these, for example

I've actually played around with replacing all this indexing behavior, etc. in Turing by the lenses from Setfield.jl (Accessors.jl seems more unstable from the README, and so it's somewhat unlikely we'll use that atm?).

The annoying case, and the one that stopped me from replacing VarName indexing with Setfield.jl's lenses, is the handling of begin and end. It's difficult; see TuringLang/AbstractPPL.jl#25 :)

But I think I have a way of addressing this actually.

@torfjelde Is it correct that you need things entirely unrolled, so each name is for a scalar? Also, do you need to be able to reconstruct everything from the names alone, or can there be something carried along with it to make this easier?

Not quite:) We also have lists of ranges and dists. So we can for example have

vals = randn(4)
ranges = [1:1, 2:4]
dists = [MvNormal(1, 1.0), MvNormal(3, 1.0)]
varnames = [@varname(x[1:1]), @varname(x[2:4])]

And whenever we encounter, say, x[2:4] in the model, we can extract a correctly sized value from VarInfo by using the size of the dists. See the reconstruct and vectorize functions that I've overloaded in this PR.

So essentially what I'm asking for is a reconstruct and vectorize for Soss-models:)
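
To make the lookup concrete, a small self-contained sketch mirroring the toy vals/ranges/dists snippet above (illustrative only; the actual VarInfo machinery is more involved, and MvNormal(d, σ) follows the Distributions API used in the snippet):

using Distributions

vals   = randn(4)
ranges = [1:1, 2:4]
dists  = [MvNormal(1, 1.0), MvNormal(3, 1.0)]

i   = 2                              # index of the entry for x[2:4]
raw = vals[ranges[i]]                # flat slice of length 3
val = reshape(raw, size(dists[i]))   # size(dists[i]) == (3,), so a no-op here;
                                     # the reshape matters for e.g. matrix-valued variates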

@cscherrer

I've actually played around with replacing all this indexing behavior, etc. in Turing by the lenses from Setfield.jl (Accessors.jl seems more unstable from the README, and so it's somewhat unlikely we'll use that atm?).

Yeah, I'm not too worried about stability. We have version dependencies anyway, plus it's just not that much code. Seems worth it IMO to have easy inroads to ongoing improvements. But Setfield is fine too; it shouldn't matter that much.

The annoying case, and the one that stopped me from replacing VarName indexing with Setfield.jl's lenses, is the handling of begin and end. It's difficult; see TuringLang/AbstractPPL.jl#25 :)

I don't understand this at all. Is there a toy example?

Not quite:) We also have lists of ranges and dists.

I see, yeah that does complicate things.

So essentially what I'm asking for is a reconstruct and vectorize for Soss-models:)

Ok I'll have a look :)

@torfjelde
Member Author

I don't understand this at all. Is there a toy example?

Eh no need, I've made a PR now anyways: TuringLang/AbstractPPL.jl#26

I see, yeah that does complicate things.

Well, it sort of makes things easier :) Just look at the impls I have for MeasureTheory now. All we really need is a way to convert a named tuple into a vector given a Soss-model. So like ParameterHandling.flatten, but without the closure.

# Linearization.
vectorize(d::MeasureTheory.AbstractMeasure, x::Real) = [x]
vectorize(d::MeasureTheory.AbstractMeasure, x::AbstractArray{<:Real}) = copy(vec(x))


Is it correct that x here is just a rearrangement, with logabsdet = 0.0?

This seems very similar to what TransformVariables gives us, something like

reconstruct(d::AbstractMeasure, x::AbstractVector) = transform(as(d), x)

That's not quite right, since (as I understand) you need this without stretching the space. But it should be possible to transform the transformation, replacing e.g. each as𝕀 with asℝ

Member Author

@torfjelde torfjelde Jul 27, 2021

Is it correct that x here is just a rearrangement, with logabsdet = 0.0?

Yep 👍

That's not quite right, since (as I understand) you need this without stretching the space. But it should be possible to transform the transformation, replacing e.g. each as𝕀 with asℝ

Probably! This is why I'm asking:) I haven't looked at TransformVariables.jl in ages. We're also going to add a Reshape, etc. to Bijectors.jl once TuringLang/Bijectors.jl#183 has gone through.


I don't understand this at all. Is there a toy example?

Eh no need, I've made a PR now anyways: TuringLang/AbstractPPL.jl#26

I see, yeah that does complicate things.

Well, it sort of makes things easier :) Just look at the impls I have for MeasureTheory now. All we really need is a way to convert a named tuple into a vector given a Soss-model. So like ParameterHandling.flatten, but without the closure.

Maybe we just need a generic flatten, then vectorize can call it? NestedTuples has

flatten(x, y...) = (flatten(x)..., flatten(y...)...)
flatten(x::Tuple) = flatten(x...)
flatten(x::NamedTuple) = flatten(values(x)...)
flatten(x) = (x,)

so I guess an array version of this?

Is there any concern for performance here, or is it quick enough not to worry about that in this case?

Member Author

Maybe we just need a generic flatten, then vectorize can call it?

so I guess an array version of this?

Exactly! Though it also likely requires knowledge of the measure, similar to the current vectorize.

Is there any concern for performance here, or is it quick enough not to worry about that in this case?

Let's get to that once we have a working impl. The only note I have is that you probably want to use an inferrable vcat, i.e. act on the first element and then reduce vcat with init set to an array containing the first element, rather than splatting (like you do for tuples above). Splatting will be super-slow for larger arrays.
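
A minimal sketch of the two approaches being compared (xs is just an illustrative vector of array leaves):

xs = [randn(3) for _ in 1:1000]

flat_splat  = vcat(xs...)                                  # splatting: compiles poorly and allocates heavily for long xs
flat_reduce = reduce(vcat, xs[2:end]; init = copy(xs[1]))  # inferrable: reduce vcat with an array init

flat_splat == flat_reduce   # true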

@torfjelde
Member Author

Yeah, I'm not too worried about stability. We have version dependencies anyway, plus it's just not that much code. Seems worth it IMO to have easy inroads to ongoing improvements. But Setfield is fine too; it shouldn't matter that much.

Not to press this point because I'm with you on what you just said, but I just realized that BangBang.jl uses Setfield.jl and we're likely to be making use of BangBang.jl in DPPL very soon. In particular I love that there's a @setfield!!! https://github.com/JuliaFolds/BangBang.jl/blob/master/src/setfield.jl


# src/utils.jl
# Linearization.
vectorize(d::MeasureTheory.AbstractMeasure, x::Real) = [x]


This seems like a generic flatten. In what cases would you do anything with the first argument?

Member Author

Yep, I agree. But the current impl is specialized to Distribution, so I added a specialization here too.

@cscherrer

Not to press this point because I'm with you on what you just said, but I just realized that BangBang.jl uses Setfield.jl and we're likely to be making use of BangBang.jl in DPPL very soon. In particular I love that there's a @setfield!!! https://github.com/JuliaFolds/BangBang.jl/blob/master/src/setfield.jl

Yeah, BangBang is pretty great. Have you seen Kaleido? Definitely worth a look as well

@torfjelde
Member Author

Yeah, BangBang is pretty great. Have you seen Kaleido? Definitely worth a look as well

I think the aim of the two is a bit different, no? Both use Setfield.jl under the hood, but BangBang is all about using mutation when it makes sense and not when it doesn't, e.g. arrays are mutated in place rather than copied. Kaleido looks cool as an extended Setfield.jl though! But it seems like overkill for what we'll need in DPPL to accomplish what we want, e.g. TuringLang/AbstractPPL.jl#26.

bors bot pushed a commit that referenced this pull request Sep 8, 2021
This PR adds a `DynamicPPL.TestUtils` submodule which is meant to include functionality to make it easy to test new samplers, new implementations of `AbstractVarInfo`, etc.

As of right now, this is mainly just a collection of models with equivalent marginal posteriors using the different features of DPPL, e.g. some are using `.~`, some are using `@submodel`, etc.

Eventually this should be expanded to be of more use, but more immediately this will be useful to test functionality in open PRs, e.g. #269, #309, #295, #292.

These models are also already used in Turing.jl's test-suite (https://github.com/TuringLang/Turing.jl/blob/9f52d75c25390b68115624b2e6cf464275a88137/test/test_utils/models.jl#L55-L56), so this PR would avoid the code-duplication + make it easier to keep things up-to-date.
@PavanChaggar PavanChaggar requested review from mohamed82008 and removed request for mohamed82008 January 10, 2022 20:21
@yebai
Member

yebai commented Nov 2, 2022

I think this is superseded by #342

@yebai yebai closed this Nov 2, 2022
@yebai yebai deleted the tor/measuretheory branch November 3, 2022 18:14