Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improved BPINN solvers #905

Merged
merged 19 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions src/BPINN_ode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
BNNODE(chain, kernel = HMC; strategy = nothing, draw_samples = 2000,
priorsNNw = (0.0, 2.0), param = [nothing], l2std = [0.05],
phystd = [0.05], dataset = [nothing], physdt = 1 / 20.0,
phystd = [0.05], phynewstd = [0.05], dataset = [nothing], physdt = 1 / 20.0,
MCMCargs = (; n_leapfrog=30), nchains = 1, init_params = nothing,
Adaptorkwargs = (; Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8,
Metric = DiagEuclideanMetric),
Expand Down Expand Up @@ -86,6 +86,7 @@ Kevin Linka, Amelie Schäfer, Xuhui Meng, Zongren Zou, George Em Karniadakis, El
param <: Union{Nothing, Vector{<:Distribution}}
l2std::Vector{Float64}
phystd::Vector{Float64}
phynewstd::Vector{Float64}
dataset <: Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}}
physdt::Float64
MCMCkwargs <: NamedTuple
Expand All @@ -100,18 +101,18 @@ Kevin Linka, Amelie Schäfer, Xuhui Meng, Zongren Zou, George Em Karniadakis, El
verbose::Bool
end

function BNNODE(chain, kernel = HMC; strategy = nothing, draw_samples = 2000,
function BNNODE(chain, kernel = HMC; strategy = nothing, draw_samples = 1000,
priorsNNw = (0.0, 2.0), param = nothing, l2std = [0.05], phystd = [0.05],
dataset = [nothing], physdt = 1 / 20.0, MCMCkwargs = (n_leapfrog = 30,),
nchains = 1, init_params = nothing,
phynewstd = [0.05], dataset = [nothing], physdt = 1 / 20.0,
MCMCkwargs = (n_leapfrog = 30,), nchains = 1, init_params = nothing,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
numensemble = floor(Int, draw_samples / 3),
estim_collocate = false, autodiff = false, progress = false, verbose = false)
chain isa AbstractLuxLayer || (chain = FromFluxAdaptor()(chain))
return BNNODE(chain, kernel, strategy, draw_samples, priorsNNw, param, l2std, phystd,
dataset, physdt, MCMCkwargs, nchains, init_params, Adaptorkwargs,
phynewstd, dataset, physdt, MCMCkwargs, nchains, init_params, Adaptorkwargs,
Integratorkwargs, numensemble, estim_collocate, autodiff, progress, verbose)
end

Expand Down Expand Up @@ -157,7 +158,7 @@ end
function SciMLBase.__solve(prob::SciMLBase.ODEProblem, alg::BNNODE, args...; dt = nothing,
timeseries_errors = true, save_everystep = true, adaptive = false,
abstol = 1.0f-6, reltol = 1.0f-3, verbose = false, saveat = 1 / 50.0,
maxiters = nothing, numensemble = floor(Int, alg.draw_samples / 3))
maxiters = nothing)
(; chain, param, strategy, draw_samples, numensemble, verbose) = alg

# ahmc_bayesian_pinn_ode needs param=[] for easier vcat operation for full vector of parameters
Expand All @@ -168,7 +169,8 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem, alg::BNNODE, args...; dt

mcmcchain, samples, statistics = ahmc_bayesian_pinn_ode(
prob, chain; strategy, alg.dataset, alg.draw_samples, alg.init_params,
alg.physdt, alg.l2std, alg.phystd, alg.priorsNNw, param, alg.nchains, alg.autodiff,
alg.physdt, alg.l2std, alg.phystd, alg.phynewstd,
alg.priorsNNw, param, alg.nchains, alg.autodiff,
Kernel = alg.kernel, alg.Adaptorkwargs, alg.Integratorkwargs,
alg.MCMCkwargs, alg.progress, alg.verbose, alg.estim_collocate)

Expand All @@ -178,7 +180,7 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem, alg::BNNODE, args...; dt

