diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl
index 945093ea04..edfaf9664a 100644
--- a/src/NeuralPDE.jl
+++ b/src/NeuralPDE.jl
@@ -50,8 +50,9 @@ include("rode_solve.jl")
 include("transform_inf_integral.jl")
 include("discretize.jl")
 include("neural_adapter.jl")
-include("advancedHMC_MCMC.jl")
-include("BPINN_ode.jl")
+include("bayesian/advancedHMC_MCMC.jl")
+include("bayesian/BPINN_ode.jl")
+include("bayesian/collocated_estim.jl")
 
 export NNODE, TerminalPDEProblem, NNPDEHan, NNPDENS, NNRODE,
        KolmogorovPDEProblem, NNKolmogorov, NNStopping, ParamKolmogorovPDEProblem,
diff --git a/src/BPINN_ode.jl b/src/bayesian/BPINN_ode.jl
similarity index 96%
rename from src/BPINN_ode.jl
rename to src/bayesian/BPINN_ode.jl
index da49640314..a2cce9db34 100644
--- a/src/BPINN_ode.jl
+++ b/src/bayesian/BPINN_ode.jl
@@ -178,7 +178,8 @@ function DiffEqBase.__solve(prob::DiffEqBase.ODEProblem,
                             verbose = false,
                             saveat = 1 / 50.0,
                             maxiters = nothing,
-                            numensemble = floor(Int, alg.draw_samples / 3))
+                            numensemble = floor(Int, alg.draw_samples / 3),
+                            estim_collocate = false)
     @unpack chain, l2std, phystd, param, priorsNNw, Kernel, strategy,
     draw_samples, dataset, init_params, nchains, physdt, Adaptorkwargs, Integratorkwargs,
@@ -207,7 +208,8 @@ function DiffEqBase.__solve(prob::DiffEqBase.ODEProblem,
                                                   Integratorkwargs = Integratorkwargs,
                                                   MCMCkwargs = MCMCkwargs,
                                                   progress = progress,
-                                                  verbose = verbose)
+                                                  verbose = verbose,
+                                                  estim_collocate = estim_collocate)
 
     fullsolution = BPINNstats(mcmcchain, samples, statistics)
     ninv = length(param)
@@ -215,8 +217,12 @@ function DiffEqBase.__solve(prob::DiffEqBase.ODEProblem,
 
     if chain isa Lux.AbstractExplicitLayer
         θinit, st = Lux.setup(Random.default_rng(), chain)
+        println(length(θinit))
+        println(length(samples[1]))
+        println(draw_samples)
         θ = [vector_to_parameters(samples[i][1:(end - ninv)], θinit)
-             for i in (draw_samples - numensemble):draw_samples]
+             for i in 1:max(draw_samples - draw_samples ÷ 10, draw_samples - 1000)]
+
         luxar = [chain(t', θ[i], st)[1] for i in 1:numensemble]
         # only need for size
        θinit = collect(ComponentArrays.ComponentArray(θinit))
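A note on the ensemble window introduced above: `θ` previously kept the last `numensemble` draws, and now keeps everything up to `max(draw_samples - draw_samples ÷ 10, draw_samples - 1000)`, i.e. it drops the trailing 10% of draws, capped at 1000. A small worked example of that index arithmetic, for illustration only (not part of the patch):

```julia
draw_samples = 1000
last_idx = max(draw_samples - draw_samples ÷ 10, draw_samples - 1000)  # max(900, 0) == 900
θ_range = 1:last_idx  # θ now spans the first 900 draws
# `luxar` below still indexes only θ[1:numensemble], i.e. 1:333 with the
# default numensemble = floor(Int, draw_samples / 3).
```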
diff --git a/src/advancedHMC_MCMC.jl b/src/bayesian/advancedHMC_MCMC.jl
similarity index 95%
rename from src/advancedHMC_MCMC.jl
rename to src/bayesian/advancedHMC_MCMC.jl
index 6032c7ca21..5e995ebfdb 100644
--- a/src/advancedHMC_MCMC.jl
+++ b/src/bayesian/advancedHMC_MCMC.jl
@@ -16,11 +16,12 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
     physdt::Float64
     extraparams::Int
     init_params::I
+    estim_collocate::Bool
 
     function LogTargetDensity(dim, prob, chain::Optimisers.Restructure, st, strategy,
                               dataset,
                               priors, phystd, l2std, autodiff, physdt, extraparams,
-                              init_params::AbstractVector)
+                              init_params::AbstractVector, estim_collocate)
         new{
             typeof(chain),
             Nothing,
@@ -39,12 +40,13 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
             autodiff,
             physdt,
             extraparams,
-            init_params)
+            init_params,
+            estim_collocate)
     end
     function LogTargetDensity(dim, prob, chain::Lux.AbstractExplicitLayer, st, strategy,
                               dataset,
                               priors, phystd, l2std, autodiff, physdt, extraparams,
-                              init_params::NamedTuple)
+                              init_params::NamedTuple, estim_collocate)
         new{
             typeof(chain),
             typeof(st),
@@ -60,7 +62,8 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
             autodiff,
             physdt,
             extraparams,
-            init_params)
+            init_params,
+            estim_collocate)
     end
 end
 
@@ -79,7 +82,11 @@ function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
 end
 
 function LogDensityProblems.logdensity(Tar::LogTargetDensity, θ)
-    return physloglikelihood(Tar, θ) + priorweights(Tar, θ) + L2LossData(Tar, θ)
+    if Tar.estim_collocate
+        return physloglikelihood(Tar, θ)/length(Tar.dataset[1]) + priorweights(Tar, θ) + L2LossData(Tar, θ)/length(Tar.dataset[1]) + L2loss2(Tar, θ)/length(Tar.dataset[1])
+    else
+        return physloglikelihood(Tar, θ)/length(Tar.dataset[1]) + priorweights(Tar, θ) + L2LossData(Tar, θ)/length(Tar.dataset[1])
+    end
 end
 
 LogDensityProblems.dimension(Tar::LogTargetDensity) = Tar.dim
@@ -481,7 +488,8 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
                                                 Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
                                 Integratorkwargs = (Integrator = Leapfrog,),
                                 MCMCkwargs = (n_leapfrog = 30,),
-                                progress = false, verbose = false)
+                                progress = false, verbose = false,
+                                estim_collocate = false)
     # NN parameter prior mean and variance(PriorsNN must be a tuple)
     if isinplace(prob)
@@ -542,7 +550,7 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
     t0 = prob.tspan[1]
     # dimensions would be total no of params,initial_nnθ for Lux namedTuples
     ℓπ = LogTargetDensity(nparameters, prob, recon, st, strategy, dataset, priors,
-                          phystd, l2std, autodiff, physdt, ninv, initial_nnθ)
+                          phystd, l2std, autodiff, physdt, ninv, initial_nnθ, estim_collocate)
 
     try
         ℓπ(t0, initial_θ[1:(nparameters - ninv)])
@@ -580,7 +588,7 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
             MCMC_alg = kernelchoice(Kernel, MCMCkwargs)
             Kernel = AdvancedHMC.make_kernel(MCMC_alg, integrator)
             samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples, adaptor;
-                                    progress = progress, verbose = verbose)
+                                    progress = progress, verbose = verbose, drop_warmup = true)
 
             samplesc[i] = samples
             statsc[i] = stats
@@ -598,11 +606,10 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
         MCMC_alg = kernelchoice(Kernel, MCMCkwargs)
         Kernel = AdvancedHMC.make_kernel(MCMC_alg, integrator)
         samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples,
-                                adaptor; progress = progress, verbose = verbose)
-
+                                adaptor; progress = progress, verbose = verbose, drop_warmup = true)
         # return a chain(basic chain),samples and stats
-        matrix_samples = hcat(samples...)
-        mcmc_chain = MCMCChains.Chains(matrix_samples')
+        matrix_samples = reshape(hcat(samples...), (length(samples[1]), length(samples), 1))
+        mcmc_chain = MCMCChains.Chains(matrix_samples)
         return mcmc_chain, samples, stats
     end
 end
\ No newline at end of file
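Net effect of the `logdensity` change: with `estim_collocate = true` the physics and data log-likelihoods are normalized by the number of data points and the new `L2loss2` collocation term is added; the prior term stays unscaled. A hedged usage sketch of the new keyword on the low-level API — the toy problem, network, and hyperparameters below are placeholders, not taken from this patch:

```julia
using NeuralPDE, Lux, OrdinaryDiffEq, Distributions, Random

Random.seed!(1)

# Toy linear ODE u' = p * u with one unknown parameter.
prob = ODEProblem((u, p, t) -> p .* u, [1.0], (0.0, 1.0), [-1.0])
chain = Lux.Chain(Lux.Dense(1, 8, tanh), Lux.Dense(8, 1))

t = collect(0.0:0.1:1.0)
u = exp.(-t) .+ 0.05 .* randn(length(t))  # noisy observations of the true solution
dataset = [u, t]

mcmcchain, samples, stats = ahmc_bayesian_pinn_ode(prob, chain;
    dataset = dataset, draw_samples = 500,
    l2std = [0.05], phystd = [0.05], priorsNNw = (0.0, 3.0),
    param = [Normal(-1.0, 1.0)],
    estim_collocate = true)  # turns on the extra gradient-matching term
```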
diff --git a/src/bayesian/collocated_estim.jl b/src/bayesian/collocated_estim.jl
new file mode 100644
index 0000000000..b113b76f12
--- /dev/null
+++ b/src/bayesian/collocated_estim.jl
@@ -0,0 +1,192 @@
+# suggested extra loss function
+function L2loss2(Tar::LogTargetDensity, θ)
+    f = Tar.prob.f
+
+    # parameter estimation chosen or not
+    if Tar.extraparams > 0
+        # deri_sol = deri_sol'
+        autodiff = Tar.autodiff
+        # # Timepoints to enforce Physics
+        # dataset = Array(reduce(hcat, dataset)')
+        # t = dataset[end, :]
+        # û = dataset[1:(end - 1), :]
+
+        # ode_params = Tar.extraparams == 1 ?
+        #              θ[((length(θ) - Tar.extraparams) + 1):length(θ)][1] :
+        #              θ[((length(θ) - Tar.extraparams) + 1):length(θ)]
+
+        # if length(û[:, 1]) == 1
+        #     physsol = [f(û[:, i][1],
+        #                  ode_params,
+        #                  t[i])
+        #                for i in 1:length(û[1, :])]
+        # else
+        #     physsol = [f(û[:, i],
+        #                  ode_params,
+        #                  t[i])
+        #                for i in 1:length(û[1, :])]
+        # end
+        # # form of NN output matrix: output dim x n
+        # deri_physsol = reduce(hcat, physsol)
+
+        # > for perfect deriv (basically gradient matching in case of an ODEFunction)
+        # in case of PDE or general ODE we would want to reduce the residual of f(du,u,p,t)
+        # if length(û[:, 1]) == 1
+        #     deri_sol = [f(û[:, i][1],
+        #                   Tar.prob.p,
+        #                   t[i])
+        #                 for i in 1:length(û[1, :])]
+        # else
+        #     deri_sol = [f(û[:, i],
+        #                   Tar.prob.p,
+        #                   t[i])
+        #                 for i in 1:length(û[1, :])]
+        # end
+        # deri_sol = reduce(hcat, deri_sol)
+        # deri_sol = reduce(hcat, derivatives)
+
+        # Timepoints to enforce Physics
+        t = Tar.dataset[end]
+        u1 = Tar.dataset[2]
+        û = Tar.dataset[1]
+        # Tar(t, θ[1:(length(θ) - Tar.extraparams)])'
+        #
+
+        nnsol = NNodederi(Tar, t, θ[1:(length(θ) - Tar.extraparams)], autodiff)
+
+        ode_params = Tar.extraparams == 1 ?
+                     θ[((length(θ) - Tar.extraparams) + 1):length(θ)][1] :
+                     θ[((length(θ) - Tar.extraparams) + 1):length(θ)]
+
+        if length(Tar.prob.u0) == 1
+            physsol = [f(û[i],
+                         ode_params,
+                         t[i])
+                       for i in 1:length(û[:, 1])]
+        else
+            physsol = [f([û[i], u1[i]],
+                         ode_params,
+                         t[i])
+                       for i in 1:length(û)]
+        end
+        # form of NN output matrix: output dim x n
+        deri_physsol = reduce(hcat, physsol)
+
+        # if length(Tar.prob.u0) == 1
+        #     nnsol = [f(û[i],
+        #                Tar.prob.p,
+        #                t[i])
+        #              for i in 1:length(û[:, 1])]
+        # else
+        #     nnsol = [f([û[i], u1[i]],
+        #                Tar.prob.p,
+        #                t[i])
+        #              for i in 1:length(û[:, 1])]
+        # end
+        # form of NN output matrix: output dim x n
+        # nnsol = reduce(hcat, nnsol)
+
+        # > Instead of dataset gradients trying NN derivatives with dataset collocation
+        # # convert to matrix as nnsol
+
+        physlogprob = 0
+        for i in 1:length(Tar.prob.u0)
+            # can add phystd[i] for u[i]
+            physlogprob += logpdf(MvNormal(deri_physsol[i, :],
+                                           LinearAlgebra.Diagonal(map(abs2,
+                                                                      (Tar.l2std[i] * 4.0) .*
+                                                                      ones(length(nnsol[i, :]))))),
+                                  nnsol[i, :])
+        end
+        return physlogprob
+    else
+        return 0
+    end
+end
+
+# PDE(DU,U,P,T)=0
+
+# Derived via central differences
+# function calculate_derivatives2(dataset)
+#     x̂, time = dataset
+#     num_points = length(x̂)
+#     # Initialize an array to store the derivative values.
+#     derivatives = similar(x̂)
+
+#     for i in 2:(num_points - 1)
+#         # Calculate the first-order derivative using central differences.
+#         Δt_forward = time[i + 1] - time[i]
+#         Δt_backward = time[i] - time[i - 1]
+
+#         derivative = (x̂[i + 1] - x̂[i - 1]) / (Δt_forward + Δt_backward)
+
+#         derivatives[i] = derivative
+#     end
+
+#     # Derivatives at the endpoints can be calculated using forward or backward differences.
+#     derivatives[1] = (x̂[2] - x̂[1]) / (time[2] - time[1])
+#     derivatives[end] = (x̂[end] - x̂[end - 1]) / (time[end] - time[end - 1])
+#     return derivatives
+# end
+
+function calderivatives(prob, dataset)
+    chainflux = Flux.Chain(Flux.Dense(1, 8, tanh), Flux.Dense(8, 8, tanh),
+                           Flux.Dense(8, 2)) |> Flux.f64
+    # chainflux = Flux.Chain(Flux.Dense(1, 7, tanh), Flux.Dense(7, 1)) |> Flux.f64
+    function loss(x, y)
+        # sum(Flux.mse.(prob.u0[1] .+ (prob.tspan[2] .- x)' .* chainflux(x)[1, :], y[1]) +
+        #     Flux.mse.(prob.u0[2] .+ (prob.tspan[2] .- x)' .* chainflux(x)[2, :], y[2]))
+        # sum(Flux.mse.(prob.u0[1] .+ (prob.tspan[2] .- x)' .* chainflux(x)[1, :], y[1]))
+        sum(Flux.mse.(chainflux(x), y))
+    end
+    optimizer = Flux.Optimise.ADAM(0.01)
+    epochs = 3000
+    for epoch in 1:epochs
+        Flux.train!(loss,
+                    Flux.params(chainflux),
+                    [(dataset[end]', dataset[1:(end - 1)])],
+                    optimizer)
+    end
+
+    # A1 = (prob.u0' .+
+    #       (prob.tspan[2] .- (dataset[end]' .+ sqrt(eps(eltype(Float64)))))' .*
+    #       chainflux(dataset[end]' .+ sqrt(eps(eltype(Float64))))')
+
+    # A2 = (prob.u0' .+
+    #       (prob.tspan[2] .- (dataset[end]'))' .*
+    #       chainflux(dataset[end]')')
+
+    A1 = chainflux(dataset[end]' .+ sqrt(eps(eltype(dataset[end][1]))))
+    A2 = chainflux(dataset[end]')
+
+    gradients = (A2 .- A1) ./ sqrt(eps(eltype(dataset[end][1])))
+
+    return gradients
+end
+
+function calculate_derivatives(dataset)
+
+    # u = dataset[1]
+    # u1 = dataset[2]
+    # t = dataset[end]
+    # # control points
+    # n = Int(floor(length(t) / 10))
+    # # spline for dataset values (solution)
+    # # interp = BSplineApprox(u, t, 4, 10, :Uniform, :Uniform)
+    # interp = CubicSpline(u, t)
+    # interp1 = CubicSpline(u1, t)
+    # # derivatives interpolation
+    # dx = t[2] - t[1]
+    # time = collect(t[1]:dx:t[end])
+    # smoothu = [interp(i) for i in time]
+    # smoothu1 = [interp1(i) for i in time]
+    # # derivative of the spline (must match function derivative)
+    # û = tvdiff(smoothu, 20, 0.5, dx = dx, ε = 1)
+    # û1 = tvdiff(smoothu1, 20, 0.5, dx = dx, ε = 1)
+    # # tvdiff(smoothu, 100, 0.035, dx = dx, ε = 1)
+    # # FDM
+    # # û1 = diff(u) / dx
+    # # dataset[1] and smoothu are almost equal (rounding errors)
+    # return [û, û1]
+
+end
\ No newline at end of file
diff --git a/test/bpinnexperimental.jl b/test/bpinnexperimental.jl
new file mode 100644
index 0000000000..3de049bf58
--- /dev/null
+++ b/test/bpinnexperimental.jl
@@ -0,0 +1,118 @@
+using Test, MCMCChains
+using ForwardDiff, Distributions, OrdinaryDiffEq
+using Flux, OptimizationOptimisers, AdvancedHMC, Lux
+using Statistics, Random, Functors, ComponentArrays
+using NeuralPDE, MonteCarloMeasurements
+
+Random.seed!(110)
+
+using NeuralPDE, Lux, Plots, OrdinaryDiffEq, Distributions, Random
+
+function lotka_volterra(u, p, t)
+    # Model parameters.
+    α, β, γ, δ = p
+    # Current state.
+    x, y = u
+
+    # Evaluate differential equations.
+    dx = (α - β * y) * x # prey
+    dy = (δ * x - γ) * y # predator
+
+    return [dx, dy]
+end
+
+# initial-value problem.
+u0 = [1.0, 1.0]
+p = [1.5, 1.0, 3.0, 1.0]
+tspan = (0.0, 4.0)
+prob = ODEProblem(lotka_volterra, u0, tspan, p)
+
+# Solve using OrdinaryDiffEq.jl solver
+dt = 0.2
+solution = solve(prob, Tsit5(); saveat = dt)
+
+times = solution.t
+u = hcat(solution.u...)
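For readers skimming `collocated_estim.jl`: the live path of `L2loss2` is gradient matching — the network's time derivative at the data points (`NNodederi`) is scored against the ODE right-hand side evaluated on the observed states, under a Gaussian whose data noise is inflated. A minimal standalone sketch of that idea, assuming matrix-shaped inputs; everything except the `4.0` inflation factor (taken from the code above) is illustrative:

```julia
using Distributions, LinearAlgebra

# f: ODE right-hand side f(u, p, t); û: observed states, dim × n;
# du_nn: network time-derivatives at the data points, dim × n.
function collocation_loglik(f, û, du_nn, t, p, l2std)
    # RHS evaluated on the data ("perfect derivative" targets)
    physsol = reduce(hcat, [f(û[:, i], p, t[i]) for i in eachindex(t)])
    # Gaussian score of the NN derivatives against those targets,
    # with the data std inflated by 4.0 as in L2loss2
    sum(logpdf(MvNormal(physsol[i, :],
                        Diagonal(abs2.((l2std[i] * 4.0) .* ones(length(t))))),
               du_nn[i, :]) for i in axes(physsol, 1))
end
```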
+x = u[1, :] + (u[1, :]) .* (0.3 .* randn(length(u[1, :])))
+y = u[2, :] + (u[2, :]) .* (0.3 .* randn(length(u[2, :])))
+dataset = [x, y, times]
+
+plot(times, x, label = "noisy x")
+plot!(times, y, label = "noisy y")
+plot!(solution, labels = ["x" "y"])
+
+chain = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 6, tanh),
+                  Lux.Dense(6, 2))
+
+alg = BNNODE(chain;
+             dataset = dataset,
+             draw_samples = 1000,
+             l2std = [0.1, 0.1],
+             phystd = [0.1, 0.1],
+             priorsNNw = (0.0, 3.0),
+             param = [
+                 Normal(1, 2),
+                 Normal(2, 2),
+                 Normal(2, 2),
+                 Normal(0, 2)], progress = true)
+
+@time sol_pestim1 = solve(prob, alg; saveat = dt)
+@time sol_pestim2 = solve(prob, alg; estim_collocate = true, saveat = dt)
+plot(times, sol_pestim1.ensemblesol[1], label = "estimated x1")
+plot!(times, sol_pestim2.ensemblesol[1], label = "estimated x2")
+plot!(times, sol_pestim1.ensemblesol[2], label = "estimated y1")
+plot!(times, sol_pestim2.ensemblesol[2], label = "estimated y2")
+
+# comparing it with the original solution
+plot!(solution, labels = ["true x" "true y"])
+
+@show sol_pestim1.estimated_ode_params
+@show sol_pestim2.estimated_ode_params
+
+function fitz(u, p, t)
+    v, w = u[1], u[2]
+    a, b, τinv, l = p[1], p[2], p[3], p[4]
+
+    dv = v - 0.33 * v^3 - w + l
+    dw = τinv * (v + a - b * w)
+
+    return [dv, dw]
+end
+
+prob_ode_fitzhughnagumo = ODEProblem(fitz, [1.0, 1.0], (0.0, 10.0), [0.7, 0.8, 1 / 12.5, 0.5])
+dt = 0.5
+sol = solve(prob_ode_fitzhughnagumo, Tsit5(), saveat = dt)
+
+sig = 0.20
+data = Array(sol)
+dataset = [data[1, :] .+ (sig .* rand(length(sol.t))),
+           data[2, :] .+ (sig .* rand(length(sol.t))), sol.t]
+priors = [Normal(0.5, 1.0), Normal(0.5, 1.0), Normal(0.0, 0.5), Normal(0.5, 1.0)]
+
+plot(sol.t, dataset[1], label = "noisy x")
+plot!(sol.t, dataset[2], label = "noisy y")
+plot!(sol, labels = ["x" "y"])
+
+chain = Lux.Chain(Lux.Dense(1, 10, tanh), Lux.Dense(10, 10, tanh),
+                  Lux.Dense(10, 2))
+
+Adaptorkwargs = (Adaptor = AdvancedHMC.StanHMCAdaptor,
+                 Metric = AdvancedHMC.DiagEuclideanMetric, targetacceptancerate = 0.8)
+alg = BNNODE(chain;
+             dataset = dataset,
+             draw_samples = 1000,
+             l2std = [0.1, 0.1],
+             phystd = [0.1, 0.1],
+             priorsNNw = (0.01, 3.0),
+             Adaptorkwargs = Adaptorkwargs,
+             param = priors, progress = true)
+
+@time sol_pestim3 = solve(prob_ode_fitzhughnagumo, alg; saveat = dt)
+@time sol_pestim4 = solve(prob_ode_fitzhughnagumo, alg; estim_collocate = true, saveat = dt)
+plot!(sol.t, sol_pestim3.ensemblesol[1], label = "estimated x1")
+plot!(sol.t, sol_pestim4.ensemblesol[1], label = "estimated x2")
+plot!(sol.t, sol_pestim3.ensemblesol[2], label = "estimated y1")
+plot!(sol.t, sol_pestim4.ensemblesol[2], label = "estimated y2")
+
+@show sol_pestim3.estimated_ode_params
+@show sol_pestim4.estimated_ode_params
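The script above only `@show`s the recovered parameters. A hedged sketch of assertions that could turn it into a real test, in the spirit of the other BPINN test files; the tolerance is an assumption, not a value validated in this patch:

```julia
using MonteCarloMeasurements: pmean

p_true = [0.7, 0.8, 1 / 12.5, 0.5]              # FitzHugh-Nagumo parameters used above
est = pmean.(sol_pestim4.estimated_ode_params)  # posterior point estimates

# hypothetical bound: collocation-aided estimates within 50% relative error
@test all(abs.(est .- p_true) ./ abs.(p_true) .< 0.5)
```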