Add InferenceObjects integration #465

Draft
wants to merge 6 commits into base: master
16 changes: 15 additions & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.22.1"
version = "0.22.2"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -9,29 +9,43 @@ BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
InferenceObjects = "b5cf5a8d-e756-4ee3-b014-01d49d192c00"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
InferenceObjects = "b5cf5a8d-e756-4ee3-b014-01d49d192c00"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[extensions]
DynamicPPLInferenceObjectsExt = ["DimensionalData", "InferenceObjects", "StatsBase"]

[compat]
AbstractMCMC = "2, 3.0, 4"
AbstractPPL = "0.5.3"
BangBang = "0.3"
Bijectors = "0.11, 0.12"
ChainRulesCore = "0.9.7, 0.10, 1"
ConstructionBase = "1"
DimensionalData = "0.23.1, 0.24"
Distributions = "0.23.8, 0.24, 0.25"
DocStringExtensions = "0.8, 0.9"
InferenceObjects = "0.3"
LogDensityProblems = "2"
MacroTools = "0.5.6"
OrderedCollections = "1"
Setfield = "0.7.1, 0.8, 1"
StatsBase = "0.32, 0.33"
ZygoteRules = "0.2"
julia = "1.6"
17 changes: 17 additions & 0 deletions ext/DynamicPPLInferenceObjectsExt/DynamicPPLInferenceObjectsExt.jl
@@ -0,0 +1,17 @@
module DynamicPPLInferenceObjectsExt

using AbstractPPL: AbstractPPL
using DimensionalData: DimensionalData, Dimensions, LookupArrays
using DynamicPPL: DynamicPPL
using InferenceObjects: InferenceObjects
using Random: Random
using StatsBase: StatsBase

include("utils.jl")
include("varinfo.jl")
include("condition.jl")
include("generated_quantities.jl")
include("predict.jl")
include("pointwise_loglikelihoods.jl")

end
10 changes: 10 additions & 0 deletions ext/DynamicPPLInferenceObjectsExt/condition.jl
@@ -0,0 +1,10 @@
function AbstractPPL.condition(
Member Author:
This whole file is type piracy.

Member:
Should it rather be an extension to AbstractPPL? Then it would not be type piracy (or rather, only the kind that extensions were designed for).

Member Author:
Would we then also make InferenceObjects a full dependency of AbstractPPL for v1.8 and earlier?

Member:
AbstractPPL is supposed to be extremely lightweight (https://github.com/TuringLang/AbstractPPL.jl/blob/main/Project.toml), so I don't think that's an attractive option. Maybe an optional dependency with Requires or a full-blown subpackage would be better (one can avoid loading it in newer Julia versions). See the sketch after this diff.

context::AbstractPPL.AbstractContext, data::InferenceObjects.Dataset
)
return AbstractPPL.condition(context, NamedTuple(data))
end
function AbstractPPL.condition(
context::AbstractPPL.AbstractContext, data::InferenceObjects.InferenceData
)
return AbstractPPL.condition(context, data.posterior)
end
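
For reference, a minimal sketch of the Requires-based option mentioned in the thread above, assuming the glue methods would live in AbstractPPL; the file name and overall structure here are hypothetical and not part of this PR:

```julia
# Hypothetical: on Julia versions without package extensions, AbstractPPL could load
# the InferenceObjects glue methods only once the user has loaded InferenceObjects.
using Requires: @require

function __init__()
    @require InferenceObjects="b5cf5a8d-e756-4ee3-b014-01d49d192c00" begin
        # hypothetical file containing methods like the `condition` overloads above
        include("inference_objects_glue.jl")
    end
end
```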
28 changes: 28 additions & 0 deletions ext/DynamicPPLInferenceObjectsExt/generated_quantities.jl
@@ -0,0 +1,28 @@
function DynamicPPL.generated_quantities(
Member Author (@sethaxen, Feb 20, 2023):
Since DynamicPPL places no restrictions on what types these can be, and users might have intermediate types that don't fit the InferenceData format, it would be nice to let users specify an output type. Either that, or we should document the constraints on the objects a model returns.

Member:
One thing in particular that the user should be able to specify somehow is which variables to include in the chain. Sometimes you want to return something that you really don't want to end up in the chain, e.g. the full solution of an ODE solve (this can be useful for checking convergence as a post-processing step, but you usually don't want the full solution in your chain). A sketch of such an option follows this diff.

mod::DynamicPPL.Model, data::InferenceObjects.Dataset; coords=(;), kwargs...
)
sample_dims = Dimensions.dims(data, (:draw, :chain))
diminds = DimensionalData.DimIndices(sample_dims)
values = map(diminds) do dims
DynamicPPL.generated_quantities(mod, data[dims...])
end
coords = merge(coords, dims2coords(sample_dims))
return InferenceObjects.convert_to_dataset(
collect(eachcol(values)); coords=coords, kwargs...
)
end

function DynamicPPL.generated_quantities(
mod::DynamicPPL.Model, idata::InferenceObjects.InferenceData; kwargs...
)
new_groups = Dict{Symbol,InferenceObjects.Dataset}()
for k in (:posterior, :prior)
if haskey(idata, k)
data = idata[k]
new_groups[k] = merge(
DynamicPPL.generated_quantities(mod, data; kwargs...), data
)
end
end
return merge(idata, InferenceObjects.InferenceData(; new_groups...))
end
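
A minimal sketch of the opt-in selection discussed in the thread above. The helper name and the `include_vars` keyword are hypothetical and not part of this PR; the filtering relies on `Dataset` supporting selection by a tuple of variable names, as `predict` below already uses:

```julia
# Hypothetical: let the caller choose which returned variables end up in the Dataset.
function generated_quantities_subset(
    mod::DynamicPPL.Model,
    data::InferenceObjects.Dataset;
    include_vars=nothing,  # e.g. (:mu, :sigma); `nothing` keeps everything
    kwargs...,
)
    gq = DynamicPPL.generated_quantities(mod, data; kwargs...)
    include_vars === nothing && return gq
    # Keep only the requested variables.
    return gq[Tuple(include_vars)]
end
```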
35 changes: 35 additions & 0 deletions ext/DynamicPPLInferenceObjectsExt/pointwise_loglikelihoods.jl
@@ -0,0 +1,35 @@
function DynamicPPL.pointwise_loglikelihoods(
model::DynamicPPL.Model, data::InferenceObjects.Dataset; coords=(;), kwargs...
)
# Get the data by executing the model once
vi = DynamicPPL.VarInfo(model)
context = DynamicPPL.PointwiseLikelihoodContext(Dict{String,Vector{Float64}}())

iters = Iterators.product(axes(data, :draw), axes(data, :chain))
for (draw, chain) in iters
# Update the values
DynamicPPL.setval!(vi, data, draw, chain)

# Execute model
model(vi, context)
end

ndraws = size(data, :draw)
nchains = size(data, :chain)
# TODO: optionally post-process idata to convert index variables like Symbol("y[1]") to Symbol("y")
Member Author (@sethaxen, Feb 20, 2023):
This is pretty important for the results to be useful with ArviZ, but it is seemingly non-trivial, so it will wait for a future PR.

Member:
We do actually have some functionality to perform this now :)

There is https://turinglang.org/DynamicPPL.jl/dev/api/#DynamicPPL.value_iterator_from_chain which makes use of this under the hood; in particular, to get the "innermost" VarName, you can use varname_leaves:

DynamicPPL.jl/src/utils.jl

Lines 820 to 844 in 1ebe8bc

"""
varname_leaves(vn::VarName, val)
Return an iterator over all varnames that are represented by `vn` on `val`.
# Examples
```jldoctest
julia> using DynamicPPL: varname_leaves
julia> foreach(println, varname_leaves(@varname(x), rand(2)))
x[1]
x[2]
julia> foreach(println, varname_leaves(@varname(x[1:2]), rand(2)))
x[1:2][1]
x[1:2][2]
julia> x = (y = 1, z = [[2.0], [3.0]]);
julia> foreach(println, varname_leaves(@varname(x), x))
x.y
x.z[1][1]
x.z[2][1]
```
"""

Using this you can take the varname from the varinfo plus a value, and then determine the varname leaves. A short usage sketch follows this diff.

loglikelihoods = Dict(
varname => reshape(logliks, ndraws, nchains) for
(varname, logliks) in context.loglikelihoods
)
isempty(loglikelihoods) && return nothing
coords = merge(coords, dims2coords(Dimensions.dims(data, (:draw, :chain))))
return InferenceObjects.convert_to_dataset(
loglikelihoods; group=:log_likelihood, coords=coords, kwargs...
)
end
function DynamicPPL.pointwise_loglikelihoods(
model::DynamicPPL.Model, data::InferenceObjects.InferenceData; kwargs...
)
log_likelihood = DynamicPPL.pointwise_loglikelihoods(model, data.posterior; kwargs...)
return merge(data, InferenceObjects.InferenceData(; log_likelihood=log_likelihood))
end
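
A small usage sketch of the `varname_leaves` approach described in the thread above, assuming `values_as(vi, OrderedDict)` for obtaining `VarName => value` pairs; the actual regrouping of flat keys like `Symbol("y[1]")` is left as the future work mentioned there:

```julia
# Sketch (not part of this PR): enumerate the per-element ("leaf") varnames of every
# variable in a VarInfo; this is the mapping needed to collapse keys such as
# Symbol("y[1]") back into a single variable `y`.
using OrderedCollections: OrderedDict
using DynamicPPL: VarInfo, values_as, varname_leaves

vi = VarInfo(model)                # `model` is any DynamicPPL.Model
vals = values_as(vi, OrderedDict)  # VarName => value pairs
leaves = OrderedDict(vn => collect(varname_leaves(vn, val)) for (vn, val) in vals)
```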
72 changes: 72 additions & 0 deletions ext/DynamicPPLInferenceObjectsExt/predict.jl
@@ -0,0 +1,72 @@
function StatsBase.predict(
Member:
predict is not part of the DynamicPPL API, is it? At least I don't remember it; I think we just use rand to obtain samples from a model (and you can of course condition the model on some data before sampling). I think extensions should not (must not?) add new API to a package, so if this does the same as rand on a conditioned model, maybe just implement rand instead? And open an issue about adding predict to the API (maybe it could be defined just as rand on a conditioned model). A sketch of the rand-based alternative follows this diff.

rng::Random.AbstractRNG,
model::DynamicPPL.Model,
data::InferenceObjects.Dataset;
coords=(;),
kwargs...,
)
spl = DynamicPPL.SampleFromPrior()
vi = DynamicPPL.VarInfo(model)
iters = Iterators.product(axes(data, :draw), axes(data, :chain))
values = map(iters) do (draw_id, chain_id)
# Set variables present in `data` and mark those NOT present in data to be resampled.
DynamicPPL.setval_and_resample!(vi, data, draw_id, chain_id)
model(rng, vi, spl)
return map(concretize, DynamicPPL.values_as(vi, NamedTuple))
end
coords = merge(coords, dims2coords(Dimensions.dims(data, (:draw, :chain))))
predictions = InferenceObjects.convert_to_dataset(
collect(eachcol(values)); group=:posterior_predictive, coords=coords, kwargs...
)
pred_keys = filter(∉(keys(data)), keys(predictions))
isempty(pred_keys) && return nothing
return predictions[pred_keys]
end
function StatsBase.predict(
model::DynamicPPL.Model, data::InferenceObjects.Dataset; kwargs...
)
return StatsBase.predict(Random.default_rng(), model, data; kwargs...)
end

function StatsBase.predict(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
data::InferenceObjects.InferenceData;
coords=(;),
kwargs...,
)
if haskey(data, :observed_data)
coords = merge(coords, dims2coords(Dimensions.dims(data.observed_data)))
end
new_groups = Dict{Symbol,InferenceObjects.Dataset}()
if haskey(data, :posterior)
posterior_predictive = StatsBase.predict(
rng, model, data.posterior; coords=coords, kwargs...
)
if posterior_predictive === nothing
@warn "No predictions were made based on posterior. Has the model been deconditioned?"
else
new_groups[:posterior_predictive] = posterior_predictive
end
end
if haskey(data, :prior)
prior_predictive = StatsBase.predict(
rng, model, data.prior; coords=coords, kwargs...
)
if prior_predictive === nothing
@warn "No predictions were made based on prior. Has the model been deconditioned?"
else
new_groups[:prior_predictive] = prior_predictive
end
end
if !(haskey(data, :posterior) || haskey(data, :prior))
@warn "No posterior or prior found in InferenceData. Returning unmodified input."
return data
end
return merge(data, InferenceObjects.InferenceData(; new_groups...))
end
function StatsBase.predict(
model::DynamicPPL.Model, data::InferenceObjects.InferenceData; kwargs...
)
return StatsBase.predict(Random.default_rng(), model, data; kwargs...)
end
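
For reference, a minimal sketch of the rand-on-a-conditioned-model alternative raised in the thread above. `predict_one_draw` is a hypothetical helper, and it is an assumption (to be verified) that `rand` returns exactly the un-conditioned variables in the form needed here:

```julia
# Hypothetical: for one posterior draw, condition the model on the draw's values and
# sample whatever is left free, instead of going through setval_and_resample!.
function predict_one_draw(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    data::InferenceObjects.Dataset,
    draw_id::Int,
    chain_id::Int,
)
    draw = NamedTuple(data[draw=draw_id, chain=chain_id])  # posterior values for this draw
    conditioned = DynamicPPL.condition(model, draw)        # fix those variables
    return rand(rng, NamedTuple, conditioned)              # sample the remaining variables
end
```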
21 changes: 21 additions & 0 deletions ext/DynamicPPLInferenceObjectsExt/utils.jl
@@ -0,0 +1,21 @@
# adapted from MCMCChains
function isconcretetype_recursive(T)
return isconcretetype(T) && (eltype(T) === T || isconcretetype_recursive(eltype(T)))
end

concretize(x) = x
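# If the array's element type is not concrete, recursively concretize the elements and
# convert to the tightest common concrete element type when one exists.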
function concretize(x::AbstractArray)
if isconcretetype_recursive(typeof(x))
return x
else
xnew = map(concretize, x)
T = mapreduce(typeof, promote_type, xnew; init=Union{})
if T <: eltype(xnew) && T !== Union{}
return convert(AbstractArray{T}, xnew)
else
return xnew
end
end
end

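# Convert a tuple of `Dimensions` into a NamedTuple mapping each dimension's name to its
# lookup values, suitable for passing as `coords` to `convert_to_dataset`.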
dims2coords(dims) = NamedTuple{Dimensions.dim2key(dims)}(Dimensions.lookup(dims))
14 changes: 14 additions & 0 deletions ext/DynamicPPLInferenceObjectsExt/varinfo.jl
@@ -0,0 +1,14 @@
function DynamicPPL.setval!(
vi::DynamicPPL.VarInfo, data::InferenceObjects.Dataset, draw_id::Int, chain_id::Int
)
return DynamicPPL.setval!(vi, data[draw=draw_id, chain=chain_id])
end

function DynamicPPL.setval_and_resample!(
vi::DynamicPPL.VarInfoOrThreadSafeVarInfo,
data::InferenceObjects.Dataset,
draw_id::Int,
chain_id::Int,
)
return DynamicPPL.setval_and_resample!(vi, data[draw=draw_id, chain=chain_id])
end
6 changes: 6 additions & 0 deletions src/DynamicPPL.jl
@@ -34,6 +34,8 @@ import Base:
keys,
haskey

const EXTENSIONS_SUPPORTED = isdefined(Base, :get_extension)

# VarInfo
export AbstractVarInfo,
VarInfo,
@@ -166,4 +168,8 @@ include("test_utils.jl")
include("transforming.jl")
include("logdensityfunction.jl")

if !EXTENSIONS_SUPPORTED
include("../ext/DynamicPPLInferenceObjectsExt/DynamicPPLInferenceObjectsExt.jl")
end

end # module
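
Finally, a short usage sketch of how the extension added here is expected to be used. `model` and `idata` are placeholders (a DynamicPPL.Model and an InferenceObjects.InferenceData with a :posterior group); on Julia ≥ 1.9 the extension loads once DynamicPPL and all three weak dependencies are in the session, while older versions include it unconditionally as shown above:

```julia
using DynamicPPL, StatsBase
using InferenceObjects, DimensionalData  # together with StatsBase, these trigger the extension

idata = DynamicPPL.generated_quantities(model, idata)      # adds generated quantities to :posterior
idata = DynamicPPL.pointwise_loglikelihoods(model, idata)  # adds a :log_likelihood group
idata = StatsBase.predict(model, idata)                    # adds a :posterior_predictive group
```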