θinit, st = LuxCore.setup(Random.default_rng(), chain)
θ = [vector_to_parameters(samples[i][1:(end - ninv)], θinit)
for i in 1:max(draw_samples - draw_samples ÷ 10, draw_samples - 1000)]
for i in (draw_samples - numensemble):draw_samples]

luxar = [chain(t', θ[i], st)[1] for i in 1:numensemble]
# only need for size
Expand Down
149 changes: 136 additions & 13 deletions src/PDE_BPINN.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,90 @@
dataset <: Union{Nothing, Vector{<:Matrix{<:Real}}}
priors <: Vector{<:Distribution}
allstd::Vector{Vector{Float64}}
phynewstd::Vector{Float64}
names::Tuple
extraparams::Int
init_params <: Union{AbstractVector, NamedTuple, ComponentArray}
full_loglikelihood
L2_loss2
Φ
end

function LogDensityProblems.logdensity(ltd::PDELogTargetDensity, θ)
# for parameter estimation neccesarry to use multioutput case
return ltd.full_loglikelihood(setparameters(ltd, θ), ltd.allstd) + priorlogpdf(ltd, θ) +
L2LossData(ltd, θ)
if ltd.L2_loss2 === nothing
return ltd.full_loglikelihood(setparameters(ltd, θ), ltd.allstd) +
priorlogpdf(ltd, θ) + L2LossData(ltd, θ)
else
return ltd.full_loglikelihood(setparameters(ltd, θ), ltd.allstd) +
priorlogpdf(ltd, θ) + L2LossData(ltd, θ) + ltd.L2_loss2(setparameters(ltd, θ), ltd.phynewstd)
end
end

# you get a vector of losses
function get_lossy(pinnrep, dataset, Dict_differentials)
eqs = pinnrep.eqs
depvars = pinnrep.depvars #depvar order is same as dataset

# Dict_differentials is filled with Differential operator => diff_i key-value pairs
# masking operation
eqs_new = SymbolicUtils.substitute.(eqs, Ref(Dict_differentials))

to_subs, tobe_subs = get_symbols(dataset, depvars, eqs)

# for values of all depvars at corresponding indvar values in dataset, create dictionaries {Dict(x(t) => 1.0496435863173237, y(t) => 1.9227770685615337)}
# In each Dict, num form of depvar is key to its value at certain coords of indvars, n_dicts = n_rows_dataset(or n_indvar_coords_dataset)
eq_subs = [Dict(tobe_subs[depvar] => to_subs[depvar][i] for depvar in depvars)
for i in 1:size(dataset[1][:, 1])[1]]

# for each dataset point(eq_sub dictionary), substitute in masked equations
# n_collocated_equations = n_rows_dataset(or n_indvar_coords_dataset)
masked_colloc_equations = [[Symbolics.substitute(eq, eq_sub) for eq in eqs_new]
for eq_sub in eq_subs]
# now we have vector of dataset depvar's collocated equations

# reverse dict for re-substituting values of Differential(t)(u(t)) etc
rev_Dict_differentials = Dict(value => key for (key, value) in Dict_differentials)

# unmask Differential terms in masked_colloc_equations
colloc_equations = [Symbolics.substitute.(
masked_colloc_equation, Ref(rev_Dict_differentials))
for masked_colloc_equation in masked_colloc_equations]
# nested vector of data_pde_loss_functions (as in discretize.jl)
# each sub vector has dataset's indvar coord's datafree_colloc_loss_function, n_subvectors = n_rows_dataset(or n_indvar_coords_dataset)
# zip each colloc equation with args for each build_loss call per equation vector
data_colloc_loss_functions = [[build_loss_function(pinnrep, eq, pde_indvar)
for (eq, pde_indvar, integration_indvar) in zip(
colloc_equation,
pinnrep.pde_indvars,
pinnrep.pde_integration_vars)]
for colloc_equation in colloc_equations]

return data_colloc_loss_functions
end

function get_symbols(dataset, depvars, eqs)
# take only values of depvars from dataset
depvar_vals = [dataset_i[:, 1] for dataset_i in dataset]
# order of pinnrep.depvars, depvar_vals, BayesianPINN.dataset must be same
to_subs = Dict(depvars .=> depvar_vals)

numform_vars = Symbolics.get_variables.(eqs)
Eq_vars = unique(reduce(vcat, numform_vars))
# got equation's depvar num format {x(t)} for use in substitute()

tobe_subs = Dict()
for a in depvars
for i in Eq_vars
expr = toexpr(i)
if (expr isa Expr) && (expr.args[1] == a)
tobe_subs[a] = i
end
end
end
# depvar symbolic and num format got, tobe_subs : Dict{Any, Any}(:y => y(t), :x => x(t))

return to_subs, tobe_subs
end

@views function setparameters(ltd::PDELogTargetDensity, θ)
Expand Down Expand Up @@ -180,8 +253,8 @@ end
"""
ahmc_bayesian_pinn_pde(pde_system, discretization;
draw_samples = 1000, bcstd = [0.01], l2std = [0.05], phystd = [0.05],
priorsNNw = (0.0, 2.0), param = [], nchains = 1, Kernel = HMC(0.1, 30),
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
phynewstd = [0.05], priorsNNw = (0.0, 2.0), param = [], nchains = 1,
Kernel = HMC(0.1, 30), Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,), saveats = [1 / 10.0],
numensemble = floor(Int, draw_samples / 3), progress = false, verbose = false)
Expand Down Expand Up @@ -210,6 +283,7 @@ end
each dependant variable of interest.
* `phystd`: Vector of standard deviations of BPINN prediction against Chosen Underlying PDE
equations.
* `phynewstd`: Vector of standard deviations of new loss term.
* `priorsNNw`: Tuple of (mean, std) for BPINN Network parameters. Weights and Biases of
BPINN are Normal Distributions by default.
* `param`: Vector of chosen PDE's parameter's Distributions in case of Inverse problems.
Expand All @@ -235,14 +309,53 @@ end
"""
function ahmc_bayesian_pinn_pde(pde_system, discretization;
draw_samples = 1000, bcstd = [0.01], l2std = [0.05], phystd = [0.05],
priorsNNw = (0.0, 2.0), param = [], nchains = 1, Kernel = HMC(0.1, 30),
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
phynewstd = [0.05], priorsNNw = (0.0, 2.0), param = [], nchains = 1,
Kernel = HMC(0.1, 30), Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,), saveats = [1 / 10.0],
numensemble = floor(Int, draw_samples / 3), progress = false, verbose = false)
numensemble = floor(Int, draw_samples / 3), Dict_differentials = nothing, progress = false, verbose = false)
pinnrep = symbolic_discretize(pde_system, discretization)
dataset_pde, dataset_bc = discretization.dataset

