-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmcmcSetup.jl
139 lines (116 loc) · 4.57 KB
/
mcmcSetup.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
using Distributions, LinearAlgebra, Random
using BifurcationKit, DifferentialEquations, BenchmarkTools
using CSV, Tables, Plots
using Optim, DataFrames
include("./model.jl")
using .Model
include("./tools.jl")
using .Tools
# Seed the default random number generator so that the noise drawn with `randn` in
# `saveData` is the same on every run
Random.seed!(1)
"""
    get_period(lc::Vector{Float64}, prob::ODEProblem)::Number

Measure the period of the limit cycle as the time between two consecutive upward
crossings of `u[1] = -20`.

The problem is re-solved from `lc` over `tspan = (0.0, 10.0)`. Only the states at the
crossings are saved, and the integration is terminated at the second crossing.

# Arguments
- `lc::Vector{Float64}`: State used as the initial condition; expected to lie on the
  limit cycle.
- `prob::ODEProblem`: The ODEProblem for the model.

# Returns
- `period::Number`: Time elapsed between the first and the second upward crossing.

# Throws
- `ErrorException`: If fewer than two upward crossings were recorded within the
  integration window, in which case no period can be measured.
"""
function get_period(lc::Vector{Float64}, prob::ODEProblem)::Number
    # Event function: has a root whenever the first state variable equals -20
    condition(u, _, _) = u[1] + 20
    # A `Ref` lets the closure count events without rebinding a captured variable
    crossings = Ref(0)
    function affect!(integrator)
        crossings[] += 1
        if crossings[] >= 2
            terminate!(integrator)
        end
    end
    # Passing `nothing` as the second affect means only upcrossings trigger `affect!`;
    # save_positions = (true, false) stores the state once per event
    cb = ContinuousCallback(condition, affect!, nothing;
        save_positions=(true, false))
    sol = DifferentialEquations.solve(prob, Tsit5(); u0=lc, tspan=(0.0, 10.0),
        maxiters=1e9, save_everystep=false, save_start=false, save_end=false,
        callback=cb)
    # Nothing but the event states is saved, so two time points are needed for a period
    if length(sol.t) < 2
        error("get_period: found $(length(sol.t)) upward crossing(s) of u[1] = -20 " *
              "within tspan = (0.0, 10.0); at least two are needed to measure a period")
    end
    period = sol.t[end] - sol.t[1]
    return period
end
"""
    saveData()

Generate a noisy synthetic data set from the model at the true parameters and write it
to `results/mcmc/data.csv`.

The model is integrated over `(0.0, 1000.0)`, the final state is checked with
`Tools.auto_converge_check`, and one aligned pulse is produced with `Tools.aligned_sol`.
The period passed to `Tools.aligned_sol` is rounded down to a multiple of 0.001.
Gaussian noise with standard deviation 2.0 is added to the pulse.

# Returns
- `sol_pulse`: The noise-free aligned solution returned by `Tools.aligned_sol`.
- `odedata`: The noisy data.
- `period`: The rounded-down period used for the alignment.
"""
function saveData()
    # Define the method specific settings and functions for MCMC
    prob = ODEProblem(Model.ode!, Model.ic, (0.0, 1000.0), Model.params;
        abstol=1e-10, reltol=1e-8, maxiters=1e7)

    # Create the true data
    # True parameters
    pTrue = Tools.param_map([1.0, 1.0, 1.0])
    prob = remake(prob, p=pTrue)::ODEProblem

    # Run ODE to converged limit cycle; only the end points are kept
    sol = DifferentialEquations.solve(prob, Tsit5(); maxiters=1e9,
        save_everystep=false)::ODESolution
    final_state = sol.u[end]
    if Tools.auto_converge_check(prob, final_state, pTrue)
        println("Data is appropriately converged")
    else
        println("Data was NOT generated from a converged limit cycle")
    end

    # Generate aligned data, with the period rounded down to a multiple of 0.001
    period = floor(get_period(final_state, prob)/0.001)*0.001
    sol_pulse = Tools.aligned_sol(final_state, prob, period)

    # Add noise
    odedata = Array(sol_pulse.u) + 2.0 * randn(size(sol_pulse))

    # Save the data (time in the first column, followed by the noisy data);
    # create the output directory first so the write cannot fail on a fresh checkout
    datafile = "results/mcmc/data.csv"
    mkpath(dirname(datafile))
    CSV.write(datafile, Tables.table([sol_pulse.t odedata]), writeheader=false)
    return sol_pulse, odedata, period
