
include dataset points into physics loglikelihood
AstitvaAggarwal committed Nov 23, 2023
1 parent 9069028 commit 98e0a94
Showing 4 changed files with 499 additions and 162 deletions.
114 changes: 91 additions & 23 deletions src/PDE_BPINN.jl
@@ -1,6 +1,6 @@
mutable struct PDELogTargetDensity{
ST <: AbstractTrainingStrategy,
D <: Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}},
D <: Union{Vector{Nothing}, Vector{<:Matrix{<:Real}}},
P <: Vector{<:Distribution},
I,
F,
@@ -12,7 +12,7 @@ mutable struct PDELogTargetDensity{
priors::P
allstd::Vector{Vector{Float64}}
names::Tuple
physdt::Float64
physdt::Vector{Float64}
extraparams::Int
init_params::I
full_loglikelihood::F
@@ -42,7 +42,8 @@ mutable struct PDELogTargetDensity{
end
function PDELogTargetDensity(dim, strategy, dataset,
priors, allstd, names, physdt, extraparams,
init_params::NamedTuple, full_loglikelihood, Phi)
init_params::Union{NamedTuple, ComponentArrays.ComponentVector},
full_loglikelihood, Phi)
new{
typeof(strategy),
typeof(dataset),
@@ -126,22 +127,48 @@ end
function L2loss2(Tar::PDELogTargetDensity, θ)
return logpdf(MvNormal(pde(phi, Tar.dataset[end], θ)), zeros(length(pde_eqs)))
end

# L2 loss log-likelihood (needed mainly for ODE parameter estimation)
function L2LossData(Tar::PDELogTargetDensity, θ)
Phi = Tar.Phi
init_params = Tar.init_params
dataset = Tar.dataset
sumt = 0
L2stds = Tar.allstd[3]
# each dependent variable has its own dataset, depending on its independent variables and their domains
# these datasets are matrices whose first column is the dependent variable and whose remaining columns are the independent variables
# Tar.init_params is needed to reconstruct the parameter vector θ as a ComponentVector

# dataset is of the form Vector[matrix_x, matrix_y, matrix_z]
# matrix_i has columns [i, indvar1, indvar2, ..] (needed in case of heterogeneous domains)

# Phi is the trial solution for each NN in the chain array
# Creating logpdf( MvNormal(Phi(t,θ),std), dataset[i] )
# dataset[i][:, 2:end] -> independent-variable columns of a particular dependent variable's dataset
# dataset[i][:, 1] -> dependent-variable column of that dataset
# (a minimal stand-alone sketch of this pattern follows this function)

if Tar.extraparams > 0
if Tar.init_params isa ComponentArrays.ComponentVector
return sum([logpdf(MvNormal(Tar.Phi[i](Tar.dataset[end]',
vector_to_parameters(θ[1:(end - Tar.extraparams)],
Tar.init_params)[Tar.names[i]])[1,
:], ones(length(Tar.dataset[end])) .* Tar.allstd[3][i]), Tar.dataset[i])
for i in eachindex(Tar.Phi)])
for i in eachindex(Phi)
sumt += logpdf(MvNormal(Phi[i](dataset[i][:, 2:end]',
vector_to_parameters(θ[1:(end - Tar.extraparams)],
init_params)[Tar.names[i]])[1,
:],
ones(size(dataset[i])[1]) .* L2stds[i]),
dataset[i][:, 1])
end
sumt
else
# Flux case needs sub-indexing with the indices stored in Tar.names
return sum([logpdf(MvNormal(Tar.Phi[i](Tar.dataset[end]',
vector_to_parameters(θ[1:(end - Tar.extraparams)],
Tar.init_params)[Tar.names[2][i]])[1,
:], ones(length(Tar.dataset[end])) .* Tar.allstd[3][i]), Tar.dataset[i])
for i in eachindex(Tar.Phi)])
for i in eachindex(Phi)
sumt += logpdf(MvNormal(Phi[i](dataset[i][:, 2:end]',
vector_to_parameters(θ[1:(end - Tar.extraparams)],
init_params)[Tar.names[2][i]])[1,
:],
ones(size(dataset[i])[1]) .* L2stds[i]),
dataset[i][:, 1])
end
sumt
end
else
return 0
Expand Down Expand Up @@ -204,21 +231,58 @@ function adaptorchoice(Adaptor, mma, ssa)
end
end

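For orientation, a minimal stand-alone sketch (not part of this commit) of the data log-likelihood pattern used in L2LossData above, with a toy function in place of the trial solution Phi; all names below are illustrative:

using Distributions

# toy "trial solution": maps a (n_indepvars × n_points) matrix of coordinates
# to predicted values of one dependent variable
toy_phi(x) = vec(sum(x; dims = 1))

# dataset matrix in the format described above:
# column 1 = observed dependent variable, remaining columns = independent variables
datamat = [1.9 1.0 1.0;
           3.1 1.0 2.0;
           4.0 2.0 2.0]

σ = 0.05
pred = toy_phi(datamat[:, 2:end]')          # predictions at the dataset points
npts = size(datamat, 1)
ll = logpdf(MvNormal(pred, ones(npts) .* σ), datamat[:, 1])  # Gaussian log-likelihood of the data column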
# dataset would be (x̂,t)
# function inference(samples, discretization, saveat, numensemble, ℓπ)
# ranges = []
# for i in eachindex(domains)
# push!(ranges, [infimum(domains[i].domain), supremum(infimum(domains[i].domain))])
# end
# ranges = map(ranges) do x
# collect(x[1]:saveat:x[2])
# end
# samples = samples[(end - numensemble):end]
# chain = discretization.chain

# if discretization.multioutput && chain[1] isa Lux.AbstractExplicitLayer
# temp = [setparameters(ℓπ, samples[i]) for i in eachindex(samples)]

# luxar = map(temp) do x
# chain(t', x, st[i])
# end

# elseif discretization.multioutput && chain[1] isa Flux.chain

# elseif chain isa Flux.Chain
# re = Flux.destructure(chain)[2]
# out1 = re.([sample for sample in samples])
# luxar = [collect(out1[i](t') for t in ranges)]
# fluxmean = map(luxar) do x
# mean(vcat(x...)[:, i]) for i in eachindex(x)
# end
# else
# transsamples = [vector_to_parameters(sample, initl) for sample in samples]
# luxar2 = [chainl(t1', transsamples[i], st)[1] for i in 800:1000]
# luxmean = [mean(vcat(luxar2...)[:, i]) for i in eachindex(t1)]
# end
# end

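The commented-out prototype above is working toward ensemble prediction from the posterior samples; a minimal stand-alone illustration (not part of this commit) of the averaging step it sketches, with a toy predictor in place of the reconstructed network:

using Statistics

toy_predict(θ, t) = θ[1] .* t .+ θ[2]        # stand-in for re(θ)(t') / chain evaluation

posterior_samples = [[1.0, 0.10], [1.1, 0.05], [0.9, 0.15]]  # dummy parameter draws
ts = collect(0.0:0.25:1.0)

preds = [toy_predict(θ, ts) for θ in posterior_samples]
ensemble_mean = [mean(p[i] for p in preds) for i in eachindex(ts)]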
# priors: pdf for W,b + pdf for ODE params
# lotka specific kwargs here
function ahmc_bayesian_pinn_pde(pde_system, discretization;
strategy = GridTraining, dataset = [nothing],
init_params = nothing, draw_samples = 1000,
physdt = 1 / 20.0, bcstd = [0.01], l2std = [0.05],
draw_samples = 1000, physdt = [1 / 20.0],
bcstd = [0.01], l2std = [0.05],
phystd = [0.05], priorsNNw = (0.0, 2.0),
param = [], nchains = 1, Kernel = HMC,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
MCMCkwargs = (n_leapfrog = 30,),
MCMCkwargs = (n_leapfrog = 30,), saveat = 1 / 50.0,
numensemble = 100,
# floor(Int, alg.draw_samples / 3),
progress = false, verbose = false)
pinnrep = symbolic_discretize(pde_system, discretization, bayesian = true)
pinnrep = symbolic_discretize(pde_system,
discretization,
bayesian = true,
dataset_given = dataset)

# for physics loglikelihood
full_weighted_loglikelihood = pinnrep.loss_functions.full_loss_function
@@ -245,7 +309,7 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
# converting vector of parameters to ComponentArray for runtimegenerated functions
names = ntuple(i -> pinnrep.depvars[i], length(chain))
else
# this case is for Flux multioutput
# Flux multioutput
i = 0
temp = []
for j in eachindex(initial_nnθ)
@@ -270,7 +334,8 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
nparameters += ninv
end

strategy = strategy(physdt)
# physdt vector in case of N-dimensional domains
strategy = discretization.strategy

# dimension = total number of parameters; initial_nnθ for Lux NamedTuples
ℓπ = PDELogTargetDensity(nparameters,
@@ -318,10 +383,13 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
Kernel = AdvancedHMC.make_kernel(MCMC_alg, integrator)
samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples, adaptor;
progress = progress, verbose = verbose)
mcmc_chain = Chains(hcat(samples...)')

fullsolution = BPINNstats(mcmcchain, samples, statistics)
estimsol = inference(samples, discretization, saveat, numensemble, ℓπ)

samplesc[i] = samples
statsc[i] = stats
mcmc_chain = Chains(hcat(samples...)')
chains[i] = mcmc_chain
end

@@ -348,7 +416,7 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
println("Current Prior Log-likelihood : ", priorlogpdf(ℓπ, samples[end]))
println("Current MSE against dataset Log-likelihood : ",
L2LossData(ℓπ, samples[end]))

fullsolution = BPINNstats(mcmc_chain, samples, stats)
return mcmc_chain, samples, stats
end
end
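Per the type change at the top of this file (dataset is now Vector{<:Matrix{<:Real}}, one matrix per dependent variable), a minimal stand-alone sketch (not part of this commit) of how such a dataset can be assembled before being passed as the `dataset` keyword of `ahmc_bayesian_pinn_pde`; the data below are synthetic:

# one matrix per dependent variable u(t, x):
# column 1 holds the observed values, remaining columns the independent variables
ts = collect(0.0:0.1:1.0)
xs = fill(0.5, length(ts))
u_obs = sin.(ts) .+ 0.01 .* randn(length(ts))   # noisy synthetic measurements

dataset = [hcat(u_obs, ts, xs)]                 # Vector of matrices, here a single dependent variable

The per-variable observation noise for this dataset is then supplied through `l2std`, matching the order of the dependent variables.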
33 changes: 26 additions & 7 deletions src/discretize.jl
@@ -401,7 +401,7 @@ to the PDE.
For more information, see `discretize` and `PINNRepresentation`.
"""
function SciMLBase.symbolic_discretize(pde_system::PDESystem,
discretization::PhysicsInformedNN; bayesian::Bool = false)
discretization::PhysicsInformedNN; bayesian::Bool = false, dataset_given = [nothing])
eqs = pde_system.eqs
bcs = pde_system.bcs
chain = discretization.chain
@@ -567,7 +567,6 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
strategy,
datafree_pde_loss_functions,
datafree_bc_loss_functions)

# setup for all adaptive losses
num_pde_losses = length(pde_loss_functions)
num_bc_losses = length(bc_loss_functions)
@@ -587,14 +586,34 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
bc_loss_functions)

if bayesian
# required because the physics loss is also needed at the dataset domain points
pde_loss_functions1, bc_loss_functions1 = if !(dataset_given[1] isa Nothing)
if !(strategy isa GridTraining)
println("only GridTraining strategy allowed")
else
merge_strategy_with_loglikelihood_function(pinnrep,
strategy,
datafree_pde_loss_functions,
datafree_bc_loss_functions, train_sets_L2loss2 = dataset_given)
end
end

function full_likelihood_function(θ, allstd)
stdpdes, stdbcs, stdextra = allstd
# the aggregation happens on cpu even if the losses are gpu, probably fine since it's only a few of them
pde_loglikelihoods = [logpdf(Normal(0, stdpdes[i]), pde_loss_function(θ))
for (i, pde_loss_function) in enumerate(pde_loss_functions)]
for (i, pde_loss_function) in enumerate(pde_loss_functions)]

bc_loglikelihoods = [logpdf(Normal(0, stdbcs[j]), bc_loss_function(θ))
for (j, bc_loss_function) in enumerate(bc_loss_functions)]
for (j, bc_loss_function) in enumerate(bc_loss_functions)]

if !(dataset_given[1] isa Nothing)
pde_loglikelihoods += [logpdf(Normal(0, stdpdes[j]), pde_loss_function1(θ))
for (j, pde_loss_function1) in enumerate(pde_loss_functions1)]

bc_loglikelihoods += [logpdf(Normal(0, stdbcs[j]), bc_loss_function1(θ))
for (j, bc_loss_function1) in enumerate(bc_loss_functions1)]
end

# this is kind of a hack, and means that whenever the outer function is evaluated the increment goes up, even if it's not being optimized
# that's why we prefer the user to maintain the increment in the outer loop callback during optimization
@@ -634,8 +653,8 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
return additional_loss(phi, θ_, p_)
end

_additional_loglikelihood = logpdf(Normal(0, stdextra),
_additional_loss(phi, θ))
_additional_loglikelihood = logpdf(Normal(0, stdextra), _additional_loss(phi, θ))

weighted_additional_loglikelihood = adaloss.additional_loss_weights[1] *
_additional_loglikelihood

@@ -645,7 +664,7 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
end

pinnrep.loss_functions = PINNLossFunctions(bc_loss_functions, pde_loss_functions,
full_likelihood_function, additional_loss,
full_likelihood_function, additional_loss,
datafree_pde_loss_functions,
datafree_bc_loss_functions)
else
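The net effect of the branch above is that, when a dataset is given, each PDE (and BC) log-likelihood picks up a second Normal term evaluated at the dataset points; a minimal stand-alone sketch (not part of this commit) of that aggregation, with dummy residual values in place of the loss functions:

using Distributions

pde_residuals_grid    = [0.02, -0.01]   # residuals on the training grid (dummy values)
pde_residuals_dataset = [0.03, 0.00]    # residuals at the dataset points (dummy values)
stdpdes = [0.05, 0.05]

loglik  = sum(logpdf(Normal(0, stdpdes[i]), pde_residuals_grid[i]) for i in eachindex(stdpdes))
loglik += sum(logpdf(Normal(0, stdpdes[i]), pde_residuals_dataset[i]) for i in eachindex(stdpdes))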
29 changes: 29 additions & 0 deletions src/training_strategies.jl
@@ -16,6 +16,35 @@ struct GridTraining{T} <: AbstractTrainingStrategy
dx::T
end

# include dataset points in pde_residual loglikelihood
function merge_strategy_with_loglikelihood_function(pinnrep::PINNRepresentation,
strategy::GridTraining,
datafree_pde_loss_function,
datafree_bc_loss_function; train_sets_L2loss2 = nothing)
@unpack domains, eqs, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep
dx = strategy.dx
eltypeθ = eltype(pinnrep.flat_init_params)

train_sets = generate_training_sets(domains, dx, eqs, bcs, eltypeθ,
dict_indvars, dict_depvars)

bcs_train_sets = train_sets[2]
pde_train_sets = [train_set[:, 2:end] for train_set in train_sets_L2loss2]
# the points in the domain and on the boundary
pde_train_sets = adapt.(parameterless_type(ComponentArrays.getdata(flat_init_params)),
pde_train_sets)
bcs_train_sets = adapt.(parameterless_type(ComponentArrays.getdata(flat_init_params)),
bcs_train_sets)
pde_loss_functions = [get_loss_function(_loss, _set, eltypeθ, strategy)
for (_loss, _set) in zip(datafree_pde_loss_function,
pde_train_sets)]

bc_loss_functions = [get_loss_function(_loss, _set, eltypeθ, strategy)
for (_loss, _set) in zip(datafree_bc_loss_function, bcs_train_sets)]

pde_loss_functions, bc_loss_functions
end

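A stand-alone illustration (not part of this commit) of the slicing this helper performs on the user-supplied dataset matrices: the dependent-variable column is dropped so only the coordinate columns remain as physics-residual training points.

# dataset matrices: column 1 = dependent variable, columns 2:end = coordinates
train_sets_L2loss2 = [
    [1.9 0.0 0.5;
     3.1 0.1 0.5],
]

# keep only the coordinate columns, one point set per dependent variable
pde_train_sets = [m[:, 2:end] for m in train_sets_L2loss2]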
function merge_strategy_with_loss_function(pinnrep::PINNRepresentation,
strategy::GridTraining,
datafree_pde_loss_function,