managing conflicts 2
AstitvaAggarwal committed Oct 18, 2024
1 parent 50a36f7 commit be7c3d4
Showing 3 changed files with 362 additions and 477 deletions.
185 changes: 119 additions & 66 deletions src/PDE_BPINN.jl
@@ -4,17 +4,91 @@
dataset <: Union{Nothing, Vector{<:Matrix{<:Real}}}
priors <: Vector{<:Distribution}
allstd::Vector{Vector{Float64}}
phynewstd::Vector{Float64}
names::Tuple
extraparams::Int
init_params <: Union{AbstractVector, NamedTuple, ComponentArray}
full_loglikelihood::Any
L2_loss2::Any
Φ::Any
end

function LogDensityProblems.logdensity(ltd::PDELogTargetDensity, θ)
    # for parameter estimation it is necessary to use the multioutput case
    if ltd.L2_loss2 === nothing
        return ltd.full_loglikelihood(setparameters(ltd, θ), ltd.allstd) +
               priorlogpdf(ltd, θ) + L2LossData(ltd, θ)
    else
        return ltd.full_loglikelihood(setparameters(ltd, θ), ltd.allstd) +
               priorlogpdf(ltd, θ) + L2LossData(ltd, θ) +
               ltd.L2_loss2(setparameters(ltd, θ), ltd.phynewstd)
    end
end
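For orientation, the target density above is purely additive: physics loglikelihood, prior, data likelihood, and, when `L2_loss2` is provided, the extra collocation likelihood. A minimal sketch of that composition, with hypothetical stand-in closures (none of these are NeuralPDE's):

```julia
using Distributions

# hypothetical stand-ins for the additive terms composed above
physics_ll(θ) = -sum(abs2, θ)                       # plays full_loglikelihood
prior_ll(θ) = sum(logpdf.(Normal(0.0, 2.0), θ))     # plays priorlogpdf
data_ll(θ) = sum(logpdf.(Normal.(θ, 0.05), 1.0))    # plays L2LossData
newloss = θ -> sum(logpdf.(Normal.(θ, 0.05), 0.0))  # plays L2_loss2; may be `nothing`

target(θ) = physics_ll(θ) + prior_ll(θ) + data_ll(θ) +
            (newloss === nothing ? 0.0 : newloss(θ))

target(randn(3))  # a finite log-density, as LogDensityProblems expects
```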

# returns a nested vector of data-free collocation loss functions (one sub-vector per dataset row)
function get_lossy(pinnrep, dataset, Dict_differentials)
eqs = pinnrep.eqs
    depvars = pinnrep.depvars # depvar order matches the dataset ordering

# Dict_differentials is filled with Differential operator => diff_i key-value pairs
# masking operation
eqs_new = substitute.(eqs, Ref(Dict_differentials))

to_subs, tobe_subs = get_symbols(dataset, depvars, eqs)

    # for the values of all depvars at each indvar coordinate in the dataset, create a Dict,
    # e.g. Dict(x(t) => 1.0496435863173237, y(t) => 1.9227770685615337); in each Dict the
    # numeric form of a depvar is the key to its value at those indvar coords,
    # one Dict per dataset row, so n_dicts = n_rows_dataset (i.e. n_indvar_coords_dataset)
    eq_subs = [Dict(tobe_subs[depvar] => to_subs[depvar][i] for depvar in depvars)
               for i in 1:size(dataset[1], 1)]

    # for each dataset point (eq_sub Dict), substitute into the masked equations
# n_collocated_equations = n_rows_dataset(or n_indvar_coords_dataset)
masked_colloc_equations = [[substitute(eq, eq_sub) for eq in eqs_new]
for eq_sub in eq_subs]
    # now we have a vector of the dataset depvars' collocated equations

# reverse dict for re-substituting values of Differential(t)(u(t)) etc
rev_Dict_differentials = Dict(value => key for (key, value) in Dict_differentials)

# unmask Differential terms in masked_colloc_equations
colloc_equations = [substitute.(masked_colloc_equation, Ref(rev_Dict_differentials))
for masked_colloc_equation in masked_colloc_equations]

# nested vector of datafree_pde_loss_functions (as in discretize.jl)
    # each sub-vector holds the datafree_colloc_loss_function for one dataset indvar coord;
    # n_subvectors = n_rows_dataset (i.e. n_indvar_coords_dataset)
# zip each colloc equation with args for each build_loss call per equation vector
datafree_colloc_loss_functions = [[build_loss_function(pinnrep, eq, pde_indvar)
for (eq, pde_indvar, integration_indvar) in zip(
colloc_equation,
pinnrep.pde_indvars,
pinnrep.pde_integration_vars)]
for colloc_equation in colloc_equations]

return datafree_colloc_loss_functions
end
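The masking step deserves a concrete illustration: substituting a dataset value for `u(t)` directly would also rewrite the `u(t)` buried inside `Differential(t)(u(t))`, so derivatives are first hidden behind plain symbols and restored afterwards. A toy Symbolics.jl sketch (the names `u` and `diff1` are illustrative, not from the package):

```julia
using Symbolics

@variables t u(t) diff1   # `diff1` is the masking placeholder
D = Differential(t)
eq = D(u) ~ u             # toy equation standing in for pinnrep.eqs

# mask derivatives first: substituting u(t) => 1.05 directly would also
# rewrite the u(t) inside D(u(t)), leaving Differential(t)(1.05)
Dict_differentials = Dict(D(u) => diff1)
masked = substitute(eq, Dict_differentials)        # diff1 ~ u(t)

# plug in one dataset row's depvar value, as the eq_subs loop does
collocated = substitute(masked, Dict(u => 1.05))   # diff1 ~ 1.05

# unmask with the reversed dictionary
rev = Dict(v => k for (k, v) in Dict_differentials)
substitute(collocated, rev)                        # D(u(t)) ~ 1.05
```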

function get_symbols(dataset, depvars, eqs)
# take only values of depvars from dataset
depvar_vals = [dataset_i[:, 1] for dataset_i in dataset]
    # order of pinnrep.depvars, depvar_vals, BayesianPINN.dataset must be the same
to_subs = Dict(depvars .=> depvar_vals)

numform_vars = Symbolics.get_variables.(eqs)
Eq_vars = unique(reduce(vcat, numform_vars))
    # collected the equations' depvars in numeric form, e.g. x(t), for use in substitute()

tobe_subs = Dict()
for a in depvars
for i in Eq_vars
expr = toexpr(i)
if (expr isa Expr) && (expr.args[1] == a)
tobe_subs[a] = i
end
end
end
    # both symbolic and numeric depvar forms collected; tobe_subs is e.g. Dict{Any, Any}(:y => y(t), :x => x(t))

return to_subs, tobe_subs
end
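Assuming the documented layout (the first column of each matrix holds that depvar's values), a hypothetical two-depvar dataset makes the shape of `to_subs` concrete:

```julia
# hypothetical two-depvar dataset in the layout documented above:
# each matrix holds [depvar_value  indvar_coord] per row
dataset = [[1.04 0.0; 1.90 0.1],   # x values at t = 0.0, 0.1
           [0.50 0.0; 0.70 0.1]]   # y values at t = 0.0, 0.1
depvars = [:x, :y]

depvar_vals = [m[:, 1] for m in dataset]
to_subs = Dict(depvars .=> depvar_vals)
# Dict(:y => [0.5, 0.7], :x => [1.04, 1.9]); tobe_subs then pairs each key
# with its numeric form (x(t), y(t)) pulled out of the equations
```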

@views function setparameters(ltd::PDELogTargetDensity, θ)
@@ -55,8 +129,6 @@ function L2LossData(ltd::PDELogTargetDensity, θ)

# dataset of form Vector[matrix_x, matrix_y, matrix_z]
    # matrix_i is of the form [depvar_i, indvar1, indvar2, ...] (needed in case of heterogeneous domains)

# Phi is the trial solution for each NN in chain array
# Creating logpdf( MvNormal(Phi(t,θ),std), dataset[i] )
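The likelihood sketched in the comment above is ordinary: for each depvar, the NN prediction at the dataset coordinates is the mean of a Gaussian whose standard deviation comes from `l2std`. A self-contained sketch with made-up numbers:

```julia
using Distributions, LinearAlgebra

pred = [0.98, 1.52, 2.11]  # hypothetical Phi(t, θ) at the dataset's coordinates
obs = [1.00, 1.50, 2.00]   # matching dataset column for this depvar
σ = 0.05                   # this depvar's entry in l2std

logpdf(MvNormal(pred, σ^2 * I), obs)  # one depvar's contribution to L2LossData
```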
Expand Down Expand Up @@ -90,6 +162,8 @@ function priorlogpdf(ltd::PDELogTargetDensity, θ)
invlogpdf = sum((length(θ) - ltd.extraparams + 1):length(θ)) do i
logpdf(invpriors[length(θ) - i + 1], θ[i])
end

return invlogpdf + logpdf(nnwparams, θ[1:(length(θ) - ltd.extraparams)])
end
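The prior factorizes over the flat parameter vector: the leading entries are NN weights under the `priorsNNw` prior, while the trailing `extraparams` entries are inverse-problem parameters under their own priors, paired in reverse order as in the indexing above. A small sketch with hypothetical priors:

```julia
using Distributions

θ = randn(7)                     # flat sample: NN weights first, inverse params last
extraparams = 2
nnwparams = Normal(0.0, 2.0)     # hypothetical priorsNNw prior
invpriors = [Normal(1.0, 0.1), Normal(2.0, 0.1)]  # hypothetical `param` priors

nn_ll = sum(logpdf.(nnwparams, θ[1:(end - extraparams)]))
# reversed pairing: θ[end] ↔ invpriors[1], matching invpriors[length(θ) - i + 1]
inv_ll = sum(logpdf(invpriors[extraparams - j + 1], θ[end - extraparams + j])
             for j in 1:extraparams)
nn_ll + inv_ll                   # the prior term added into logdensity
```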

function integratorchoice(Integratorkwargs, initial_ϵ)
@@ -177,27 +251,6 @@ function inference(samples, pinnrep, saveats, numensemble, ℓπ)
return ensemblecurves, estimatedLuxparams, estimated_params, timepoints
end


"""
ahmc_bayesian_pinn_pde(pde_system, discretization;
draw_samples = 1000, bcstd = [0.01], l2std = [0.05], phystd = [0.05],
Expand Down Expand Up @@ -255,15 +308,12 @@ end
releases.
"""
function ahmc_bayesian_pinn_pde(pde_system, discretization;
        draw_samples = 1000, bcstd = [0.01], l2std = [0.05], phystd = [0.05],
        phynewstd = [0.05], priorsNNw = (0.0, 2.0), param = [], nchains = 1,
        Kernel = HMC(0.1, 30), Adaptorkwargs = (Adaptor = StanHMCAdaptor,
            Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
        Integratorkwargs = (Integrator = Leapfrog,), saveats = [1 / 10.0],
        numensemble = floor(Int, draw_samples / 3), progress = false, verbose = false)
pinnrep = symbolic_discretize(pde_system, discretization)
dataset_pde, dataset_bc = discretization.dataset

@@ -275,31 +325,31 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
    # TODO: add a case for parameters present in the BCs?

train_sets_pde = get_dataset_train_points(pde_system.eqs,
        dataset_pde,
        pinnrep)
    colloc_train_sets = [[hcat(train_sets_pde[i][:, j]...)'
                          for i in eachindex(datafree_colloc_loss_functions[1])]
                         for j in eachindex(datafree_colloc_loss_functions)]

    # for each datafree_colloc_loss_function, create loss functions by passing the dataset's
    # indvar coords as train_sets_pde. GridTraining(0.1) is a placeholder strategy;
    # datafree_bc_loss_function and train_sets_bc must be nothing.
    # the order of indvar coords matches the corresponding depvar values in the dataset
    # passed in the get_lossy() call.
pde_loss_function_points = [merge_strategy_with_loglikelihood_function(
        pinnrep,
        GridTraining(0.1),
        datafree_colloc_loss_functions[i],
        nothing;
        train_sets_pde = colloc_train_sets[i],
        train_sets_bc = nothing)[1]
                                for i in eachindex(datafree_colloc_loss_functions)]

    function L2_loss2(θ, phynewstd)
        # first vector of losses from the tuple -> PDE losses; [1] selects the PDE loss functions
        pde_loglikelihoods = [sum([pde_loss_function(θ, phynewstd[i])
for (i, pde_loss_function) in enumerate(pde_loss_functions)])
for pde_loss_functions in pde_loss_function_points]

        # bc_loglikelihoods = [sum([bc_loss_function(θ, phynewstd[i]) for (i, bc_loss_function) in enumerate(pde_loss_function_points[1])]) for pde_loss_function_points in pde_loss_functions]
# for (j, bc_loss_function) in enumerate(bc_loss_functions)]

return sum(pde_loglikelihoods)
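The reduction above is a nested sum: the outer index runs over dataset coordinates, the inner over equations, each scaled by its `phynewstd` entry. A toy sketch with hypothetical loss closures:

```julia
# hypothetical loss closures; the real ones come from
# merge_strategy_with_loglikelihood_function above
pde_loss_function_points = [[(θ, σ) -> -0.5, (θ, σ) -> -0.7],
                            [(θ, σ) -> -0.4, (θ, σ) -> -0.6]]
phynewstd = [0.05, 0.05]
θ = Float64[]

pde_loglikelihoods = [sum(f(θ, phynewstd[i]) for (i, f) in enumerate(fns))
                      for fns in pde_loss_function_points]
sum(pde_loglikelihoods)  # -2.2
```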
Expand Down Expand Up @@ -368,18 +418,10 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
# vector in case of N-dimensional domains
strategy = discretization.strategy

    # dimensions are the total number of params; initial_nnθ is a NamedTuple for Lux
    ℓπ = PDELogTargetDensity(
        nparameters, strategy, dataset, priors, [phystd, bcstd, l2std], phynewstd,
        names, ninv, initial_nnθ, full_weighted_loglikelihood, newloss, Φ)

Adaptor, Metric, targetacceptancerate = Adaptorkwargs[:Adaptor],
Adaptorkwargs[:Metric], Adaptorkwargs[:targetacceptancerate]
@@ -394,10 +436,16 @@
@printf("Current Prior Log-likelihood : %g\n", priorlogpdf(ℓπ, initial_θ))
@printf("Current MSE against dataset Log-likelihood : %g\n",
L2LossData(ℓπ, initial_θ))
if !(newloss isa Nothing)
@printf("Current new loss : %g\n",
ℓπ.L2_loss2(setparameters(ℓπ, initial_θ),
ℓπ.phynewstd))
end
end

# parallel sampling option
if nchains != 1

# Cache to store the chains
bpinnsols = Vector{Any}(undef, nchains)

@@ -441,11 +489,16 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;

if verbose
@printf("Sampling Complete.\n")
@printf("Current Physics Log-likelihood : %g\n",
@printf("Final Physics Log-likelihood : %g\n",
ℓπ.full_loglikelihood(setparameters(ℓπ, samples[end]), ℓπ.allstd))
@printf("Current Prior Log-likelihood : %g\n", priorlogpdf(ℓπ, samples[end]))
@printf("Current MSE against dataset Log-likelihood : %g\n",
@printf("Final Prior Log-likelihood : %g\n", priorlogpdf(ℓπ, samples[end]))
@printf("Final MSE against dataset Log-likelihood : %g\n",
L2LossData(ℓπ, samples[end]))
if !(newloss isa Nothing)
@printf("Final L2_LOSSY : %g\n",
ℓπ.L2_loss2(setparameters(ℓπ, samples[end]),
ℓπ.phynewstd))
end
end

fullsolution = BPINNstats(mcmc_chain, samples, stats)
@@ -455,4 +508,4 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
return BPINNsolution(
fullsolution, ensemblecurves, estimnnparams, estimated_params, timepoints)
end
end