@@ -42,6 +42,307 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
42
42
return keys (c. info. varname_to_symbol)
43
43
end
44
44
45
# this is copied from Turing.jl, `stats` field is omitted as it is never used
struct Transition{T,F}
    θ::T   # parameter realizations, as (varname, value) pairs produced by `getparams`
    lp::F  # log probability of this realization, as returned by `DynamicPPL.getlogp`
end
50
+
51
"""
    Transition(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)

Build a `Transition` from `vi`: extract the flattened parameter values for
`model` together with the log probability stored in `vi`.
"""
function Transition(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
    params = getparams(model, vi)
    logp = DynamicPPL.getlogp(vi)
    return Transition(params, logp)
end
54
+
55
# a copy of Turing.Inference.getparams
# Fallback: any transition-like value exposing a `θ` field yields its parameters as-is.
function getparams(model, t)
    return t.θ
end
57
"""
    getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)

Return the parameters in `vi` as a flat vector of `(varname, value)` leaves,
with values expressed in the constrained space of `model`.
"""
function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
    # NOTE: In the past, `invlink(vi, model)` + `values_as(vi, OrderedDict)` was used.
    # Unfortunately, using `invlink` can cause issues in scenarios where the constraints
    # of the parameters change depending on the realizations. Hence we have to use
    # `values_as_in_model`, which re-runs the model and extracts the parameters
    # as they are seen in the model, i.e. in the constrained space. This also means
    # the code below works for both linked and invlinked `vi`.
    # Ref: https://github.com/TuringLang/Turing.jl/issues/2195
    # NOTE: `deepcopy` so the caller's `vi` is left untouched.
    constrained_vals = DynamicPPL.values_as_in_model(model, deepcopy(vi))

    # One lazy iterator of flattened (varname, value) leaves per top-level variable.
    leaf_iters = map(
        DynamicPPL.varname_and_value_leaves,
        keys(constrained_vals),
        values(constrained_vals),
    )

    # Materialize each iterator and concatenate into a single flat vector.
    return mapreduce(collect, vcat, leaf_iters)
end
74
+
75
"""
    _params_to_array(model::DynamicPPL.Model, ts::Vector)

Convert the transitions `ts` into `(names, vals)`: `names` is the ordered list
of every `VarName` seen across all transitions, and `vals` is a
`length(ts) × length(names)` matrix of values, with `missing` where a
transition lacks a given variable.
"""
function _params_to_array(model::DynamicPPL.Model, ts::Vector)
    names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
    # Extract the parameter names and values from each transition.
    dicts = map(ts) do t
        nms_and_vs = getparams(model, t)
        nms = map(first, nms_and_vs)
        vs = map(last, nms_and_vs)
        for nm in nms
            push!(names_set, nm)
        end
        # Convert the names and values to a single dictionary.
        return DynamicPPL.OrderedCollections.OrderedDict(zip(nms, vs))
    end
    names = collect(names_set)
    # NOTE: the original iterated `(j, key) in enumerate(names)` but never used
    # `j`; iterating the names directly is equivalent and clearer.
    vals = [get(dicts[i], key, missing) for i in eachindex(dicts), key in names]

    return names, vals
end
95
+
96
+ """
97
+
98
+ predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
99
+
100
+ Execute `model` conditioned on each sample in `chain`, and return the resulting `Chains`.
101
+
102
+ If `include_all` is `false`, the returned `Chains` will contain only those variables
103
+ sampled/not present in `chain`.
104
+
105
+ # Details
106
+ Internally calls `Turing.Inference.transitions_from_chain` to obtained the samples
107
+ and then converts these into a `Chains` object using `AbstractMCMC.bundle_samples`.
108
+
109
+ # Example
110
+ ```jldoctest
111
+ julia> using AbstractMCMC, AdvancedHMC, DynamicPPL, ForwardDiff;
112
+ [ Info: [Turing]: progress logging is disabled globally
113
+
114
+ julia> @model function linear_reg(x, y, σ = 0.1)
115
+ β ~ Normal(0, 1)
116
+
117
+ for i ∈ eachindex(y)
118
+ y[i] ~ Normal(β * x[i], σ)
119
+ end
120
+ end;
121
+
122
+ julia> σ = 0.1; f(x) = 2 * x + 0.1 * randn();
123
+
124
+ julia> Δ = 0.1; xs_train = 0:Δ:10; ys_train = f.(xs_train);
125
+
126
+ julia> xs_test = [10 + Δ, 10 + 2 * Δ]; ys_test = f.(xs_test);
127
+
128
+ julia> m_train = linear_reg(xs_train, ys_train, σ);
129
+
130
+ julia> n_train_logdensity_function = DynamicPPL.LogDensityFunction(m_train, DynamicPPL.VarInfo(m_train));
131
+
132
+ julia> chain_lin_reg = AbstractMCMC.sample(n_train_logdensity_function, NUTS(0.65), 200; chain_type=MCMCChains.Chains, param_names=[:β]);
133
+ ┌ Info: Found initial step size
134
+ └ ϵ = 0.003125
135
+
136
+ julia> m_test = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), σ);
137
+
138
+ julia> predictions = predict(m_test, chain_lin_reg)
139
+ Object of type Chains, with data of type 100×2×1 Array{Float64,3}
140
+
141
+ Iterations = 1:100
142
+ Thinning interval = 1
143
+ Chains = 1
144
+ Samples per chain = 100
145
+ parameters = y[1], y[2]
146
+
147
+ 2-element Array{ChainDataFrame,1}
148
+
149
+ Summary Statistics
150
+ parameters mean std naive_se mcse ess r_hat
151
+ ────────── ─────── ────── ──────── ─────── ──────── ──────
152
+ y[1] 20.1974 0.1007 0.0101 missing 101.0711 0.9922
153
+ y[2] 20.3867 0.1062 0.0106 missing 101.4889 0.9903
154
+
155
+ Quantiles
156
+ parameters 2.5% 25.0% 50.0% 75.0% 97.5%
157
+ ────────── ─────── ─────── ─────── ─────── ───────
158
+ y[1] 20.0342 20.1188 20.2135 20.2588 20.4188
159
+ y[2] 20.1870 20.3178 20.3839 20.4466 20.5895
160
+
161
+
162
+ julia> ys_pred = vec(mean(Array(group(predictions, :y)); dims = 1));
163
+
164
+ julia> sum(abs2, ys_test - ys_pred) ≤ 0.1
165
+ true
166
+ ```
167
+ """
168
+ function DynamicPPL. predict (
169
+ rng:: DynamicPPL.Random.AbstractRNG ,
170
+ model:: DynamicPPL.Model ,
171
+ chain:: MCMCChains.Chains ;
172
+ include_all= false ,
173
+ )
174
+ # Don't need all the diagnostics
175
+ chain_parameters = MCMCChains. get_sections (chain, :parameters )
176
+
177
+ spl = DynamicPPL. SampleFromPrior ()
178
+
179
+ # Sample transitions using `spl` conditioned on values in `chain`
180
+ transitions = transitions_from_chain (rng, model, chain_parameters; sampler= spl)
181
+
182
+ # Let the Turing internals handle everything else for you
183
+ chain_result = reduce (
184
+ MCMCChains. chainscat,
185
+ [
186
+ _bundle_samples (transitions[:, chain_idx], model, spl) for
187
+ chain_idx in 1 : size (transitions, 2 )
188
+ ],
189
+ )
190
+
191
+ parameter_names = if include_all
192
+ MCMCChains. names (chain_result, :parameters )
193
+ else
194
+ filter (
195
+ k -> ! (k in MCMCChains. names (chain_parameters, :parameters )),
196
+ names (chain_result, :parameters ),
197
+ )
198
+ end
199
+
200
+ return chain_result[parameter_names]
201
+ end
202
+
203
# Log probability recorded for the transition.
function getlogp(t::Transition)
    return t.lp
end
204
+
205
"""
    get_transition_extras(ts::AbstractVector{<:Transition})

Return the non-parameter quantities carried by the transitions: the names
(currently only `:lp`) and a `length(ts) × 1` matrix of the corresponding values.
"""
function get_transition_extras(ts::AbstractVector{<:Transition})
    logps = map(getlogp, ts)
    return [:lp], reshape(logps, :, 1)
end
209
+
210
"""
    names_values(extra_data::AbstractVector{<:NamedTuple{names}}) where {names}

Fast path for named tuples that all share the same field names: return the
names as a vector and the values as a matrix with one row per tuple.
"""
function names_values(extra_data::AbstractVector{<:NamedTuple{names}}) where {names}
    vals = [nt[name] for nt in extra_data, name in names]
    return collect(names), vals
end
214
+
215
"""
    names_values(xs::AbstractVector{<:NamedTuple})

General path for named tuples with possibly differing fields: collect the
union of all field names and build a value matrix (one row per tuple, one
column per name), padding with `missing` where a tuple lacks a field.
"""
function names_values(xs::AbstractVector{<:NamedTuple})
    # Gather every field name that occurs in any element.
    seen = Set{Symbol}()
    for x in xs
        union!(seen, keys(x))
    end
    names_unique = collect(seen)

    # `missing` marks positions where an element does not carry the field.
    vals = [get(x, name, missing) for x in xs, name in names_unique]

    return names_unique, vals
end
230
+
231
# No log-evidence estimate is available here; report `missing` for any inputs.
function getlogevidence(transitions, sampler, state)
    return missing
end
232
+
233
# this is copied from Turing.jl/src/mcmc/Inference.jl, types are more restrictive (removed types that are defined in Turing)
# the function is simplified, so that unused arguments are removed
"""
    _bundle_samples(ts, model, spl)

Convert the transitions `ts` into an `MCMCChains.Chains`: parameter values plus
the `:lp` internal, with a `varname_to_symbol` mapping stored in the chain's
`info` (read back by `DynamicPPL.varnames`).
"""
function _bundle_samples(
    ts::Vector{<:Transition}, model::DynamicPPL.Model, spl::DynamicPPL.SampleFromPrior
)
    # Convert transitions to array format.
    # Also retrieve the variable names.
    varnames, vals = _params_to_array(model, ts)
    varnames_symbol = map(Symbol, varnames)

    # Get the values of the extra parameters in each transition.
    extra_params, extra_values = get_transition_extras(ts)

    # Extract names & construct param array.
    nms = [varnames_symbol; extra_params]
    parray = hcat(vals, extra_values)

    # Set up the info tuple. (Built directly instead of `merge`-ing onto an
    # empty `NamedTuple()` — same result, less noise.)
    info = (;
        varname_to_symbol=DynamicPPL.OrderedCollections.OrderedDict(
            zip(varnames, varnames_symbol)
        ),
    )

    # Concretize the array before giving it to MCMCChains.
    parray = MCMCChains.concretize(parray)

    # Chain construction.
    # BUGFIX: the original computed `info` but never passed it to the `Chains`
    # constructor, so `varname_to_symbol` was silently dropped and
    # `DynamicPPL.varnames(chain)` (which reads `chain.info.varname_to_symbol`)
    # could not work on the result. Attach it via the `info` keyword.
    chain = MCMCChains.Chains(parray, nms, (internals=extra_params,); info=info)

    return chain
end
270
+
271
+ """
272
+ transitions_from_chain(
273
+ [rng::AbstractRNG,]
274
+ model::Model,
275
+ chain::MCMCChains.Chains;
276
+ sampler = DynamicPPL.SampleFromPrior()
277
+ )
278
+
279
+ Execute `model` conditioned on each sample in `chain`, and return resulting transitions.
280
+
281
+ The returned transitions are represented in a `Vector{<:Turing.Inference.Transition}`.
282
+
283
+ # Details
284
+
285
+ In a bit more detail, the process is as follows:
286
+ 1. For every `sample` in `chain`
287
+ 1. For every `variable` in `sample`
288
+ 1. Set `variable` in `model` to its value in `sample`
289
+ 2. Execute `model` with variables fixed as above, sampling variables NOT present
290
+ in `chain` using `SampleFromPrior`
291
+ 3. Return sampled variables and log-joint
292
+
293
+ # Example
294
+ ```julia-repl
295
+ julia> using Turing
296
+
297
+ julia> @model function demo()
298
+ m ~ Normal(0, 1)
299
+ x ~ Normal(m, 1)
300
+ end;
301
+
302
+ julia> m = demo();
303
+
304
+ julia> chain = Chains(randn(2, 1, 1), ["m"]); # 2 samples of `m`
305
+
306
+ julia> transitions = Turing.Inference.transitions_from_chain(m, chain);
307
+
308
+ julia> [Turing.Inference.getlogp(t) for t in transitions] # extract the logjoints
309
+ 2-element Array{Float64,1}:
310
+ -3.6294991938628374
311
+ -2.5697948166987845
312
+
313
+ julia> [first(t.θ.x) for t in transitions] # extract samples for `x`
314
+ 2-element Array{Array{Float64,1},1}:
315
+ [-2.0844148956440796]
316
+ [-1.704630494695469]
317
+ ```
318
+ """
319
+ function transitions_from_chain (
320
+ model:: DynamicPPL.Model , chain:: MCMCChains.Chains ; kwargs...
321
+ )
322
+ return transitions_from_chain (Random. default_rng (), model, chain; kwargs... )
323
+ end
324
+
325
"""
    transitions_from_chain(rng, model, chain; sampler=DynamicPPL.SampleFromPrior())

Re-execute `model` once per (sample, chain) index pair of `chain`, fixing the
variables present in `chain` and drawing the remaining ones via `sampler`;
return the resulting matrix of `Transition`s (samples × chains).
"""
function transitions_from_chain(
    rng::DynamicPPL.Random.AbstractRNG,
    model::DynamicPPL.Model,
    chain::MCMCChains.Chains;
    sampler=DynamicPPL.SampleFromPrior(),
)
    varinfo = DynamicPPL.VarInfo(model)

    index_pairs = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
    transitions = map(index_pairs) do (sample_idx, chain_idx)
        # Fix the variables stored in `chain`; flag those NOT present for resampling.
        DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
        # Re-run the model: flagged variables are drawn via `sampler`.
        model(rng, varinfo, sampler)

        # Capture the parameter values and log probability of this execution.
        Transition(model, varinfo)
    end

    return transitions
end
345
+
45
346
"""
46
347
generated_quantities(model::Model, chain::MCMCChains.Chains)
47
348
0 commit comments