diff --git a/.github/workflows/Downgrade.yml b/.github/workflows/Downgrade.yml
index d5e3d732e..dd233f538 100644
--- a/.github/workflows/Downgrade.yml
+++ b/.github/workflows/Downgrade.yml
@@ -25,6 +25,7 @@ jobs:
          - QA
          - ODEBPINN
          - PDEBPINN
+          - NNSDE
          - NNPDE1
          - NNPDE2
          - AdaptiveLoss
diff --git a/.github/workflows/Tests.yml b/.github/workflows/Tests.yml
index 1a405daf4..e5bec03f0 100644
--- a/.github/workflows/Tests.yml
+++ b/.github/workflows/Tests.yml
@@ -31,6 +31,7 @@ jobs:
          - "QA"
          - "ODEBPINN"
          - "PDEBPINN"
+          - "NNSDE"
          - "NNPDE1"
          - "NNPDE2"
          - "AdaptiveLoss"
diff --git a/src/NN_SDE_solve.jl b/src/NN_SDE_solve.jl
new file mode 100644
index 000000000..9686d6a1c
--- /dev/null
+++ b/src/NN_SDE_solve.jl
@@ -0,0 +1,481 @@
+@concrete struct NNSDE
+    chain <: AbstractLuxLayer
+    opt
+    init_params
+    autodiff::Bool
+    batch::Bool
+    strategy <: Union{Nothing, AbstractTrainingStrategy}
+    param_estim::Bool
+    additional_loss <: Union{Nothing, Function}
+    sub_batch::Int64
+    strong_loss::Bool
+    numensemble::Number
+    kwargs
+end
+
+function NNSDE(chain, opt, init_params = nothing; strategy = nothing, autodiff = false,
+        batch = true, param_estim = false, additional_loss = nothing,
+        sub_batch = 1, strong_loss = false, numensemble = 10, kwargs...)
+    chain isa AbstractLuxLayer || (chain = FromFluxAdaptor()(chain))
+    return NNSDE(chain, opt, init_params, autodiff, batch,
+        strategy, param_estim, additional_loss, sub_batch, strong_loss, numensemble, kwargs)
+end
+
+"""
+    SDEPhi(chain::Lux.AbstractLuxLayer, t, u0, st)
+
+Internal struct, used for representing the SDE solution as a neural network in a form that
+respects the initial condition, i.e. `phi(inp) = u0 + (inp[1] - t0) * NN(inp)`.
+"""
+@concrete struct SDEPhi
+    u0
+    t0
+    smodel <: StatefulLuxLayer
+end
+
+function SDEPhi(model::AbstractLuxLayer, t0::Number, u0, st)
+    return SDEPhi(u0, t0, StatefulLuxLayer{true}(model, nothing, st))
+end
+
+function (f::SDEPhi)(inp, θ)
+    dev = safe_get_device(θ)
+    return f(dev, safe_expand(dev, inp), θ)
+end
+
+function (f::SDEPhi{<:Number})(dev, inp::Array{<:Number, 1}, θ)
+    res = only(cdev(f.smodel(dev(inp), θ.depvar)))
+    return f.u0 + (inp[1] - f.t0) * res
+end
+
+function (f::SDEPhi)(dev, inp::Array{<:Number, 1}, θ)
+    return dev(f.u0) .+ (inp[1] - f.t0) .* f.smodel(dev(inp), θ.depvar)
+end
+
+function generate_phi(chain::AbstractLuxLayer, t, u0, ::Nothing)
+    θ, st = LuxCore.setup(Random.default_rng(), chain)
+    return SDEPhi(chain, t, u0, st), θ
+end
+
+function generate_phi(chain::AbstractLuxLayer, t, u0, init_params)
+    st = LuxCore.initialstates(Random.default_rng(), chain)
+    return SDEPhi(chain, t, u0, st), init_params
+end
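+
+# Illustrative note: for scalar `u0` with `t0 = 0.0`, the trial form gives
+# `phi([0.0, z...], θ) == u0` for any parameters θ and any sampled coefficients z,
+# since the network output is scaled by `(inp[1] - t0)`; the initial condition is
+# therefore satisfied by construction rather than learned.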
+""" +function ∂u_∂t end + +# earlier input as number or abstract vector, now becomes array1 or matrix +# input first col must be time, rest all z_i's +# vector of t_i's sub_batch input vectors = inputs +# returns a vector of gradients for each sub_batch, for each method call over all its sub_batches +function ∂u_∂t(phi::SDEPhi, inputs::Array{<:Array{<:Number, 1}}, θ, autodiff::Bool) + autodiff && + return [ForwardDiff.gradient( + t -> (phi(vcat(t, input[2:end]), θ)), input[1]) + for input in inputs] + ϵ = sqrt(eps(eltype(inputs[1]))) + return [(phi(vcat(input[1] + ϵ, input[2:end]), θ) .- phi(input, θ)) ./ ϵ + for input in inputs] +end + +""" + inner_sde_loss(phi, f, autodiff, t, θ, p, param_estim) + +Simple L2 inner loss for the SDE at a time `t` and random variables z_i with parameters `θ` of the neural network. + +For NNSE instead of a matrix, input is a N x n_sub x (t+n_z) dim array N -> n timepoints, n_sub -> n_sub_batch, t+n_z -> chain/phi input dims. +Inner Sde loss enforces weak/strong solution across sub_batches (strong sol convergence implies weak sol convergence but not vice versa) +Note: test file GBM SDE case, weak sol training gives better results for more sub_samples. strong sol training gives opposite results. +Note: NNODE, NNSDE take only a single Neural Network which is multioutput or singleoutput + +""" +function inner_sde_loss end + +# no batching +function inner_sde_loss( + phi::SDEPhi, f, g, autodiff::Bool, inputs::Array{<:Array{<:Number, 1}}, + θ, p, param_estim::Bool, train_type) + p_ = param_estim ? θ.p : p + + # phi's many outputs for the same timepoint's many sub_batches + # must also cover multioutput case + u = [phi(sub_batch_input, θ) for sub_batch_input in inputs] + + # for t's sub_batch we consider the i_th batch in sub_batch as indvar, z_i... inputs for the ith phi output (also covers sub_batch=1 case) + fs = if phi.u0 isa Number + [f(u[i][1], p_, inp[1]) + + g(u[i][1], p_, inp[1]) * √2 * + sum(inp[1 + j] * cos((j - 1 / 2)pi * inp[1]) for j in 1:(length(inp) - 1)) + for (i, inp) in enumerate(inputs)] + else + # will be vector in multioutput case + [f(u[i], p_, inp[1]) + + g(u[i], p_, inp[1]) * √2 * + sum(inp[1 + j] * cos((j - 1 / 2)pi * inp[1]) for j in 1:(length(inp) - 1)) + for (i, inp) in enumerate(inputs)] + end + + # gradient at t is not affected by sub_batch z_i values beyond use in phi(t+ϵ) + dudt = ∂u_∂t(phi, inputs, θ, autodiff) + + # initially .- broadcasting over NN multiple/single outputs and subbatches simultaneously + # fs and dudt size is (n_sub_batch, NN_output_size), loss is for one timepoint + # broadcasted sum(abs2,) over losses for each sub_batch, where rows - sub_batch's and cols - NN multiple/single outputs + # finally mean over vector of L2 errors (all the sub_batches) (direct sum(strong sol)/mean(weak sol) across sub_batches) + return train_type(sum.(abs2, fs .- dudt)) +end + +# batching case +function inner_sde_loss( + phi::SDEPhi, f, g, autodiff::Bool, inputs::Array{<:Array{<:Array{<:Number, 1}}}, + θ, p, param_estim::Bool, train_type) + p_ = param_estim ? θ.p : p + + # phi's many outputs for each timepoint's many sub_batches, input for each t_i, sub_batch_input for each input's sub_batches + # must also cover multioutput case + u = [[phi(sub_batch_input, θ) for sub_batch_input in input] for input in inputs] + + # for each t_i's sub_batch we consider the i_th batch in sub_batch as indvar, z_i... 
+
+# batching case
+function inner_sde_loss(
+        phi::SDEPhi, f, g, autodiff::Bool, inputs::Array{<:Array{<:Array{<:Number, 1}}},
+        θ, p, param_estim::Bool, train_type)
+    p_ = param_estim ? θ.p : p
+
+    # phi's outputs for each timepoint's sub_batches: `input` for each t_i,
+    # `sub_batch_input` for each input's sub-batches
+    # (must also cover the multi-output case)
+    u = [[phi(sub_batch_input, θ) for sub_batch_input in input] for input in inputs]
+
+    # for each t_i's sub_batch, the j-th element supplies the indvar and z_i... inputs
+    # for the j-th phi output (covers the sub_batch = 1 case for each t_i)
+    fs = if phi.u0 isa Number
+        [[f(u[i][j][1], p_, inpi[1]) +
+          g(u[i][j][1], p_, inpi[1]) * 2^(1 / 2) *
+          sum(inpi[1 + k] * cos((k - 1 / 2)pi * inpi[1]) for k in 1:(length(inpi) - 1))
+          for (j, inpi) in enumerate(inp)] for (i, inp) in enumerate(inputs)]
+    else
+        [[f(u[i][j], p_, inpi[1]) +
+          g(u[i][j], p_, inpi[1]) * 2^(1 / 2) *
+          sum(inpi[1 + k] * cos((k - 1 / 2)pi * inpi[1]) for k in 1:(length(inpi) - 1))
+          for (j, inpi) in enumerate(inp)] for (i, inp) in enumerate(inputs)]
+    end
+
+    # fs[i] consists of n_sub_batch vectors with n_output_dims entries each.
+    # the gradient at each t_i is not affected by its sub_batch's z_i values beyond
+    # their use in phi(t_i + ϵ)
+    dudt = [∂u_∂t(phi, inpi, θ, autodiff) for inpi in inputs]
+
+    # taking the MSE across Z: fs and dudt hold n_sub_batch elements per timepoint.
+    # `mean` over each timepoint's sub_batch enforces the weak solution across
+    # WienerProcess realizations (better results for the same number of iterations);
+    # otherwise as in the non-batching case, with a final aggregation over all timepoints.
+    return sum(train_type(sum.(abs2, fs[i] .- dudt[i])) for i in eachindex(inputs)) /
+           length(inputs)
+end
+
+"""
+    add_rand_coeff(times, n_z)
+
+`n_z` is the number of orthogonal basis random variables (on the probability space) used
+in the KKL expansion for an SDE.
+Appends `n_z` standard Gaussian samples to a single time value, or to each element of a
+list of times, and returns the result.
+"""
+function add_rand_coeff(times, n_z::Int64)
+    times isa Number && return vcat(times, rand(Normal(0, 1), n_z))
+    return [vcat(time, rand(Normal(0, 1), n_z))
+            for time in times]
+end
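+
+# Illustrative note: add_rand_coeff(0.3, 3) returns [0.3, z1, z2, z3] with z_i ~ N(0, 1);
+# the same z_i reconstruct the truncated Wiener path
+#     W(t) ≈ √2 * Σ_j z_j * sin((j - 1/2)π * t) / ((j - 1/2)π),
+# whose time derivative appears in the inner loss above.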
+
+"""
+    generate_loss(strategy, phi, f, g, autodiff, tspan, n_z, n_sub_batch, train_type, p,
+        batch, param_estim)
+
+Representation of the loss function, parametric on the training strategy `strategy`.
+"""
+function generate_loss(
+        strategy::QuadratureTraining, phi, f, g, autodiff::Bool, tspan, n_z::Int64,
+        n_sub_batch::Int64, train_type, p, batch::Bool, param_estim::Bool)
+    inputs = AbstractVector{Any}[]
+    function integrand(t::Number, θ)
+        inputs = [[add_rand_coeff(t, n_z) for i in 1:n_sub_batch]]
+        return abs2(inner_sde_loss(
+            phi, f, g, autodiff, inputs, θ, p, param_estim, train_type))
+    end
+
+    # when ts is a 1D Array
+    function integrand(ts, θ)
+        inputs = [[add_rand_coeff(t, n_z) for i in 1:n_sub_batch] for t in ts]
+        return [abs2(inner_sde_loss(
+                    phi, f, g, autodiff, input, θ, p, param_estim, train_type))
+                for input in inputs]
+    end
+
+    function loss(θ, _)
+        intf = BatchIntegralFunction(integrand, max_batch = strategy.batch)
+        intprob = IntegralProblem(intf, (tspan[1], tspan[2]), θ)
+        sol = solve(intprob, strategy.quadrature_alg; strategy.abstol,
+            strategy.reltol, strategy.maxiters)
+        return sol.u
+    end
+
+    return loss, inputs
+end
+
+function generate_loss(
+        strategy::GridTraining, phi, f, g, autodiff::Bool, tspan, n_z::Int64,
+        n_sub_batch::Int64, train_type, p, batch::Bool, param_estim::Bool)
+    ts = tspan[1]:(strategy.dx):tspan[2]
+    # in (t, z_i, ...) space we solve at one point; NN(input) can also represent only
+    # this point if sub_batch = 1.
+    # for each t_i in ts there are n_sub_batch phi input possibilities
+    inputs = [[add_rand_coeff(t, n_z) for i in 1:n_sub_batch] for t in ts]
+
+    autodiff && throw(ArgumentError("autodiff not supported for GridTraining."))
+    batch &&
+        return (θ, _) -> inner_sde_loss(
+                   phi, f, g, autodiff, inputs, θ, p, param_estim, train_type),
+               inputs
+    return (θ, _) -> sum([inner_sde_loss(phi, f, g, autodiff, input, θ, p,
+                              param_estim, train_type)
+                          for input in inputs]),
+           inputs
+end
+
+function generate_loss(strategy::StochasticTraining, phi, f, g, autodiff::Bool,
+        tspan, n_z::Int64, n_sub_batch::Int64, train_type, p, batch::Bool, param_estim::Bool)
+    autodiff && throw(ArgumentError("autodiff not supported for StochasticTraining."))
+    inputs = AbstractVector{Any}[]
+
+    return (θ, _) -> begin
+        T = promote_type(eltype(tspan[1]), eltype(tspan[2]))
+        ts = ((tspan[2] - tspan[1]) .* rand(T, strategy.points) .+ tspan[1])
+        inputs = [[add_rand_coeff(t, n_z) for i in 1:n_sub_batch] for t in ts]
+
+        if batch
+            inner_sde_loss(
+                phi, f, g, autodiff, inputs, θ, p, param_estim, train_type)
+        else
+            sum([inner_sde_loss(phi, f, g, autodiff, input, θ, p,
+                     param_estim, train_type)
+                 for input in inputs])
+        end
+    end,
+    inputs
+end
+
+function generate_loss(
+        strategy::WeightedIntervalTraining, phi, f, g, autodiff::Bool, tspan, n_z::Int64,
+        n_sub_batch::Int64, train_type, p, batch::Bool, param_estim::Bool)
+    autodiff && throw(ArgumentError("autodiff not supported for WeightedIntervalTraining."))
+    minT, maxT = tspan
+    weights = strategy.weights ./ sum(strategy.weights)
+    N = length(weights)
+    difference = (maxT - minT) / N
+
+    ts = eltype(difference)[]
+    for (index, item) in enumerate(weights)
+        temp_data = rand(1, trunc(Int, strategy.points * item)) .* difference .+ minT .+
+                    ((index - 1) * difference)
+        append!(ts, temp_data)
+    end
+    inputs = [[add_rand_coeff(t, n_z) for i in 1:n_sub_batch] for t in ts]
+
+    batch &&
+        return (θ, _) -> inner_sde_loss(
+                   phi, f, g, autodiff, inputs, θ, p, param_estim, train_type),
+               inputs
+    return (θ, _) -> sum([inner_sde_loss(phi, f, g, autodiff, input, θ, p,
+                              param_estim, train_type)
+                          for input in inputs]),
+           inputs
+end
+
+function evaluate_tstops_loss(
+        phi, f, g, autodiff::Bool, tstops, n_z::Int64, n_sub_batch::Int64,
+        train_type, p, batch::Bool, param_estim::Bool)
+    inputs = [[add_rand_coeff(t, n_z) for i in 1:n_sub_batch] for t in tstops]
+    batch &&
+        return (θ, _) -> inner_sde_loss(
+                   phi, f, g, autodiff, inputs, θ, p, param_estim, train_type),
+               inputs
+    return (θ, _) -> sum([inner_sde_loss(phi, f, g, autodiff, input, θ, p,
+                              param_estim, train_type)
+                          for input in inputs]),
+           inputs
+end
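+
+# Illustrative note (usage sketch): the strategy is chosen when constructing the solver,
+# e.g. NNSDE(chain, opt; strategy = StochasticTraining(100)). When no strategy is given,
+# `solve` falls back to GridTraining(dt) if `dt` is supplied and to QuadratureTraining
+# otherwise; QuasiRandomTraining is rejected below.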
+
+function generate_loss(::QuasiRandomTraining, phi, f, g, autodiff::Bool,
+        tspan, n_z::Int64, n_sub_batch::Int64, train_type, p, batch::Bool, param_estim::Bool)
+    error("QuasiRandomTraining is not supported by NNSDE since it's for high dimensional \
+           spaces only. Use StochasticTraining instead.")
+end
+
+@concrete struct NNSDEInterpolation
+    phi <: SDEPhi
+    θ
+end
+
+(f::NNSDEInterpolation)(inp, ::Nothing, ::Type{Val{0}}, p, continuity) = f.phi(inp, f.θ)
+(f::NNSDEInterpolation)(inp, idxs, ::Type{Val{0}}, p, continuity) = f.phi(inp, f.θ)[idxs]
+
+function (f::NNSDEInterpolation)(
+        inp::Array{<:Number, 1}, ::Nothing, ::Type{Val{0}}, p, continuity)
+    out = f.phi(inp, f.θ)
+    return DiffEqArray([out[:, i] for i in axes(out, 2)], inp)
+end
+
+function (f::NNSDEInterpolation)(
+        inp::Array{<:Number, 1}, idxs, ::Type{Val{0}}, p, continuity)
+    out = f.phi(inp, f.θ)
+    return DiffEqArray([out[idxs, i] for i in axes(out, 2)], inp)
+end
+
+SciMLBase.interp_summary(::NNSDEInterpolation) = "Trained neural network interpolation"
+SciMLBase.allowscomplex(::NNSDE) = true
+
+@concrete struct SDEsol
+    solution
+    strong_sol::AbstractVector{<:Particles}
+    timepoints::AbstractVector{<:Number}
+    ensemble_fits::AbstractVector
+    ensemble_inputs::AbstractVector
+    numensemble::Number
+    training_sets::AbstractVector
+end
+
+function SciMLBase.__solve(
+        prob::SciMLBase.AbstractSDEProblem,
+        alg::NNSDE,
+        args...;
+        dt = nothing,
+        timeseries_errors = true,
+        save_everystep = true,
+        adaptive = false,
+        abstol = 1.0f-6,
+        reltol = 1.0f-3,
+        verbose = false,
+        saveat = nothing,
+        maxiters = nothing,
+        tstops = nothing
+)
+    (; u0, tspan, f, g, p) = prob
+    # rescale tspan and the discretization so the KKL expansion can be applied in the
+    # loss formulation
+    tspan_scale = tspan ./ tspan[end]
+    if dt !== nothing
+        dt = dt / abs(tspan_scale[2] - tspan_scale[1])
+    end
+
+    t0 = tspan_scale[1]
+    (; param_estim, chain, opt, autodiff, init_params, batch, additional_loss, sub_batch, strong_loss, numensemble) = alg
+    n_z = chain[1].in_dims - 1
+    sde_phi, init_params = generate_phi(chain, t0, u0, init_params)
+
+    (recursive_eltype(init_params) <: Complex && alg.strategy isa QuadratureTraining) &&
+        error("QuadratureTraining cannot be used with complex parameters. Use other strategies.")
+
+    init_params = if alg.param_estim
+        ComponentArray(; depvar = init_params, p)
+    else
+        ComponentArray(; depvar = init_params)
+    end
+
+    @assert !isinplace(prob) "The NNSDE solver only supports out-of-place SDE definitions, i.e. du = f(u,p,t) + g(u,p,t)*dW(t)"
+
+    strategy = if alg.strategy === nothing
+        if dt !== nothing
+            GridTraining(dt)
+        else
+            QuadratureTraining(; quadrature_alg = QuadGKJL(),
+                reltol = convert(eltype(u0), reltol), abstol = convert(eltype(u0), abstol),
+                maxiters, batch = 0)
+        end
+    else
+        alg.strategy
+    end
+
+    # train_type defaults to weak training (expectation-based loss); set
+    # strong_loss = true for the strong (pathwise) loss
+    train_type = strong_loss ? sum : mean
+    inner_f, training_sets = generate_loss(
+        strategy, sde_phi, f, g, autodiff, tspan_scale, n_z,
+        sub_batch, train_type, p, batch, param_estim)
+
+    (param_estim && additional_loss === nothing) &&
+        throw(ArgumentError("Please provide `additional_loss` in `NNSDE` for parameter estimation (`param_estim` is true)."))
+
+    # creates the OptimizationFunction object from total_loss
+    function total_loss(θ, _)
+        L2_loss = inner_f(θ, sde_phi)
+        if additional_loss !== nothing
+            L2_loss = L2_loss + additional_loss(sde_phi, θ)
+        end
+        if tstops !== nothing
+            num_tstops_points = length(tstops)
+            tstops_loss_func = evaluate_tstops_loss(
+                sde_phi, f, g, autodiff, tstops, n_z, sub_batch,
+                train_type, p, batch, param_estim)
+            tstops_loss = tstops_loss_func(θ, sde_phi)
+            if strategy isa GridTraining
+                num_original_points = length(tspan_scale[1]:(strategy.dx):tspan_scale[2])
+            elseif strategy isa Union{WeightedIntervalTraining, StochasticTraining}
+                num_original_points = strategy.points
+            else
+                return L2_loss + tstops_loss
+            end
+            total_original_loss = L2_loss * num_original_points
+            total_tstops_loss = tstops_loss * num_tstops_points
+            total_points = num_original_points + num_tstops_points
+            L2_loss = (total_original_loss + total_tstops_loss) / total_points
+        end
+        return L2_loss
+    end
+
+    opt_algo = ifelse(strategy isa QuadratureTraining, AutoForwardDiff(), AutoZygote())
+    optf = OptimizationFunction(total_loss, opt_algo)
+
+    plen = maxiters === nothing ? 6 : ndigits(maxiters)
+    callback = function (p, l)
+        if verbose
+            if maxiters === nothing
+                @printf("[NNSDE]\tIter: [%*d]\tLoss: %g\n", plen, p.iter, l)
+            else
+                @printf("[NNSDE]\tIter: [%*d/%d]\tLoss: %g\n", plen, p.iter, maxiters, l)
+            end
+        end
+        return l < abstol
+    end
+
+    optprob = OptimizationProblem(optf, init_params)
+    res = solve(optprob, opt; callback, maxiters, alg.kwargs...)
+
+    # solutions at timepoints
+    if saveat isa Number
+        ts = tspan_scale[1]:saveat:tspan_scale[2]
+    elseif saveat isa AbstractArray
+        ts = saveat
+    elseif dt !== nothing
+        ts = tspan_scale[1]:dt:tspan_scale[2]
+    elseif save_everystep
+        ts = range(tspan_scale[1], tspan_scale[2], length = 100)
+    else
+        ts = [tspan_scale[1], tspan_scale[2]]
+    end
+    ts = collect(ts)
+
+    ensembles = []
+    ensemble_inputs = []
+    for i in 1:numensemble
+        inputs = add_rand_coeff(ts, n_z)
+
+        if u0 isa Number
+            u = [first(sde_phi(input, res.u)) for input in inputs]
+        else
+            u = [sde_phi(input, res.u) for input in inputs]
+        end
+        push!(ensembles, u)
+        push!(ensemble_inputs, inputs)
+    end
+    sde_sols = hcat(ensembles...)
+    strong_sde_sol = [Particles(sde_sols[i, :]) for i in eachindex(ts)]
+
+    # SDEsol.solution contains the weak solution only; the strong solution can be
+    # accessed via SDEsol.strong_sol
+    sol = SciMLBase.build_solution(prob, alg, ts, strong_sde_sol; k = res, dense = true,
+        interp = NNSDEInterpolation(sde_phi, res.u), calculate_error = false,
+        retcode = ReturnCode.Success, original = res, resid = res.objective)
+
+    SciMLBase.has_analytic(prob.f) &&
+        SciMLBase.calculate_solution_errors!(
+            sol; timeseries_errors = true, dense_errors = false)
+
+    # separate Wiener process realisations and their solutions can be accessed via
+    # ensembles and ensemble_inputs
+    return SDEsol(
+        sol, strong_sde_sol, ts, ensembles, ensemble_inputs, numensemble, training_sets)
+end
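
For review context, a minimal usage sketch of the exported API (hedged: the setup mirrors the GBM test below, and the hyperparameters are illustrative rather than recommended defaults):

```julia
using NeuralPDE, OrdinaryDiffEq, Lux, Optimisers

# geometric Brownian motion: du = α*u*dt + β*u*dW
α, β, u₀ = 1.2, 1.1, 0.5
f(u, p, t) = α * u
g(u, p, t) = β * u
prob = SDEProblem(f, g, u₀, (0.0, 1.0))

# the chain input is (t, z_1, ..., z_n): one time dimension plus n_z KKL coefficients
n_z = 3
chain = Chain(Dense(1 + n_z, 16, σ), Dense(16, 16, σ), Dense(16, 1))

alg = NNSDE(chain, Adam(0.01); sub_batch = 5, numensemble = 500)
sol = solve(prob, alg; dt = 1 / 50, abstol = 1e-6, maxiters = 300)

sol.timepoints  # rescaled time grid
sol.strong_sol  # Vector of Particles: the pathwise (strong) ensemble solution
```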
diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl
index 55227371e..54c1ac091 100644
--- a/src/NeuralPDE.jl
+++ b/src/NeuralPDE.jl
@@ -49,7 +49,7 @@ using AdvancedHMC: AdvancedHMC, DiagEuclideanMetric, HMC, HMCDA, Hamiltonian,
 using Distributions: Distributions, Distribution, MvNormal, Normal, dim, logpdf
 using LogDensityProblems: LogDensityProblems
 using MCMCChains: MCMCChains, Chains, sample
-using MonteCarloMeasurements: Particles
+using MonteCarloMeasurements: Particles, pmean
 import LuxCore: initialparameters, initialstates, parameterlength
@@ -90,10 +90,12 @@ include("BPINN_ode.jl")
 include("PDE_BPINN.jl")
 include("dgm.jl")
+include("NN_SDE_solve.jl")
 
 export PINOODE
 export NNODE, NNDAE
 export BNNODE, ahmc_bayesian_pinn_ode, ahmc_bayesian_pinn_pde
+export NNSDE
 export PhysicsInformedNN, discretize
 export BPINNsolution, BayesianPINN
 export DeepGalerkin
diff --git a/test/NNODE_tests.jl b/test/NNODE_tests.jl
index d8a833ca7..8dea6fde0 100644
--- a/test/NNODE_tests.jl
+++ b/test/NNODE_tests.jl
@@ -1,10 +1,3 @@
-# using Test, Random, NeuralPDE, OrdinaryDiffEq, Statistics, Lux, OptimizationOptimisers,
-#     OptimizationOptimJL, WeightInitializers, LineSearches
-# import Flux
-
-# rng = Random.default_rng()
-# Random.seed!(100)
-
 @testitem "Scalar" tags=[:nnode] begin
     using OrdinaryDiffEq, Random, Lux, Optimisers
     using OptimizationOptimJL: BFGS
diff --git a/test/NN_SDE_tests.jl b/test/NN_SDE_tests.jl
new file mode 100644
index 000000000..8f97421e8
--- /dev/null
+++ b/test/NN_SDE_tests.jl
@@ -0,0 +1,301 @@
+@testitem "Test-1" tags=[:nnsde] begin
+    using OrdinaryDiffEq, Random, Lux, Optimisers
+    using OptimizationOptimJL: BFGS
+    Random.seed!(100)
+
+    α = 1.2
+    β = 1.1
+    u₀ = 0.5
+    f(u, p, t) = α * u
+    g(u, p, t) = β * u
+    tspan = (0.0, 1.0)
+    prob = SDEProblem(f, g, u₀, tspan)
+    dim = 1 + 3
+    luxchain = Chain(Dense(dim, 16, σ), Dense(16, 16, σ), Dense(16, 1))
+
+    @testset "$(nameof(typeof(opt))) -- $(autodiff)" for opt in [BFGS(), Adam(0.1)],
+        autodiff in [false, true]
+
+        if autodiff
+            @test_throws ArgumentError solve(
+                prob, NNSDE(luxchain, opt; autodiff); maxiters = 200, dt = 1 / 20.0f0)
+            continue
+        end
+
+        @testset for (dt, abstol) in [(1 / 20.0f0, 1e-10), (nothing, 1e-6)]
+            kwargs = (; verbose = false, dt, abstol, maxiters = 200)
+            sol = solve(prob, NNSDE(luxchain, opt; autodiff); kwargs...)
+        end
+    end
+end
+
+@testitem "Test - GBM SDE" tags=[:nnsde] begin
+    using OrdinaryDiffEq, Random, Lux, Optimisers, DiffEqNoiseProcess, Distributions
+    using OptimizationOptimJL: BFGS
+    using MonteCarloMeasurements: Particles, pmean
+    Random.seed!(100)
+
+    α = 1.2
+    β = 1.1
+    u₀ = 0.5
+    f(u, p, t) = α * u
+    g(u, p, t) = β * u
+    tspan = (0.0, 1.0)
+    prob = SDEProblem(f, g, u₀, tspan)
+    dim = 1 + 3
+    luxchain = Chain(Dense(dim, 16, σ), Dense(16, 16, σ), Dense(16, 1))
+
+    dt = 1 / 50.0f0
+    abstol = 1e-6
+    autodiff = false
+    kwargs = (; verbose = true, dt = dt, abstol, maxiters = 300)
+    opt = BFGS()
+    numensemble = 2000
+
+    sol_2 = solve(
+        prob, NNSDE(
+            luxchain, opt; autodiff, numensemble = numensemble, sub_batch = 10, batch = true);
+        kwargs...)
+
+    sol_1 = solve(
+        prob, NNSDE(
+            luxchain, opt; autodiff, numensemble = numensemble, sub_batch = 1, batch = true);
+        kwargs...)
+
+    # sol_1 and sol_2 share the same timespan
+    ts = sol_1.timepoints
+    u1 = sol_1.strong_sol
+    u2 = sol_2.strong_sol
+
+    analytic_sol(u0, p, t, W) = u0 * exp((α - β^2 / 2) * t + β * W)
+    function W_kkl(t, z1, z2, z3)
+        √2 * (z1 * sin((1 - 1 / 2) * π * t) / ((1 - 1 / 2) * π) +
+         z2 * sin((2 - 1 / 2) * π * t) / ((2 - 1 / 2) * π) +
+         z3 * sin((3 - 1 / 2) * π * t) / ((3 - 1 / 2) * π))
+    end
+    truncated_sol(u0, t, z1, z2, z3) = u0 *
+                                       exp((α - β^2 / 2) * t + β * W_kkl(t, z1, z2, z3))
+
+    num_samples = 3000
+    z1_samples = rand(Normal(0, 1), num_samples)
+    z2_samples = rand(Normal(0, 1), num_samples)
+    z3_samples = rand(Normal(0, 1), num_samples)
+
+    num_time_steps = length(ts)
+    W_samples = Array{Float64}(undef, num_time_steps, num_samples)
+    for i in 1:num_samples
+        W = WienerProcess(0.0, 0.0)
+        probtemp = NoiseProblem(W, (0.0, 1.0))
+        Np_sol = solve(probtemp; dt = dt)
+        W_samples[:, i] = Np_sol.u
+    end
+
+    temp_rands = hcat(z1_samples, z2_samples, z3_samples)'
+    phi_inputs = [hcat([vcat(ts[j], temp_rands[:, i]) for j in eachindex(ts)]...)
+                  for i in 1:num_samples]
+
+    analytic_solution_samples = Array{Float64}(undef, num_time_steps, num_samples)
+    truncated_solution_samples = Array{Float64}(undef, num_time_steps, num_samples)
+    predicted_solution_samples_1 = Array{Float64}(undef, num_time_steps, num_samples)
+    predicted_solution_samples_2 = Array{Float64}(undef, num_time_steps, num_samples)
+
+    for j in 1:num_samples
+        for i in 1:num_time_steps
+            # for each sample, evaluate the solutions at each timepoint
+            analytic_solution_samples[i, j] = analytic_sol(u₀, 0, ts[i], W_samples[i, j])
+
+            predicted_solution_samples_1[i, j] = sol_1.solution.interp.phi(
+                phi_inputs[j][:, i], sol_1.solution.interp.θ)
+            predicted_solution_samples_2[i, j] = sol_2.solution.interp.phi(
+                phi_inputs[j][:, i], sol_2.solution.interp.θ)
+
+            truncated_solution_samples[i, j] = truncated_sol(
+                u₀, ts[i], z1_samples[j], z2_samples[j], z3_samples[j])
+        end
+    end
+
+    # strong solution tests
+    strong_analytic_solution = [Particles(analytic_solution_samples[i, :])
+                                for i in eachindex(ts)]
+    strong_truncated_solution = [Particles(truncated_solution_samples[i, :])
+                                 for i in eachindex(ts)]
+    strong_predicted_solution_1 = [Particles(predicted_solution_samples_1[i, :])
+                                   for i in eachindex(ts)]
+    strong_predicted_solution_2 = [Particles(predicted_solution_samples_2[i, :])
+                                   for i in eachindex(ts)]
+
+    error_1 = sum(abs2, strong_analytic_solution .- strong_predicted_solution_1)
+    error_2 = sum(abs2, strong_analytic_solution .- strong_predicted_solution_2)
+    @test pmean(error_1) > pmean(error_2)
+
+    @test pmean(sum(abs2.(strong_predicted_solution_1 .- strong_truncated_solution))) >
+          pmean(sum(abs2.(strong_predicted_solution_2 .- strong_truncated_solution)))
+
+    # weak solution tests
+    mean_analytic_solution = mean(analytic_solution_samples, dims = 2)
+    mean_truncated_solution = mean(truncated_solution_samples, dims = 2)
+    mean_predicted_solution_1 = mean(predicted_solution_samples_1, dims = 2)
+    mean_predicted_solution_2 = mean(predicted_solution_samples_2, dims = 2)
+
+    # testing over different Z_i sample sizes
+    error_1 = sum(abs2, mean_analytic_solution .- pmean(u1))
+    error_2 = sum(abs2, mean_analytic_solution .- pmean(u2))
+    @test error_1 > error_2
+
+    MSE_1 = mean(abs2.(mean_analytic_solution .- pmean(u1)))
+    MSE_2 = mean(abs2.(mean_analytic_solution .- pmean(u2)))
+    @test MSE_2 < MSE_1
+    @test MSE_2 < 5e-2
+
+    error_1 = sum(abs2, mean_analytic_solution .- mean_predicted_solution_1)
+    error_2 = sum(abs2, mean_analytic_solution .- mean_predicted_solution_2)
+    @test error_1 > error_2
+
+    MSE_1 = mean(abs2.(mean_analytic_solution .- mean_predicted_solution_1))
+    MSE_2 = mean(abs2.(mean_analytic_solution .- mean_predicted_solution_2))
+    @test MSE_2 < MSE_1
+    @test MSE_2 < 5e-2
+
+    @test mean(abs2.(mean_predicted_solution_1 .- mean_truncated_solution)) >
+          mean(abs2.(mean_predicted_solution_2 .- mean_truncated_solution))
+    @test mean(abs2.(mean_predicted_solution_1 .- mean_truncated_solution)) < 3e-1
+    @test mean(abs2.(mean_predicted_solution_2 .- mean_truncated_solution)) < 4e-2
+end
+
+# Equation 65 from https://arxiv.org/abs/1804.04344
+@testitem "Test-3 Additive Noise Test Equation" tags=[:nnsde] begin
+    using OrdinaryDiffEq, Random, Lux, Optimisers, DiffEqNoiseProcess, Distributions
+    using OptimizationOptimJL: BFGS
+    using MonteCarloMeasurements: Particles, pmean
+    Random.seed!(100)
+
+    α = 0.1
+    β = 0.05
+    u₀ = 0.5
+    f(u, p, t) = (β / sqrt(1 + t)) - (u[1] / ((1 + t) * 2))
+    g(u, p, t) = β * α / sqrt(1 + t)
+    tspan = (0.0, 1.0)
+    prob = SDEProblem(f, g, u₀, tspan)
+    dim = 1 + 6
+    luxchain = Chain(Dense(dim, 16, σ), Dense(16, 16, tanh), Dense(16, 16, σ), Dense(16, 1))
+
+    dt = 1 / 50.0f0
+    abstol = 1e-7
+    autodiff = false
+    kwargs = (; verbose = true, dt = dt, abstol, maxiters = 300)
+    opt = BFGS()
+    numensemble = 2000
+
+    sol_1 = solve(
+        prob, NNSDE(
+            luxchain, opt; autodiff, numensemble = numensemble, sub_batch = 1, batch = true);
+        kwargs...)
+
+    sol_2 = solve(
+        prob, NNSDE(
+            luxchain, opt; autodiff, numensemble = numensemble, sub_batch = 10, batch = true);
+        kwargs...)
+
+    # sol_1 and sol_2 share the same timespan
+    ts = sol_1.timepoints
+    u1 = sol_1.strong_sol
+    u2 = sol_2.strong_sol
+
+    analytic_sol(u0, p, t, W) = (u0 / sqrt(1 + t)) + (β * (t + α * W) / sqrt(1 + t))
+    function W_kkl(t, z1, z2, z3, z4, z5, z6)
+        √2 * ((z1 * sin((1 - 1 / 2) * π * t) / ((1 - 1 / 2) * π)) +
+         (z2 * sin((2 - 1 / 2) * π * t) / ((2 - 1 / 2) * π)) +
+         (z3 * sin((3 - 1 / 2) * π * t) / ((3 - 1 / 2) * π)) +
+         (z4 * sin((4 - 1 / 2) * π * t) / ((4 - 1 / 2) * π)) +
+         (z5 * sin((5 - 1 / 2) * π * t) / ((5 - 1 / 2) * π)) +
+         (z6 * sin((6 - 1 / 2) * π * t) / ((6 - 1 / 2) * π)))
+    end
+    function truncated_sol(u0, t, z1, z2, z3, z4, z5, z6)
+        (u0 / sqrt(1 + t)) + (β * (t + α * W_kkl(t, z1, z2, z3, z4, z5, z6)) / sqrt(1 + t))
+    end
+
+    num_samples = 3000
+    z1_samples = rand(Normal(0, 1), num_samples)
+    z2_samples = rand(Normal(0, 1), num_samples)
+    z3_samples = rand(Normal(0, 1), num_samples)
+    z4_samples = rand(Normal(0, 1), num_samples)
+    z5_samples = rand(Normal(0, 1), num_samples)
+    z6_samples = rand(Normal(0, 1), num_samples)
+
+    num_time_steps = length(ts)
+    W_samples = Array{Float64}(undef, num_time_steps, num_samples)
+    for i in 1:num_samples
+        W = WienerProcess(0.0, 1.0)
+        probtemp = NoiseProblem(W, (0.0, 1.0))
+        Np_sol = solve(probtemp; dt = dt)
+        W_samples[:, i] = Np_sol.u
+    end
+
+    temp_rands = hcat(
+        z1_samples, z2_samples, z3_samples, z4_samples, z5_samples, z6_samples)'
+    phi_inputs = [hcat([vcat(ts[j], temp_rands[:, i]) for j in eachindex(ts)]...)
+                  for i in 1:num_samples]
+
+    analytic_solution_samples = Array{Float64}(undef, num_time_steps, num_samples)
+    truncated_solution_samples = Array{Float64}(undef, num_time_steps, num_samples)
+    predicted_solution_samples_1 = Array{Float64}(undef, num_time_steps, num_samples)
+    predicted_solution_samples_2 = Array{Float64}(undef, num_time_steps, num_samples)
+
+    for j in 1:num_samples
+        for i in 1:num_time_steps
+            # for each sample, evaluate the solutions at each timepoint
+            analytic_solution_samples[i, j] = analytic_sol(u₀, 0, ts[i], W_samples[i, j])
+
+            predicted_solution_samples_1[i, j] = sol_1.solution.interp.phi(
+                phi_inputs[j][:, i], sol_1.solution.interp.θ)
+            predicted_solution_samples_2[i, j] = sol_2.solution.interp.phi(
+                phi_inputs[j][:, i], sol_2.solution.interp.θ)
+
+            truncated_solution_samples[i, j] = truncated_sol(
+                u₀, ts[i], z1_samples[j], z2_samples[j], z3_samples[j],
+                z4_samples[j], z5_samples[j], z6_samples[j])
+        end
+    end
+
+    # strong solution tests
+    strong_analytic_solution = [Particles(analytic_solution_samples[i, :])
+                                for i in eachindex(ts)]
+    strong_truncated_solution = [Particles(truncated_solution_samples[i, :])
+                                 for i in eachindex(ts)]
+    strong_predicted_solution_1 = [Particles(predicted_solution_samples_1[i, :])
+                                   for i in eachindex(ts)]
+    strong_predicted_solution_2 = [Particles(predicted_solution_samples_2[i, :])
+                                   for i in eachindex(ts)]
+
+    error_1 = sum(abs2, strong_analytic_solution .- strong_predicted_solution_1)
+    error_2 = sum(abs2, strong_analytic_solution .- strong_predicted_solution_2)
+    @test pmean(error_1) > pmean(error_2)
+
+    error1 = sum(abs2.(strong_predicted_solution_1 .- strong_truncated_solution))
+    error2 = sum(abs2.(strong_predicted_solution_2 .- strong_truncated_solution))
+    @test pmean(error1) > pmean(error2)
+
+    # weak solution tests
+    mean_analytic_solution = mean(analytic_solution_samples, dims = 2)
+    mean_truncated_solution = mean(truncated_solution_samples, dims = 2)
+    mean_predicted_solution_1 = mean(predicted_solution_samples_1, dims = 2)
+    mean_predicted_solution_2 = mean(predicted_solution_samples_2, dims = 2)
+
+    # testing over different Z_i sample sizes
+    MSE_1 = mean(abs2.(mean_analytic_solution .- pmean(u1)))
+    MSE_2 = mean(abs2.(mean_analytic_solution .- pmean(u2)))
+    @test MSE_1 < 5e-5
+    @test MSE_2 < 3e-5
+
+    error_1 = sum(abs2, mean_truncated_solution .- mean_predicted_solution_1)
+    error_2 = sum(abs2, mean_truncated_solution .- mean_predicted_solution_2)
+    @test error_1 > error_2
+    @test error_2 < 3e-3
+
+    MSE_1 = mean(abs2.(mean_truncated_solution .- mean_predicted_solution_1))
+    MSE_2 = mean(abs2.(mean_truncated_solution .- mean_predicted_solution_2))
+    @test MSE_2 < MSE_1
+    @test MSE_2 < 3e-5
+end
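
As a post-processing sketch (hedged: `sol_2`, `u₀`, and `α` are the names from the GBM test above, and the closed-form mean E[u(t)] = u₀·exp(α·t) holds for that test's geometric Brownian motion):

```julia
using MonteCarloMeasurements: pmean

# weak-solution check: the ensemble mean of the learned strong solution should
# track the analytic mean E[u(t)] = u₀ * exp(α * t) of the GBM test problem
ts = sol_2.timepoints
weak_sol = pmean.(sol_2.strong_sol)   # collapse each Particles to its mean
analytic_mean = u₀ .* exp.(α .* ts)
mse = sum(abs2, weak_sol .- analytic_mean) / length(ts)
```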