Skip to content

Commit

Permalink
refactor: bayesian PINN ODEs
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 15, 2024
1 parent b76d8a0 commit 22c316b
Show file tree
Hide file tree
Showing 9 changed files with 373 additions and 601 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Expand Down Expand Up @@ -78,6 +79,7 @@ OptimizationOptimJL = "0.4"
OptimizationOptimisers = "0.3"
OrdinaryDiffEq = "6.87"
Pkg = "1.10"
Printf = "1.10"
QuasiMonteCarlo = "0.3.2"
Random = "1"
RecursiveArrayTools = "3.27.0"
Expand Down
180 changes: 70 additions & 110 deletions src/BPINN_ode.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
# HIGH level API for BPINN ODE solver

"""
BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
priorsNNw = (0.0, 2.0), param = [nothing], l2std = [0.05],
phystd = [0.05], dataset = [nothing], physdt = 1 / 20.0,
MCMCargs = (n_leapfrog=30), nchains = 1, init_params = nothing,
Adaptorkwargs = (Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8, Metric = DiagEuclideanMetric),
Integratorkwargs = (Integrator = Leapfrog,), autodiff = false,
progress = false, verbose = false)
Algorithm for solving ordinary differential equations using a Bayesian neural network. This is a specialization
of the physics-informed neural network which is used as a solver for a standard `ODEProblem`.
BNNODE(chain, kernel = HMC; strategy = nothing, draw_samples = 2000,
priorsNNw = (0.0, 2.0), param = [nothing], l2std = [0.05],
phystd = [0.05], dataset = [nothing], physdt = 1 / 20.0,
MCMCargs = (; n_leapfrog=30), nchains = 1, init_params = nothing,
Adaptorkwargs = (; Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8,
Metric = DiagEuclideanMetric),
Integratorkwargs = (Integrator = Leapfrog,), autodiff = false,
progress = false, verbose = false)
Algorithm for solving ordinary differential equations using a Bayesian neural network. This
is a specialization of the physics-informed neural network which is used as a solver for a
standard `ODEProblem`.
!!! warn
Expand All @@ -21,9 +23,10 @@ of the physics-informed neural network which is used as a solver for a standard
## Positional Arguments
* `chain`: A neural network architecture, defined as a `Lux.AbstractLuxLayer`.
* `Kernel`: Choice of MCMC Sampling Algorithm. Defaults to `AdvancedHMC.HMC`
* `kernel`: Choice of MCMC Sampling Algorithm. Defaults to `AdvancedHMC.HMC`
## Keyword Arguments
(refer `NeuralPDE.ahmc_bayesian_pinn_ode` keyword arguments.)
## Example
Expand All @@ -44,18 +47,15 @@ dataset = [x̂, time]
chainlux = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 6, tanh), Lux.Dense(6, 1))
alg = BNNODE(chainlux, draw_samples = 2000,
l2std = [0.05], phystd = [0.05],
priorsNNw = (0.0, 3.0), progress = true)
alg = BNNODE(chainlux; draw_samples = 2000, l2std = [0.05], phystd = [0.05],
priorsNNw = (0.0, 3.0), progress = true)
sol_lux = solve(prob, alg)
# with parameter estimation
alg = BNNODE(chainlux,dataset = dataset,
draw_samples = 2000,l2std = [0.05],
phystd = [0.05],priorsNNw = (0.0, 10.0),
param = [Normal(6.5, 0.5), Normal(-3, 0.5)],
progress = true)
alg = BNNODE(chainlux; dataset, draw_samples = 2000, l2std = [0.05], phystd = [0.05],
priorsNNw = (0.0, 10.0), param = [Normal(6.5, 0.5), Normal(-3, 0.5)],
progress = true)
sol_lux_pestim = solve(prob, alg)
```
Expand All @@ -71,60 +71,48 @@ is an accurate interpolation (up to the neural network training result). In addi
## References
Liu Yanga, Xuhui Menga, George Em Karniadakis. "B-PINNs: Bayesian Physics-Informed Neural Networks for
Forward and Inverse PDE Problems with Noisy Data".
Liu Yanga, Xuhui Menga, George Em Karniadakis. "B-PINNs: Bayesian Physics-Informed Neural
Networks for Forward and Inverse PDE Problems with Noisy Data".
Kevin Linka, Amelie Schäfer, Xuhui Meng, Zongren Zou, George Em Karniadakis, Ellen Kuhl
"Bayesian Physics Informed Neural Networks for real-world nonlinear dynamical systems".
"""
struct BNNODE{C, K, IT <: NamedTuple,
A <: NamedTuple, H <: NamedTuple,
ST <: Union{Nothing, AbstractTrainingStrategy},
I <: Union{Nothing, <:NamedTuple, Vector{<:AbstractFloat}},
P <: Union{Nothing, Vector{<:Distribution}},
D <:
Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}}} <:
NeuralPDEAlgorithm
chain::C
Kernel::K
strategy::ST
draw_samples::Int64
@concrete struct BNNODE <: NeuralPDEAlgorithm
chain <: AbstractLuxLayer
kernel
strategy <: Union{Nothing, AbstractTrainingStrategy}
draw_samples::Int
priorsNNw::Tuple{Float64, Float64}
param::P
param <: Union{Nothing, Vector{<:Distribution}}
l2std::Vector{Float64}
phystd::Vector{Float64}
dataset::D
dataset <: Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}}
physdt::Float64
MCMCkwargs::H
nchains::Int64
init_params::I
Adaptorkwargs::A
Integratorkwargs::IT
numensemble::Int64
MCMCkwargs <: NamedTuple
nchains::Int
init_params <: Union{Nothing, <:NamedTuple, Vector{<:AbstractFloat}}
Adaptorkwargs <: NamedTuple
Integratorkwargs <: NamedTuple
numensemble::Int
estim_collocate::Bool
autodiff::Bool
progress::Bool
verbose::Bool
end
function BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,