end
"""
    ℓ(data, sol; σ=2.0)

Compute the Gaussian log-likelihood of `data` given the simulated values `sol.u`,
assuming independent, identically distributed noise with standard deviation `σ`.

# Arguments
- `data`: Observed values.
- `sol`: Any object whose field `u` holds the simulated values, with the same shape as
  `data`.

# Keywords
- `σ`: Noise standard deviation. The default of 2.0 matches the noise level added in
  `saveData`.

# Returns
- The log-likelihood `-n*log(2π)/2 - n*log(σ^2)/2 - SSE/(2σ^2)`, where `n` is
  `length(data)` and `SSE` is the sum of squared residuals.

# Throws
- `ArgumentError`: If `σ` is not strictly positive.
"""
function ℓ(data, sol; σ=2.0)
    σ > 0 || throw(ArgumentError("σ must be positive, got $σ"))
    n = length(data)
    # Plain `-` (not `.-`) so that mismatched shapes raise instead of broadcasting
    residual = data - sol.u
    return -n*log(2π)/2 - n*log(σ^2)/2 - 1/(2σ^2)*sum(abs2, residual)
end
"""
    plotData(sol, data, mle, period)

Plot the true solution, the solution at the fitted parameters and the data, and save the
figure to `results/mcmc/data.pdf`.

# Arguments
- `sol`: The true aligned solution; its `t` field supplies the time axis for `data`.
- `data`: The noisy observations.
- `mle`: Fitted parameter vector; it is passed through `Tools.param_map`.
- `period`: Period forwarded to `Tools.aligned_sol` when aligning the fitted solution.

# Returns
- `solMLE`: The aligned solution at the fitted parameters.
"""
function plotData(sol, data, mle, period)
    plot(sol; title="True data",
        label="True solution - ℓ: "*string(round(ℓ(data, sol), sigdigits=4)))

    # Simulate the model at the fitted parameters and align it in the same way as `sol`
    prob = ODEProblem(Model.ode!, Model.ic, (0.0, 1000.0);
        abstol=1e-10, reltol=1e-8, maxiters=1e7)
    prob = remake(prob, p=Tools.param_map(mle))::ODEProblem
    # Separate names for the raw and the aligned solution keep each variable one type
    raw_sol = DifferentialEquations.solve(prob, Tsit5(); maxiters=1e9)::ODESolution
    solMLE = Tools.aligned_sol(raw_sol.u[end], prob, period)
    plot!(solMLE, label="MLE - ℓ: "*string(round(ℓ(data, solMLE), sigdigits=4)))

    # Use the time axis of the `sol` argument; the original read a global `sol_pulse`
    plot!(sol.t, data, label="Data")

    # Make sure the output directory exists before saving the figure
    figfile = "results/mcmc/data.pdf"
    mkpath(dirname(figfile))
    savefig(figfile)
    return solMLE
end
"""
    optimiseParameters()

Fit the model parameters to the data stored in `results/mcmc/data.csv` by minimising the
negative log-likelihood `-ℓ` with Nelder–Mead, starting from `[1.0, 1.0, 1.0]`.

The first column of the file is read as time and the second as the observations; the
last time value is used as the period for `Tools.aligned_sol`.

# Returns
- `res`: The result object returned by `Optim.optimize`.
"""
function optimiseParameters()
    # Load data
    table = CSV.read("results/mcmc/data.csv", DataFrame, header=false)
    t = table[:, 1]
    data = table[:, 2]
    period = t[end]

    # Base problem; the parameters are passed positionally (as in `saveData`) so that
    # they are stored as the problem's parameters rather than as a solver keyword
    prob = ODEProblem(Model.ode!, Model.ic, (0.0, 1000.0), Model.params;
        abstol=1e-10, reltol=1e-8, maxiters=1e9)

    # Define model simulation function
    function model_simulator(p)
        # A local problem avoids rebinding the captured `prob` on every evaluation
        sim_prob = remake(prob, p=Tools.param_map(p))::ODEProblem

        # Converge: compare the state at successive upcrossings of u[1] = -20 and stop
        # once the summed absolute change drops below 1e-6
        condition(u, _, _) = u[1] + 20
        last_state = zeros(size(Model.ic))
        function affect!(integrator)
            if sum(abs, last_state .- integrator.u) < 1e-6
                terminate!(integrator)
            end
            last_state .= integrator.u
        end
        cb = ContinuousCallback(condition, affect!, nothing;
            save_positions=(false, false))

        # Only the final state is saved; it seeds the aligned pulse
        sol = DifferentialEquations.solve(sim_prob, Tsit5(); save_everystep=false,
            save_start=false, save_end=true, callback=cb)
        sol_pulse = Tools.aligned_sol(sol.u[end], sim_prob, period)
        return sol_pulse
    end

    # Define the cost function (negative log-likelihood)
    function cost(p)
        sim = model_simulator(p)
        return -ℓ(data, sim)
    end

    # Optimise
    p0 = [1.0, 1.0, 1.0]
    res = Optim.optimize(cost, p0,
        NelderMead(; initial_simplex=Optim.AffineSimplexer(b=0.0)),
        Optim.Options(show_trace=true))
    println(res.minimizer)
    return res
end
# Script entry point: generate the noisy synthetic data set and write it to disk
sol_pulse, odedata, period = saveData()
# Fit the parameters to the saved data by minimising the negative log-likelihood
res = optimiseParameters()
# Plot the true solution, the fitted solution and the data, and save the figure
solMLE = plotData(sol_pulse, odedata, res.minimizer, period)