Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CompatHelper: bump compat for "Turing" to "0.16" #216

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ authors = ["Vaibhavdixit02 <[email protected]>"]
version = "2.24.0"

[deps]
ApproxBayes = "f5f396d3-230c-5e07-80e6-9fadf06146cc"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Expand All @@ -26,12 +25,12 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StanSample = "c1514b29-d3a0-5178-b312-660c88baa699"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
ApproxBayes = "0.3"
DiffEqBase = "6.36"
DiffResults = "0.0.4, 1.0"
Distances = "0.8, 0.9, 0.10"
Expand All @@ -46,14 +45,13 @@ Missings = "0.4, 1.0"
ModelingToolkit = "5.6"
Optim = "0.19, 0.20, 0.21, 0.22, 1.0"
PDMats = "0.9, 0.10, 0.11"
ParameterizedFunctions = "5"
Parameters = "0.12"
RecursiveArrayTools = "1,2"
Reexport = "0.2, 1.0"
Requires = "0.5, 1.0"
StructArrays = "0.4, 0.5"
TransformVariables = "0.3, 0.4"
Turing = "0.12, 0.13, 0.14, 0.15"
Turing = "0.12, 0.13, 0.14, 0.15, 0.16"
julia = "1.3"

[extras]
Expand Down
16 changes: 5 additions & 11 deletions src/DiffEqBayes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,23 @@ using DocStringExtensions
using DiffEqBase, Distributions, Turing, MacroTools
using RecursiveArrayTools, ModelingToolkit
using Parameters, Distributions, Optim, Requires
using Distances, ApproxBayes, DocStringExtensions, Random
using Distances, DocStringExtensions, Random, StanSample

STANDARD_PROB_GENERATOR(prob,p) = remake(prob;u0=eltype(p).(prob.u0),p=p)
STANDARD_PROB_GENERATOR(prob::EnsembleProblem,p) = EnsembleProblem(remake(prob.prob;u0=eltype(p).(prob.prob.u0),p=p))

include("turing_inference.jl")
include("abc_inference.jl")
# include("abc_inference.jl")
include("stan_string.jl")
include("stan_inference.jl")

function __init__()
@require CmdStan="593b3428-ca2f-500c-ae53-031589ec8ddd" begin
using .CmdStan
include("stan_inference.jl")
include("stan_string.jl")
export stan_inference, stan_string
end

@require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" begin
using .DynamicHMC, TransformVariables, LogDensityProblems
include("dynamichmc_inference.jl")
export dynamichmc_inference
end
end

export turing_inference, abc_inference

export turing_inference, stan_inference ,abc_inference
end # module
2 changes: 1 addition & 1 deletion src/dynamichmc_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,5 +103,5 @@ function dynamichmc_inference(problem::DiffEqBase.DEProblem, algorithm, t, data,
ℓ = TransformedLogDensity(trans, P)
∇ℓ = LogDensityProblems.ADgradient(AD_gradient_kind, ℓ)
results = mcmc_with_warmup(rng, ∇ℓ, num_samples; mcmc_kwargs...)
merge((posterior = transform.(Ref(trans), results.chain), ), results)
merge((posterior = TransformVariables.transform.(Ref(trans), results.chain), ), results)
end
78 changes: 40 additions & 38 deletions src/stan_inference.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
struct StanModel{M,R,C,N}
struct StanResult{M,R,C}
model::M
return_code::R
chains::C
cnames::N
end

function Base.show(io::IO, mime::MIME"text/plain", res::StanResult)
show(io, mime, res.chains)
end

struct StanODEData
end

function generate_priors(n,priors)
priors_string = ""
if priors==nothing
if priors===nothing
for i in 1:n
priors_string = string(priors_string,"theta[$i] ~ normal(0, 1)", " ; ")
priors_string = string(priors_string,"theta_$i ~ normal(0, 1)", " ; ")
end
else
for i in 1:n
priors_string = string(priors_string,"theta[$i] ~",stan_string(priors[i]),";")
priors_string = string(priors_string,"theta_$i ~ ",stan_string(priors[i]),";")
end
end
priors_string
Expand All @@ -34,13 +37,13 @@ function generate_theta(n,priors)
lower_bound = string("lower=",minimum(priors[i]))
end
if lower_bound != "" && upper_bound != ""
theta = string(theta,"real","<$lower_bound",",","$upper_bound>"," theta$i",";")
theta = string(theta,"real","<$lower_bound",",","$upper_bound>"," theta_$i",";")
elseif lower_bound != ""
theta = string(theta,"real","<$lower_bound",">"," theta$i",";")
theta = string(theta,"real","<$lower_bound",">"," theta_$i",";")
elseif upper_bound != ""
theta = string(theta,"real","<","$upper_bound>"," theta$i",";")
theta = string(theta,"real","<","$upper_bound>"," theta_$i",";")
else
theta = string(theta,"real"," theta$i",";")
theta = string(theta,"real"," theta_$i",";")
end
end
return theta
Expand All @@ -50,9 +53,10 @@ function stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing,
stanmodel = nothing;alg=:rk45,
num_samples=1000, num_warmup=1000, reltol=1e-3,
abstol=1e-6, maxiter=Int(1e5),likelihood=Normal,
vars=(StanODEData(),InverseGamma(3,3)),nchains=1,
sample_u0 = false, save_idxs = nothing, diffeq_string = nothing, printsummary = true)

