@@ -2,6 +2,7 @@ module DynamicPPLMCMCChainsExt
22
33using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC
44using MCMCChains: MCMCChains
5+ using Statistics: mean
56
67_has_varname_to_symbol (info:: NamedTuple{names} ) where {names} = :varname_to_symbol in names
78
@@ -140,6 +141,44 @@ function AbstractMCMC.to_samples(
140141 end
141142end
142143
144+ function AbstractMCMC. bundle_samples (
145+ ts:: Vector{<:DynamicPPL.ParamsWithStats} ,
146+ model:: DynamicPPL.Model ,
147+ spl:: AbstractMCMC.AbstractSampler ,
148+ state,
149+ chain_type:: Type{MCMCChains.Chains} ;
150+ save_state= false ,
151+ stats= missing ,
152+ sort_chain= false ,
153+ discard_initial= 0 ,
154+ thinning= 1 ,
155+ kwargs... ,
156+ )
157+ # Construct the 'bare' chain first
158+ bare_chain = AbstractMCMC. from_samples (MCMCChains. Chains, reshape (ts, :, 1 ))
159+
160+ # Add additional MCMC-specific info
161+ info = bare_chain. info
162+ if save_state
163+ info = merge (info, (model= model, sampler= spl, samplerstate= state))
164+ end
165+ if ! ismissing (stats)
166+ info = merge (info, (start_time= stats. start, stop_time= stats. stop))
167+ end
168+
169+ # Reconstruct the chain with the extra information
170+ # Yeah, this is quite ugly. Blame MCMCChains.
171+ chain = MCMCChains. Chains (
172+ bare_chain. value. data,
173+ names (bare_chain),
174+ bare_chain. name_map;
175+ info= info,
176+ start= discard_initial + 1 ,
177+ thin= thinning,
178+ )
179+ return sort_chain ? sort (chain) : chain
180+ end
181+
143182"""
144183 predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
145184
0 commit comments