Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move predict from Turing, implemented using fix #651

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 55 additions & 2 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
module DynamicPPLMCMCChainsExt

if isdefined(Base, :get_extension)
using DynamicPPL: DynamicPPL
using DynamicPPL: DynamicPPL, Random
using MCMCChains: MCMCChains
else
using ..DynamicPPL: DynamicPPL
using ..DynamicPPL: DynamicPPL, Random
using ..MCMCChains: MCMCChains
end

Expand Down Expand Up @@ -190,4 +190,57 @@
return varname_pairs
end

function DynamicPPL.predict(

Check warning on line 193 in ext/DynamicPPLMCMCChainsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLMCMCChainsExt.jl#L193

Added line #L193 was not covered by tests
model::DynamicPPL.Model, chain::MCMCChains.Chains; include_all=false
)
return predict(Random.default_rng(), model, chain; include_all=include_all)

Check warning on line 196 in ext/DynamicPPLMCMCChainsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLMCMCChainsExt.jl#L196

Added line #L196 was not covered by tests
end
function DynamicPPL.predict(

Check warning on line 198 in ext/DynamicPPLMCMCChainsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLMCMCChainsExt.jl#L198

Added line #L198 was not covered by tests
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
chain::MCMCChains.Chains;
include_all=false,
)
params_only_chain = MCMCChains.get_sections(chain, :parameters)

Check warning on line 204 in ext/DynamicPPLMCMCChainsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLMCMCChainsExt.jl#L204

Added line #L204 was not covered by tests

varname_to_symbol = if :varname_to_symbol in keys(params_only_chain.info)

Check warning on line 206 in ext/DynamicPPLMCMCChainsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLMCMCChainsExt.jl#L206

Added line #L206 was not covered by tests
# the mapping is introduced in Turing by
# https://github.com/TuringLang/Turing.jl/commit/8d8416ac6c7363c6003ee6ea1fbaac26b4fc8dc3
params_only_chain.info[:varname_to_symbol]

Check warning on line 209 in ext/DynamicPPLMCMCChainsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLMCMCChainsExt.jl#L209

Added line #L209 was not covered by tests
else
# if not using Turing, then we need to construct the mapping ourselves
Dict{DynamicPPL.VarName,Symbol}([
DynamicPPL.@varname($sym) => sym for

Check warning on line 213 in ext/DynamicPPLMCMCChainsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLMCMCChainsExt.jl#L212-L213

Added lines #L212 - L213 were not covered by tests
sym in params_only_chain.name_map.parameters
])
end

num_of_chains = size(params_only_chain, 3)

Check warning on line 218 in ext/DynamicPPLMCMCChainsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLMCMCChainsExt.jl#L218

Added line #L218 was not covered by tests
# num_of_params =
num_of_samples = size(params_only_chain, 1)

Check warning on line 220 in ext/DynamicPPLMCMCChainsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLMCMCChainsExt.jl#L220

Added line #L220 was not covered by tests

predictions = []
for chain_idx in 1:num_of_chains
predictions_single_chain = []
for sample_idx in 1:num_of_samples
d_to_fix = OrderedDict{DynamicPPL.VarName,Any}()

Check warning on line 226 in ext/DynamicPPLMCMCChainsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLMCMCChainsExt.jl#L222-L226

Added lines #L222 - L226 were not covered by tests

# construct the dictionary to fix the model
for (vn, sym) in varname_to_symbol
d_to_fix[vn] = params_only_chain[sample_idx, sym, chain_idx]

Check warning on line 230 in ext/DynamicPPLMCMCChainsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLMCMCChainsExt.jl#L229-L230

Added lines #L229 - L230 were not covered by tests
end

# fix the model and sample from it
fixed_model = DynamicPPL.fix(model, d_to_fix)
predictive_sample = rand(rng, fixed_model)

Check warning on line 235 in ext/DynamicPPLMCMCChainsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLMCMCChainsExt.jl#L234-L235

Added lines #L234 - L235 were not covered by tests

# TODO: Turing version uses `Transition` and `bundle_samples` to form new chains: is it worth it to move Transition to AbstractMCMC?
push!(predictions, predictive_sample)

Check warning on line 238 in ext/DynamicPPLMCMCChainsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLMCMCChainsExt.jl#L238

Added line #L238 was not covered by tests
end
push!(predictions, predictions_single_chain)

Check warning on line 240 in ext/DynamicPPLMCMCChainsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLMCMCChainsExt.jl#L240

Added line #L240 was not covered by tests
end

return predictions

Check warning on line 243 in ext/DynamicPPLMCMCChainsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLMCMCChainsExt.jl#L243

Added line #L243 was not covered by tests
end

end
2 changes: 2 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ end
# Used here and overloaded in Turing
function getspace end

function predict end

"""
AbstractVarInfo

Expand Down
57 changes: 57 additions & 0 deletions test/predict.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
module TestPredict

using Test
using DynamicPPL
using AbstractMCMC
using MCMCChains
using Distributions
using Random
using LogDensityProblemsAD
using AdvancedHMC
using Tapir
using ForwardDiff

@model function linear_reg(x, y, σ=0.1)
β ~ Normal(0, 1)
for i in eachindex(y)
y[i] ~ Normal(β * x[i], σ)
end
end

@model function linear_reg_vec(x, y, σ=0.1)
β ~ Normal(0, 1)
return y ~ MvNormal(β .* x, σ^2 * I)
end

f(x) = 2 * x + 0.1 * randn()

Δ = 0.1
xs_train = 0:Δ:10
ys_train = f.(xs_train)
xs_test = [10 + Δ, 10 + 2 * Δ]
ys_test = f.(xs_test)

model = linear_reg(xs_train, ys_train)

m_lin_reg = linear_reg(xs_train, ys_train)
ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model))
ad_ldf = LogDensityProblemsAD.ADgradient(Val(:Tapir), ldf; safety_on=false)
chain = AbstractMCMC.sample(
ad_ldf, AdvancedHMC.NUTS(0.6), 1000; chain_type=MCMCChains.Chains, param_names=[:β]
)

DynamicPPL.predict(test_model, chain)

# LKJ example
@model demo_lkj() = x ~ LKJCholesky(2, 1.0)

model = demo_lkj()

ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.SimpleVarInfo(model))
ad_ldf = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ldf)

chain = AbstractMCMC.sample(
ad_ldf, AdvancedHMC.NUTS(0.6), 1000; chain_type=MCMCChains.Chains, param_names=[:Σ, :x]
)

end # module
Loading