vars=(StanODEData(),InverseGamma(3,3)),nchains=[1],
sample_u0 = false, save_idxs = nothing, diffeq_string = nothing,
printsummary = true, output_format = :mcmcchains)

save_idxs !== nothing && length(save_idxs) == 1 ? save_idxs = save_idxs[1] : save_idxs = save_idxs
length_of_y = length(prob.u0)
save_idxs = something(save_idxs, 1:length_of_y)
Expand All @@ -63,24 +67,26 @@ function stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing,
else
length_of_parameter = length(prob.p) + sample_u0 * length(save_idxs)
end

if stanmodel === nothing
if alg ==:adams
algorithm = "integrate_ode_adams"
algorithm = "ode_adams_tol"
elseif alg ==:rk45
algorithm = "integrate_ode_rk45"
algorithm = "ode_rk45_tol"
elseif alg == :bdf
algorithm = "integrate_ode_bdf"
algorithm = "ode_bdf_tol"
else
error("The choices for alg are :adams, :rk45, or :bdf")
end
hyper_params = ""
tuple_hyper_params = ""
setup_params = ""
thetas = ""
theta_names = ""
theta_string = generate_theta(length_of_parameter,priors)
for i in 1:length_of_parameter
thetas = string(thetas,"theta[$i] = theta$i",";")
thetas = string(thetas,"real theta_$i",";")
theta_names = string(theta_names,"theta_$i",",")
end
for i in 1:length_of_params
if isa(vars[i],StanODEData)
Expand All @@ -97,18 +103,18 @@ function stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing,
stan_likelihood = stan_string(likelihood)
if sample_u0
nu = length(save_idxs)
dv_names_ind = findfirst("$nu", theta_names)[1]
if nu < length(prob.u0)
u0 = "{"
u0 = ""
for u_ in prob.u0[nu+1:length(prob.u0)]
u0 = u0*string(u_)
end
u0 = u0*"}"
integral_string = "u_hat = $algorithm(sho, append_array(theta[1:$nu],$u0), t0, ts, theta[$(nu+1):$length_of_parameter], x_r, x_i, $reltol, $abstol, $maxiter);"
else
integral_string = "u_hat = $algorithm(sho, theta[1:$nu], t0, ts, theta[$(nu+1):$length_of_parameter], x_r, x_i, $reltol, $abstol, $maxiter);"
integral_string = "u_hat = $algorithm(sho, [$(theta_names[1:dv_names_ind]),$u0]', t0, ts, $reltol, $abstol, $maxiter, $(rstrip(theta_names[dv_names_ind+2:end],',')));"
else
integral_string = "u_hat = $algorithm(sho, [$(theta_names[1:dv_names_ind])]', t0, ts, $reltol, $abstol, $maxiter, $(rstrip(theta_names[dv_names_ind+2:end],',')));"
end
else
integral_string = "u_hat = $algorithm(sho, u0, t0, ts, theta, x_r, x_i, $reltol, $abstol, $maxiter);"
integral_string = "u_hat = $algorithm(sho, u0, t0, ts, $reltol, $abstol, $maxiter, $(rstrip(theta_names,',')));"
end
binsearch_string = """
int bin_search(real x, int min_val, int max_val){
Expand All @@ -120,8 +126,8 @@ function stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing,
out = mid_pt;
range = 0;
} else {
range = (range + 1) / 2;
mid_pt = x > mid_pt ? mid_pt + range: mid_pt - range;
range = (range + 1) / 2;
mid_pt = x > mid_pt ? mid_pt + range: mid_pt - range;
}
}
return out;
Expand All @@ -141,26 +147,18 @@ function stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing,
$diffeq_string
}
data {
real u0[$length_of_y];
vector[$length_of_y] u0;
int<lower=1> T;
real internal_var___u[T,$(length(save_idxs))];
real t0;
real ts[T];
}
transformed data {
real x_r[0];
int x_i[0];
}
parameters {
$setup_params
$theta_string
}
transformed parameters{
real theta[$length_of_parameter];
$thetas
}
model{
real u_hat[T,$length_of_y];
vector[$length_of_y] u_hat[T];
$hyper_params
$priors_string
$integral_string
Expand All @@ -169,9 +167,13 @@ function stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing,
}
}
"
stanmodel = CmdStan.Stanmodel(num_samples=num_samples, num_warmup=num_warmup, name="parameter_estimation_model", model=parameter_estimation_model, nchains=nchains, printsummary = printsummary)
stanmodel = StanSample.SampleModel("parameter_estimation_model", parameter_estimation_model, nchains; printsummary = printsummary, method = StanSample.Sample(;num_samples = num_samples, num_warmup = num_warmup))
end
parameter_estimation_data = Dict("u0"=>prob.u0, "T" => length(t), "internal_var___u" => view(data, :, 1:length(t))', "t0" => prob.tspan[1], "ts" => t)
return_code, chains, cnames = CmdStan.stan(stanmodel, [parameter_estimation_data])
return StanModel(stanmodel, return_code, chains, cnames)
rc = stan_sample(stanmodel; data = parameter_estimation_data)
if success(rc)
return StanResult(stanmodel, rc, read_samples(stanmodel; output_format=output_format))
else
rc.err
end
end
4 changes: 1 addition & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,10 @@ const GROUP = get(ENV, "GROUP", "All")
if GROUP == "All" || GROUP == "Core"
@time @safetestset "DynamicHMC" begin include("dynamicHMC.jl") end
@time @safetestset "Turing" begin include("turing.jl") end
@time @safetestset "ABC" begin include("abc.jl") end
# @time @safetestset "ABC" begin include("abc.jl") end
end

if GROUP == "Stan" || GROUP == "All"
using Pkg
Pkg.add("CmdStan")
@time @safetestset "Stan_String" begin include("stan_string.jl") end
@time @safetestset "Stan" begin include("stan.jl") end
end
32 changes: 13 additions & 19 deletions test/stan.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using CmdStan, DiffEqBayes, OrdinaryDiffEq, ParameterizedFunctions,
using DiffEqBayes, OrdinaryDiffEq, ParameterizedFunctions,
RecursiveArrayTools, Distributions, Test

println("One parameter case")
Expand All @@ -19,24 +19,21 @@ priors = [truncated(Normal(1.5,0.1),1.0,1.8)]
bayesian_result = stan_inference(prob1,t,data,priors;num_samples=300,
num_warmup=500,likelihood=Normal)

sdf = CmdStan.read_summary(bayesian_result.model)
@test sdf[sdf.parameters .== :theta1, :mean][1] ≈ 1.5 atol=3e-1
@test mean(get(bayesian_result.chains,:theta_1)[1]) ≈ 1.5 atol=3e-1

# Test norecompile
bayesian_result2 = stan_inference(prob1,t,data,priors,bayesian_result.model;
num_samples=300,num_warmup=500,likelihood=Normal)

sdf = CmdStan.read_summary(bayesian_result.model)
@test sdf[sdf.parameters .== :theta1, :mean][1] ≈ 1.5 atol=3e-1
@test mean(get(bayesian_result2.chains,:theta_1)[1]) ≈ 1.5 atol=3e-1

priors = [truncated(Normal(1.,0.01),0.5,2.0),truncated(Normal(1.,0.01),0.5,2.0),truncated(Normal(1.5,0.01),1.0,2.0)]
bayesian_result = stan_inference(prob1,t,data,priors;num_samples=300,
num_warmup=500,likelihood=Normal,sample_u0=true)

sdf = CmdStan.read_summary(bayesian_result.model)
@test sdf[sdf.parameters .== :theta1, :mean][1] ≈ 1. atol=3e-1
@test sdf[sdf.parameters .== :theta2, :mean][1] ≈ 1. atol=3e-1
@test sdf[sdf.parameters .== :theta3, :mean][1] ≈ 1.5 atol=3e-1
@test mean(get(bayesian_result.chains,:theta_1)[1]) ≈ 1. atol=3e-1
@test mean(get(bayesian_result.chains,:theta_2)[1]) ≈ 1. atol=3e-1
@test mean(get(bayesian_result.chains,:theta_3)[1]) ≈ 1.5 atol=3e-1

sol = solve(prob1,Tsit5(),save_idxs=[1])
randomized = VectorOfArray([(sol(t[i]) + .01 * randn(1)) for i in 1:length(t)])
Expand All @@ -45,17 +42,15 @@ priors = [truncated(Normal(1.5,0.1),0.5,2)]
bayesian_result = stan_inference(prob1,t,data,priors;num_samples=300,
num_warmup=500,likelihood=Normal,save_idxs=[1])

sdf = CmdStan.read_summary(bayesian_result.model)
@test sdf[sdf.parameters .== :theta1, :mean][1] ≈ 1.5 atol=3e-1
@test mean(get(bayesian_result.chains,:theta_1)[1]) ≈ 1.5 atol=3e-1


priors = [truncated(Normal(1.,0.01),0.5,2),truncated(Normal(1.5,0.01),0.5,2)]
bayesian_result = stan_inference(prob1,t,data,priors;num_samples=300,
num_warmup=500,likelihood=Normal,save_idxs=[1],sample_u0=true)

sdf = CmdStan.read_summary(bayesian_result.model)
@test sdf[sdf.parameters .== :theta1, :mean][1] ≈ 1. atol=3e-1
@test sdf[sdf.parameters .== :theta2, :mean][1] ≈ 1.5 atol=3e-1
@test mean(get(bayesian_result.chains,:theta_1)[1]) ≈ 1. atol=3e-1
@test mean(get(bayesian_result.chains,:theta_2)[1]) ≈ 1.5 atol=3e-1

println("Four parameter case")
f1 = @ode_def begin
Expand All @@ -74,8 +69,7 @@ priors = [truncated(Normal(1.5,0.01),0.5,2),truncated(Normal(1.0,0.01),0.5,1.5),
truncated(Normal(3.0,0.01),0.5,4),truncated(Normal(1.0,0.01),0.5,2)]

bayesian_result = stan_inference(prob1,t,data,priors;num_samples=100,num_warmup=500,vars =(DiffEqBayes.StanODEData(),InverseGamma(4,1)))
sdf = CmdStan.read_summary(bayesian_result.model)
@test sdf[sdf.parameters .== :theta1, :mean][1] ≈ 1.5 atol=1e-1
@test sdf[sdf.parameters .== :theta2, :mean][1] ≈ 1.0 atol=1e-1
@test sdf[sdf.parameters .== :theta3, :mean][1] ≈ 3.0 atol=1e-1
@test sdf[sdf.parameters .== :theta4, :mean][1] ≈ 1.0 atol=1e-1
@test mean(get(bayesian_result.chains,:theta_1)[1]) ≈ 1.5 atol=1e-1
@test mean(get(bayesian_result.chains,:theta_2)[1]) ≈ 1.0 atol=1e-1
@test mean(get(bayesian_result.chains,:theta_3)[1]) ≈ 3.0 atol=1e-1
@test mean(get(bayesian_result.chains,:theta_4)[1]) ≈ 1.0 atol=1e-1
Loading