Skip to content

Commit

Permalink
Merge pull request #150 from JuliaDiffEq/nystrom
Browse files Browse the repository at this point in the history
Add IRKN3 and tableaus for DPRKN6
  • Loading branch information
ChrisRackauckas authored Aug 13, 2017
2 parents a112113 + f0736c6 commit 08bd93c
Show file tree
Hide file tree
Showing 7 changed files with 304 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ module OrdinaryDiffEq
include("tableaus/feagin_tableaus.jl")
include("tableaus/rosenbrock_tableaus.jl")
include("tableaus/sdirk_tableaus.jl")
include("tableaus/rkn_tableaus.jl")

include("integrators/type.jl")
include("integrators/integrator_utils.jl")
Expand Down Expand Up @@ -150,5 +151,5 @@ module OrdinaryDiffEq
export SplitEuler

export Nystrom4, Nystrom4VelocityIndependent, Nystrom5VelocityIndependent,
IRKN4
IRKN3, IRKN4, DPRKN6
end # module
5 changes: 5 additions & 0 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,12 @@ isfsal(alg::McAte8) = true
isfsal(alg::KahanLi8) = true
isfsal(alg::SofSpa10) = true

isfsal(alg::IRKN3) = true
isfsal(alg::Nystrom4) = true
isfsal(alg::Nystrom4VelocityIndependent) = true
isfsal(alg::IRKN4) = true
isfsal(alg::Nystrom5VelocityIndependent) = true
isfsal(alg::DPRKN6) = true

fsal_typeof(alg::OrdinaryDiffEqAlgorithm,rate_prototype) = typeof(rate_prototype)
#fsal_typeof(alg::LawsonEuler,rate_prototype) = Vector{typeof(rate_prototype)}
Expand Down Expand Up @@ -178,6 +180,7 @@ alg_extrapolates(alg::Cash4) = true
alg_extrapolates(alg::Hairer4) = true
alg_extrapolates(alg::Hairer42) = true
alg_extrapolates(alg::IRKN4) = true
alg_extrapolates(alg::IRKN3) = true

alg_autodiff(alg::OrdinaryDiffEqAlgorithm) = error("This algorithm does not have an autodifferentiation option defined.")
alg_autodiff{CS,AD}(alg::ImplicitEuler{CS,AD}) = AD
Expand Down Expand Up @@ -238,10 +241,12 @@ alg_order(alg::McAte8) = 8
alg_order(alg::KahanLi8) = 8
alg_order(alg::SofSpa10) = 10

alg_order(alg::IRKN3) = 3
alg_order(alg::Nystrom4) = 4
alg_order(alg::Nystrom4VelocityIndependent) = 4
alg_order(alg::IRKN4) = 4
alg_order(alg::Nystrom5VelocityIndependent) = 5
alg_order(alg::DPRKN6) = 6

alg_order(alg::Midpoint) = 2
alg_order(alg::IIF1) = 1
Expand Down
2 changes: 2 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,12 @@ struct SofSpa10 <: OrdinaryDiffEqAlgorithm end

# Nyström methods

struct IRKN3 <: OrdinaryDiffEqAlgorithm end
struct Nystrom4 <: OrdinaryDiffEqAlgorithm end
struct Nystrom4VelocityIndependent <: OrdinaryDiffEqAlgorithm end
struct IRKN4 <: OrdinaryDiffEqAlgorithm end
struct Nystrom5VelocityIndependent <: OrdinaryDiffEqAlgorithm end
struct DPRKN6 <: OrdinaryDiffEqAlgorithm end

################################################################################