function BNNODE(chain, kernel = HMC; strategy = nothing, draw_samples = 2000,
priorsNNw = (0.0, 2.0), param = nothing, l2std = [0.05], phystd = [0.05],
dataset = [nothing], physdt = 1 / 20.0, MCMCkwargs = (n_leapfrog = 30,), nchains = 1,
init_params = nothing,
dataset = [nothing], physdt = 1 / 20.0, MCMCkwargs = (n_leapfrog = 30,),
nchains = 1, init_params = nothing,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric,
targetacceptancerate = 0.8),
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
numensemble = floor(Int, draw_samples / 3),
estim_collocate = false,
autodiff = false, progress = false, verbose = false)
estim_collocate = false, autodiff = false, progress = false, verbose = false)
chain isa AbstractLuxLayer || (chain = FromFluxAdaptor()(chain))
BNNODE(chain, Kernel, strategy,
draw_samples, priorsNNw, param, l2std,
phystd, dataset, physdt, MCMCkwargs,
nchains, init_params,
Adaptorkwargs, Integratorkwargs,
numensemble, estim_collocate,
autodiff, progress, verbose)
return BNNODE(chain, kernel, strategy, draw_samples, priorsNNw, param, l2std, phystd,
dataset, physdt, MCMCkwargs, nchains, init_params, Adaptorkwargs,
Integratorkwargs, numensemble, estim_collocate, autodiff, progress, verbose)
end

