From cb557e2bb7281a04409bc690833ce0ac8bdcc6ef Mon Sep 17 00:00:00 2001
From: Xianda Sun
Date: Thu, 5 Sep 2024 07:05:08 +0100
Subject: [PATCH] some draft, not working yet

---
Review notes (not part of the commit message): this patch was reflowed and edited by
hand, so the blob ids on the `index` lines still refer to the original draft.

 ext/DynamicPPLMCMCChainsExt.jl | 61 ++++++++++++++++++++++++++++++++--
 src/DynamicPPL.jl              |  2 ++
 test/predict.jl                | 62 ++++++++++++++++++++++++++++++++++
 3 files changed, 123 insertions(+), 2 deletions(-)
 create mode 100644 test/predict.jl

diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl
index 7c7fb216d..c0e78ffb9 100644
--- a/ext/DynamicPPLMCMCChainsExt.jl
+++ b/ext/DynamicPPLMCMCChainsExt.jl
@@ -1,10 +1,10 @@
 module DynamicPPLMCMCChainsExt
 
 if isdefined(Base, :get_extension)
-    using DynamicPPL: DynamicPPL
+    using DynamicPPL: DynamicPPL, Random
     using MCMCChains: MCMCChains
 else
-    using ..DynamicPPL: DynamicPPL
+    using ..DynamicPPL: DynamicPPL, Random
     using ..MCMCChains: MCMCChains
 end
 
@@ -59,4 +59,61 @@ function DynamicPPL.generated_quantities(
     end
 end
 
+"""
+    predict([rng,] model, chain::MCMCChains.Chains; include_all=false)
+
+Draw one sample from `model` for every iteration of every chain in `chain`, with the
+variables in the `:parameters` section of `chain` fixed to their sampled values.
+
+Return a vector with one entry per chain; each entry is a vector holding the result of
+`rand(rng, fixed_model)` for every iteration of that chain.
+
+`include_all` is accepted but currently has no effect.
+"""
+function DynamicPPL.predict(
+    model::DynamicPPL.Model, chain::MCMCChains.Chains; include_all=false
+)
+    return DynamicPPL.predict(Random.default_rng(), model, chain; include_all=include_all)
+end
+function DynamicPPL.predict(
+    rng::Random.AbstractRNG,
+    model::DynamicPPL.Model,
+    chain::MCMCChains.Chains;
+    include_all=false,
+)
+    params_only_chain = MCMCChains.get_sections(chain, :parameters)
+
+    varname_to_symbol = if :varname_to_symbol in keys(params_only_chain.info)
+        # the mapping is introduced in Turing by
+        # https://github.com/TuringLang/Turing.jl/commit/8d8416ac6c7363c6003ee6ea1fbaac26b4fc8dc3
+        params_only_chain.info[:varname_to_symbol]
+    else
+        # if not using Turing, then we need to construct the mapping ourselves;
+        # build the `VarName`s directly since `sym` is only known at run time
+        Dict{DynamicPPL.VarName,Symbol}(
+            DynamicPPL.VarName{sym}() => sym for
+            sym in params_only_chain.name_map.parameters
+        )
+    end
+
+    num_of_chains = size(params_only_chain, 3)
+    num_of_samples = size(params_only_chain, 1)
+
+    # one vector of predictive samples per chain
+    return map(1:num_of_chains) do chain_idx
+        map(1:num_of_samples) do sample_idx
+            # construct the dictionary to fix the model
+            d_to_fix = Dict{DynamicPPL.VarName,Any}(
+                vn => params_only_chain[sample_idx, sym, chain_idx] for
+                (vn, sym) in varname_to_symbol
+            )
+
+            # fix the model and sample from it
+            fixed_model = DynamicPPL.fix(model, d_to_fix)
+            # TODO: Turing version uses `Transition` and `bundle_samples` to form new chains: is it worth it to move Transition to AbstractMCMC?
+            rand(rng, fixed_model)
+        end
+    end
+end
+
 end
diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl
index eb027b45b..b9e158e17 100644
--- a/src/DynamicPPL.jl
+++ b/src/DynamicPPL.jl
@@ -150,6 +150,8 @@ end
 # Used here and overloaded in Turing
 function getspace end
 
+function predict end
+
 """
     AbstractVarInfo
 
diff --git a/test/predict.jl b/test/predict.jl
new file mode 100644
index 000000000..534414646
--- /dev/null
+++ b/test/predict.jl
@@ -0,0 +1,62 @@
+module TestPredict
+
+using Test
+using DynamicPPL
+using AbstractMCMC
+using MCMCChains
+using Distributions
+using LinearAlgebra
+using Random
+using LogDensityProblemsAD
+using AdvancedHMC
+using Tapir
+using ForwardDiff
+
+@model function linear_reg(x, y, σ=0.1)
+    β ~ Normal(0, 1)
+    for i in eachindex(y)
+        y[i] ~ Normal(β * x[i], σ)
+    end
+end
+
+@model function linear_reg_vec(x, y, σ=0.1)
+    β ~ Normal(0, 1)
+    return y ~ MvNormal(β .* x, σ^2 * I)
+end
+
+f(x) = 2 * x + 0.1 * randn()
+
+Δ = 0.1
+xs_train = 0:Δ:10
+ys_train = f.(xs_train)
+xs_test = [10 + Δ, 10 + 2 * Δ]
+ys_test = f.(xs_test)
+
+model = linear_reg(xs_train, ys_train)
+
+m_lin_reg = linear_reg(xs_train, ys_train)
+ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model))
+ad_ldf = LogDensityProblemsAD.ADgradient(Val(:Tapir), ldf; safety_on=false)
+chain = AbstractMCMC.sample(
+    ad_ldf, AdvancedHMC.NUTS(0.6), 1000; chain_type=MCMCChains.Chains, param_names=[:β]
+)
+
+# the test model is given `missing` responses, which are the values to be predicted
+test_model = linear_reg(xs_test, fill(missing, length(ys_test)))
+predictions = DynamicPPL.predict(test_model, chain)
+@test length(predictions) == size(chain, 3)
+@test all(p -> length(p) == size(chain, 1), predictions)
+
+# LKJ example
+@model demo_lkj() = x ~ LKJCholesky(2, 1.0)
+
+model = demo_lkj()
+
+ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.SimpleVarInfo(model))
+ad_ldf = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ldf)
+
+chain = AbstractMCMC.sample(
+    ad_ldf, AdvancedHMC.NUTS(0.6), 1000; chain_type=MCMCChains.Chains, param_names=[:Σ, :x]
+)
+
+end # module