newloss = if Dict_differentials isa Nothing
nothing
else
data_colloc_loss_functions = get_lossy(pinnrep, dataset_pde, Dict_differentials)
# size = number of indvar coords in dataset
# add case for if parameters present in bcs?

train_sets_pde = get_dataset_train_points(pde_system.eqs,
dataset_pde,
pinnrep)
# j is number of indvar coords in dataset, i is number of PDE equations in system
# -1 is placeholder, removed in merge_strategy_with_loglikelihood_function function call (train_sets[:, 2:end]())
colloc_train_sets = [[hcat([-1], train_sets_pde[i][:, j]...)
for i in eachindex(data_colloc_loss_functions[1])]
for j in eachindex(data_colloc_loss_functions)]

# using dataset's indvar coords as train_sets_pde and indvar coord's datafree_colloc_loss_function, create loss functions
# placeholder strategy = GridTraining(0.1), datafree_bc_loss_function and train_sets_bc must be nothing
# order of indvar coords will be same as corresponding depvar coords values in dataset provided in get_lossy() call.
pde_loss_function_points = [merge_strategy_with_loglikelihood_function(
pinnrep,
GridTraining(0.1),
data_colloc_loss_functions[i],
nothing;
train_sets_pde = colloc_train_sets[i],
train_sets_bc = nothing)[1]
for i in eachindex(data_colloc_loss_functions)]