"""
Expand All @@ -149,11 +137,14 @@ struct BPINNstats{MC, S, ST}
end

"""
BPINN Solution contains the original solution from AdvancedHMC.jl sampling (BPINNstats contains fields related to that).
BPINN Solution contains the original solution from AdvancedHMC.jl sampling (BPINNstats
contains fields related to that).
1. `ensemblesol` is the Probabilistic Estimate (MonteCarloMeasurements.jl Particles type) of Ensemble solution from All Neural Network's (made using all sampled parameters) output's.
1. `ensemblesol` is the Probabilistic Estimate (MonteCarloMeasurements.jl Particles type) of
Ensemble solution from All Neural Network's (made using all sampled parameters) output's.
2. `estimated_nn_params` - Probabilistic Estimate of NN params from sampled weights, biases.
3. `estimated_de_params` - Probabilistic Estimate of DE params from sampled unknown DE parameters.
3. `estimated_de_params` - Probabilistic Estimate of DE params from sampled unknown DE
parameters.
"""
struct BPINNsolution{O <: BPINNstats, E, NP, OP, P}
original::O
Expand All @@ -162,74 +153,43 @@ struct BPINNsolution{O <: BPINNstats, E, NP, OP, P}
estimated_de_params::OP
timepoints::P

function BPINNsolution(original,
ensemblesol,
estimated_nn_params,
estimated_de_params,
timepoints)
function BPINNsolution(
original, ensemblesol, estimated_nn_params, estimated_de_params, timepoints)
new{typeof(original), typeof(ensemblesol), typeof(estimated_nn_params),
typeof(estimated_de_params), typeof(timepoints)}(
original, ensemblesol, estimated_nn_params,
estimated_de_params, timepoints)
original, ensemblesol, estimated_nn_params, estimated_de_params, timepoints)
end
end

function SciMLBase.__solve(prob::SciMLBase.ODEProblem,
alg::BNNODE,
args...;
dt = nothing,
timeseries_errors = true,
save_everystep = true,
adaptive = false,
abstol = 1.0f-6,
reltol = 1.0f-3,
verbose = false,
saveat = 1 / 50.0,
maxiters = nothing,
numensemble = floor(Int, alg.draw_samples / 3))
(; chain, l2std, phystd, param, priorsNNw, Kernel, strategy, draw_samples, dataset, init_params, nchains, physdt, Adaptorkwargs, Integratorkwargs, MCMCkwargs, numensemble, estim_collocate, autodiff, progress, verbose) = alg
function SciMLBase.__solve(prob::SciMLBase.ODEProblem, alg::BNNODE, args...; dt = nothing,
timeseries_errors = true, save_everystep = true, adaptive = false,
abstol = 1.0f-6, reltol = 1.0f-3, verbose = false, saveat = 1 / 50.0,
maxiters = nothing, numensemble = floor(Int, alg.draw_samples / 3))
(; chain, param, strategy, draw_samples, numensemble, verbose) = alg

# ahmc_bayesian_pinn_ode needs param=[] for easier vcat operation for full vector of parameters
param = param === nothing ? [] : param
strategy = strategy === nothing ? GridTraining : strategy

if draw_samples < 0
error("Number of samples to be drawn has to be >=0.")
end
@assert alg.draw_samples0 "Number of samples to be drawn has to be >=0."

mcmcchain, samples, statistics = ahmc_bayesian_pinn_ode(prob, chain,
strategy = strategy, dataset = dataset,
draw_samples = draw_samples,
init_params = init_params,
physdt = physdt, l2std = l2std,
phystd = phystd,
priorsNNw = priorsNNw,
param = param,
nchains = nchains,
autodiff = autodiff,
Kernel = Kernel,
Adaptorkwargs = Adaptorkwargs,
Integratorkwargs = Integratorkwargs,
MCMCkwargs = MCMCkwargs,
progress = progress,
verbose = verbose,
estim_collocate = estim_collocate)
mcmcchain, samples, statistics = ahmc_bayesian_pinn_ode(
prob, chain; strategy, alg.dataset, alg.draw_samples, alg.init_params,
alg.physdt, alg.l2std, alg.phystd, alg.priorsNNw, param, alg.nchains, alg.autodiff,
Kernel = alg.kernel, alg.Adaptorkwargs, alg.Integratorkwargs,
alg.MCMCkwargs, alg.progress, alg.verbose, alg.estim_collocate)

fullsolution = BPINNstats(mcmcchain, samples, statistics)
ninv = length(param)
t = collect(eltype(saveat), prob.tspan[1]:saveat:prob.tspan[2])

if chain isa AbstractLuxLayer
θinit, st = LuxCore.setup(Random.default_rng(), chain)
θ = [vector_to_parameters(samples[i][1:(end - ninv)], θinit)
for i in 1:max(draw_samples - draw_samples ÷ 10, draw_samples - 1000)]
θinit, st = LuxCore.setup(Random.default_rng(), chain)
θ = [vector_to_parameters(samples[i][1:(end - ninv)], θinit)
for i in 1:max(draw_samples - draw_samples ÷ 10, draw_samples - 1000)]

luxar = [chain(t', θ[i], st)[1] for i in 1:numensemble]
# only need for size
θinit = collect(ComponentArray(θinit))
else
error("Only Lux.AbstractLuxLayer neural networks are supported")
end
luxar = [chain(t', θ[i], st)[1] for i in 1:numensemble]
# only need for size
θinit = collect(ComponentArray(θinit))

# constructing ensemble predictions
ensemblecurves = Vector{}[]
Expand Down Expand Up @@ -272,5 +232,5 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem,
for i in (nnparams + 1):(nnparams + ninv)]
end

BPINNsolution(fullsolution, ensemblecurves, estimnnparams, estimated_params, t)
return BPINNsolution(fullsolution, ensemblecurves, estimnnparams, estimated_params, t)
end
3 changes: 2 additions & 1 deletion src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ $(DocStringExtensions.README)
module NeuralPDE

using ADTypes: ADTypes, AutoForwardDiff, AutoZygote
using Adapt: Adapt, adapt
using Adapt: Adapt
using AdvancedHMC: AdvancedHMC, DiagEuclideanMetric, HMC, HMCDA, Hamiltonian,
JitteredLeapfrog, Leapfrog, MassMatrixAdaptor, NUTS, StanHMCAdaptor,
StepSizeAdaptor, TemperedLeapfrog, find_good_stepsize
Expand All @@ -31,6 +31,7 @@ using MonteCarloMeasurements: Particles
using Optimisers: Optimisers, Adam
using Optimization: Optimization
using OptimizationOptimisers: OptimizationOptimisers
using Printf: @printf
using Random: Random, AbstractRNG
using RecursiveArrayTools: DiffEqArray
using Reexport: @reexport
Expand Down
Loading

0 comments on commit 22c316b

Please sign in to comment.