Expand Down
60 changes: 59 additions & 1 deletion src/caches/rkn_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,29 @@ function alg_cache(alg::Nystrom4VelocityIndependent,u,rate_prototype,uEltypeNoUn
Nystrom4VelocityIndependentCache(u,uprev,k₁,k₂,k₃,k,tmp)
end

struct IRKN3Cache{uType,rateType} <: OrdinaryDiffEqMutableCache
u::uType
uprev::uType
uprev2::uType
fsalfirst::rateType
k₂::rateType
k::rateType
tmp::uType
onestep_cache::Nystrom4VelocityIndependentCache
end

u_cache(c::IRKN3Cache) = ()
du_cache(c::IRKN3Cache) = (c.fsalfirst,c.k₂,c.k)

function alg_cache(alg::IRKN3,u,rate_prototype,uEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,reltol,::Type{Val{true}})
k₁ = zeros(rate_prototype)
k₂ = zeros(rate_prototype)
k₃ = zeros(rate_prototype)
k = zeros(rate_prototype)
tmp = similar(u)
IRKN3Cache(u,uprev,uprev2,k₁,k₂,k,tmp,Nystrom4VelocityIndependentCache(u,uprev,k₁,k₂,k₃,k,tmp))
end

struct IRKN4Cache{uType,rateType} <: OrdinaryDiffEqMutableCache
u::uType
uprev::uType
Expand All @@ -61,7 +84,7 @@ struct IRKN4Cache{uType,rateType} <: OrdinaryDiffEqMutableCache
end

u_cache(c::IRKN4Cache) = ()
du_cache(c::IRKN4Cache) = (c.fsalfirst,c.k₂,c.k₃,c.k,c.k_₁,c.k_₂,c.k_₃)
du_cache(c::IRKN4Cache) = (c.fsalfirst,c.k₂,c.k₃,c.k)

function alg_cache(alg::IRKN4,u,rate_prototype,uEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,reltol,::Type{Val{true}})
k₁ = zeros(rate_prototype)
Expand Down Expand Up @@ -95,3 +118,38 @@ function alg_cache(alg::Nystrom5VelocityIndependent,u,rate_prototype,uEltypeNoUn
tmp = similar(u)
Nystrom5VelocityIndependentCache(u,uprev,k₁,k₂,k₃,k₄,k,tmp)
end

struct DPRKN6Cache{uType,uArrayType,rateType,uEltypeNoUnits,TabType} <: OrdinaryDiffEqMutableCache
u::uType
uprev::uType
fsalfirst::rateType
k2::rateType
k3::rateType
k4::rateType
k5::rateType
k6::rateType
k::rateType
utilde::uArrayType
tmp::uType
atmp::uEltypeNoUnits
tab::TabType
end

u_cache(c::DPRKN6Cache) = (c.atmp,c.utilde)
du_cache(c::DPRKN6Cache) = (c.fsalfirst,c.k2,c.k3,c.k4,c.k5,c.k6)

function alg_cache(alg::DPRKN6,u,rate_prototype,uEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,reltol,::Type{Val{true}})
tab = DPRKN6ConstantCache(real(uEltypeNoUnits),real(tTypeNoUnits))
k1 = zeros(rate_prototype)
k2 = zeros(rate_prototype)
k3 = zeros(rate_prototype)
k4 = zeros(rate_prototype)
k5 = zeros(rate_prototype)
k6 = zeros(rate_prototype)
k = zeros(rate_prototype)
utilde = similar(u,indices(u))
atmp = similar(u,uEltypeNoUnits)
tmp = similar(u)
DPRKN6Cache(u,uprev,k1,k2,k3,k4,k5,k6,k,utilde,tmp,atmp,tab)
end

94 changes: 94 additions & 0 deletions src/integrators/rkn_integrators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,53 @@ end
f.f2(t+dt,u,du,k.x[2])
end

function initialize!(integrator,cache::IRKN3Cache,f=integrator.f)
@unpack tmp,fsalfirst,k₂,k = cache
uprev,duprev = integrator.uprev.x

integrator.fsalfirst = fsalfirst
integrator.fsallast = k
integrator.kshortsize = 2
integrator.k = eltype(integrator.sol.k)(integrator.kshortsize)
integrator.k[1] = integrator.fsalfirst
integrator.k[2] = integrator.fsallast
f.f1(integrator.t,uprev,duprev,integrator.k[2].x[1])
f.f2(integrator.t,uprev,duprev,integrator.k[2].x[2])
end

@muladd function perform_step!(integrator,cache::IRKN3Cache,f=integrator.f)
# if there's a discontinuity or the solver is in the first step
if integrator.iter < 2 && !integrator.u_modified
perform_step!(integrator,integrator.cache.onestep_cache)
else
@unpack t,dt,k,tprev = integrator
u,du = integrator.u.x
uprev, duprev = integrator.uprev.x
uprev2,duprev2 = integrator.uprev2.x
uidx = eachindex(integrator.uprev.x[1])
@unpack tmp,fsalfirst,k₂,k = cache
ku, kdu = integrator.cache.tmp.x[1], integrator.cache.tmp.x[2]
k₁ = fsalfirst
dtsq = dt^2

f.f2(t+1//2*dt, uprev, duprev, k.x[1])
f.f2(tprev+1//2*dt,uprev2,duprev2,k.x[2])
@tight_loop_macros for i in uidx
@inbounds ku[i] = uprev[i] + (1//2*dt)*duprev[i] + (1//8*dtsq)*k.x[1][i]
@inbounds kdu[i] = uprev2[i] + (1//2*dt)*duprev2[i] + (1//8*dtsq)*k.x[2][i]
end

f.f2(t+1//2*dt, ku, duprev, k₂.x[1])
f.f2(tprev+1//2*dt,kdu,duprev2,k₂.x[2])
@tight_loop_macros for i in uidx
@inbounds u[i] = uprev[i] + (3//2*dt)*duprev[i] + (1//2*-dt)*duprev2[i] + (5//12*dtsq)*(k₂.x[1][i]-k₂.x[2][i])
@inbounds du[i] = duprev[i] + dt*(2//3*k.x[1][i] + 1//3*k.x[2][i] + 5//6*(k₂.x[1][i]-k₂.x[2][i]))
end
f.f1(t+dt,u,du,k.x[1])
f.f2(t+dt,u,du,k.x[2])
end # end if
end

@inline function initialize!(integrator,cache::IRKN4Cache,f=integrator.f)
@unpack tmp,fsalfirst,k₂,k = cache
uprev,duprev = integrator.uprev.x
Expand Down Expand Up @@ -219,3 +266,50 @@ end
f.f1(t+dt,u,du,k.x[1])
f.f2(t+dt,u,du,k.x[2])
end

function initialize!(integrator,cache::DPRKN6Cache,f=integrator.f)
integrator.fsalfirst = cache.fsalfirst
integrator.fsallast = cache.k

integrator.kshortsize = 2
integrator.k = eltype(integrator.sol.k)(integrator.kshortsize)
integrator.k[1] = integrator.fsalfirst
integrator.k[2] = integrator.fsallast

uprev,duprev = integrator.uprev.x
f.f1(integrator.t,uprev,duprev,integrator.k[2].x[1])
f.f2(integrator.t,uprev,duprev,integrator.k[2].x[2])
end

@muladd function perform_step!(integrator,cache::DPRKN6Cache,f=integrator.f)
@unpack t,dt = integrator
u,du = integrator.u.x
uprev,duprev = integrator.uprev.x
@unpack tmp,atmp,fsalfirst,k2,k3,k4,k5,k6,k,utilde = cache
@unpack c1, c2, c3, c4, c5, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a63, a64, a65, b1, b3, b4, b5, bp1, bp3, bp4, bp5, bp6, btilde1, btilde2, btilde3, bptilde1, bptilde3, bptilde4, bptilde5, bptilde6 = cache.tab
ku, kdu = integrator.cache.tmp.x[1], integrator.cache.tmp.x[2]
k1 = fsalfirst

@. ku = uprev + dt*(duprev + dt*a21*k1.x[2])

f.f2(t+dt*c1,ku,du,k2.x[2])
@. ku = uprev + dt*(c1*duprev + dt*(a31*k1.x[2] + a32*k2.x[2]))

f.f2(t+dt*c2,ku,du,k3.x[2])
@. ku = uprev + dt*(c2*duprev + dt*(a41*k1.x[2] + a42*k2.x[2] + a43*k3.x[2]))

f.f2(t+dt*c3,ku,du,k4.x[2])
@. ku = uprev + dt*(c3*duprev + dt*(a51*k1.x[2] + a52*k2.x[2] + a53*k3.x[2] + a54*k4.x[2]))

f.f2(t+dt*c4,ku,du,k5.x[2])
@. ku = uprev + dt*(c4*duprev + dt*(a61*k1.x[2] + a63*k3.x[2] + a64*k4.x[2] + a65*k5.x[2])) # no a62

f.f2(t+dt*c5,ku,du,k6.x[2])

@. u = uprev + dt*(duprev + dt*(b1 *k1.x[2] + b3 *k3.x[2] + b4 *k4.x[2] + b5 *k5.x[2])) # b1 -- b5, no b2
@. du = duprev + dt*(bp1*k1.x[2] + bp3*k3.x[2] + bp4*k4.x[2] + bp5*k5.x[2] + bp6*k6.x[2]) # bp1 -- bp6, no bp2

f.f1(t+dt,u,du,k.x[1])
f.f2(t+dt,u,du,k.x[2])
end

136 changes: 136 additions & 0 deletions src/tableaus/rkn_tableaus.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
struct DPRKN6ConstantCache{T,T2} <: OrdinaryDiffEqConstantCache
c1::T2
c2::T2
c3::T2
c4::T2
c5::T2
a21::T
a31::T
a32::T
a41::T
a42::T
a43::T
a51::T
a52::T
a53::T
a54::T
a61::T
# a62::T
a63::T
a64::T
a65::T
b1::T
# b2::T
b3::T
b4::T
b5::T
# b6::T
bp1::T # bp denotes bprime
# bp2::T
bp3::T
bp4::T
bp5::T
bp6::T
btilde1::T
btilde2::T
btilde3::T
# btilde4::T
# btilde5::T
# btilde6::T
bptilde1::T
# bptilde2::T
bptilde3::T
bptilde4::T
bptilde5::T
bptilde6::T
end

Base.@pure function DPRKN6ConstantCache{T<:CompiledFloats,T2<:CompiledFloats}(::Type{T},::Type{T2})
c1 = T2(0.12929590313670442)
c2 = T2(0.25859180627340883)
c3 = T2(0.67029708261548)
c4 = T2(0.9)
c5 = T2(1.0)
a21 = T(0.008358715283968025)
a31 = T(0.011144953711957367)
a32 = T(0.022289907423914734)
a41 = T(0.1454747428010918)
a42 = T(-0.22986064052264749)
a43 = T(0.3090349872029675)
a51 = T(-0.20766826295078997)
a52 = T(0.6863667842925143)
a53 = T(-0.19954927787234925)
a54 = T(0.12585075653062489)
a61 = T(0.07811016144349478)
a63 = T(0.2882917411897668)
a64 = T(0.12242553717457041)
a65 = T(0.011172560192168035)
b1 = T(0.07811016144349478)
b3 = T(0.2882917411897668)
b4 = T(0.12242553717457041)
b5 = T(0.011172560192168035)
bp1 = T(0.07811016144349478)
bp3 = T(3.220718367176496)
bp4 = T(3.203195646399356)
bp5 = T(0.11172560192168035)
bp6 = T(0.05)
btilde1 = T(1.0588592603704183)
btilde2 = T(-2.406751371924452)
btilde3 = T(1.8478921115540339)
bptilde1 = T(0.054605887939221276)
bptilde3 = T(0.46126678590362685)
bptilde4 = T(0.19588085947931266)
bptilde5 = T(0.38824646667783924)
bptilde6 = T(-0.1)
DPRKN6ConstantCache(c1, c2, c3, c4, c5, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a63, a64, a65, b1, b3, b4, b5, bp1, bp3, bp4, bp5, bp6, btilde1, btilde2, btilde3, bptilde1, bptilde3, bptilde4, bptilde5, bptilde6)
end

function DPRKN6ConstantCache(T::Type,T2::Type)
R = sqrt(big(8581))
c1 = T2((209-R)/900)
c2 = T2((209-R)/450)
c3 = T2((209+R)/450)
c4 = T2(9//10)
c5 = T2(1)
a21 = T((26131-209R)/81_0000)
a31 = T((26131-209R)/60_7500)
a32 = T((26131-209R)/30_3750)
a41 = T((980403512254+7781688431R)/116944_6992_1875)
a42 = T(-(126288_4486208+153854_81287R)/116944_6992_1875)
a43 = T((7166_233_891_441+786_945_632_99R)/46_777_879_687_500)
a51 = T(-9(329260+3181R)/2704_0000)
a52 = T(27(35129+3331R)/1352_0000)
a53 = T(-27(554358343+31040327R)/46406048_0000)
a54 = T(153(8555_257-67973R)/274592_0000)
a61 = T(329//4212)
# a62 = T(0)
a63 = T((8411_9543+366_727R)/4096_22616)
a64 = T((8411_9543-366_727R)/4096_22616)
a65 = T(200//17901)
b1 = T(329//4212)
# b2 = T(0)
b3 = a63
b4 = a64
b5 = T(200//17901)
# b6 = T(0)
bp1 = b1
# bp2 = b2
bp3 = T((32_8922_5579+96856R)/10_2405_6540)
bp4 = T((32_8922_5579-96856R)/10_2405_6540)
bp5 = T(2000//17901)
bp6 = T(1//20)
btilde1 = T((2701+23R)/4563)
btilde2 = T(-(9829+131R)/9126)
btilde3 = T(5(1798+17R)/9126)
# btilde4 = T(0)
# btilde5 = T(0)
# btilde6 = T(0)
bptilde1 = T(115//2106)
# btildep2 = T(0)
bptilde3 = T((8411_9543+366_727R)/2560_14135)
bptilde4 = T((8411_9543-366_727R)/2560_14135)
bptilde5 = T(6950//17901)
bptilde6 = T(-1//10)
DPRKN6ConstantCache(c1, c2, c3, c4, c5, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a63, a64, a65, b1, b3, b4, b5, bp1, bp3, bp4, bp5, bp6, btilde1, btilde2, btilde3, bptilde1, bptilde3, bptilde4, bptilde5, bptilde6)
end

Loading

0 comments on commit 08bd93c

Please sign in to comment.