Skip to content

Commit 33d4574

Browse files
More examples
1 parent 3fabc62 commit 33d4574

File tree

2 files changed

+81
-34
lines changed

2 files changed

+81
-34
lines changed

examples/kuramoto_example.jl

+80-33
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@ Pkg.activate(".")
66

77
using BAC
88

9+
using Random
910
Random.seed!(1);
1011
using Pipe: @pipe
1112
using LaTeXStrings
1213
using Statistics
1314
using DiffEqFlux
15+
using Plots
1416

1517
##
1618
const t_steps = 0.:0.1:4pi
@@ -26,47 +28,41 @@ p_spec_init = rand(dim_p_spec) .+ 1.
2628

2729
p_initial = vcat(p_sys_init, repeat(p_spec_init, N_samples))
2830

29-
i = BAC.rand_fourier_input_generator(1)
31+
@views begin
32+
p_syss = p_initial[1:dim_p]
33+
p_specs = [p_initial[(dim_p + 1 + (n - 1) * dim_p_spec):(dim_p + n * dim_p_spec)] for n in 1:N_samples]
34+
end
35+
36+
3037
K_av = 1.
3138

3239
##
40+
41+
i = BAC.rand_fourier_input_generator(1)
3342
plot(i, 0., 4pi)
3443

35-
##
44+
## Start with a small frequency spread
3645

3746
omega = 1. * randn(N_osc);
3847
omega .-= mean(omega)
3948

4049
##
4150

42-
@views begin
43-
p_syss = p_initial[1:dim_p]
44-
p_specs = [p_initial[(dim_p + 1 + (n - 1) * dim_p_spec):(dim_p + n * dim_p_spec)] for n in 1:N_samples]
45-
end
46-
4751
kur = BAC.create_kuramoto_example(omega, N_osc, dim_p_spec, K_av, t_steps, N_samples) # specify modes = 0 for no input
4852

49-
solve_sys_spec(kur, i, p_syss, p_specs[1])
50-
5153
scen = 1:5
52-
sol1, sol2 = BAC.solve_bl_n(kur, 3, p_initial, scenario_nums = scen) # n and scen_nums???
53-
kur.output_metric(sol1, sol2)
5454

5555
## Plot where we start
5656
p_initial = BAC.bac_spec_only(kur, p_initial; optimizer_options=(:maxiters => 1000,), solver_options = (abstol = 1e-4, reltol=1e-4))
5757

5858
l = kur(p_initial, abstol=1e-4, reltol=1e-4)
5959
plot_callback(kur, p_initial, l, scenario_nums = scen, xlims = (kur.t_span[2]/2, kur.t_span[2]))
60-
##
6160

