Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add converter from Turing using both Chains and Model #133

Open
wants to merge 36 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
dd7abf8
Add Turing to extras
sethaxen May 19, 2021
dff0ed1
Add initial implementation of from_turing
sethaxen May 19, 2021
55cbc78
Handle non-array eltype constraints
sethaxen May 19, 2021
d25e592
Apply suggestions from code review
sethaxen May 20, 2021
31f902b
Repair predictive model code
sethaxen May 20, 2021
e1da410
Run formatter
sethaxen May 20, 2021
454c23e
Constrain type of model
sethaxen May 20, 2021
98511f4
Add model name to attributes
sethaxen May 20, 2021
191cfd7
Support specifying groups to not be generated
sethaxen May 20, 2021
2c6cc32
Constrain type of rng
sethaxen May 20, 2021
134a6a1
Add docstring
sethaxen May 20, 2021
9e3edb7
Document from_turing
sethaxen May 20, 2021
3408476
Also generate constant_data
sethaxen May 20, 2021
180e560
Make code more modular
sethaxen May 20, 2021
85ddd69
Add Turing tests
sethaxen May 20, 2021
4f50d7a
Force library to be Turing
sethaxen May 20, 2021
1b58b6b
Overload setattribute! for InferenceData
sethaxen May 20, 2021
12ca874
Add function to add inference library info
sethaxen May 20, 2021
a7bb79f
Globally use library utility
sethaxen May 20, 2021
42c8823
Test library utility for Turing
sethaxen May 20, 2021
842447f
Increment version number
sethaxen May 20, 2021
c9b4562
Repair Turing example
sethaxen May 20, 2021
e0d9ae3
Don't import Turing's exports
sethaxen May 20, 2021
a90df63
Return correct variable name
sethaxen May 20, 2021
ea273ef
Indent wrapped lines
sethaxen May 20, 2021
6e93fa7
Update quickstart.md
sethaxen May 20, 2021
92e8a25
Run formatter
sethaxen May 20, 2021
1581eb0
Deep copy arguments
sethaxen May 20, 2021
b863095
Capture status in string
sethaxen May 20, 2021
13ad03a
Better handle adding library info
sethaxen May 21, 2021
a188e27
Run formatter
sethaxen May 21, 2021
ee474ae
Add attribute and library tests
sethaxen May 21, 2021
5ee652a
Extract observed_data from model
sethaxen May 21, 2021
09c37da
Update example
sethaxen May 21, 2021
a05bfd2
Update quickstart
sethaxen May 21, 2021
bf474a1
Fix test
sethaxen May 22, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

Expand All @@ -25,6 +26,7 @@ PyCall = "1.91.2"
PyPlot = "2.8.2"
Requires = "0.5.2, 1.0"
StatsBase = "0.32, 0.33"
Turing = "0.15"
julia = "^1"

[extras]
Expand All @@ -33,6 +35,7 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[targets]
test = ["CmdStan", "MCMCChains", "MonteCarloMeasurements", "Random", "Test"]
test = ["CmdStan", "MCMCChains", "MonteCarloMeasurements", "Random", "Test", "Turing"]
6 changes: 6 additions & 0 deletions src/ArviZ.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ __precompile__()
module ArviZ

using Base: @__doc__
using Random
using Requires
using REPL
using NamedTupleTools
Expand Down Expand Up @@ -76,6 +77,7 @@ export InferenceData,
from_dict,
from_cmdstan,
from_mcmcchains,
from_turing,
concat,
concat!

Expand Down Expand Up @@ -109,6 +111,10 @@ function __init__()
import .MCMCChains: Chains, sections
include("mcmcchains.jl")
end
@require Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" begin
import .Turing: Turing
include("turing.jl")
end
return nothing
end

