diff --git a/Project.toml b/Project.toml
index 0b8dc927b..d765ec2f5 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "DynamicPPL"
 uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
-version = "0.22.1"
+version = "0.22.2"
 
 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -9,17 +9,28 @@ BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
+DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
+InferenceObjects = "b5cf5a8d-e756-4ee3-b014-01d49d192c00"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
+StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
 
+[weakdeps]
+DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
+InferenceObjects = "b5cf5a8d-e756-4ee3-b014-01d49d192c00"
+StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+
+[extensions]
+DynamicPPLInferenceObjectsExt = ["DimensionalData", "InferenceObjects", "StatsBase"]
+
 [compat]
 AbstractMCMC = "2, 3.0, 4"
 AbstractPPL = "0.5.3"
@@ -27,11 +38,14 @@ BangBang = "0.3"
 Bijectors = "0.11, 0.12"
 ChainRulesCore = "0.9.7, 0.10, 1"
 ConstructionBase = "1"
+DimensionalData = "0.23.1, 0.24"
 Distributions = "0.23.8, 0.24, 0.25"
 DocStringExtensions = "0.8, 0.9"
+InferenceObjects = "0.3"
 LogDensityProblems = "2"
 MacroTools = "0.5.6"
 OrderedCollections = "1"
 Setfield = "0.7.1, 0.8, 1"
+StatsBase = "0.32, 0.33"
 ZygoteRules = "0.2"
 julia = "1.6"
