forked from FluxML/model-zoo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sde.jl
28 lines (26 loc) · 830 Bytes
/
sde.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
using Flux, DiffEqFlux, DifferentialEquations, Plots
"""
    lotka_volterra(du, u, p, t)

In-place drift term of the Lotka–Volterra system, in the `f(du, u, p, t)`
form used by `SDEProblem`.

With state `u = (x, y)` and parameters `p = (α, β, δ, γ)`, this writes

    du[1] = α*x - β*x*y
    du[2] = -δ*y + γ*x*y

`t` is not used (the drift has no explicit time dependence). Returns `nothing`;
the result is communicated only through `du`.
"""
function lotka_volterra(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    # Write straight into `du`; the previous `du[i] = dx = ...` chains only
    # created dead locals.
    du[1] = α * x - β * x * y
    du[2] = -δ * y + γ * x * y
    return nothing
end
"""
    lotka_volterra_noise(du, u, p, t)

In-place diffusion term for the SDE: each state component gets a noise
coefficient equal to 10% of its current value. `p` and `t` are accepted to
match the `g(du, u, p, t)` signature but are not used.
"""
function lotka_volterra_noise(du, u, p, t)
    σ = 0.1              # relative noise intensity, shared by both components
    du[1] = σ * u[1]
    du[2] = σ * u[2]
end
# SDE problem: drift `lotka_volterra`, diffusion `lotka_volterra_noise`,
# starting from u0 = [1.0, 1.0] over t ∈ [0, 10]. No parameters are attached
# here; `diffeq_fd` and `remake(prob, p = ...)` supply them when solving.
prob = let u0 = [1.0, 1.0], tspan = (0.0, 10.0)
    SDEProblem(lotka_volterra, lotka_volterra_noise, u0, tspan)
end
# Trainable parameters (α, β, δ, γ), wrapped with Flux's (Tracker-era) `param`
# so gradients can be taken with respect to them.
p = param([2.2, 1.0, 2.0, 0.4])
"""
    predict_fd_sde()

Solve the global SDE `prob` with the current parameters `p` through DiffEqFlux's
`diffeq_fd` (its forward-mode differentiable variant), using the `SOSRI()`
solver with `saveat = 0.1`. Return the trajectory of the first state component.
"""
function predict_fd_sde()
    first_state(sol) = sol[1, :]   # reduce the solution to component 1 over time
    npoints = 101                  # presumably the reduced output length: 0:0.1:10 has 101 points
    return diffeq_fd(p, first_state, npoints, prob, SOSRI(); saveat = 0.1)
end
"""
    loss_fd_sde()

Sum of squared deviations of the predicted first-state trajectory from the
constant target level 1.
"""
function loss_fd_sde()
    prediction = predict_fd_sde()
    residuals = (value - 1 for value in prediction)   # lazy; target level is 1
    return sum(abs2, residuals)
end
# 100 training steps; each yields an empty argument tuple because the loss
# takes no arguments. ADAM with learning rate 0.1.
data = Iterators.repeated((), 100)
opt = ADAM(0.1)

# Callback: print the current loss, then plot a fresh SDE solve that uses the
# plain (untracked) values of `p`.
cb = function ()
    current_loss = loss_fd_sde()
    display(current_loss)
    current_params = Flux.data(p)                  # strip the Tracker wrapper
    current_prob = remake(prob, p = current_params)
    current_sol = solve(current_prob, SOSRI(), saveat = 0.1)
    display(plot(current_sol, ylim = (0, 6)))
end

# Show the solution for the initial parameter guess before any training step.
cb()
Flux.train!(loss_fd_sde, [p], data, opt; cb = cb)