-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsimulationTimings.jl
135 lines (117 loc) · 4.53 KB
/
simulationTimings.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
using Plots, BenchmarkPlots, StatsPlots
using BifurcationKit, DifferentialEquations
using BenchmarkTools
const BK = BifurcationKit
include("./model.jl")
using .Model
include("./tools.jl")
using .Tools
# Benchmark suite layout: bg[perturbation size][method family][benchmark name].
# Only the (empty) group skeleton is created here; the individual benchmarks are
# registered by the loops further down in this script.
bg = BenchmarkGroup()
for perturbation in ("Small", "Large")
    bg[perturbation] = BenchmarkGroup()
    for family in ("ODE", "Cont")
        bg[perturbation][family] = BenchmarkGroup()
    end
end
# Plot parameters
# Keyword settings bundled as a NamedTuple; not referenced in the visible part of
# this script.
plot_params = (linewidth = 2.0, dpi = 300, size = (450, 300), legend = false)

# ODE Convergence
# Two modified copies of the baseline parameter set `Model.params`:
#   pSmall changes only g_Na_sf,
#   pLarge changes g_Na_sf, g_K_sf and g_L_sf.
tmp = Model.params
pSmall = @set tmp.g_Na_sf = 1.1
pLarge = let q = tmp
    q = @set q.g_Na_sf = 1.5
    q = @set q.g_K_sf = 1.2
    q = @set q.g_L_sf = 0.8
    q
end

# Template problem built with the unmodified parameters; the loops below swap in
# the perturbed parameter sets via `remake`.
prob = ODEProblem(Model.ode!, Model.ic, (0.0, 10000.0), tmp; abstol = 1e-10, reltol = 1e-8)
# ODE Convergence - Standard
# For each perturbed parameter set, integrate from `Model.ic` until the state sampled
# at two successive threshold crossings differs by less than 1e-6 (L1 norm), print the
# time at which the run stopped, and register the identical solve as a benchmark.
params = [pSmall, pLarge]
println("Standard Approach")
for (label, p) in zip(("Small", "Large"), params)
    # Event function: its root is u[1] == -20.
    condition(u, _, _) = u[1] + 20

    # State seen at the previous event. Starts at zero, so the first event can only
    # terminate the run if the state itself is within tolerance of the zero vector.
    STATE::Vector{Float64} = zeros(size(Model.ic))

    # Terminate once two consecutive event states agree; otherwise remember the
    # current state for the next comparison.
    function affect!(integrator)
        # L1 distance computed lazily, so no temporary arrays are allocated per event.
        drift = sum(abs(prev - cur) for (prev, cur) in zip(STATE, integrator.u))
        if drift < 1e-6
            terminate!(integrator)
        end
        STATE .= integrator.u
        return nothing
    end

    # `nothing` as third positional argument: no action for the opposite crossing
    # direction (see the ContinuousCallback documentation).
    cb = ContinuousCallback(condition, affect!, nothing; save_positions = (false, false))

    prob_de = remake(prob; p = p)
    sol = DifferentialEquations.solve(prob_de, Tsit5(); maxiters = 1e9,
        save_everystep = false, save_start = false, save_end = true, callback = cb)
    println("Simulation time to convergence for $(lowercase(label)) perturbation")
    # Only the end point is saved, so this shows the termination time.
    display(sol.t)

    # The callback keeps `STATE` between solves. Reset it before every sample so each
    # timed run starts from the same callback state as the reference solve above
    # (setup runs once per sample; evals stays at 1 because the suite is not tuned).
    b = @benchmarkable DifferentialEquations.solve($prob_de, $Tsit5(), maxiters=1e9, save_everystep=false, save_start=false, save_end=true, callback=$cb) setup=(fill!($STATE, 0.0))
    bg[label]["ODE"]["ODE - Standard"] = b
end
# ODE Convergence - Tracking
# Same convergence experiment as the standard approach, except that every solve
# starts from `Model.ic_conv` (passed as `u0`) instead of the problem's own initial
# condition.
# NOTE(review): `Model.ic_conv` is presumably a state on the converged orbit of the
# unperturbed model - confirm against model.jl.
println("Tracking Approach")
for (label, p) in zip(("Small", "Large"), params)
    # Event function: its root is u[1] == -20.
    condition(u, _, _) = u[1] + 20

    # State seen at the previous event. Starts at zero, so the first event can only
    # terminate the run if the state itself is within tolerance of the zero vector.
    STATE::Vector{Float64} = zeros(size(Model.ic))

    # Terminate once two consecutive event states agree (L1 distance below 1e-6);
    # otherwise remember the current state for the next comparison.
    function affect!(integrator)
        # Lazy reduction: no temporary arrays are allocated per event.
        drift = sum(abs(prev - cur) for (prev, cur) in zip(STATE, integrator.u))
        if drift < 1e-6
            terminate!(integrator)
        end
        STATE .= integrator.u
        return nothing
    end

    # `nothing` as third positional argument: no action for the opposite crossing
    # direction (see the ContinuousCallback documentation).
    cb = ContinuousCallback(condition, affect!, nothing; save_positions = (false, false))

    prob_de = remake(prob; p = p)
    sol = DifferentialEquations.solve(prob_de, Tsit5(); maxiters = 1e9,
        u0 = Model.ic_conv, save_everystep = false, save_start = false,
        save_end = true, callback = cb)
    println("Simulation time to convergence for $(lowercase(label)) perturbation")
    # Only the end point is saved, so this shows the termination time.
    display(sol.t)

    # `Model.ic_conv` is interpolated so the timed expression does not look up a
    # global, and `STATE` is zeroed before every sample so each timed run starts from
    # the same callback state as the reference solve above (setup runs once per
    # sample; evals stays at 1 because the suite is not tuned).
    b = @benchmarkable DifferentialEquations.solve($prob_de, $Tsit5(), maxiters=1e9, u0=$(Model.ic_conv), save_everystep=false, save_start=false, save_end=true, callback=$cb) setup=(fill!($STATE, 0.0))
    bg[label]["ODE"]["ODE - Tracking"] = b
