Skip to content

Commit 2fad97b

Browse files
committed
Implement bundle_samples for ParamsWithStats -> MCMCChains
1 parent 2bf5b18 commit 2fad97b

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
4141
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
4242
DynamicPPLForwardDiffExt = ["ForwardDiff"]
4343
DynamicPPLJETExt = ["JET"]
44-
DynamicPPLMCMCChainsExt = ["MCMCChains"]
44+
DynamicPPLMCMCChainsExt = ["MCMCChains", "Statistics"]
4545
DynamicPPLMarginalLogDensitiesExt = ["MarginalLogDensities"]
4646
DynamicPPLMooncakeExt = ["Mooncake"]
4747

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module DynamicPPLMCMCChainsExt
22

33
using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC
44
using MCMCChains: MCMCChains
5+
using Statistics: mean
56

67
_has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names
78

@@ -140,6 +141,44 @@ function AbstractMCMC.to_samples(
140141
end
141142
end
142143

144+
function AbstractMCMC.bundle_samples(
145+
ts::Vector{<:DynamicPPL.ParamsWithStats},
146+
model::DynamicPPL.Model,
147+
spl::AbstractMCMC.AbstractSampler,
148+
state,
149+
chain_type::Type{MCMCChains.Chains};
150+
save_state=false,
151+
stats=missing,
152+
sort_chain=false,
153+
discard_initial=0,
154+
thinning=1,
155+
kwargs...,
156+
)
157+
# Construct the 'bare' chain first
158+
bare_chain = AbstractMCMC.from_samples(MCMCChains.Chains, reshape(ts, :, 1))
159+
160+
# Add additional MCMC-specific info
161+
info = bare_chain.info
162+
if save_state
163+
info = merge(info, (model=model, sampler=spl, samplerstate=state))
164+
end
165+
if !ismissing(stats)
166+
info = merge(info, (start_time=stats.start, stop_time=stats.stop))
167+
end
168+
169+
# Reconstruct the chain with the extra information
170+
# Yeah, this is quite ugly. Blame MCMCChains.
171+
chain = MCMCChains.Chains(
172+
bare_chain.value.data,
173+
names(bare_chain),
174+
bare_chain.name_map;
175+
info=info,
176+
start=discard_initial + 1,
177+
thin=thinning,
178+
)
179+
return sort_chain ? sort(chain) : chain
180+
end
181+
143182
"""
144183
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
145184

0 commit comments

Comments
 (0)