Skip to content

Commit 0687aaf

Browse files
Merge pull request #745 from AstitvaAggarwal/Bpinn_pde
BPINN PDE solver
2 parents 9f191f8 + 8008f3a commit 0687aaf

13 files changed

+1267
-111
lines changed

Diff for: .github/workflows/CI.yml

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ jobs:
1818
matrix:
1919
group:
2020
- ODEBPINN
21+
- PDEBPINN
2122
- NNPDE1
2223
- NNPDE2
2324
- AdaptiveLoss

Diff for: docs/src/tutorials/Lotka_Volterra_BPINNs.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ plot!(solution, labels = ["true x" "true y"])
108108
We can see the estimated ODE parameters by -
109109

110110
```@example bpinn
111-
sol_pestim.estimated_ode_params
111+
sol_pestim.estimated_de_params
112112
```
113113

114114
We can see it is close to the true values of the parameters.

Diff for: src/BPINN_ode.jl

+16-13
Original file line numberDiff line numberDiff line change
@@ -148,21 +148,24 @@ end
148148
BPINN Solution contains the original solution from AdvancedHMC.jl sampling(BPINNstats contains fields related to that)
149149
> ensemblesol is the Probabilistic Estimate(MonteCarloMeasurements.jl Particles type) of Ensemble solution from All Neural Network's(made using all sampled parameters) output's.
150150
> estimated_nn_params - Probabilistic Estimate of NN params from sampled weights,biases
151-
> estimated_ode_params - Probabilistic Estimate of ODE params from sampled unknown ode paramters
151+
> estimated_de_params - Probabilistic Estimate of DE params from sampled unknown DE paramters
152152
"""
153-
struct BPINNsolution{O <: BPINNstats, E,
154-
NP <: Vector{<:MonteCarloMeasurements.Particles{<:Float64}},
155-
OP <: Union{Vector{Nothing},
156-
Vector{<:MonteCarloMeasurements.Particles{<:Float64}}}}
153+
154+
struct BPINNsolution{O <: BPINNstats, E, NP, OP, P}
157155
original::O
158156
ensemblesol::E
159157
estimated_nn_params::NP
160-
estimated_ode_params::OP
161-
162-
function BPINNsolution(original, ensemblesol, estimated_nn_params, estimated_ode_params)
158+
estimated_de_params::OP
159+
timepoints::P
160+
161+
function BPINNsolution(original,
162+
ensemblesol,
163+
estimated_nn_params,
164+
estimated_de_params,
165+
timepoints)
163166
new{typeof(original), typeof(ensemblesol), typeof(estimated_nn_params),
164-
typeof(estimated_ode_params)}(original, ensemblesol, estimated_nn_params,
165-
estimated_ode_params)
167+
typeof(estimated_de_params), typeof(timepoints)}(original, ensemblesol, estimated_nn_params,
168+
estimated_de_params, timepoints)
166169
end
167170
end
168171

@@ -260,14 +263,14 @@ function DiffEqBase.__solve(prob::DiffEqBase.ODEProblem,
260263
end
261264

262265
nnparams = length(θinit)
263-
estimnnparams = [Particles(reduce(hcat, samples)[i, :]) for i in 1:nnparams]
266+
estimnnparams = [Particles(reduce(hcat, samples[(end - numensemble):end])[i, :]) for i in 1:nnparams]
264267

265268
if ninv == 0
266269
estimated_params = [nothing]
267270
else
268-
estimated_params = [Particles(reduce(hcat, samples[(end - ninv + 1):end])[i, :])
271+
estimated_params = [Particles(reduce(hcat, samples[(end - numensemble):end])[i, :])
269272
for i in (nnparams + 1):(nnparams + ninv)]
270273
end
271274

272-
BPINNsolution(fullsolution, ensemblecurves, estimnnparams, estimated_params)
275+
BPINNsolution(fullsolution, ensemblecurves, estimnnparams, estimated_params, t)
273276
end

Diff for: src/NeuralPDE.jl

+14-12
Original file line numberDiff line numberDiff line change
@@ -52,19 +52,21 @@ include("discretize.jl")
5252
include("neural_adapter.jl")
5353
include("advancedHMC_MCMC.jl")
5454
include("BPINN_ode.jl")
55+
include("PDE_BPINN.jl")
5556

5657
export NNODE, TerminalPDEProblem, NNPDEHan, NNPDENS, NNRODE,
57-
KolmogorovPDEProblem, NNKolmogorov, NNStopping, ParamKolmogorovPDEProblem,
58-
KolmogorovParamDomain, NNParamKolmogorov,
59-
PhysicsInformedNN, discretize,
60-
GridTraining, StochasticTraining, QuadratureTraining, QuasiRandomTraining,
61-
WeightedIntervalTraining,
62-
build_loss_function, get_loss_function,
63-
generate_training_sets, get_variables, get_argument, get_bounds,
64-
get_phi, get_numeric_derivative, get_numeric_integral,
65-
build_symbolic_equation, build_symbolic_loss_function, symbolic_discretize,
66-
AbstractAdaptiveLoss, NonAdaptiveLoss, GradientScaleAdaptiveLoss,
67-
MiniMaxAdaptiveLoss,
68-
LogOptions, ahmc_bayesian_pinn_ode, BNNODE
58+
KolmogorovPDEProblem, NNKolmogorov, NNStopping, ParamKolmogorovPDEProblem,
59+
KolmogorovParamDomain, NNParamKolmogorov,
60+
PhysicsInformedNN, discretize,
61+
GridTraining, StochasticTraining, QuadratureTraining, QuasiRandomTraining,
62+
WeightedIntervalTraining,
63+
build_loss_function, get_loss_function,
64+
generate_training_sets, get_variables, get_argument, get_bounds,
65+
get_phi, get_numeric_derivative, get_numeric_integral,
66+
build_symbolic_equation, build_symbolic_loss_function, symbolic_discretize,
67+
AbstractAdaptiveLoss, NonAdaptiveLoss, GradientScaleAdaptiveLoss,
68+
MiniMaxAdaptiveLoss, LogOptions,
69+
ahmc_bayesian_pinn_ode, BNNODE, ahmc_bayesian_pinn_pde, vector_to_parameters,
70+
BPINNsolution
6971

7072
end # module

0 commit comments

Comments
 (0)