end
# Continuation Convergence
"""
    early_abort((x, f, J, res, iteration, itlinear, options); kwargs...)

Newton-iteration callback; this script passes it as `callback_newton` to
`continuation`.

Return `true` while the residual `res` is below `5e2` and `false` otherwise. Only
`res` is inspected: the other entries of the destructured state and all keyword
arguments are ignored. Per the BifurcationKit documentation, a `false` return value
from such a callback stops the Newton iterations.
"""
function early_abort((x, f, J, res, iteration, itlinear, options); kwargs...)
    return res < 5e2
end
# Continuation parameter: the `step` field of the parameter set.
lens = @optic _.step

# Two modified copies of `Model.params_cont`:
#   pSmall changes only na_step,
#   pLarge changes na_step, k_step and l_step.
tmp = Model.params_cont
pSmall = @set tmp.na_step = 0.1
pLarge = let q = tmp
    q = @set q.na_step = 0.5
    q = @set q.k_step = 0.2
    q = @set q.l_step = -0.2
    q
end
params = [pSmall, pLarge]

# Initial continuation step `ds` for the small and the large case, in that order.
ds = [1.0, 0.3]
# For each parameter set: build a single-shooting periodic-orbit problem seeded by a
# one-pulse ODE solution, run one verbose continuation for inspection, then register
# a quiet continuation of the same problem as a benchmark.
for (label, p, ds0) in zip(("Small", "Large"), params, ds)
    # Integration window of the seed orbit; also handed to `generate_ci_problem`.
    # NOTE(review): presumably the period of the unperturbed orbit - confirm origin.
    period = 0.5216

    # Record a single component of the solution along the branch.
    # NOTE(review): the original wrote `(V = x[Model.plot_idx])`, which is a
    # parenthesised assignment returning the bare value, not a NamedTuple. If a
    # NamedTuple was intended, use `(; V = x[Model.plot_idx])` instead.
    bp = BifurcationProblem(Model.ode_cont!, Model.ic_conv, p, lens;
        record_from_solution = (x, par) -> x[Model.plot_idx])

    # 1 pulse solution
    prob_cont = ODEProblem(Model.ode_cont!, Model.ic_conv, (0.0, period), p;
        abstol = 1e-10, reltol = 1e-8)
    sol_pulse = DifferentialEquations.solve(prob_cont, Tsit5())

    # Options for the verbose reference run (Newton iterations are printed).
    opts_br = ContinuationPar(p_min = 0.0, p_max = 1.0, max_steps = 50,
        tol_stability = 1e-8, ds = ds0, dsmax = 1.0, detect_bifurcation = 0,
        detect_fold = false, newton_options = NewtonPar(verbose = true, tol = 1e-10))

    # Shooting method
    # update_section_every_step=0 avoids bpsh being perturbed between benchmark runs
    bpsh, cish = BK.generate_ci_problem(
        ShootingProblem(M = 1, update_section_every_step = 0),
        bp, prob_cont, sol_pulse, period; alg = Tsit5(), abstol = 1e-10, reltol = 1e-8)
    brpo_sh = continuation(bpsh, cish, PALC(), opts_br;
        verbosity = 3, callback_newton = early_abort)

    # Options for the timed run: same as above but without Newton printing.
    # NOTE(review): this uses ds=1.0 for both cases, whereas the verbose run above
    # uses the per-case step from `ds`; confirm that the mismatch is intentional.
    reducedOpts = ContinuationPar(p_min = 0.0, p_max = 1.0, max_steps = 50,
        tol_stability = 1e-8, ds = 1.0, dsmax = 1.0, detect_bifurcation = 0,
        detect_fold = false, newton_options = NewtonPar(tol = 1e-10))
    b = @benchmarkable continuation($bpsh, $cish, $PALC(), $reducedOpts; callback_newton = early_abort)
    bg[label]["Cont"]["Cont - Shooting"] = b
end
# Run the whole suite and persist the results.
results_path = "results/simTimings/data.json"
# Create the output directory up front: `BenchmarkTools.save` does not create missing
# directories, and failing here would otherwise discard a long benchmark run.
mkpath(dirname(results_path))
println("Reached the end of the script. Just running benchmark now.")
# seconds=100 is the BenchmarkTools time budget applied to the benchmarks in `bg`.
t = run(bg, seconds=100)
BenchmarkTools.save(results_path, t)