Commit 7b88e7f
…DE.jl into research
AstitvaAggarwal committed Oct 18, 2024
2 parents be7c3d4 + f5eca91 commit 7b88e7f
Showing 2 changed files with 46 additions and 8 deletions.
8 changes: 0 additions & 8 deletions Project.toml
@@ -115,27 +115,19 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
MethodOfLines = "94925ecb-adb7-4558-8ed8-f975c56a0bf4"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
46 changes: 46 additions & 0 deletions src/training_strategies.jl
@@ -46,6 +46,38 @@ function get_dataset_train_points(eqs, train_sets, pinnrep)
return points
end

# the dataset must provide depvar values at the same indvar points
function get_dataset_train_points(eqs, train_sets, pinnrep)
dict_depvar_input = pinnrep.dict_depvar_input
depvars = pinnrep.depvars
dict_depvars = pinnrep.dict_depvars
dict_indvars = pinnrep.dict_indvars

symbols_input = [(i, dict_depvar_input[i]) for i in depvars]
# [(:u, [:t])]
eq_args = NeuralPDE.get_argument(eqs, dict_indvars, dict_depvars)
# indvars present in each equation, e.g. [[:t]]
# in each equation at least one depvar must be a function of all indvars (to cover both heterogeneous and non-heterogeneous cases)

# train_sets follows the order of depvars
# take the dataset's indvar values when an equation's indvars match a depvar's input symbols
points = []
for eq_arg in eq_args
eq_points = []
for i in eachindex(symbols_input)
if symbols_input[i][2] == eq_arg
push!(eq_points, train_sets[i][:, 2:end]')
# break to avoid including the same indvar points more than once
break
end
end
# Concatenate points for this equation argument
push!(points, vcat(eq_points...))
end

return points
end
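
For orientation, here is a minimal standalone sketch of the column-selection logic above, assuming a single dependent variable u(t) and a toy dataset whose first column holds u values and whose remaining column holds t values; the toy_* names are illustrative stand-ins, not NeuralPDE internals.

toy_symbols_input = [(:u, [:t])]                    # depvar :u depends on indvar :t
toy_eq_args       = [[:t]]                          # the single equation involves only :t
toy_train_sets    = [[1.0 0.0; 0.9 0.1; 0.7 0.2]]   # columns: u values, then t values

toy_points = []
for eq_arg in toy_eq_args
    eq_points = []
    for i in eachindex(toy_symbols_input)
        if toy_symbols_input[i][2] == eq_arg
            # keep only the indvar columns and transpose: rows = indvars, columns = points
            push!(eq_points, toy_train_sets[i][:, 2:end]')
            break
        end
    end
    push!(toy_points, vcat(eq_points...))
end
toy_points   # 1-element vector holding the 1×3 matrix [0.0 0.1 0.2]

Each entry of the result corresponds to one equation and contains only indvar coordinates (rows = indvars, columns = points), which is the layout the log-likelihood terms below expect.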

# include dataset points in pde_residual loglikelihood (BayesianPINN)
function merge_strategy_with_loglikelihood_function(pinnrep::PINNRepresentation,
strategy::GridTraining, datafree_pde_loss_function,
Expand Down Expand Up @@ -105,6 +137,18 @@ function get_points_loss_functions(loss_function, train_set, eltypeθ, strategy:
end
end

function get_points_loss_functions(loss_function, train_set, eltypeθ, strategy::GridTraining;
τ = nothing)
# loss_function returns one residual value per training point
# train_set rows correspond to indvars; columns are the coordinates at which the loss is evaluated
function loss(θ, std)
logpdf(
MvNormal(loss_function(train_set, θ)[1, :],
LinearAlgebra.Diagonal(abs2.(std .* ones(size(train_set)[2])))),
zeros(size(train_set)[2]))
end
end
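
The closure returned here scores the data-free residuals under an isotropic Gaussian. The following toy sketch (assuming Distributions.jl and LinearAlgebra, with toy_loss_function and toy_train_set as hypothetical stand-ins for the data-free loss and its training set) shows the same computation in isolation.

using Distributions, LinearAlgebra

toy_train_set = [0.0 0.1 0.2]        # one indvar row, three coordinates
# stand-in for a data-free residual function: returns a 1×N row of residuals
toy_loss_function(train_set, θ) = θ[1] .* train_set .^ 2

function toy_loss(θ, std)
    n = size(toy_train_set, 2)
    residuals = toy_loss_function(toy_train_set, θ)[1, :]
    # Gaussian log-likelihood of zero-valued observations given the residuals:
    # larger when the residuals are close to zero
    logpdf(MvNormal(residuals, Diagonal(abs2.(std .* ones(n)))), zeros(n))
end

toy_loss([0.5], 0.05)   # a single scalar log-likelihood value

With a Diagonal covariance this is equivalent to summing independent Normal(residual_i, std) log-densities evaluated at zero, one term per training point.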

function merge_strategy_with_loss_function(pinnrep::PINNRepresentation,
strategy::GridTraining, datafree_pde_loss_function, datafree_bc_loss_function)
(; domains, eqs, bcs, dict_indvars, dict_depvars) = pinnrep
@@ -127,6 +171,8 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation,
bc_loss_functions = [get_loss_function(_loss, _set, eltypeθ, strategy)
for (_loss, _set) in zip(datafree_bc_loss_function,
bcs_train_sets)]
for (_loss, _set) in zip(datafree_bc_loss_function,
bcs_train_sets)]
pde_loss_functions, bc_loss_functions
end
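
The comprehension-over-zip pattern above just pairs each data-free loss closure with its own training set; a self-contained sketch with toy closures (not the real get_loss_function) illustrates the shape of the result.

toy_datafree_losses = [(set, θ) -> sum(abs2, θ[1] .* set),
                       (set, θ) -> sum(abs2, set .- θ[1])]
toy_bcs_train_sets  = [[0.0 0.1 0.2], [0.0 0.5 1.0]]

# one loss-of-θ function per boundary condition, each closing over its own training set
toy_bc_loss_functions = [θ -> _loss(_set, θ)
                         for (_loss, _set) in zip(toy_datafree_losses, toy_bcs_train_sets)]

[f([0.3]) for f in toy_bc_loss_functions]   # two scalar losses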
