better BPINN formulation, improvements to include dataset domain points #842

Open

Wants to merge 63 commits into base: master from AstitvaAggarwal:research

Commits (63)
90dfd76
Better Posterior Formulation
AstitvaAggarwal Sep 27, 2023
b2f3ac1
Put new loglikelihood behind a conditional
Vaibhavdixit02 Oct 19, 2023
058aa05
fitzhughnagumo experiment and some edits
Vaibhavdixit02 Oct 27, 2023
103e1fe
Scale logpdfs and fix chain creation
Vaibhavdixit02 Oct 28, 2023
f5b4f1c
trying to sync
AstitvaAggarwal Jan 20, 2024
a999ccf
removed new files
AstitvaAggarwal Jan 20, 2024
637006b
Merge branch 'research' of https://github.com/AstitvaAggarwal/NeuralP…
AstitvaAggarwal Jan 20, 2024
a5c3148
update advancedHMC_MCMC.jl
AstitvaAggarwal Jan 20, 2024
f8427a3
update advancedHMC_MCMC.jl
AstitvaAggarwal Jan 20, 2024
005c906
Merge branch 'SciML:master' into research
AstitvaAggarwal Jan 22, 2024
237fb45
most of logic done
AstitvaAggarwal Feb 3, 2024
f1e0315
removed duplicate methods
AstitvaAggarwal Feb 3, 2024
d077ec7
Merge branch 'master' into research
AstitvaAggarwal Feb 3, 2024
4c88dd4
update BPINN_PDE_tests.jl
AstitvaAggarwal Feb 3, 2024
f7836fd
update BPINN_PDE_tests.jl
AstitvaAggarwal Feb 3, 2024
f334168
Merge remote-tracking branch 'upstream/bpinnexperimental' into research
AstitvaAggarwal Feb 3, 2024
a3a0cb5
keeping bayesian directory files in sync with master
AstitvaAggarwal Feb 3, 2024
e0803b1
changes to sync with master
AstitvaAggarwal Feb 3, 2024
3498ddc
keep new dir
AstitvaAggarwal Feb 3, 2024
8ba64b2
having problems with eval() call in recursive Dict creation
AstitvaAggarwal Feb 12, 2024
aa15410
removed bayesian folder
AstitvaAggarwal Feb 12, 2024
9475d27
cleaned files, removed DataInterpolations
AstitvaAggarwal Feb 12, 2024
20e14b7
Merge branch 'SciML:master' into research
AstitvaAggarwal Feb 12, 2024
b50989a
done with implementation
AstitvaAggarwal Feb 14, 2024
2fbe4a9
update BPINN_PDEinvsol_tests.jl
AstitvaAggarwal Feb 14, 2024
e002802
spellings, newloss now optional
AstitvaAggarwal Feb 15, 2024
b26a75b
update PDE_BPINN.jl
AstitvaAggarwal Feb 15, 2024
f4b1bfb
removed length reweighing in BPINN ode, testset for recur..
AstitvaAggarwal Feb 15, 2024
cd01cee
corrected tests, datasetnew format
AstitvaAggarwal Feb 16, 2024
78cadf1
changes from reviews
AstitvaAggarwal Feb 21, 2024
49dd7cb
refactor code, Corrected PDE_BPINN Logphys calc.
AstitvaAggarwal Feb 26, 2024
4e8b23d
Merge branch 'SciML:master' into research
AstitvaAggarwal Feb 26, 2024
014a11d
corrected original and new implementation, comments
AstitvaAggarwal Feb 28, 2024
11bbba7
update BPINN_ode, BPINN_PDE_tests
AstitvaAggarwal Feb 28, 2024
908cb5b
update BPINN_PDE_tests.jl
AstitvaAggarwal Feb 28, 2024
585a4f5
update BPINN_PDE_tests.jl
AstitvaAggarwal Feb 29, 2024
cf77408
done for now
AstitvaAggarwal Mar 26, 2024
16e1b56
merge conflict resolution
AstitvaAggarwal Mar 26, 2024
9f694f8
update NeuralPDE.jl, advancedHMC_MCMC.jl
AstitvaAggarwal Mar 26, 2024
a28d12f
update NeuralPDE.jl
AstitvaAggarwal Mar 26, 2024
03ee7b4
update NeuralPDE.jl
AstitvaAggarwal Mar 26, 2024
29eff11
Merge branch 'SciML:master' into research
AstitvaAggarwal Mar 26, 2024
27cfb56
update NeuralPDE.jl
AstitvaAggarwal Mar 26, 2024
68c73b2
Merge branch 'research' of https://github.com/AstitvaAggarwal/NeuralP…
AstitvaAggarwal Mar 26, 2024
6ef9d48
pmean for tests
AstitvaAggarwal Mar 29, 2024
f8cf2da
.
AstitvaAggarwal Mar 29, 2024
9b535a9
Merge branch 'research' of https://github.com/AstitvaAggarwal/NeuralP…
AstitvaAggarwal Mar 29, 2024
3e96e3d
update BPINN_PDEinvsol_tests.jl
AstitvaAggarwal Mar 29, 2024
efbccda
update training_strategies.jl
AstitvaAggarwal Mar 29, 2024
39ed5f6
update BPINN_PDEinvsol_tests.jl
AstitvaAggarwal Mar 29, 2024
fad53ff
Merge branch 'SciML:master' into research
AstitvaAggarwal Apr 14, 2024
3bb93dd
Merge branch 'SciML:master' into research
AstitvaAggarwal Apr 29, 2024
2731063
changes from reviews
AstitvaAggarwal May 4, 2024
cfca4a7
Merge branch 'SciML:master' into research
AstitvaAggarwal May 4, 2024
8b980d2
Merge branch 'research' of https://github.com/AstitvaAggarwal/NeuralP…
AstitvaAggarwal May 4, 2024
a90c730
Testing code for BPINN PDEs
AstitvaAggarwal May 8, 2024
2331614
spelling corrections, cleared test space, seperated pr
AstitvaAggarwal May 8, 2024
29d270d
Merge branch 'SciML:master' into research
AstitvaAggarwal May 9, 2024
2b42acb
Merge branch 'research' of https://github.com/AstitvaAggarwal/NeuralP…
AstitvaAggarwal May 12, 2024
94770d4
need PDE exp file to be concise
AstitvaAggarwal May 12, 2024
caa727c
Merge branch 'SciML:master' into research
AstitvaAggarwal Jun 11, 2024
dda2176
Merge branch 'SciML:master' into research
AstitvaAggarwal Jul 5, 2024
7ffadbf
Merge branch 'SciML:master' into research
AstitvaAggarwal Aug 4, 2024
13 changes: 10 additions & 3 deletions src/BPINN_ode.jl
@@ -100,6 +100,8 @@ struct BNNODE{C, K, IT <: NamedTuple,
init_params::I
Adaptorkwargs::A
Integratorkwargs::IT
numensemble::Int64
estim_collocate::Bool
autodiff::Bool
progress::Bool
verbose::Bool
@@ -112,6 +114,8 @@ function BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
Metric = DiagEuclideanMetric,
targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
numensemble = floor(Int, draw_samples / 3),
estim_collocate = false,
autodiff = false, progress = false, verbose = false)
!(chain isa Lux.AbstractExplicitLayer) &&
(chain = adapt(FromFluxAdaptor(false, false), chain))
@@ -120,6 +124,7 @@ function BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
phystd, dataset, physdt, MCMCkwargs,
nchains, init_params,
Adaptorkwargs, Integratorkwargs,
numensemble, estim_collocate,
autodiff, progress, verbose)
end

@@ -186,7 +191,7 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem,
@unpack chain, l2std, phystd, param, priorsNNw, Kernel, strategy,
draw_samples, dataset, init_params,
nchains, physdt, Adaptorkwargs, Integratorkwargs,
MCMCkwargs, autodiff, progress, verbose = alg
MCMCkwargs, numensemble, estim_collocate, autodiff, progress, verbose = alg

# ahmc_bayesian_pinn_ode needs param=[] for easier vcat operation for full vector of parameters
param = param === nothing ? [] : param
@@ -211,7 +216,8 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem,
Integratorkwargs = Integratorkwargs,
MCMCkwargs = MCMCkwargs,
progress = progress,
verbose = verbose)
verbose = verbose,
estim_collocate = estim_collocate)

fullsolution = BPINNstats(mcmcchain, samples, statistics)
ninv = length(param)
@@ -220,7 +226,8 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem,
if chain isa Lux.AbstractExplicitLayer
θinit, st = Lux.setup(Random.default_rng(), chain)
θ = [vector_to_parameters(samples[i][1:(end - ninv)], θinit)
for i in (draw_samples - numensemble):draw_samples]
for i in 1:max(draw_samples - draw_samples ÷ 10, draw_samples - 1000)]

luxar = [chain(t', θ[i], st)[1] for i in 1:numensemble]
# only need for size
θinit = collect(ComponentArrays.ComponentArray(θinit))
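
The BPINN_ode.jl changes above add two keyword arguments to `BNNODE`: `numensemble` (how many posterior samples are used to build the ensemble solution) and `estim_collocate` (whether the additional collocation loglikelihood is added). A minimal, hedged usage sketch follows; the toy ODE, network, and dataset below are illustrative and are not taken from this PR.

# --- illustration only, not part of this diff ---
using NeuralPDE, Lux, OrdinaryDiffEq

# toy ODE u' = -u with a small synthetic dataset
# (dataset layout assumed as [u_values, t_values])
f(u, p, t) = -u
prob = ODEProblem(f, 1.0, (0.0, 2.0))
ts = collect(0.0:0.1:2.0)
us = exp.(-ts) .+ 0.02 .* randn(length(ts))
dataset = [us, ts]

chainlux = Lux.Chain(Lux.Dense(1, 8, tanh), Lux.Dense(8, 1))

alg = BNNODE(chainlux;
    dataset = dataset,
    draw_samples = 1000,
    l2std = [0.05], phystd = [0.05],
    priorsNNw = (0.0, 3.0),
    numensemble = 300,        # new: number of samples used for the ensemble solution
    estim_collocate = true,   # new: enable the additional collocation loglikelihood
    progress = true)

sol = solve(prob, alg)
# --- end illustration ---
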
196 changes: 158 additions & 38 deletions src/PDE_BPINN.jl
@@ -4,6 +4,7 @@ mutable struct PDELogTargetDensity{
P <: Vector{<:Distribution},
I,
F,
FF,
PH
}
dim::Int64
@@ -15,17 +16,19 @@ mutable struct PDELogTargetDensity{
extraparams::Int
init_params::I
full_loglikelihood::F
L2_loss2::FF
Φ::PH

function PDELogTargetDensity(dim, strategy, dataset,
priors, allstd, names, extraparams,
init_params::AbstractVector, full_loglikelihood, Φ)
init_params::AbstractVector, full_loglikelihood, L2_loss2, Φ)
new{
typeof(strategy),
typeof(dataset),
typeof(priors),
typeof(init_params),
typeof(full_loglikelihood),
typeof(L2_loss2),
typeof(Φ)
}(dim,
strategy,
@@ -36,18 +39,20 @@ mutable struct PDELogTargetDensity{
extraparams,
init_params,
full_loglikelihood,
L2_loss2,
Φ)
end
function PDELogTargetDensity(dim, strategy, dataset,
priors, allstd, names, extraparams,
init_params::Union{NamedTuple, ComponentArrays.ComponentVector},
full_loglikelihood, Φ)
full_loglikelihood, L2_loss2, Φ)
new{
typeof(strategy),
typeof(dataset),
typeof(priors),
typeof(init_params),
typeof(full_loglikelihood),
typeof(L2_loss2),
typeof(Φ)
}(dim,
strategy,
@@ -58,22 +63,87 @@ mutable struct PDELogTargetDensity{
extraparams,
init_params,
full_loglikelihood,
L2_loss2,
Φ)
end
end

# you get a vector of losses
function get_lossy(pinnrep, dataset, Dict_differentials)
eqs = pinnrep.eqs
depvars = pinnrep.depvars #depvar order is same as dataset

# Dict_differentials is filled with Differential operator => diff_i key-value pairs
# masking operation
eqs_new = substitute.(eqs, Ref(Dict_differentials))

to_subs, tobe_subs = get_symbols(dataset, depvars, eqs)

# for values of all depvars at corresponding indvar values in dataset, create dictionaries {Dict(x(t) => 1.0496435863173237, y(t) => 1.9227770685615337)}
# In each Dict, num form of depvar is key to its value at certain coords of indvars, n_dicts = n_rows_dataset(or n_indvar_coords_dataset)
eq_subs = [Dict(tobe_subs[depvar] => to_subs[depvar][i] for depvar in depvars)
for i in 1:size(dataset[1][:, 1])[1]]

# for each dataset point(eq_sub dictionary), substitute in masked equations
# n_collocated_equations = n_rows_dataset(or n_indvar_coords_dataset)
masked_colloc_equations = [[substitute(eq, eq_sub) for eq in eqs_new]
for eq_sub in eq_subs]
# now we have vector of dataset depvar's collocated equations

# reverse dict for re-substituting values of Differential(t)(u(t)) etc
rev_Dict_differentials = Dict(value => key for (key, value) in Dict_differentials)

# unmask Differential terms in masked_colloc_equations
colloc_equations = [substitute.(masked_colloc_equation, Ref(rev_Dict_differentials))
for masked_colloc_equation in masked_colloc_equations]

# nested vector of datafree_pde_loss_functions (as in discretize.jl)
# each sub vector has dataset's indvar coord's datafree_colloc_loss_function, n_subvectors = n_rows_dataset(or n_indvar_coords_dataset)
# zip each colloc equation with args for each build_loss call per equation vector
datafree_colloc_loss_functions = [[build_loss_function(pinnrep, eq, pde_indvar)
for (eq, pde_indvar, integration_indvar) in zip(colloc_equation,
pinnrep.pde_indvars,
pinnrep.pde_integration_vars)] for colloc_equation in colloc_equations]

return datafree_colloc_loss_functions
end
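
# --- illustration only, not part of this diff: a minimal Symbolics sketch of the
# masking idea used above; the placeholder variable diff_1 and the toy equation are
# introduced here just for this example ---
using Symbolics

@variables t x(..) diff_1
Dt = Differential(t)
eq = Dt(x(t)) ~ 2.0 * x(t)                        # toy equation standing in for pinnrep.eqs

Dict_differentials = Dict(Dt(x(t)) => diff_1)     # mask: Differential(t)(x(t)) => diff_1
masked_eq = substitute(eq, Dict_differentials)    # diff_1 ~ 2.0*x(t)

# substitute one dataset value of the depvar (one row of the dataset)
colloc_masked = substitute(masked_eq, Dict(x(t) => 1.0496))

# unmask: restore the Differential term
rev_Dict_differentials = Dict(value => key for (key, value) in Dict_differentials)
colloc_eq = substitute(colloc_masked, rev_Dict_differentials)
# -> Differential(t)(x(t)) ~ 2.0992, i.e. one collocated equation per dataset row
# --- end illustration ---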

function get_symbols(dataset, depvars, eqs)
# take only values of depvars from dataset
depvar_vals = [dataset_i[:, 1] for dataset_i in dataset]
# order of pinnrep.depvars, depvar_vals, BayesianPINN.dataset must be same
to_subs = Dict(depvars .=> depvar_vals)

numform_vars = Symbolics.get_variables.(eqs)
Eq_vars = unique(reduce(vcat, numform_vars))
    # gets each equation's depvars in num form, e.g. x(t), for use in substitute()

tobe_subs = Dict()
for a in depvars
for i in Eq_vars
expr = toexpr(i)
if (expr isa Expr) && (expr.args[1] == a)
tobe_subs[a] = i
end
end
end
    # depvar symbol and num forms collected; tobe_subs : Dict{Any, Any}(:y => y(t), :x => x(t))

return to_subs, tobe_subs
end

function LogDensityProblems.logdensity(Tar::PDELogTargetDensity, θ)
    # for parameter estimation it is necessary to use the multioutput case
return Tar.full_loglikelihood(setparameters(Tar, θ),
Tar.allstd) + priorlogpdf(Tar, θ) + L2LossData(Tar, θ)
# + L2loss2(Tar, θ)
if Tar.L2_loss2 isa Nothing
return Tar.full_loglikelihood(setparameters(Tar, θ), Tar.allstd) +
priorlogpdf(Tar, θ) + L2LossData(Tar, θ)
else
return Tar.full_loglikelihood(setparameters(Tar, θ), Tar.allstd) +
priorlogpdf(Tar, θ) + L2LossData(Tar, θ) +
Tar.L2_loss2(setparameters(Tar, θ), Tar.allstd)
end
end
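
# --- note (not part of this diff): in short, the log target density sampled by
# AdvancedHMC is
#     full_loglikelihood(θ) + priorlogpdf(θ) + L2LossData(θ) + L2_loss2(θ)
# where the last (collocation) term is included only when Dict_differentials is
# supplied, so that L2_loss2 is not `nothing`. ---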

# function L2loss2(Tar::PDELogTargetDensity, θ)
# return Tar.full_loglikelihood(setparameters(Tar, θ),
# Tar.allstd)
# end

function setparameters(Tar::PDELogTargetDensity, θ)
names = Tar.names
ps_new = θ[1:(end - Tar.extraparams)]
@@ -117,6 +187,8 @@ function L2LossData(Tar::PDELogTargetDensity, θ)

# dataset of form Vector[matrix_x, matrix_y, matrix_z]
    # matrix_i is of form [i,indvar1,indvar2,..] (needed in case of heterogeneous domains)
# note that indvar1,indvar2.. cols can be different values for different depvar matrices
# dataset,phi order follows pinnrep.depvars orders of variables (order of declaration in @variables macro)

# Phi is the trial solution for each NN in chain array
# Creating logpdf( MvNormal(Phi(t,θ),std), dataset[i] )
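
    # --- illustration only, not part of this diff: the dataset layout described
    # above, with hypothetical values ---
    ts = collect(0.0:0.1:1.0)
    xs = sin.(ts); ys = cos.(ts)          # pretend measurements of x(t), y(t)
    matrix_x = hcat(xs, ts)               # column 1: depvar values, remaining columns: indvar coords
    matrix_y = hcat(ys, ts)               # indvar coords may differ between depvar matrices
    dataset = [matrix_x, matrix_y]        # order must follow pinnrep.depvars
    # --- end illustration ---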
@@ -162,27 +234,6 @@ function priorlogpdf(Tar::PDELogTargetDensity, θ)
return logpdf(nnwparams, θ)
end

function integratorchoice(Integratorkwargs, initial_ϵ)
Integrator = Integratorkwargs[:Integrator]
if Integrator == JitteredLeapfrog
jitter_rate = Integratorkwargs[:jitter_rate]
Integrator(initial_ϵ, jitter_rate)
elseif Integrator == TemperedLeapfrog
tempering_rate = Integratorkwargs[:tempering_rate]
Integrator(initial_ϵ, tempering_rate)
else
Integrator(initial_ϵ)
end
end

function adaptorchoice(Adaptor, mma, ssa)
if Adaptor != AdvancedHMC.NoAdaptation()
Adaptor(mma, ssa)
else
AdvancedHMC.NoAdaptation()
end
end

function inference(samples, pinnrep, saveats, numensemble, ℓπ)
domains = pinnrep.domains
phi = pinnrep.phi
@@ -247,6 +298,27 @@ function inference(samples, pinnrep, saveats, numensemble, ℓπ)
return ensemblecurves, estimatedLuxparams, estimated_params, timepoints
end

function integratorchoice(Integratorkwargs, initial_ϵ)
Integrator = Integratorkwargs[:Integrator]
if Integrator == JitteredLeapfrog
jitter_rate = Integratorkwargs[:jitter_rate]
Integrator(initial_ϵ, jitter_rate)
elseif Integrator == TemperedLeapfrog
tempering_rate = Integratorkwargs[:tempering_rate]
Integrator(initial_ϵ, tempering_rate)
else
Integrator(initial_ϵ)
end
end

function adaptorchoice(Adaptor, mma, ssa)
if Adaptor != AdvancedHMC.NoAdaptation()
Adaptor(mma, ssa)
else
AdvancedHMC.NoAdaptation()
end
end

"""
ahmc_bayesian_pinn_pde(pde_system, discretization;
draw_samples = 1000,
@@ -295,15 +367,56 @@ end
function ahmc_bayesian_pinn_pde(pde_system, discretization;
draw_samples = 1000,
bcstd = [0.01], l2std = [0.05],
phystd = [0.05], priorsNNw = (0.0, 2.0),
phystd = [0.05], phystdnew = [0.05], priorsNNw = (0.0, 2.0),
param = [], nchains = 1, Kernel = HMC(0.1, 30),
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,), saveats = [1 / 10.0],
numensemble = floor(Int, draw_samples / 3), progress = false, verbose = false)
numensemble = floor(Int, draw_samples / 3), Dict_differentials = nothing,
progress = false, verbose = false)
pinnrep = symbolic_discretize(pde_system, discretization)
dataset_pde, dataset_bc = discretization.dataset

newloss = if Dict_differentials isa Nothing
nothing
else
datafree_colloc_loss_functions = get_lossy(pinnrep, dataset_pde, Dict_differentials)
# equals number of indvar coords in dataset
        # add a case for parameters present in the bcs?

train_sets_pde = get_dataset_train_points(pde_system.eqs,
dataset_pde,
pinnrep)
colloc_train_sets = [[hcat(train_sets_pde[i][:, j]...)' for i in eachindex(datafree_colloc_loss_functions[1])] for j in eachindex(datafree_colloc_loss_functions)]

# for each datafree_colloc_loss_function create loss_functions by passing dataset's indvar coords as train_sets_pde.
# placeholder strategy = GridTraining(0.1), datafree_bc_loss_function and train_sets_bc must be nothing
# order of indvar coords will be same as corresponding depvar coords values in dataset provided in get_lossy() call.
pde_loss_function_points = [merge_strategy_with_loglikelihood_function(
pinnrep,
GridTraining(0.1),
datafree_colloc_loss_functions[i],
nothing;
train_sets_pde = colloc_train_sets[i],
train_sets_bc = nothing)[1]
for i in eachindex(datafree_colloc_loss_functions)]

function L2_loss2(θ, allstd)
stdpdesnew = allstd[4]

            # first vector of losses from the returned tuple -> pde losses (the [1] above selects the pde part)
pde_loglikelihoods = [sum([pde_loss_function(θ, stdpdesnew[i])
for (i, pde_loss_function) in enumerate(pde_loss_functions)])
for pde_loss_functions in pde_loss_function_points]

# bc_loglikelihoods = [sum([bc_loss_function(θ, stdpdesnew[i]) for (i, bc_loss_function) in enumerate(pde_loss_function_points[1])]) for pde_loss_function_points in pde_loss_functions]
# for (j, bc_loss_function) in enumerate(bc_loss_functions)]

return sum(pde_loglikelihoods)
end
end

# [WIP] add overall functionality for BC dataset points
if ((dataset_bc isa Nothing) && (dataset_pde isa Nothing))
dataset = nothing
elseif dataset_bc isa Nothing
@@ -333,9 +446,6 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
# NN solutions for loglikelihood which is used for L2lossdata
Φ = pinnrep.phi

# for new L2 loss
# discretization.additional_loss =

if nchains < 1
throw(error("number of chains must be greater than or equal to 1"))
end
@@ -375,11 +485,12 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
strategy,
dataset,
priors,
[phystd, bcstd, l2std],
[phystd, bcstd, l2std, phystdnew],
names,
ninv,
initial_nnθ,
full_weighted_loglikelihood,
newloss,
Φ)

Adaptor, Metric, targetacceptancerate = Adaptorkwargs[:Adaptor],
@@ -394,10 +505,14 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
ℓπ.allstd))
@info("Current Prior Log-likelihood : ", priorlogpdf(ℓπ, initial_θ))
@info("Current MSE against dataset Log-likelihood : ", L2LossData(ℓπ, initial_θ))
if !(newloss isa Nothing)
@info("Current L2_LOSSY : ",
ℓπ.L2_loss2(setparameters(ℓπ, initial_θ),
ℓπ.allstd))
end

# parallel sampling option
if nchains != 1

# Cache to store the chains
bpinnsols = Vector{Any}(undef, nchains)

@@ -453,6 +568,11 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
@info("Current Prior Log-likelihood : ", priorlogpdf(ℓπ, samples[end]))
@info("Current MSE against dataset Log-likelihood : ",
L2LossData(ℓπ, samples[end]))
if !(newloss isa Nothing)
@info("Current L2_LOSSY : ",
ℓπ.L2_loss2(setparameters(ℓπ, samples[end]),
ℓπ.allstd))
end

fullsolution = BPINNstats(mcmc_chain, samples, stats)
ensemblecurves, estimnnparams, estimated_params, timepoints = inference(samples,
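
To tie the PDE_BPINN.jl changes together, here is a hedged end-to-end sketch of calling `ahmc_bayesian_pinn_pde` with the new `phystdnew` and `Dict_differentials` keywords on a toy inverse problem. The system, dataset layout, placeholder variable `diff_1`, and the `BayesianPINN` setup below are assumptions made for illustration and are not taken from this PR's tests.

# --- illustration only ---
using NeuralPDE, Lux, ModelingToolkit, Distributions
import ModelingToolkit: Interval

@parameters t p
@variables u(..)
Dt = Differential(t)
eqs = [Dt(u(t)) ~ p * u(t)]                       # toy ODE posed as a 1D PDE system
bcs = [u(0.0) ~ 1.0]
domains = [t ∈ Interval(0.0, 2.0)]
@named pde_system = PDESystem(eqs, bcs, domains, [t], [u(t)], [p],
    defaults = Dict(p => -1.0))

# dataset: one matrix per depvar, column 1 = depvar values, remaining columns = indvar coords
ts = collect(0.0:0.1:2.0)
us = exp.(-ts) .+ 0.02 .* randn(length(ts))
dataset = [[hcat(us, ts)], nothing]               # [pde dataset, bc dataset] -- assumed layout

chain = Lux.Chain(Lux.Dense(1, 8, tanh), Lux.Dense(8, 1))
discretization = BayesianPINN([chain], GridTraining(0.05);
    param_estim = true, dataset = dataset)

# mask every Differential term with a placeholder to enable the collocation loglikelihood
@variables diff_1
Dict_differentials = Dict(Dt(u(t)) => diff_1)

sol = ahmc_bayesian_pinn_pde(pde_system, discretization;
    draw_samples = 1000,
    bcstd = [0.01], l2std = [0.05],
    phystd = [0.05], phystdnew = [0.05],          # new: std for the collocation term
    priorsNNw = (0.0, 2.0),
    param = [Normal(-1.0, 2.0)],                  # prior for the unknown parameter p
    saveats = [1 / 50.0],
    Dict_differentials = Dict_differentials,      # new: enables the collocation loss
    progress = true)
# --- end illustration ---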