Expand Down
83 changes: 83 additions & 0 deletions src/turing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
function from_turing(
chns=nothing;
model=nothing,
rng=Random.default_rng(),
nchains=ndraws = chns isa Turing.MCMCChains.Chains ? last(size(chns)) : 1,
ndraws=chns isa Turing.MCMCChains.Chains ? first(size(chns)) : 1_000,
library=Turing,
observed_data=nothing,
constant_data=nothing,
posterior_predictive=nothing,
prior=nothing,
prior_predictive=nothing,
log_likelihood=nothing,
kwargs...,
)
groups = Dict{Symbol,Any}(
:observed_data => observed_data,
:constant_data => constant_data,
:posterior_predictive => posterior_predictive,
:prior => prior,
:prior_predictive => prior_predictive,
:log_likelihood => log_likelihood,
)
model === nothing && return from_mcmcchains(chns; library=library, groups..., kwargs...)
if groups[:prior] === nothing
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
groups[:prior] = reduce(
Turing.chainscat,
map(
_ -> Turing.sample(rng, model, Turing.Prior(), ndraws; progress=false),
1:nchains,
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
),
)
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
end

groups[:observed_data] === nothing &&
return from_mcmcchains(chns; library=library, groups..., kwargs...)

observed_data = groups[:observed_data]
data_var_names = Set(
observed_data isa Dict ? Symbol.(keys(observed_data)) : propertynames(observed_data)
)

if groups[:constant_data] === nothing
groups[:constant_data] = NamedTuple(
filter(p -> first(p) ∉ data_var_names, pairs(model.args))
)
end

# Instantiate the predictive model
args_pred = NamedTuple(
k => k in data_var_names ? similar(v, Missing) : v for (k, v) in pairs(model.args)
)
model_predict = Turing.DynamicPPL.Model(model.name, model.f, args_pred, model.defaults)
sethaxen marked this conversation as resolved.
Show resolved Hide resolved

# and then sample!
if groups[:prior_predictive] === nothing && groups[:prior] isa Turing.MCMCChains.Chains
groups[:prior_predictive] = Turing.predict(rng, model_predict, groups[:prior])
end

if chns isa Turing.MCMCChains.Chains
if groups[:posterior_predictive] === nothing && chns isa Turing.MCMCChains.Chains
groups[:posterior_predictive] = Turing.predict(rng, model_predict, chns)
end

if groups[:log_likelihood] === nothing &&
groups[:posterior_predictive] isa MCMCChains.Chains
loglikelihoods = Turing.pointwise_loglikelihoods(
model, Turing.MCMCChains.get_sections(chns, :parameters)
)

# Bundle loglikelihoods into a `Chains` object so we can reuse our own variable
# name parsing
pred_names = string.(keys(groups[:posterior_predictive]))
loglikelihoods_vals = getindex.(Ref(loglikelihoods), pred_names)
loglikelihoods_arr = permutedims(cat(loglikelihoods_vals...; dims=3), (1, 3, 2))
groups[:log_likelihood] = Turing.MCMCChains.Chains(
loglikelihoods_arr, pred_names
)
end
end

return from_mcmcchains(chns; library=Turing, groups..., kwargs...)
end
9 changes: 7 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,17 @@ _asstringkeydict(x::Dict{String}) = x

enforce_stat_eltypes(stats) = convert_to_eltypes(stats, sample_stats_eltypes)

_convert_to_eltype(v::AbstractArray, T) = convert(Array{T}, v)
_convert_to_eltype(v, T) = convert(T, v)
function convert_to_eltypes(data::Dict, data_eltypes)
return Dict(k => convert(Array{get(data_eltypes, k, eltype(v))}, v) for (k, v) in data)
return Dict(
k => _convert_to_eltype(v, get(data_eltypes, k, eltype(v))) for (k, v) in data
)
end
function convert_to_eltypes(data::NamedTuple, data_eltypes)
return NamedTuple(
k => convert(Array{get(data_eltypes, k, eltype(v))}, v) for (k, v) in pairs(data)
k => _convert_to_eltype(v, get(data_eltypes, k, eltype(v))) for
(k, v) in pairs(data)
)
end

Expand Down