Commit 1c1c907

move predict from Turing
1 parent f5890a1 commit 1c1c907

File tree: 6 files changed, +477 −1 lines changed

ext/DynamicPPLMCMCChainsExt.jl (+301)
@@ -42,6 +42,307 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
    return keys(c.info.varname_to_symbol)
end

# This is copied from Turing.jl; the `stats` field is omitted as it is never used.
struct Transition{T,F}
    θ::T
    lp::F
end

function Transition(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
    return Transition(getparams(model, vi), DynamicPPL.getlogp(vi))
end

# a copy of Turing.Inference.getparams
getparams(model, t) = t.θ
function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
    # NOTE: In the past, `invlink(vi, model)` + `values_as(vi, OrderedDict)` was used.
    # Unfortunately, using `invlink` can cause issues in scenarios where the constraints
    # of the parameters change depending on the realizations. Hence we have to use
    # `values_as_in_model`, which re-runs the model and extracts the parameters
    # as they are seen in the model, i.e. in the constrained space. Moreover,
    # this means that the code below will work with both linked and invlinked `vi`.
    # Ref: https://github.com/TuringLang/Turing.jl/issues/2195
    # NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
    vals = DynamicPPL.values_as_in_model(model, deepcopy(vi))

    # Obtain an iterator over the flattened parameter names and values.
    iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))

    # Materialize the iterators and concatenate.
    return mapreduce(collect, vcat, iters)
end
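
To make the comment above concrete, here is a minimal hedged sketch (the model is illustrative, not part of this commit) of the dynamic-constraint scenario referenced in TuringLang/Turing.jl#2195: the support of `x` depends on the realization of `y`, so a fixed `invlink` transformation cannot be correct, whereas `values_as_in_model` re-runs the model and always reports values in the constrained space.

```julia
# Illustrative only; not part of this commit. The support of `x` is (0, y),
# which changes with every realization of `y`.
using DynamicPPL, Distributions

@model function dynamic_constraint()
    y ~ Exponential(1)
    x ~ Uniform(0, y)
end

model = dynamic_constraint()
vi = DynamicPPL.VarInfo(model)

# Re-runs the model, so values come back in the constrained space
# whether or not `vi` has been linked.
vals = DynamicPPL.values_as_in_model(model, deepcopy(vi))
```
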

function _params_to_array(model::DynamicPPL.Model, ts::Vector)
    names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
    # Extract the parameter names and values from each transition.
    dicts = map(ts) do t
        nms_and_vs = getparams(model, t)
        nms = map(first, nms_and_vs)
        vs = map(last, nms_and_vs)
        for nm in nms
            push!(names_set, nm)
        end
        # Convert the names and values to a single dictionary.
        return DynamicPPL.OrderedCollections.OrderedDict(zip(nms, vs))
    end
    names = collect(names_set)
    vals = [
        get(dicts[i], key, missing) for i in eachindex(dicts), (j, key) in enumerate(names)
    ]

    return names, vals
end
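
A hedged sketch of what `_params_to_array` produces. These are internal helpers of this extension, so the unqualified names assume you are inside the extension module, and the transitions are hand-built for illustration. The point: transitions need not share parameters, and gaps are filled with `missing`.

```julia
using DynamicPPL, Distributions

@model demo() = x ~ Normal()   # illustrative; `getparams(model, t)` just returns `t.θ` here
m = demo()

# Hand-built transitions: the second one carries an extra variable `y`.
t1 = Transition([(@varname(x), 1.0)], -1.2)
t2 = Transition([(@varname(x), 2.0), (@varname(y), 3.0)], -2.3)

names, vals = _params_to_array(m, [t1, t2])
# names == [x, y]          (VarNames, in order of first appearance)
# vals  == [1.0 missing;   (one row per transition, `missing` where absent)
#           2.0 3.0]
```
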
"""
97+
98+
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
99+
100+
Execute `model` conditioned on each sample in `chain`, and return the resulting `Chains`.
101+
102+
If `include_all` is `false`, the returned `Chains` will contain only those variables
103+
sampled/not present in `chain`.
104+
105+
# Details
106+
Internally calls `Turing.Inference.transitions_from_chain` to obtained the samples
107+
and then converts these into a `Chains` object using `AbstractMCMC.bundle_samples`.
108+
109+
# Example
110+
```jldoctest
111+
julia> using AbstractMCMC, AdvancedHMC, DynamicPPL, ForwardDiff;
112+
[ Info: [Turing]: progress logging is disabled globally
113+
114+
julia> @model function linear_reg(x, y, σ = 0.1)
115+
β ~ Normal(0, 1)
116+
117+
for i ∈ eachindex(y)
118+
y[i] ~ Normal(β * x[i], σ)
119+
end
120+
end;
121+
122+
julia> σ = 0.1; f(x) = 2 * x + 0.1 * randn();
123+
124+
julia> Δ = 0.1; xs_train = 0:Δ:10; ys_train = f.(xs_train);
125+
126+
julia> xs_test = [10 + Δ, 10 + 2 * Δ]; ys_test = f.(xs_test);
127+
128+
julia> m_train = linear_reg(xs_train, ys_train, σ);
129+
130+
julia> n_train_logdensity_function = DynamicPPL.LogDensityFunction(m_train, DynamicPPL.VarInfo(m_train));
131+
132+
julia> chain_lin_reg = AbstractMCMC.sample(n_train_logdensity_function, NUTS(0.65), 200; chain_type=MCMCChains.Chains, param_names=[:β]);
133+
┌ Info: Found initial step size
134+
└ ϵ = 0.003125
135+
136+
julia> m_test = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), σ);
137+
138+
julia> predictions = predict(m_test, chain_lin_reg)
139+
Object of type Chains, with data of type 100×2×1 Array{Float64,3}
140+
141+
Iterations = 1:100
142+
Thinning interval = 1
143+
Chains = 1
144+
Samples per chain = 100
145+
parameters = y[1], y[2]
146+
147+
2-element Array{ChainDataFrame,1}
148+
149+
Summary Statistics
150+
parameters mean std naive_se mcse ess r_hat
151+
────────── ─────── ────── ──────── ─────── ──────── ──────
152+
y[1] 20.1974 0.1007 0.0101 missing 101.0711 0.9922
153+
y[2] 20.3867 0.1062 0.0106 missing 101.4889 0.9903
154+
155+
Quantiles
156+
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
157+
────────── ─────── ─────── ─────── ─────── ───────
158+
y[1] 20.0342 20.1188 20.2135 20.2588 20.4188
159+
y[2] 20.1870 20.3178 20.3839 20.4466 20.5895
160+
161+
162+
julia> ys_pred = vec(mean(Array(group(predictions, :y)); dims = 1));
163+
164+
julia> sum(abs2, ys_test - ys_pred) ≤ 0.1
165+
true
166+
```
167+
"""
function DynamicPPL.predict(
    rng::DynamicPPL.Random.AbstractRNG,
    model::DynamicPPL.Model,
    chain::MCMCChains.Chains;
    include_all=false,
)
    # Don't need all the diagnostics
    chain_parameters = MCMCChains.get_sections(chain, :parameters)

    spl = DynamicPPL.SampleFromPrior()

    # Sample transitions using `spl` conditioned on values in `chain`
    transitions = transitions_from_chain(rng, model, chain_parameters; sampler=spl)

    # Let the (copied) Turing internals handle everything else
    chain_result = reduce(
        MCMCChains.chainscat,
        [
            _bundle_samples(transitions[:, chain_idx], model, spl) for
            chain_idx in 1:size(transitions, 2)
        ],
    )

    parameter_names = if include_all
        MCMCChains.names(chain_result, :parameters)
    else
        filter(
            k -> !(k in MCMCChains.names(chain_parameters, :parameters)),
            MCMCChains.names(chain_result, :parameters),
        )
    end

    return chain_result[parameter_names]
end
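
A quick usage note on `include_all`, continuing the docstring example above (hedged: `m_test` and `chain_lin_reg` are assumed to be in scope from that example, so this is not self-contained):

```julia
# With the default `include_all=false`, the conditioned-on parameter β is
# filtered out, leaving only the newly predicted y[1], y[2].
predictions = DynamicPPL.predict(m_test, chain_lin_reg)

# With `include_all=true`, β is kept alongside the predictions.
predictions_all = DynamicPPL.predict(m_test, chain_lin_reg; include_all=true)
MCMCChains.names(predictions_all, :parameters)  # includes :β as well as y[1], y[2]
```
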

getlogp(t::Transition) = t.lp

function get_transition_extras(ts::AbstractVector{<:Transition})
    valmat = reshape([getlogp(t) for t in ts], :, 1)
    return [:lp], valmat
end

function names_values(extra_data::AbstractVector{<:NamedTuple{names}}) where {names}
    values = [getfield(data, name) for data in extra_data, name in names]
    return collect(names), values
end

function names_values(xs::AbstractVector{<:NamedTuple})
    # Obtain all parameter names.
    names_set = Set{Symbol}()
    for x in xs
        for k in keys(x)
            push!(names_set, k)
        end
    end
    names_unique = collect(names_set)

    # Extract all values as matrix.
    values = [haskey(x, name) ? x[name] : missing for x in xs, name in names_unique]

    return names_unique, values
end
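
A hedged illustration of the generic `names_values` fallback with hypothetical data (the `NamedTuple{names}` method above takes the fast path when all elements share the same keys):

```julia
# Heterogeneous NamedTuples hit the generic method: keys are unioned
# and missing entries are padded with `missing`.
xs = [(a = 1.0,), (a = 2.0, b = 3.0)]
ns, vs = names_values(xs)
# ns is a collection of [:a, :b] (order unspecified; it comes from a Set)
# vs has one row per element of `xs`, with `missing` where a key is absent
```
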

getlogevidence(transitions, sampler, state) = missing

# This is copied from Turing.jl/src/mcmc/Inference.jl; the type annotations are more
# restrictive (types defined in Turing have been removed), and the function is
# simplified so that unused arguments are dropped.
function _bundle_samples(
    ts::Vector{<:Transition}, model::DynamicPPL.Model, spl::DynamicPPL.SampleFromPrior
)
    # Convert transitions to array format.
    # Also retrieve the variable names.
    varnames, vals = _params_to_array(model, ts)
    varnames_symbol = map(Symbol, varnames)

    # Get the values of the extra parameters in each transition.
    extra_params, extra_values = get_transition_extras(ts)

    # Extract names & construct param array.
    nms = [varnames_symbol; extra_params]
    parray = hcat(vals, extra_values)

    # Set up the info tuple.
    info = NamedTuple()

    info = merge(
        info,
        (
            varname_to_symbol=DynamicPPL.OrderedCollections.OrderedDict(
                zip(varnames, varnames_symbol)
            ),
        ),
    )

    # Concretize the array before giving it to MCMCChains.
    parray = MCMCChains.concretize(parray)

    # Chain construction.
    chain = MCMCChains.Chains(parray, nms, (internals=extra_params,))

    return chain
end
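
One design note worth flagging: the `varname_to_symbol` mapping stored in `info` is exactly what `DynamicPPL.varnames(c::MCMCChains.Chains)` (top of this file) reads back. A hedged round-trip sketch, again assuming the extension's internal names are in scope and using an illustrative model:

```julia
using DynamicPPL, Distributions, MCMCChains

@model demo_bundle() = s ~ Normal()   # illustrative model
m = demo_bundle()
vi = DynamicPPL.VarInfo(m)            # draws `s` from the prior

ts = [Transition(m, vi)]              # one transition
chain = _bundle_samples(ts, m, DynamicPPL.SampleFromPrior())

DynamicPPL.varnames(chain)            # recovers the VarNames via chain.info.varname_to_symbol
```
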

"""
    transitions_from_chain(
        [rng::AbstractRNG,]
        model::Model,
        chain::MCMCChains.Chains;
        sampler = DynamicPPL.SampleFromPrior()
    )

Execute `model` conditioned on each sample in `chain`, and return the resulting transitions.

The returned transitions are represented in a `Vector{<:Transition}`.

# Details

In a bit more detail, the process is as follows:
1. For every `sample` in `chain`
   1. For every `variable` in `sample`
      1. Set `variable` in `model` to its value in `sample`
   2. Execute `model` with variables fixed as above, sampling variables NOT present
      in `chain` using `SampleFromPrior`
   3. Return sampled variables and log-joint

# Example
```julia-repl
julia> using DynamicPPL, Distributions, MCMCChains

julia> @model function demo()
           m ~ Normal(0, 1)
           x ~ Normal(m, 1)
       end;

julia> m = demo();

julia> chain = Chains(randn(2, 1, 1), ["m"]); # 2 samples of `m`

julia> transitions = transitions_from_chain(m, chain);

julia> [getlogp(t) for t in transitions] # extract the logjoints
2-element Array{Float64,1}:
 -3.6294991938628374
 -2.5697948166987845

julia> [first(t.θ.x) for t in transitions] # extract samples for `x`
2-element Array{Array{Float64,1},1}:
 [-2.0844148956440796]
 [-1.704630494695469]
```
"""
function transitions_from_chain(
    model::DynamicPPL.Model, chain::MCMCChains.Chains; kwargs...
)
    return transitions_from_chain(DynamicPPL.Random.default_rng(), model, chain; kwargs...)
end

function transitions_from_chain(
    rng::DynamicPPL.Random.AbstractRNG,
    model::DynamicPPL.Model,
    chain::MCMCChains.Chains;
    sampler=DynamicPPL.SampleFromPrior(),
)
    vi = DynamicPPL.VarInfo(model)

    iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
    transitions = map(iters) do (sample_idx, chain_idx)
        # Set variables present in `chain` and mark those NOT present in `chain` to be resampled.
        DynamicPPL.setval_and_resample!(vi, chain, sample_idx, chain_idx)
        model(rng, vi, sampler)

        # Wrap the current state of `vi` in a `Transition` and save it.
        Transition(model, vi)
    end

    return transitions
end

"""
    generated_quantities(model::Model, chain::MCMCChains.Chains)

src/DynamicPPL.jl (+1 −1)

@@ -5,7 +5,7 @@ using AbstractPPL
 using Bijectors
 using Compat
 using Distributions
-using OrderedCollections: OrderedDict
+using OrderedCollections: OrderedCollections, OrderedDict

 using AbstractMCMC: AbstractMCMC
 using ADTypes: ADTypes

src/model.jl (+14)

@@ -1203,6 +1203,20 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
    end
end

"""
    predict([rng::AbstractRNG,] model::Model, chain; include_all=false)

Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
in `chain`, and return the resulting `Chains`. At the moment, `chain` must be an `MCMCChains.Chains` object.

If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
the samples in `chain`. This is useful when you want to sample only new variables from the posterior
predictive distribution.
"""
function predict(model::Model, chain; include_all=false)
    return predict(Random.default_rng(), model, chain; include_all)
end

"""
    generated_quantities(model::Model, parameters::NamedTuple)
    generated_quantities(model::Model, values, keys)
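
A minimal hedged usage sketch of this new entry point (the model and chain are illustrative; loading MCMCChains is what makes the extension's RNG-accepting method available):

```julia
using DynamicPPL, Distributions, MCMCChains

@model function demo()
    m ~ Normal(0, 1)
    x ~ Normal(m, 1)
end

# A hand-built chain with two samples of `m`; normally this comes from inference.
chain = Chains(randn(2, 1, 1), [:m])

# `m` is fixed to each sample in `chain`; `x` is drawn anew for each one.
predictions = DynamicPPL.predict(demo(), chain)
```
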

test/Project.toml (+2)

@@ -2,6 +2,7 @@
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
 AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
+AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
 Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"

@@ -32,6 +33,7 @@ AbstractMCMC = "5"
 AbstractPPL = "0.8.4, 0.9"
 Accessors = "0.1"
 Bijectors = "0.13.9, 0.14"
+AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6"
 Combinatorics = "1"
 Compat = "4.3.0"
 Distributions = "0.25"