function L2_loss2(θ, phynewstd)
# first sum is over points losses over many equations for the same points
# second sum is over all points
pde_loglikelihoods = sum([sum([pde_loss_function(θ,
phynewstd[i])
for (i, pde_loss_function) in enumerate(pde_loss_functions)])
for pde_loss_functions in pde_loss_function_points])
end
end

# add overall functionality for BC dataset points (case of parametric BC) ?
if ((dataset_bc isa Nothing) && (dataset_pde isa Nothing))
dataset = nothing
elseif dataset_bc isa Nothing
Expand Down Expand Up @@ -306,8 +419,8 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;

# dimensions would be total no of params,initial_nnθ for Lux namedTuples
ℓπ = PDELogTargetDensity(
nparameters, strategy, dataset, priors, [phystd, bcstd, l2std],
names, ninv, initial_nnθ, full_weighted_loglikelihood, Φ)
nparameters, strategy, dataset, priors, [phystd, bcstd, l2std], phynewstd,
names, ninv, initial_nnθ, full_weighted_loglikelihood, newloss, Φ)

Adaptor, Metric, targetacceptancerate = Adaptorkwargs[:Adaptor],
Adaptorkwargs[:Metric], Adaptorkwargs[:targetacceptancerate]
Expand All @@ -320,8 +433,13 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
@printf("Current Physics Log-likelihood : %g\n",
ℓπ.full_loglikelihood(setparameters(ℓπ, initial_θ), ℓπ.allstd))
@printf("Current Prior Log-likelihood : %g\n", priorlogpdf(ℓπ, initial_θ))
@printf("Current MSE against dataset Log-likelihood : %g\n",
@printf("Current SSE against dataset Log-likelihood : %g\n",
L2LossData(ℓπ, initial_θ))
if !(newloss isa Nothing)
@printf("Current new loss : %g\n",
ℓπ.L2_loss2(setparameters(ℓπ, initial_θ),
ℓπ.phynewstd))
end
end

# parallel sampling option
Expand Down Expand Up @@ -370,11 +488,16 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;

if verbose
@printf("Sampling Complete.\n")
@printf("Current Physics Log-likelihood : %g\n",
@printf("Final Physics Log-likelihood : %g\n",
ℓπ.full_loglikelihood(setparameters(ℓπ, samples[end]), ℓπ.allstd))
@printf("Current Prior Log-likelihood : %g\n", priorlogpdf(ℓπ, samples[end]))
@printf("Current MSE against dataset Log-likelihood : %g\n",
@printf("Final Prior Log-likelihood : %g\n", priorlogpdf(ℓπ, samples[end]))
@printf("Final SSE against dataset Log-likelihood : %g\n",
L2LossData(ℓπ, samples[end]))
if !(newloss isa Nothing)
@printf("Final new loss : %g\n",
ℓπ.L2_loss2(setparameters(ℓπ, samples[end]),
ℓπ.phynewstd))
end
end

fullsolution = BPINNstats(mcmc_chain, samples, stats)
Expand Down
27 changes: 14 additions & 13 deletions src/advancedHMC_MCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
dataset <: Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}}
priors <: Vector{<:Distribution}
phystd::Vector{Float64}
phynewstd::Vector{Float64}
l2std::Vector{Float64}
autodiff::Bool
physdt::Float64
Expand Down Expand Up @@ -97,7 +98,7 @@ suggested extra loss function for ODE solver case
for i in 1:length(ltd.prob.u0)
physlogprob += logpdf(
MvNormal(deri_physsol[i, :],
Diagonal(abs2.(T(ltd.phystd[i]) .* ones(T, length(nnsol[i, :]))))),
Diagonal(abs2.(T(ltd.phynewstd[i]) .* ones(T, length(nnsol[i, :]))))),
nnsol[i, :]
)
end
Expand Down Expand Up @@ -263,7 +264,7 @@ end
"""
ahmc_bayesian_pinn_ode(prob, chain; strategy = GridTraining, dataset = [nothing],
init_params = nothing, draw_samples = 1000, physdt = 1 / 20.0f0,
l2std = [0.05], phystd = [0.05], priorsNNw = (0.0, 2.0),
l2std = [0.05], phystd = [0.05], phynewstd = [0.05], priorsNNw = (0.0, 2.0),
param = [], nchains = 1, autodiff = false, Kernel = HMC,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Expand Down Expand Up @@ -336,6 +337,7 @@ Incase you are only solving the Equations for solution, do not provide dataset
~2/3 of draw samples)
* `l2std`: standard deviation of BPINN prediction against L2 losses/Dataset
* `phystd`: standard deviation of BPINN prediction against Chosen Underlying ODE System
* `phynewstd`: standard deviation of new loss func term
* `priorsNNw`: Tuple of (mean, std) for BPINN Network parameters. Weights and Biases of
BPINN are Normal Distributions by default.
* `param`: Vector of chosen ODE parameters Distributions in case of Inverse problems.
Expand Down Expand Up @@ -366,10 +368,10 @@ Incase you are only solving the Equations for solution, do not provide dataset
function ahmc_bayesian_pinn_ode(
prob::SciMLBase.ODEProblem, chain; strategy = GridTraining, dataset = [nothing],
init_params = nothing, draw_samples = 1000, physdt = 1 / 20.0, l2std = [0.05],
phystd = [0.05], priorsNNw = (0.0, 2.0), param = [], nchains = 1, autodiff = false,
Kernel = HMC,
Adaptorkwargs = (Adaptor = StanHMCAdaptor, Metric = DiagEuclideanMetric,
targetacceptancerate = 0.8),
phystd = [0.05], phynewstd = [0.05], priorsNNw = (0.0, 2.0), param = [], nchains = 1,
autodiff = false, Kernel = HMC,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,), MCMCkwargs = (n_leapfrog = 30,),
progress = false, verbose = false, estim_collocate = false)
@assert !isinplace(prob) "The BPINN ODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t)."
Expand Down Expand Up @@ -415,16 +417,15 @@ function ahmc_bayesian_pinn_ode(
nparameters += ninv
end

t0 = prob.tspan[1]
smodel = StatefulLuxLayer{true}(chain, nothing, st)
# dimensions would be total no of params,initial_nnθ for Lux namedTuples
ℓπ = LogTargetDensity(nparameters, prob, smodel, strategy, dataset, priors,
phystd, l2std, autodiff, physdt, ninv, initial_nnθ, estim_collocate)
phystd, phynewstd, l2std, autodiff, physdt, ninv, initial_nnθ, estim_collocate)

if verbose
@printf("Current Physics Log-likelihood: %g\n", physloglikelihood(ℓπ, initial_θ))
@printf("Current Prior Log-likelihood: %g\n", priorweights(ℓπ, initial_θ))
@printf("Current MSE against dataset Log-likelihood: %g\n",
@printf("Current SSE against dataset Log-likelihood: %g\n",
L2LossData(ℓπ, initial_θ))
if estim_collocate
@printf("Current gradient loss against dataset Log-likelihood: %g\n",
Expand Down Expand Up @@ -483,13 +484,13 @@ function ahmc_bayesian_pinn_ode(

if verbose
println("Sampling Complete.")
@printf("Current Physics Log-likelihood: %g\n",
@printf("Final Physics Log-likelihood: %g\n",
physloglikelihood(ℓπ, samples[end]))
@printf("Current Prior Log-likelihood: %g\n", priorweights(ℓπ, samples[end]))
@printf("Current MSE against dataset Log-likelihood: %g\n",
@printf("Final Prior Log-likelihood: %g\n", priorweights(ℓπ, samples[end]))
@printf("Final SSE against dataset Log-likelihood: %g\n",
L2LossData(ℓπ, samples[end]))
if estim_collocate
@printf("Current gradient loss against dataset Log-likelihood: %g\n",
@printf("Final gradient loss against dataset Log-likelihood: %g\n",
L2loss2(ℓπ, samples[end]))
end
end
Expand Down
Loading
Loading