diff --git a/ext/DynamicPPLInferenceObjectsExt/DynamicPPLInferenceObjectsExt.jl b/ext/DynamicPPLInferenceObjectsExt/DynamicPPLInferenceObjectsExt.jl
new file mode 100644
index 000000000..e6ffbd4b3
--- /dev/null
+++ b/ext/DynamicPPLInferenceObjectsExt/DynamicPPLInferenceObjectsExt.jl
@@ -0,0 +1,17 @@
+module DynamicPPLInferenceObjectsExt
+
+using AbstractPPL: AbstractPPL
+using DimensionalData: DimensionalData, Dimensions, LookupArrays
+using DynamicPPL: DynamicPPL
+using InferenceObjects: InferenceObjects
+using Random: Random
+using StatsBase: StatsBase
+
+include("utils.jl")
+include("varinfo.jl")
+include("condition.jl")
+include("generated_quantities.jl")
+include("predict.jl")
+include("pointwise_loglikelihoods.jl")
+
+end
diff --git a/ext/DynamicPPLInferenceObjectsExt/condition.jl b/ext/DynamicPPLInferenceObjectsExt/condition.jl
new file mode 100644
index 000000000..ac40d84b2
--- /dev/null
+++ b/ext/DynamicPPLInferenceObjectsExt/condition.jl
@@ -0,0 +1,13 @@
+# Condition `context` on the variables in `data`, converted to a `NamedTuple`.
+function AbstractPPL.condition(
+    context::AbstractPPL.AbstractContext, data::InferenceObjects.Dataset
+)
+    return AbstractPPL.condition(context, NamedTuple(data))
+end
+
+# Condition `context` on the `posterior` group of `data`.
+function AbstractPPL.condition(
+    context::AbstractPPL.AbstractContext, data::InferenceObjects.InferenceData
+)
+    return AbstractPPL.condition(context, data.posterior)
+end
diff --git a/ext/DynamicPPLInferenceObjectsExt/generated_quantities.jl b/ext/DynamicPPLInferenceObjectsExt/generated_quantities.jl
new file mode 100644
index 000000000..1da2842e5
--- /dev/null
+++ b/ext/DynamicPPLInferenceObjectsExt/generated_quantities.jl
@@ -0,0 +1,34 @@
+# Evaluate `generated_quantities(mod, ...)` once for every (draw, chain) entry of `data`
+# and collect the results into a new `Dataset`. `coords` is merged with the draw/chain
+# coordinates of `data`; remaining keywords are forwarded to `convert_to_dataset`.
+function DynamicPPL.generated_quantities(
+    mod::DynamicPPL.Model, data::InferenceObjects.Dataset; coords=(;), kwargs...
+)
+    sample_dims = Dimensions.dims(data, (:draw, :chain))
+    diminds = DimensionalData.DimIndices(sample_dims)
+    values = map(diminds) do dims
+        DynamicPPL.generated_quantities(mod, data[dims...])
+    end
+    coords = merge(coords, dims2coords(sample_dims))
+    return InferenceObjects.convert_to_dataset(
+        collect(eachcol(values)); coords=coords, kwargs...
+    )
+end
+
+# Add generated quantities to the `posterior` and `prior` groups of `idata`, whichever
+# are present. The existing group is passed last to `merge`, so presumably its variables
+# take precedence on a name clash (NOTE(review): confirm against `merge` for `Dataset`).
+function DynamicPPL.generated_quantities(
+    mod::DynamicPPL.Model, idata::InferenceObjects.InferenceData; kwargs...
+)
+    new_groups = Dict{Symbol,InferenceObjects.Dataset}()
+    for k in (:posterior, :prior)
+        if haskey(idata, k)
+            data = idata[k]
+            new_groups[k] = merge(
+                DynamicPPL.generated_quantities(mod, data; kwargs...), data
+            )
+        end
+    end
+    return merge(idata, InferenceObjects.InferenceData(; new_groups...))
+end
diff --git a/ext/DynamicPPLInferenceObjectsExt/pointwise_loglikelihoods.jl b/ext/DynamicPPLInferenceObjectsExt/pointwise_loglikelihoods.jl
new file mode 100644
index 000000000..79c14a2d8
--- /dev/null
+++ b/ext/DynamicPPLInferenceObjectsExt/pointwise_loglikelihoods.jl
@@ -0,0 +1,52 @@
+# Compute the pointwise log-likelihoods of `model` for every (draw, chain) entry of
+# `data`. Returns a `Dataset` holding one draw-by-chain array per entry recorded by the
+# `PointwiseLikelihoodContext`, or `nothing` if no entry was recorded. `coords` is merged
+# with the draw/chain coordinates of `data`; remaining keywords are forwarded to
+# `convert_to_dataset`.
+function DynamicPPL.pointwise_loglikelihoods(
+    model::DynamicPPL.Model, data::InferenceObjects.Dataset; coords=(;), kwargs...
+)
+    # Get the data by executing the model once
+    vi = DynamicPPL.VarInfo(model)
+    context = DynamicPPL.PointwiseLikelihoodContext(Dict{String,Vector{Float64}}())
+
+    iters = Iterators.product(axes(data, :draw), axes(data, :chain))
+    for (draw, chain) in iters
+        # Update the values
+        DynamicPPL.setval!(vi, data, draw, chain)
+
+        # Execute model
+        model(vi, context)
+    end
+
+    ndraws = size(data, :draw)
+    nchains = size(data, :chain)
+    # TODO: optionally post-process idata to convert index variables like Symbol("y[1]") to Symbol("y")
+    # `iters` runs over draws fastest, so (assuming the context appends one value per
+    # model evaluation) each vector reshapes column-major into a (draw, chain) matrix.
+    loglikelihoods = Dict(
+        varname => reshape(logliks, ndraws, nchains) for
+        (varname, logliks) in context.loglikelihoods
+    )
+    isempty(loglikelihoods) && return nothing
+    coords = merge(coords, dims2coords(Dimensions.dims(data, (:draw, :chain))))
+    return InferenceObjects.convert_to_dataset(
+        loglikelihoods; group=:log_likelihood, coords=coords, kwargs...
+    )
+end
+
+# Compute pointwise log-likelihoods from the `posterior` group of `data` and return
+# `data` merged with a new `log_likelihood` group. If nothing was computed, `data` is
+# returned unchanged with a warning.
+function DynamicPPL.pointwise_loglikelihoods(
+    model::DynamicPPL.Model, data::InferenceObjects.InferenceData; kwargs...
+)
+    log_likelihood = DynamicPPL.pointwise_loglikelihoods(model, data.posterior; kwargs...)
+    if log_likelihood === nothing
+        # The `Dataset` method returns `nothing` when no likelihood terms were recorded;
+        # `nothing` is not a `Dataset`, so it must not be passed on as a group.
+        @warn "No pointwise log-likelihoods were computed. Returning unmodified input."
+        return data
+    end
+    return merge(data, InferenceObjects.InferenceData(; log_likelihood=log_likelihood))
+end
diff --git a/ext/DynamicPPLInferenceObjectsExt/predict.jl b/ext/DynamicPPLInferenceObjectsExt/predict.jl
new file mode 100644
index 000000000..586c611de
--- /dev/null
+++ b/ext/DynamicPPLInferenceObjectsExt/predict.jl
@@ -0,0 +1,81 @@
+# Sample predictions from `model` for every (draw, chain) entry of `data`: variables
+# present in `data` are set to their stored values and all others are resampled.
+# Returns a `Dataset` with only the variables not already in `data`, or `nothing` if
+# there are none.
+function StatsBase.predict(
+    rng::Random.AbstractRNG,
+    model::DynamicPPL.Model,
+    data::InferenceObjects.Dataset;
+    coords=(;),
+    kwargs...,
+)
+    spl = DynamicPPL.SampleFromPrior()
+    vi = DynamicPPL.VarInfo(model)
+    iters = Iterators.product(axes(data, :draw), axes(data, :chain))
+    values = map(iters) do (draw_id, chain_id)
+        # Set variables present in `data` and mark those NOT present in data to be resampled.
+        DynamicPPL.setval_and_resample!(vi, data, draw_id, chain_id)
+        model(rng, vi, spl)
+        return map(concretize, DynamicPPL.values_as(vi, NamedTuple))
+    end
+    coords = merge(coords, dims2coords(Dimensions.dims(data, (:draw, :chain))))
+    predictions = InferenceObjects.convert_to_dataset(
+        collect(eachcol(values)); group=:posterior_predictive, coords=coords, kwargs...
+    )
+    pred_keys = filter(∉(keys(data)), keys(predictions))
+    isempty(pred_keys) && return nothing
+    return predictions[pred_keys]
+end
+# Convenience method that uses `Random.default_rng()`.
+function StatsBase.predict(
+    model::DynamicPPL.Model, data::InferenceObjects.Dataset; kwargs...
+)
+    return StatsBase.predict(Random.default_rng(), model, data; kwargs...)
+end
+
+# Add `posterior_predictive` and/or `prior_predictive` groups to `data`, predicted from
+# its `posterior` and `prior` groups. Dimensions of `observed_data`, if present, are
+# merged into `coords`. Warns when a group yields no predictions or neither group exists.
+function StatsBase.predict(
+    rng::Random.AbstractRNG,
+    model::DynamicPPL.Model,
+    data::InferenceObjects.InferenceData;
+    coords=(;),
+    kwargs...,
+)
+    if haskey(data, :observed_data)
+        coords = merge(coords, dims2coords(Dimensions.dims(data.observed_data)))
+    end
+    new_groups = Dict{Symbol,InferenceObjects.Dataset}()
+    if haskey(data, :posterior)
+        posterior_predictive = StatsBase.predict(
+            rng, model, data.posterior; coords=coords, kwargs...
+        )
+        if posterior_predictive === nothing
+            @warn "No predictions were made based on posterior. Has the model been deconditioned?"
+        else
+            new_groups[:posterior_predictive] = posterior_predictive
+        end
+    end
+    if haskey(data, :prior)
+        prior_predictive = StatsBase.predict(
+            rng, model, data.prior; coords=coords, kwargs...
+        )
+        if prior_predictive === nothing
+            @warn "No predictions were made based on prior. Has the model been deconditioned?"
+        else
+            new_groups[:prior_predictive] = prior_predictive
+        end
+    end
+    if !(haskey(data, :posterior) || haskey(data, :prior))
+        @warn "No posterior or prior found in InferenceData. Returning unmodified input."
+        return data
+    end
+    return merge(data, InferenceObjects.InferenceData(; new_groups...))
+end
+# Convenience method that uses `Random.default_rng()`.
+function StatsBase.predict(
+    model::DynamicPPL.Model, data::InferenceObjects.InferenceData; kwargs...
+)
+    return StatsBase.predict(Random.default_rng(), model, data; kwargs...)
+end
diff --git a/ext/DynamicPPLInferenceObjectsExt/utils.jl b/ext/DynamicPPLInferenceObjectsExt/utils.jl
new file mode 100644
index 000000000..13f9b3db2
--- /dev/null
+++ b/ext/DynamicPPLInferenceObjectsExt/utils.jl
@@ -0,0 +1,25 @@
+# adapted from MCMCChains
+# Return `true` if `T` and every element type reached by repeatedly taking `eltype` are
+# concrete; the recursion stops at a type that is its own element type.
+function isconcretetype_recursive(T)
+    return isconcretetype(T) && (eltype(T) === T || isconcretetype_recursive(eltype(T)))
+end
+
+# Narrow the element type of an abstractly-typed array to the promoted type of its actual
+# elements, recursing into nested arrays. Non-array values are returned unchanged.
+concretize(x) = x
+function concretize(x::AbstractArray)
+    if isconcretetype_recursive(typeof(x))
+        return x
+    else
+        xnew = map(concretize, x)
+        T = mapreduce(typeof, promote_type, xnew; init=Union{})
+        if T <: eltype(xnew) && T !== Union{}
+            return convert(AbstractArray{T}, xnew)
+        else
+            return xnew
+        end
+    end
+end
+
+# Build a `NamedTuple` from the keys of `dims` (names) and their lookups (values).
+dims2coords(dims) = NamedTuple{Dimensions.dim2key(dims)}(Dimensions.lookup(dims))
diff --git a/ext/DynamicPPLInferenceObjectsExt/varinfo.jl b/ext/DynamicPPLInferenceObjectsExt/varinfo.jl
new file mode 100644
index 000000000..6d2331ce7
--- /dev/null
+++ b/ext/DynamicPPLInferenceObjectsExt/varinfo.jl
@@ -0,0 +1,16 @@
+# Index `data` at (`draw_id`, `chain_id`) and pass that sample to `setval!`.
+function DynamicPPL.setval!(
+    vi::DynamicPPL.VarInfo, data::InferenceObjects.Dataset, draw_id::Int, chain_id::Int
+)
+    return DynamicPPL.setval!(vi, data[draw=draw_id, chain=chain_id])
+end
+
+# Index `data` at (`draw_id`, `chain_id`) and pass that sample to `setval_and_resample!`.
+function DynamicPPL.setval_and_resample!(
+    vi::DynamicPPL.VarInfoOrThreadSafeVarInfo,
+    data::InferenceObjects.Dataset,
+    draw_id::Int,
+    chain_id::Int,
+)
+    return DynamicPPL.setval_and_resample!(vi, data[draw=draw_id, chain=chain_id])
+end
diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl
index 594084d66..74c9c06df 100644
--- a/src/DynamicPPL.jl
+++ b/src/DynamicPPL.jl
@@ -34,6 +34,9 @@ import Base:
     keys,
     haskey
 
+# `Base.get_extension` is only defined on Julia versions with package-extension support.
+const EXTENSIONS_SUPPORTED = isdefined(Base, :get_extension)
+
 # VarInfo
 export AbstractVarInfo,
     VarInfo,
@@ -166,4 +169,9 @@ include("test_utils.jl")
 include("transforming.jl")
 include("logdensityfunction.jl")
 
+# Without extension support, load the InferenceObjects integration as regular code.
+if !EXTENSIONS_SUPPORTED
+    include("../ext/DynamicPPLInferenceObjectsExt/DynamicPPLInferenceObjectsExt.jl")
+end
+
 end # module