From 4071615fdd3121c56592299d64ff91629b7511db Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 24 Nov 2022 11:24:00 +0100 Subject: [PATCH 1/2] Add InferenceObjects as dependency --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index ebe9555a6..384cb32b8 100644 --- a/Project.toml +++ b/Project.toml @@ -17,6 +17,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" EllipticalSliceSampling = "cad2338a-1db2-11e9-3401-43bc07c9ede2" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +InferenceObjects = "b5cf5a8d-e756-4ee3-b014-01d49d192c00" Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" @@ -49,6 +50,7 @@ DocStringExtensions = "0.8, 0.9" DynamicPPL = "0.21" EllipticalSliceSampling = "0.5, 1" ForwardDiff = "0.10.3" +InferenceObjects = "0.2" Libtask = "0.6.7, 0.7" LogDensityProblems = "0.12, 1" MCMCChains = "5" From e92be0dac58d3d5dc6a4b351dd50286b2b76a482 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 24 Nov 2022 11:24:33 +0100 Subject: [PATCH 2/2] Add AbstractMCMC / InferenceObjects code --- src/inference/Inference.jl | 85 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index 2024cff2c..9cd15b392 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -33,6 +33,7 @@ import EllipticalSliceSampling import LogDensityProblems import Random import MCMCChains +using InferenceObjects: InferenceObjects import StatsBase: predict export InferenceAlgorithm, @@ -67,6 +68,15 @@ export InferenceAlgorithm, predict, isgibbscomponent +const turing_inferencedata_key_map = ( + hamiltonian_energy = :energy, + hamiltonian_energy_error = :energy_error, + is_adapt = :tune, + max_hamiltonian_energy_error = :max_energy_error, + nom_step_size = :step_size_nom, + numerical_error = :diverging, +) + ####################### # Sampler abstraction # ####################### @@ -418,6 +428,81 @@ end DynamicPPL.loadstate(chain::MCMCChains.Chains) = chain.info[:samplerstate] +# Default InferenceObjects constructor +# This is type piracy! +function AbstractMCMC.bundle_samples( + ts::Vector, + model::AbstractModel, + spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior}, + state, + chain_type::Type{InferenceObjects.InferenceData}; + group = spl isa SampleFromPrior ? :prior : :posterior, + save_state = false, + stats = missing, + dims=(;), + coords=(;), + kwargs... +) + sample = map(t -> map(v -> length(v[1]) == 1 ? v[1][1] : v[1], getparams(t)), ts) + sample_stats = map(_rename_sample_stats ∘ metadata, ts) + + # Set up the info tuple. + attrs = OrderedDict{String,Any}() + if save_state + attrs["model"] = model + attrs["sampler"] = spl + attrs["samplerstate"] = state + end + + # Merge in the timing info, if available + if !ismissing(stats) + attrs["start_time"] = stats.start + attrs["stop_time"] = stats.stop + end + + # Get the average or final log evidence, if it exists. + le = getlogevidence(ts, spl, state) + if !ismissing(le) + attrs["log_evidence"] = le + end + + # identify if this is posterior or prior + sample_stats_group = group === :prior ? :sample_stats_prior : :sample_stats + + # InferenceData construction. + idata = InferenceObjects.convert_to_inference_data( + [sample]; + group=group, + sample_stats_group => [sample_stats], + attrs=attrs, + dims=dims, + coords=coords, + ) + return idata +end + +function AbstractMCMC.chainsstack(c::AbstractVector{<:InferenceObjects.InferenceData}) + nchains = length(c) + nchains == 1 && return c[1] + groups = map(keys(first(c))) do k + k => AbstractMCMC.chainsstack(map(idata -> idata[k], c)) + end + return InferenceObjects.InferenceData(; groups...) +end +function AbstractMCMC.chainsstack(c::AbstractVector{<:InferenceObjects.Dataset}) + nchains = length(c) + nchains == 1 && return c[1] + # TODO: gather our metadata into vectors instead of replacing + group = cat(c...; dims=:chain) + # give each chain a different index + return InferenceObjects.DimensionalData.set(group, :chain => Base.OneTo(nchains)) +end + +function _rename_sample_stats(stats::NamedTuple) + new_keys = map(k -> get(turing_inferencedata_key_map, k, k), keys(stats)) + return NamedTuple{new_keys}(values(stats)) +end + ####################################### # Concrete algorithm implementations. # #######################################