Skip to content

Commit

Permalink
stop using PredictiveSample type
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed Nov 29, 2024
1 parent 53b6749 commit fcd7c3d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 14 deletions.
18 changes: 11 additions & 7 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ function DynamicPPL.predict(
chain_result = reduce(
MCMCChains.chainscat,
[
_bundle_samples(predictive_samples[:, chain_idx]) for
_bundle_predictive_samples(predictive_samples[:, chain_idx]) for
chain_idx in 1:size(predictive_samples, 2)
],
)
Expand All @@ -143,11 +143,11 @@ function DynamicPPL.predict(
return chain_result[parameter_names]
end

function _params_to_array(ts::Vector)
function _params_to_array(predictive_samples)
names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()

dicts = map(ts) do t
nms_and_vs = t.values
dicts = map(predictive_samples) do t
nms_and_vs = t[:values]
nms = map(first, nms_and_vs)
vs = map(last, nms_and_vs)
for nm in nms
Expand All @@ -164,11 +164,15 @@ function _params_to_array(ts::Vector)
return names, vals
end

function _bundle_samples(ts::Vector{<:DynamicPPL.PredictiveSample})
varnames, vals = _params_to_array(ts)
function _bundle_predictive_samples(
predictive_samples::AbstractArray{
<:DynamicPPL.OrderedCollections.OrderedDict{Symbol,Any}
},
)
varnames, vals = _params_to_array(predictive_samples)
varnames_symbol = map(Symbol, varnames)
extra_params = [:lp]
extra_values = reshape([t.logp for t in ts], :, 1)
extra_values = reshape([t[:logp] for t in predictive_samples], :, 1)
nms = [varnames_symbol; extra_params]
parray = hcat(vals, extra_values)
parray = MCMCChains.concretize(parray)
Expand Down
11 changes: 4 additions & 7 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1203,11 +1203,6 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
end
end

struct PredictiveSample{T,F}
values::T
logp::F
end

"""
predict([rng::AbstractRNG,] model::Model, chain; include_all=false)
Expand All @@ -1228,13 +1223,15 @@ function predict(
varinfos::AbstractArray{<:AbstractVarInfo};
include_all=false,
)
predictive_samples = Array{PredictiveSample}(undef, size(varinfos))
predictive_samples = similar(varinfos, OrderedDict{Symbol,Any})
for i in eachindex(varinfos)
model(rng, varinfos[i], SampleFromPrior())
vals = values_as_in_model(model, varinfos[i])
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
params = mapreduce(collect, vcat, iters)
predictive_samples[i] = PredictiveSample(params, getlogp(varinfos[i]))
predictive_samples[i] = OrderedDict(
:values => params, :logp => getlogp(varinfos[i])
)
end
return predictive_samples
end
Expand Down

0 comments on commit fcd7c3d

Please sign in to comment.