62-
## For some reason using kur(p) inside an optimization loop results in an error. I have not been able to find the reason yet
63-
# The error occurs in the differential equation system (line 36), as if p is a 4-element vector instead of 2x2 matrix.
64-
# All the functions run normally without DiffEqFlux
61+
##
6562
res_1 = DiffEqFlux.sciml_train(
6663
p -> kur(p, abstol=1e-4, reltol=1e-4),
6764
p_initial,
6865
DiffEqFlux.ADAM(0.1),
69-
# DiffEqFlux.BFGS(),
7066
maxiters=25,
7167
cb=basic_bac_callback
7268
# cb = (p, l) -> plot_callback(kur, p, l, scenario_nums=scenarios)
@@ -81,7 +77,6 @@ plot_callback(kur, res_1.u, res_1.minimum, scenario_nums = scen, xlims = (kur.t_
8177
res_2 = DiffEqFlux.sciml_train(
8278
p -> kur(p, abstol=1e-4, reltol=1e-4),
8379
res_1.u,
84-
# DiffEqFlux.ADAM(0.1),
8580
DiffEqFlux.BFGS(),
8681
maxiters=25,
8782
cb=basic_bac_callback
@@ -97,8 +92,6 @@ plot_callback(kur, res_2.u, res_2.minimum, scenario_nums = scen, xlims = (kur.t_
9792
res_3 = DiffEqFlux.sciml_train(
9893
p -> kur(p, abstol=1e-4, reltol=1e-4),
9994
res_2.u,
100-
# DiffEqFlux.BFGS(),
101-
# DiffEqFlux.ADAM(0.1),
10295
DiffEqFlux.AMSGrad(0.01),
10396
maxiters=100,
10497
cb=basic_bac_callback
@@ -114,7 +107,6 @@ plot_callback(kur, res_3.u, res_3.minimum, scenario_nums = scen, xlims = (kur.t_
114107
res_4 = DiffEqFlux.sciml_train(
115108
p -> kur(p, abstol=1e-4, reltol=1e-4),
116109
res_3.u,
117-
# DiffEqFlux.ADAM(0.1),
118110
DiffEqFlux.BFGS(),
119111
maxiters=25,
120112
cb=basic_bac_callback
@@ -127,26 +119,21 @@ plot_callback(kur, res_4.u, res_4.minimum, scenario_nums = scen, xlims = (kur.t_
127119

128120
##
129121

130-
plot_callback(kur, res_1.u, res_1.minimum, scenario_nums = scen, xlims = (kur.t_span[2]/2, kur.t_span[2]))
131-
plot_callback(kur, res_2.u, res_2.minimum, scenario_nums = scen, xlims = (kur.t_span[2]/2, kur.t_span[2]))
132-
plot_callback(kur, res_3.u, res_3.minimum, scenario_nums = scen, xlims = (kur.t_span[2]/2, kur.t_span[2]))
133-
plot_callback(kur, res_4.u, res_4.minimum, scenario_nums = scen, xlims = (kur.t_span[2]/2, kur.t_span[2]))
134-
135-
##
136-
137122
res_5 = DiffEqFlux.sciml_train(
138-
p -> kur(p, abstol=1e-4, reltol=1e-4),
123+
p -> kur(p, abstol=1e-5, reltol=1e-5),
139124
res_4.u,
140-
# DiffEqFlux.ADAM(0.1),
141125
DiffEqFlux.AMSGrad(0.01),
142-
# DiffEqFlux.BFGS(),
143-
maxiters=225,
126+
maxiters=100,
144127
cb=basic_bac_callback
145128
# cb = (p, l) -> plot_callback(kur, p, l, scenario_nums=scenarios)
146129
)
147130

148131
##
149-
132+
plot_callback(kur, p_initial, l, scenario_nums = scen, xlims = (kur.t_span[2]/2, kur.t_span[2]))
133+
plot_callback(kur, res_1.u, res_1.minimum, scenario_nums = scen, xlims = (kur.t_span[2]/2, kur.t_span[2]))
134+
plot_callback(kur, res_2.u, res_2.minimum, scenario_nums = scen, xlims = (kur.t_span[2]/2, kur.t_span[2]))
135+
plot_callback(kur, res_3.u, res_3.minimum, scenario_nums = scen, xlims = (kur.t_span[2]/2, kur.t_span[2]))
136+
plot_callback(kur, res_4.u, res_4.minimum, scenario_nums = scen, xlims = (kur.t_span[2]/2, kur.t_span[2]))
150137
plot_callback(kur, res_5.u, res_5.minimum, scenario_nums = scen, xlims = (kur.t_span[2]/2, kur.t_span[2]))
151138

152139
##
@@ -156,8 +143,37 @@ plot_callback(kur, res_5.u, res_5.minimum, scenario_nums = scen, xlims = (kur.t_
156143
p_final_specs = [res_5.u[(dim_p + 1 + (n - 1) * dim_p_spec):(dim_p + n * dim_p_spec)] for n in 1:N_samples]
157144
end
158145

146+
## Try with a larger spread
147+
148+
omega = 20. * randn(N_osc);
149+
omega .-= mean(omega)
150+
151+
##
152+
153+
kur2 = BAC.create_kuramoto_example(omega, N_osc, dim_p_spec, K_av, t_steps, N_samples) # specify modes = 0 for no input
154+
155+
##
156+
p_initial2 = BAC.bac_spec_only(kur2, p_initial; optimizer_options=(:maxiters => 1000,), solver_options = (abstol = 1e-4, reltol=1e-4))
157+
158+
l2 = kur2(p_initial2, abstol=1e-4, reltol=1e-4)
159+
plot_callback(kur2, p_initial2, l, scenario_nums = scen, xlims = (kur.t_span[2]/2, kur.t_span[2]))
160+
161+
##
162+
res2_1 = DiffEqFlux.sciml_train(
163+
p -> kur2(p, abstol=1e-4, reltol=1e-4),
164+
p_initial2,
165+
DiffEqFlux.ADAM(0.1),
166+
maxiters=50,
167+
cb=basic_bac_callback
168+
# cb = (p, l) -> plot_callback(kur, p, l, scenario_nums=scenarios)
169+
)
170+
159171
##
160172

173+
plot_callback(kur2, res2_1.u, res2_1.minimum, scenario_nums = scen, xlims = (kur.t_span[2]/2, kur.t_span[2]))
174+
175+
## Increase number of nodes to 100
176+
161177
kur_100 = resample(BAC.rand_fourier_input_generator, kur; n = 100);
162178

163179
p_sys_init_100 = 6. * rand(dim_p) .+ 1.
@@ -168,4 +184,35 @@ p_100 = vcat(p_sys_init, repeat(p_spec_init, 100))
168184
p_100[1:dim_p] .= relu.(p_final)
169185
p_100[dim_p+1:end] .= repeat(relu.(p_final_specs[1]), 100)
170186

171-
p_initial_100 = BAC.bac_spec_only(kur_100, p_100)
187+
##
188+
p_initial_100 = BAC.bac_spec_only(kur_100, p_100)
189+
190+
##
191+
res_100 = DiffEqFlux.sciml_train(
192+
p -> kur_100(p, abstol=1e-4, reltol=1e-4),
193+
p_initial_100,
194+
DiffEqFlux.ADAM(0.1),
195+
maxiters=50,
196+
cb=basic_bac_callback
197+
# cb = (p, l) -> plot_callback(kur, p, l, scenario_nums=scenarios)
198+
)
199+
200+
##
201+
202+
plot_callback(kur_100, res_100.u, res_100.minimum, scenario_nums = scen, xlims = (kur.t_span[2]/2, kur.t_span[2]))
203+
204+
##
205+
res_100 = DiffEqFlux.sciml_train(
206+
p -> kur_100(p, abstol=1e-4, reltol=1e-4),
207+
res_100.u,
208+
DiffEqFlux.AMSGrad(0.01),
209+
maxiters=50,
210+
cb=basic_bac_callback
211+
# cb = (p, l) -> plot_callback(kur, p, l, scenario_nums=scenarios)
212+
)
213+
214+
##
215+
216+
plot_callback(kur_100, res_100.u, res_100.minimum, scenario_nums = scen, xlims = (kur.t_span[2]/2, kur.t_span[2]))
217+
218+
##

src/Core.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ function resample(sampler, bac::BAC_Loss; n = 0)
172172
bac.N_samples = n
173173
end
174174
new_input_sample = [sampler(n) for n in 1:bac.N_samples]
175-
BAC_Loss(bac.f_spec, bac.f_sys, bac.tsteps, bac.t_span, new_input_sample, bac.output_metric, bac.N_samples, bac.size_p_spec, bac.size_p_sys, bac.y0_spec, bac.y0_sys, bac.solver)
175+
return BAC_Loss(bac.f_spec, bac.f_sys, bac.tsteps, bac.t_span, new_input_sample, bac.output_metric, bac.N_samples, bac.size_p_spec, bac.size_p_sys, bac.y0_spec, bac.y0_sys, bac.solver)
176176
end
177177

178178
# Basic callback

0 commit comments

Comments
 (0)