
[WIP] More flexibility in RHS of ~, e.g. MeasureTheory.jl #292

Closed
wants to merge 6 commits into from
1 change: 1 addition & 0 deletions Project.toml
@@ -9,6 +9,7 @@
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
MeasureTheory = "eadaa1a4-d27c-401d-8699-e962e1bbc33b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

1 change: 1 addition & 0 deletions src/DynamicPPL.jl
@@ -129,5 +129,6 @@
include("prob_macro.jl")
include("compat/ad.jl")
include("loglikelihoods.jl")
include("submodel_macro.jl")
include("measuretheory.jl")
Contributor:
Not that I know much about DynamicPPL, but shouldn't this be a https://github.com/JuliaPackaging/Requires.jl include, since MeasureTheory is quite a large package? 🙂

Member Author:

So this PR as it stands isn't intended to make it into master for exactly this reason, i.e. MeasureTheory.jl is too large a dependency (btw @cscherrer, have you considered reducing the number of deps?). Instead, this PR demonstrates how we could allow such extensions, e.g. in a DynamicPPLMeasureTheory.jl bridge package, or even just by adding these overloads in Turing.jl. I only added it here in case people wanted to try the branch out.

And Requires.jl isn't without its own costs btw and will increase compilation times, so we're probably not going to be using Requires.jl in DPPL in the near future.

Contributor:

And Requires.jl isn't without its own costs btw and will increase compilation times, so we're probably not going to be using Requires.jl in DPPL in the near future.

Interesting to learn that! Thanks

Member:

@cscherrer Would it be possible to define the AbstractMeasure interface functions in another lightweight package, so that Turing only needs to depend on that lightweight package? One possibility is AbstractPPL.


@yebai that's a great idea. I had started MeasureBase for this a while back, but it's out of date now. I think this can be the core, and MeasureTheory can define the actual parameterized measures, etc.

I like the concept of AbstractPPL, but I still need to understand better what it would look like to recast Soss in a way that uses this. Maybe we should have a call about this some time after JuliaCon?

Member Author:

Removing some dependencies, or at least making it possible to depend on MeasureTheory.jl without all the extras, would be dope @cscherrer 🎉


end # module
70 changes: 70 additions & 0 deletions src/measuretheory.jl
@@ -0,0 +1,70 @@
using MeasureTheory: MeasureTheory

# src/compiler.jl
# Allow `AbstractMeasure` on the RHS of `~`.
check_tilde_rhs(x::MeasureTheory.AbstractMeasure) = x

# src/utils.jl
# Linearization.
vectorize(d::MeasureTheory.AbstractMeasure, x::Real) = [x]


This seems like a generic flatten. In what cases would you do anything with the first argument?

Member Author:

Yep, I agree. But the current impl is restricted to Distribution, so I added the specialization here too.

vectorize(d::MeasureTheory.AbstractMeasure, x::AbstractArray{<:Real}) = copy(vec(x))


Is it correct that x here is just a rearrangement, with logabsdet = 0.0?

This seems very similar to what TransformVariables gives us, something like

reconstruct(d::AbstractMeasure, x::AbstractVector) = transform(as(d), x)

That's not quite right, since (as I understand) you need this without stretching the space. But it should be possible to transform the transformation, replacing e.g. each as𝕀 with asℝ

Member Author (@torfjelde, Jul 27, 2021):

Is it correct that x here is just a rearrangement, with logabsdet = 0.0?

Yep 👍

That's not quite right, since (as I understand) you need this without stretching the space. But it should be possible to transform the transformation, replacing e.g. each as𝕀 with asℝ

Probably! This is why I'm asking :) I haven't looked at TransformVariables.jl in ages. We're also going to add a Reshape etc. to Bijectors.jl once TuringLang/Bijectors.jl#183 has gone through.


I don't understand this at all. Is there a toy example?

Eh no need, I've made a PR now anyways: TuringLang/AbstractPPL.jl#26

I see, yeah that does complicate things.

Well, it sort of makes things easier:) Just look at the impls I have for MeasureTheory now. All we really need is a way to convert a named tuple in to a vector given a Soss-model. So like ParameterHandling.flatten, but without the closure.

Maybe we just need a generic flatten, then vectorize can call it? NestedTuples has

flatten(x, y...) = (flatten(x)..., flatten(y...)...)
flatten(x::Tuple) = flatten(x...)
flatten(x::NamedTuple) = flatten(values(x)...)
flatten(x) = (x,)

so I guess an array version of this?

Is there any concern for performance here, or is it quick enough not to worry about that in this case?

Member Author:

Maybe we just need a generic flatten, then vectorize can call it?

so I guess an array version of this?

Exactly! Though it likely also requires knowledge of the measure, similar to the current vectorize.

Is there any concern for performance here, or is it quick enough not to worry about that in this case?

Let's get to that once we have a working impl. The only note I have is that you probably want an inferrable vcat, i.e. act on the first element and then reduce vcat with init set to an array containing elements of that type, rather than splatting (like you do for tuples above). Splatting will be super slow for larger arrays.
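As a rough illustration of the advice above (a sketch, not part of the PR — `flatten_array` is a hypothetical name for the array analogue of NestedTuples.flatten):

```julia
# Sketch: flatten nested tuples/named tuples of reals into a Vector.
# Uses `reduce(vcat; init=...)` on the mapped parts instead of splatting,
# so the eltype stays inferrable and large inputs avoid splatting costs.
flatten_array(x::Real) = [x]
flatten_array(x::AbstractArray{<:Real}) = copy(vec(x))
function flatten_array(x::Union{Tuple,NamedTuple})
    parts = map(flatten_array, values(x))
    # init is an empty vector matching the first part's eltype.
    return reduce(vcat, parts; init=similar(first(parts), 0))
end
```

For example, `flatten_array((a = 1.0, b = [2.0, 3.0]))` would collect into a single `Vector{Float64}`.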

function reconstruct(d::MeasureTheory.AbstractMeasure, x::AbstractVector{<:Real})
return _reconstruct(d, x, sampletype(d))
end

# TODO: Higher dims. What to do? Do we have access to size, e.g. for `LKJ` we should have?
function _reconstruct(
d::MeasureTheory.AbstractMeasure, x::AbstractVector{<:Real}, ::Type{<:Real}
)
return x[1]
end
function _reconstruct(
d::MeasureTheory.AbstractMeasure,
x::AbstractVector{<:Real},
::Type{<:AbstractVector{<:Real}},
)
return x
end

# src/context_implementations.jl
# assume
function assume(dist::MeasureTheory.AbstractMeasure, vn::VarName, vi)
r = vi[vn]
# TODO: Transformed variables.
return r, MeasureTheory.logdensity(dist, r)
end

function assume(
rng::Random.AbstractRNG,
sampler::Union{SampleFromPrior,SampleFromUniform},
dist::MeasureTheory.AbstractMeasure,
vn::VarName,
vi,
)
if haskey(vi, vn)
# Always overwrite the parameters with new ones for `SampleFromUniform`.
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
unset_flag!(vi, vn, "del")
r = init(rng, dist, sampler)
vi[vn] = vectorize(dist, r)
settrans!(vi, false, vn)
setorder!(vi, vn, get_num_produce(vi))
else
r = vi[vn]
end
else
r = init(rng, dist, sampler)
push!(vi, vn, r, dist, sampler)
settrans!(vi, false, vn)
end

# TODO: Transformed variables.
return r, MeasureTheory.logdensity(dist, r)
Member Author:

@cscherrer How do I get the "transformed" logdensity here, i.e. with the logabsdetjac factor?


So far I've been using TransformVariables for this, so it all follows that interface. Measures have `as` methods for computing the transform, and TV handles the computation. I'm open to making this more general; TV is just what I started with. The most important things were the dimensionality, as in the Dirichlet case, and of course performance.


end

# observe
function observe(right::MeasureTheory.AbstractMeasure, left, vi)
increment_num_produce!(vi)
return MeasureTheory.logdensity(right, left)
end
16 changes: 7 additions & 9 deletions src/varinfo.jl
@@ -43,7 +43,7 @@
barrier to make the rest of the sampling type stable.
"""
struct Metadata{
TIdcs<:Dict{<:VarName,Int},
TDists<:AbstractVector{<:Distribution},
TDists<:AbstractVector,
TVN<:AbstractVector{<:VarName},
TVal<:AbstractVector{<:Real},
TGIds<:AbstractVector{Set{Selector}},
@@ -192,7 +192,7 @@
function Metadata()
Vector{VarName}(),
Vector{UnitRange{Int}}(),
vals,
Vector{Distribution}(),
Vector(),
Member Author:

I pinged you @mohamed82008 because I figured you might have something useful to say about this change :)

Member:

Hmm I think it's fine.

Member:

Can also change it to a Union{Distribution, AbstractMeasure} just to communicate that not anything goes.

Member:

In fact it might be better to have a const DistOrMeasure = Union{Distribution, AbstractMeasure} and sprinkle that everywhere instead of Distribution..

Member Author:

If there are no perf considerations, then IMO we should just remove it completely rather than introducing some other type, because:

  1. We're going to error in the tilde-check anyway.
  2. It will allow additional extensions by simply implementing the tilde functionality for a particular type; e.g. we could allow iterators of distributions on the RHS of .~ rather than only arrays, etc.

Vector{Set{Selector}}(),
Vector{Int}(),
flags,
@@ -1096,7 +1096,7 @@
end
Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to
the `VarInfo` `vi`.
"""
function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution)
function push!(vi::AbstractVarInfo, vn::VarName, r, dist)
return push!(vi, vn, r, dist, Set{Selector}([]))
end

@@ -1108,12 +1108,10 @@
from a distribution `dist` to `VarInfo` `vi`.

The sampler is passed here to invalidate its cache where defined.
"""
function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler)
function push!(vi::AbstractVarInfo, vn::VarName, r, dist, spl::Sampler)
return push!(vi, vn, r, dist, spl.selector)
end
function push!(
vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler
)
function push!(vi::AbstractVarInfo, vn::VarName, r, dist, spl::AbstractSampler)
return push!(vi, vn, r, dist)
end

@@ -1123,10 +1121,10 @@
end
Push a new random variable `vn` with a sampled value `r` sampled with a sampler of
selector `gid` from a distribution `dist` to `VarInfo` `vi`.
"""
function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector)
function push!(vi::AbstractVarInfo, vn::VarName, r, dist, gid::Selector)
return push!(vi, vn, r, dist, Set([gid]))
end
function push!(vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector})
function push!(vi::VarInfo, vn::VarName, r, dist, gidset::Set{Selector})
if vi isa UntypedVarInfo
@assert ~(vn in keys(vi)) "[push!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset"
elseif vi isa TypedVarInfo