-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] More flexibility in RHS of ~
, e.g. MeasureTheory.jl
#292
Changes from 4 commits
19335dd
a68aeed
267e9fd
4443ec8
4714178
4c06543
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
using MeasureTheory: MeasureTheory | ||
|
||
# src/compiler.jl | ||
# Allow `AbstractMeasure` on the RHS of `~`. | ||
check_tilde_rhs(x::MeasureTheory.AbstractMeasure) = x | ||
|
||
# src/utils.jl | ||
# Linearization. | ||
vectorize(d::MeasureTheory.AbstractMeasure, x::Real) = [x] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems like a generic There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep, I agree. But the current impl is specified to |
||
vectorize(d::MeasureTheory.AbstractMeasure, x::AbstractArray{<:Real}) = copy(vec(x)) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it correct that This seems very similar to what TransformVariables gives us, something like reconstruct(d::AbstractMeasure, x::AbstractVector) = transform(as(d), x) That's not quite right, since (as I understand) you need this without stretching the space. But it should be possible to transform the transformation, replacing e.g. each There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Yep 👍
Probably! This is why I'm asking:) I haven't looked at TransformVariables.jl in ages. We're also going to add a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Maybe we just need a generic flatten(x, y...) = (flatten(x)..., flatten(y...)...)
flatten(x::Tuple) = flatten(x...)
flatten(x::NamedTuple) = flatten(values(x)...)
flatten(x) = (x,) so I guess an array version of this? Is there any concern for performance here, or here it quick enough not to worry about that in thi There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Exactly! Though it also likely requires knowledge of the measure that, similar to the current
Let's get to that once we have a working impl. Only note I have is that you probably want to use inferrable |
||
function reconstruct(d::MeasureTheory.AbstractMeasure, x::AbstractVector{<:Real}) | ||
return _reconstruct(d, x, sampletype(d)) | ||
end | ||
|
||
# TODO: Higher dims. What to do? Do we have access to size, e.g. for `LKJ` we should have? | ||
function _reconstruct( | ||
d::MeasureTheory.AbstractMeasure, x::AbstractVector{<:Real}, ::Type{<:Real} | ||
) | ||
return x[1] | ||
end | ||
function _reconstruct( | ||
d::MeasureTheory.AbstractMeasure, | ||
x::AbstractVector{<:Real}, | ||
::Type{<:AbstractVector{<:Real}}, | ||
) | ||
return x | ||
end | ||
|
||
# src/context_implementations.jl | ||
# assume | ||
function assume(dist::MeasureTheory.AbstractMeasure, vn::VarName, vi) | ||
r = vi[vn] | ||
# TODO: Transformed variables. | ||
return r, MeasureTheory.logdensity(dist, r) | ||
end | ||
|
||
function assume( | ||
rng::Random.AbstractRNG, | ||
sampler::Union{SampleFromPrior,SampleFromUniform}, | ||
dist::MeasureTheory.AbstractMeasure, | ||
vn::VarName, | ||
vi, | ||
) | ||
if haskey(vi, vn) | ||
# Always overwrite the parameters with new ones for `SampleFromUniform`. | ||
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") | ||
unset_flag!(vi, vn, "del") | ||
r = init(rng, dist, sampler) | ||
vi[vn] = vectorize(dist, r) | ||
settrans!(vi, false, vn) | ||
setorder!(vi, vn, get_num_produce(vi)) | ||
else | ||
r = vi[vn] | ||
end | ||
else | ||
r = init(rng, dist, sampler) | ||
push!(vi, vn, r, dist, sampler) | ||
settrans!(vi, false, vn) | ||
end | ||
|
||
# TODO: Transformed variables. | ||
return r, MeasureTheory.logdensity(dist, r) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @cscherrer How do I get the "transformed" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So far I've been using TransformVariables for this, so it all follows that interface. Measures have There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
end | ||
|
||
# observe | ||
function observe(right::MeasureTheory.AbstractMeasure, left, vi) | ||
increment_num_produce!(vi) | ||
return MeasureTheory.logdensity(right, left) | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,7 +43,7 @@ barrier to make the rest of the sampling type stable. | |
""" | ||
struct Metadata{ | ||
TIdcs<:Dict{<:VarName,Int}, | ||
TDists<:AbstractVector{<:Distribution}, | ||
TDists<:AbstractVector, | ||
TVN<:AbstractVector{<:VarName}, | ||
TVal<:AbstractVector{<:Real}, | ||
TGIds<:AbstractVector{Set{Selector}}, | ||
|
@@ -192,7 +192,7 @@ function Metadata() | |
Vector{VarName}(), | ||
Vector{UnitRange{Int}}(), | ||
vals, | ||
Vector{Distribution}(), | ||
Vector(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I pinged you @mohamed82008 because I figured you'd might have something useful to say about this change:) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm I think it's fine. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can also change it to a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In fact it might be better to have a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If there are no perf considerations, then IMO we should just remove it completely rather than introducing some other type because:
|
||
Vector{Set{Selector}}(), | ||
Vector{Int}(), | ||
flags, | ||
|
@@ -1096,7 +1096,7 @@ end | |
Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to | ||
the `VarInfo` `vi`. | ||
""" | ||
function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) | ||
function push!(vi::AbstractVarInfo, vn::VarName, r, dist) | ||
return push!(vi, vn, r, dist, Set{Selector}([])) | ||
end | ||
|
||
|
@@ -1108,12 +1108,10 @@ from a distribution `dist` to `VarInfo` `vi`. | |
|
||
The sampler is passed here to invalidate its cache where defined. | ||
""" | ||
function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler) | ||
function push!(vi::AbstractVarInfo, vn::VarName, r, dist, spl::Sampler) | ||
return push!(vi, vn, r, dist, spl.selector) | ||
end | ||
function push!( | ||
vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler | ||
) | ||
function push!(vi::AbstractVarInfo, vn::VarName, r, dist, spl::AbstractSampler) | ||
return push!(vi, vn, r, dist) | ||
end | ||
|
||
|
@@ -1123,10 +1121,10 @@ end | |
Push a new random variable `vn` with a sampled value `r` sampled with a sampler of | ||
selector `gid` from a distribution `dist` to `VarInfo` `vi`. | ||
""" | ||
function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) | ||
function push!(vi::AbstractVarInfo, vn::VarName, r, dist, gid::Selector) | ||
return push!(vi, vn, r, dist, Set([gid])) | ||
end | ||
function push!(vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}) | ||
function push!(vi::VarInfo, vn::VarName, r, dist, gidset::Set{Selector}) | ||
if vi isa UntypedVarInfo | ||
@assert ~(vn in keys(vi)) "[push!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset" | ||
elseif vi isa TypedVarInfo | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not that I know much about DynamicPPL, but shouldn't this be a https://github.com/JuliaPackaging/Requires.jl include because measuretheory is quite a large package? 🙂
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So this PR as it is now isn't intended to make it into master because of exactly this, i.e. MeasureTheory.jl is too large of a dependency (btw @cscherrer have you considered reducing the number of deps?). Instead this PR demonstrates how we could allow such extensions, e.g. in a DynamicPPLMeasureTheory.jl bridge package or even just adding these overloads in Turing.jl. I just added it here in case people wanted to try the branch out.
And Requires.jl isn't without it's own costs btw and will increase compilation times, so we're probably not going to be using Requires.jl in DPPL in the near future.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting to learn that! Thanks
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@cscherrer Would it be possible to define the
AbstractMeasure
interface functions in another lightweight package, so Turing only need to depend on the lightweight package? One possibility isAbstractPPL
..There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yebai that's a great idea. I had started
MeasureBase
for this a while back, but it's out of date now. I think this can be the core, and MeasureTheory can define the actual parameterized measures, etc.I like the concept of AbstractPPL, but I still need to understand better what it would look like to recast Soss in a way that uses this. Maybe we should have a call about this some time after JuliaCon?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removing some dependencies, or at least making it possible to depend on MeasureTheory.jl without all the extras, would be dope @cscherrer 🎉