diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 8c598a6a8..7c7fb216d 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -8,12 +8,6 @@ else using ..MCMCChains: MCMCChains end -_has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names -function _check_varname_indexing(c::MCMCChains.Chains) - return DynamicPPL.supports_varname_indexing(c) || - error("Chains do not support indexing using $vn.") -end - # Load state from a `Chains`: By convention, it is stored in `:samplerstate` metadata function DynamicPPL.loadstate(chain::MCMCChains.Chains) if !haskey(chain.info, :samplerstate) @@ -26,10 +20,17 @@ function DynamicPPL.loadstate(chain::MCMCChains.Chains) return chain.info[:samplerstate] end -# A few methods needed. +_has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names + function DynamicPPL.supports_varname_indexing(chain::MCMCChains.Chains) return _has_varname_to_symbol(chain.info) end + +function _check_varname_indexing(c::MCMCChains.Chains) + return DynamicPPL.supports_varname_indexing(c) || + error("Chains do not support indexing using `VarName`s.") +end + function DynamicPPL.getindex_varname( c::MCMCChains.Chains, sample_idx, vn::DynamicPPL.VarName, chain_idx )