diff --git a/Project.toml b/Project.toml index ab2a6ae2f..2544e1a82 100644 --- a/Project.toml +++ b/Project.toml @@ -115,27 +115,19 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503" -ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" -InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" -LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" -LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" MethodOfLines = "94925ecb-adb7-4558-8ed8-f975c56a0bf4" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f" -ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" -StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" -TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] diff --git a/src/training_strategies.jl b/src/training_strategies.jl index f7493af8a..de65647e6 100644 --- a/src/training_strategies.jl +++ b/src/training_strategies.jl @@ -46,6 +46,38 @@ function get_dataset_train_points(eqs, train_sets, pinnrep) return points end +# dataset must have depvar values for same values of indvars +function get_dataset_train_points(eqs, train_sets, pinnrep) + dict_depvar_input = pinnrep.dict_depvar_input + depvars = pinnrep.depvars + dict_depvars = pinnrep.dict_depvars + dict_indvars = pinnrep.dict_indvars + + symbols_input = [(i, dict_depvar_input[i]) for i in depvars] + # [(:u, [:t])] + eq_args = NeuralPDE.get_argument(eqs, dict_indvars, dict_depvars) + # equation wise indvar presence ~ [[:t]] + # in each equation atleast one depvars must be a function of all indvars(to cover heterogenous/not case) + + # train_sets follows order of depvars + # take dataset indvar values if for equations depvar's indvar matches input symbol indvar + points = [] + for eq_arg in eq_args + eq_points = [] + for i in eachindex(symbols_input) + if symbols_input[i][2] == eq_arg + push!(eq_points, train_sets[i][:, 2:end]') + # Terminate to avoid repetitive ind var points inclusion + break + end + end + # Concatenate points for this equation argument + push!(points, vcat(eq_points...)) + end + + return points +end + # include dataset points in pde_residual loglikelihood (BayesianPINN) function merge_strategy_with_loglikelihood_function(pinnrep::PINNRepresentation, strategy::GridTraining, datafree_pde_loss_function, @@ -105,6 +137,18 @@ function get_points_loss_functions(loss_function, train_set, eltypeθ, strategy: end end +function get_points_loss_functions(loss_function, train_set, eltypeθ, strategy::GridTraining; + τ = nothing) + # loss_function length is number of all points loss is being evaluated upon + # train sets rows are for each indvar, cols are coordinates (row_1,row_2,..row_n) at which loss evaluated + function loss(θ, std) + logpdf( + MvNormal(loss_function(train_set, θ)[1, :], + LinearAlgebra.Diagonal(abs2.(std .* ones(size(train_set)[2])))), + zeros(size(train_set)[2])) + end +end + function merge_strategy_with_loss_function(pinnrep::PINNRepresentation, strategy::GridTraining, datafree_pde_loss_function, datafree_bc_loss_function) (; domains, eqs, bcs, dict_indvars, dict_depvars) = pinnrep @@ -127,6 +171,8 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation, bc_loss_functions = [get_loss_function(_loss, _set, eltypeθ, strategy) for (_loss, _set) in zip(datafree_bc_loss_function, bcs_train_sets)] + for (_loss, _set) in zip(datafree_bc_loss_function, + bcs_train_sets)] pde_loss_functions, bc_loss_functions end