[WIP] Run BPINN experiments #761

Open. Wants to merge 3 commits into base: master.
5 changes: 3 additions & 2 deletions src/NeuralPDE.jl
@@ -50,8 +50,9 @@ include("rode_solve.jl")
 include("transform_inf_integral.jl")
 include("discretize.jl")
 include("neural_adapter.jl")
-include("advancedHMC_MCMC.jl")
-include("BPINN_ode.jl")
+include("bayesian/advancedHMC_MCMC.jl")
+include("bayesian/BPINN_ode.jl")
+include("bayesian/collocated_estim.jl")
 
 export NNODE, TerminalPDEProblem, NNPDEHan, NNPDENS, NNRODE,
        KolmogorovPDEProblem, NNKolmogorov, NNStopping, ParamKolmogorovPDEProblem,
12 changes: 9 additions & 3 deletions src/BPINN_ode.jl → src/bayesian/BPINN_ode.jl
@@ -178,7 +178,8 @@ function DiffEqBase.__solve(prob::DiffEqBase.ODEProblem,
         verbose = false,
         saveat = 1 / 50.0,
         maxiters = nothing,
-        numensemble = floor(Int, alg.draw_samples / 3))
+        numensemble = floor(Int, alg.draw_samples / 3),
+        estim_collocate = false)
     @unpack chain, l2std, phystd, param, priorsNNw, Kernel, strategy,
     draw_samples, dataset, init_params,
     nchains, physdt, Adaptorkwargs, Integratorkwargs,
@@ -207,16 +208,21 @@ function DiffEqBase.__solve(prob::DiffEqBase.ODEProblem,
         Integratorkwargs = Integratorkwargs,
         MCMCkwargs = MCMCkwargs,
         progress = progress,
-        verbose = verbose)
+        verbose = verbose,
+        estim_collocate = estim_collocate)
 
     fullsolution = BPINNstats(mcmcchain, samples, statistics)
     ninv = length(param)
     t = collect(eltype(saveat), prob.tspan[1]:saveat:prob.tspan[2])
 
     if chain isa Lux.AbstractExplicitLayer
         θinit, st = Lux.setup(Random.default_rng(), chain)
+        println(length(θinit))
+        println(length(samples[1]))
+        println(draw_samples)
         θ = [vector_to_parameters(samples[i][1:(end - ninv)], θinit)
-             for i in (draw_samples - numensemble):draw_samples]
+             for i in 1:max(draw_samples - draw_samples ÷ 10, draw_samples - 1000)]
 
         luxar = [chain(t', θ[i], st)[1] for i in 1:numensemble]
         # only needed for size
         θinit = collect(ComponentArrays.ComponentArray(θinit))
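For context, the new keyword threads straight through solve to __solve. A minimal, self-contained usage sketch (the estim_collocate flag and its false default come from the diff above; the problem, data, chain, and hyperparameter values are illustrative assumptions, not from this PR):

using NeuralPDE, Lux, Distributions, OrdinaryDiffEq

# toy decay ODE du/dt = -p*u with noisy data for parameter estimation
prob = ODEProblem((u, p, t) -> -p[1] * u, 1.0, (0.0, 2.0), [1.0])
t = collect(0.0:0.1:2.0)
û = exp.(-t) .+ 0.02 .* randn(length(t))

chain = Lux.Chain(Lux.Dense(1, 8, tanh), Lux.Dense(8, 1))
alg = BNNODE(chain; draw_samples = 1500, dataset = [û, t],
    param = [Normal(1.0, 1.0)], priorsNNw = (0.0, 3.0),
    l2std = [0.05], phystd = [0.05])

# opt in to the extra gradient-matching (collocation) likelihood
sol = solve(prob, alg; saveat = 1 / 50.0, estim_collocate = true)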
31 changes: 19 additions & 12 deletions src/advancedHMC_MCMC.jl → src/bayesian/advancedHMC_MCMC.jl
@@ -16,11 +16,12 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
     physdt::Float64
     extraparams::Int
     init_params::I
+    estim_collocate::Bool
 
     function LogTargetDensity(dim, prob, chain::Optimisers.Restructure, st, strategy,
             dataset,
             priors, phystd, l2std, autodiff, physdt, extraparams,
-            init_params::AbstractVector)
+            init_params::AbstractVector, estim_collocate)
         new{
             typeof(chain),
             Nothing,
@@ -39,12 +40,13 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
             autodiff,
             physdt,
             extraparams,
-            init_params)
+            init_params,
+            estim_collocate)
     end
     function LogTargetDensity(dim, prob, chain::Lux.AbstractExplicitLayer, st, strategy,
             dataset,
             priors, phystd, l2std, autodiff, physdt, extraparams,
-            init_params::NamedTuple)
+            init_params::NamedTuple, estim_collocate)
         new{
             typeof(chain),
             typeof(st),
@@ -60,7 +62,8 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
             autodiff,
             physdt,
             extraparams,
-            init_params)
+            init_params,
+            estim_collocate)
     end
 end
 
@@ -79,7 +82,11 @@ function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
 end
 
 function LogDensityProblems.logdensity(Tar::LogTargetDensity, θ)
-    return physloglikelihood(Tar, θ) + priorweights(Tar, θ) + L2LossData(Tar, θ)
+    if Tar.estim_collocate
+        return physloglikelihood(Tar, θ) / length(Tar.dataset[1]) + priorweights(Tar, θ) + L2LossData(Tar, θ) / length(Tar.dataset[1]) + L2loss2(Tar, θ) / length(Tar.dataset[1])
+    else
+        return physloglikelihood(Tar, θ) / length(Tar.dataset[1]) + priorweights(Tar, θ) + L2LossData(Tar, θ) / length(Tar.dataset[1])
+    end
 end
 
 LogDensityProblems.dimension(Tar::LogTargetDensity) = Tar.dim
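The change above rescales the summed physics and data log-likelihoods by the number of data points, so their magnitudes stay comparable to the size-independent prior term as the dataset grows. A standalone numeric sketch of the effect (made-up values; N stands in for length(Tar.dataset[1])):

# per-point scaling keeps the likelihood terms from swamping the prior
N = 100
phys_ll = -450.0    # physics log-likelihood summed over N points
data_ll = -120.0    # data log-likelihood summed over N points
prior_lp = -3.5     # prior log-density, independent of N

logdensity = phys_ll / N + prior_lp + data_ll / N    # -4.5 - 3.5 - 1.2 = -9.2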
@@ -481,7 +488,8 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
         Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
     Integratorkwargs = (Integrator = Leapfrog,),
     MCMCkwargs = (n_leapfrog = 30,),
-    progress = false, verbose = false)
+    progress = false, verbose = false,
+    estim_collocate = false)
 
     # NN parameter prior mean and variance (PriorsNN must be a tuple)
     if isinplace(prob)
@@ -542,7 +550,7 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
     t0 = prob.tspan[1]
     # dimensions would be the total no. of params; initial_nnθ for Lux NamedTuples
     ℓπ = LogTargetDensity(nparameters, prob, recon, st, strategy, dataset, priors,
-        phystd, l2std, autodiff, physdt, ninv, initial_nnθ)
+        phystd, l2std, autodiff, physdt, ninv, initial_nnθ, estim_collocate)
 
     try
         ℓπ(t0, initial_θ[1:(nparameters - ninv)])
@@ -580,7 +588,7 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
             MCMC_alg = kernelchoice(Kernel, MCMCkwargs)
             Kernel = AdvancedHMC.make_kernel(MCMC_alg, integrator)
             samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples, adaptor;
-                progress = progress, verbose = verbose)
+                progress = progress, verbose = verbose, drop_warmup = true)
 
             samplesc[i] = samples
             statsc[i] = stats
@@ -598,11 +606,10 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
         MCMC_alg = kernelchoice(Kernel, MCMCkwargs)
         Kernel = AdvancedHMC.make_kernel(MCMC_alg, integrator)
         samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples,
-            adaptor; progress = progress, verbose = verbose)
-
+            adaptor; progress = progress, verbose = verbose, drop_warmup = true)
         # return a chain (basic chain), samples and stats
-        matrix_samples = hcat(samples...)
-        mcmc_chain = MCMCChains.Chains(matrix_samples')
+        matrix_samples = reshape(hcat(samples...), (length(samples[1]), length(samples), 1))
+        mcmc_chain = MCMCChains.Chains(matrix_samples)
         return mcmc_chain, samples, stats
     end
 end
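One thing worth noting about the new Chains construction: MCMCChains.Chains interprets a 3-D array as (iterations, parameters, chains). A standalone sketch of that layout (illustrative sizes; not code from this PR):

using MCMCChains

niters, nparams = 500, 12
# sampling returns a vector of niters parameter vectors, each of length nparams,
# so hcat(samples...) is nparams × niters and must be permuted before Chains
samples = [randn(nparams) for _ in 1:niters]
arr = permutedims(reshape(hcat(samples...), (nparams, niters, 1)), (2, 1, 3))
chn = MCMCChains.Chains(arr)    # 500 iterations × 12 parameters × 1 chain

As written, the diff's reshape produces a (parameters, iterations, 1) array, so the axes land swapped relative to this convention.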
192 changes: 192 additions & 0 deletions src/bayesian/collocated_estim.jl
@@ -0,0 +1,192 @@
# suggested extra loss function: gradient matching of the NN solution's time
# derivative against the ODE right-hand side at the dataset timepoints
function L2loss2(Tar::LogTargetDensity, θ)
    f = Tar.prob.f

    # only active when parameter estimation is chosen
    if Tar.extraparams > 0
        autodiff = Tar.autodiff
        # deri_sol = deri_sol'
        # # timepoints to enforce physics
        # dataset = Array(reduce(hcat, dataset)')
        # t = dataset[end, :]
        # û = dataset[1:(end - 1), :]

        # ode_params = Tar.extraparams == 1 ?
        #              θ[((length(θ) - Tar.extraparams) + 1):length(θ)][1] :
        #              θ[((length(θ) - Tar.extraparams) + 1):length(θ)]

        # if length(û[:, 1]) == 1
        #     physsol = [f(û[:, i][1], ode_params, t[i]) for i in 1:length(û[1, :])]
        # else
        #     physsol = [f(û[:, i], ode_params, t[i]) for i in 1:length(û[1, :])]
        # end
        # # form of NN output matrix: output dim x n
        # deri_physsol = reduce(hcat, physsol)

        # > for a perfect derivative (basically gradient matching in the case of an ODEFunction);
        # in the case of a PDE or general ODE we would want to reduce the residual of f(du, u, p, t)
        # if length(û[:, 1]) == 1
        #     deri_sol = [f(û[:, i][1], Tar.prob.p, t[i]) for i in 1:length(û[1, :])]
        # else
        #     deri_sol = [f(û[:, i], Tar.prob.p, t[i]) for i in 1:length(û[1, :])]
        # end
        # deri_sol = reduce(hcat, deri_sol)
        # deri_sol = reduce(hcat, derivatives)

        # timepoints to enforce physics
        t = Tar.dataset[end]
        u1 = Tar.dataset[2]
        û = Tar.dataset[1]
        # Tar(t, θ[1:(length(θ) - Tar.extraparams)])'

        # NN solution's time derivative at the dataset timepoints
        nnsol = NNodederi(Tar, t, θ[1:(length(θ) - Tar.extraparams)], autodiff)

        ode_params = Tar.extraparams == 1 ?
                     θ[((length(θ) - Tar.extraparams) + 1):length(θ)][1] :
                     θ[((length(θ) - Tar.extraparams) + 1):length(θ)]

        # ODE right-hand side evaluated on the dataset values
        if length(Tar.prob.u0) == 1
            physsol = [f(û[i], ode_params, t[i]) for i in 1:length(û[:, 1])]
        else
            physsol = [f([û[i], u1[i]], ode_params, t[i]) for i in 1:length(û)]
        end
        # form of NN output matrix: output dim x n
        deri_physsol = reduce(hcat, physsol)

        # if length(Tar.prob.u0) == 1
        #     nnsol = [f(û[i], Tar.prob.p, t[i]) for i in 1:length(û[:, 1])]
        # else
        #     nnsol = [f([û[i], u1[i]], Tar.prob.p, t[i]) for i in 1:length(û[:, 1])]
        # end
        # # form of NN output matrix: output dim x n
        # nnsol = reduce(hcat, nnsol)

        # > instead of dataset gradients, trying NN derivatives with dataset collocation

        # Gaussian log-likelihood of the derivative mismatch, one term per state
        physlogprob = 0
        for i in 1:length(Tar.prob.u0)
            # can add phystd[i] for u[i]
            physlogprob += logpdf(MvNormal(deri_physsol[i, :],
                    LinearAlgebra.Diagonal(map(abs2,
                        (Tar.l2std[i] * 4.0) .*
                        ones(length(nnsol[i, :]))))),
                nnsol[i, :])
        end
        return physlogprob
    else
        return 0
    end
end
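In isolation, L2loss2 is plain gradient matching: at each data timepoint the surrogate's derivative is compared against f(u, p, t) evaluated on the data, under a Gaussian likelihood. A self-contained sketch for a scalar ODE (all names and numbers here are illustrative, not the package API):

using Distributions, LinearAlgebra

# du/dt = p * u with true rate p = 0.5, observed with small noise
f(u, p, t) = p * u
t = collect(0.0:0.1:1.0)
û = exp.(0.5 .* t) .+ 0.01 .* randn(length(t))
dû = 0.5 .* exp.(0.5 .* t)    # stand-in for the NN/surrogate derivative

# log-likelihood that the surrogate derivative matches f on the data for a candidate p
function gradmatch_loglik(p; σ = 0.05)
    rhs = [f(û[i], p, t[i]) for i in eachindex(t)]
    return logpdf(MvNormal(rhs, Diagonal(fill(abs2(σ), length(t)))), dû)
end

gradmatch_loglik(0.5) > gradmatch_loglik(2.0)    # true: the correct rate scores higher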

# PDE residual form: f(du, u, p, t) = 0

# derivatives via central differences
# function calculate_derivatives2(dataset)
#     x̂, time = dataset
#     num_points = length(x̂)
#     # initialize an array to store the derivative values
#     derivatives = similar(x̂)

#     for i in 2:(num_points - 1)
#         # first-order derivative from central differences
#         Δt_forward = time[i + 1] - time[i]
#         Δt_backward = time[i] - time[i - 1]

#         derivative = (x̂[i + 1] - x̂[i - 1]) / (Δt_forward + Δt_backward)

#         derivatives[i] = derivative
#     end

#     # endpoint derivatives from forward/backward differences
#     derivatives[1] = (x̂[2] - x̂[1]) / (time[2] - time[1])
#     derivatives[end] = (x̂[end] - x̂[end - 1]) / (time[end] - time[end - 1])
#     return derivatives
# end
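The commented-out helper above is recoverable as-is; a compact working version of the same central-difference scheme (a standalone sketch, not part of the PR's active code):

# central differences in the interior, one-sided differences at the endpoints
function central_diff(x̂::AbstractVector, time::AbstractVector)
    derivatives = similar(x̂)
    for i in 2:(length(x̂) - 1)
        derivatives[i] = (x̂[i + 1] - x̂[i - 1]) / (time[i + 1] - time[i - 1])
    end
    derivatives[1] = (x̂[2] - x̂[1]) / (time[2] - time[1])
    derivatives[end] = (x̂[end] - x̂[end - 1]) / (time[end] - time[end - 1])
    return derivatives
end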

function calderivatives(prob, dataset)
    # fit a small surrogate network to the dataset, then differentiate it
    chainflux = Flux.Chain(Flux.Dense(1, 8, tanh), Flux.Dense(8, 8, tanh),
                    Flux.Dense(8, 2)) |> Flux.f64
    # chainflux = Flux.Chain(Flux.Dense(1, 7, tanh), Flux.Dense(7, 1)) |> Flux.f64
    function loss(x, y)
        # sum(Flux.mse.(prob.u0[1] .+ (prob.tspan[2] .- x)' .* chainflux(x)[1, :], y[1]) +
        #     Flux.mse.(prob.u0[2] .+ (prob.tspan[2] .- x)' .* chainflux(x)[2, :], y[2]))
        # sum(Flux.mse.(prob.u0[1] .+ (prob.tspan[2] .- x)' .* chainflux(x)[1, :], y[1]))
        sum(Flux.mse.(chainflux(x), y))
    end
    optimizer = Flux.Optimise.ADAM(0.01)
    epochs = 3000
    for epoch in 1:epochs
        Flux.train!(loss,
            Flux.params(chainflux),
            [(dataset[end]', dataset[1:(end - 1)])],
            optimizer)
    end

    # A1 = (prob.u0' .+
    #       (prob.tspan[2] .- (dataset[end]' .+ sqrt(eps(Float64))))' .*
    #       chainflux(dataset[end]' .+ sqrt(eps(Float64)))')

    # A2 = (prob.u0' .+
    #       (prob.tspan[2] .- (dataset[end]'))' .*
    #       chainflux(dataset[end]')')

    # forward finite difference of the fitted network at the data timepoints
    h = sqrt(eps(eltype(dataset[end])))
    A1 = chainflux(dataset[end]' .+ h)
    A2 = chainflux(dataset[end]')

    gradients = (A1 .- A2) ./ h

    return gradients
end
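Usage is then just the following (assuming the [u..., t] dataset layout used throughout this file; the output is an output-dim × n matrix of surrogate derivatives):

gradients = calderivatives(prob, dataset)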

function calculate_derivatives(dataset)
    # u = dataset[1]
    # u1 = dataset[2]
    # t = dataset[end]
    # # control points
    # n = Int(floor(length(t) / 10))
    # # spline for dataset values (solution)
    # # interp = BSplineApprox(u, t, 4, 10, :Uniform, :Uniform)
    # interp = CubicSpline(u, t)
    # interp1 = CubicSpline(u1, t)
    # # derivatives via interpolation on a uniform grid
    # dx = t[2] - t[1]
    # time = collect(t[1]:dx:t[end])
    # smoothu = [interp(i) for i in time]
    # smoothu1 = [interp1(i) for i in time]
    # # derivative of the spline (must match the function derivative)
    # û = tvdiff(smoothu, 20, 0.5, dx = dx, ε = 1)
    # û1 = tvdiff(smoothu1, 20, 0.5, dx = dx, ε = 1)
    # # tvdiff(smoothu, 100, 0.035, dx = dx, ε = 1)
    # # FDM
    # # û1 = diff(u) / dx
    # # dataset[1] and smoothu are almost equal (rounding errors)
    # return [û, û1]
end
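Taken together, the new pieces are exercised through ahmc_bayesian_pinn_ode. An end-to-end sketch (a hedged illustration: the problem, data, and hyperparameter values are assumptions, with the dataset laid out as [u, t] as in the code above):

using NeuralPDE, Lux, Distributions, OrdinaryDiffEq

# logistic growth du/dt = p * u * (1 - u) with unknown rate p (true value 1.5)
prob = ODEProblem((u, p, t) -> p[1] * u * (1 - u), 0.01, (0.0, 5.0), [1.5])
chain = Lux.Chain(Lux.Dense(1, 8, tanh), Lux.Dense(8, 1))

# synthetic noisy observations; dataset layout: state values first, timepoints last
t = collect(0.0:0.1:5.0)
û = 0.01 ./ (0.01 .+ 0.99 .* exp.(-1.5 .* t)) .+ 0.02 .* randn(length(t))
dataset = [û, t]

mcmc_chain, samples, stats = ahmc_bayesian_pinn_ode(prob, chain;
    dataset = dataset, draw_samples = 1000,
    l2std = [0.05], phystd = [0.05],
    priorsNNw = (0.0, 3.0), param = [Normal(1.0, 2.0)],
    estim_collocate = true)    # enables the extra gradient-matching likelihood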