From be7c3d4e088769ad10da0622c1945ec6fca70676 Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Fri, 18 Oct 2024 21:55:11 +0530 Subject: [PATCH] managing conflicts 2 --- src/PDE_BPINN.jl | 185 +++++++----- src/advancedHMC_MCMC.jl | 603 +++++++++++++++------------------------- src/discretize.jl | 51 +--- 3 files changed, 362 insertions(+), 477 deletions(-) diff --git a/src/PDE_BPINN.jl b/src/PDE_BPINN.jl index 9c283f1e00..4f7e51b3a0 100644 --- a/src/PDE_BPINN.jl +++ b/src/PDE_BPINN.jl @@ -4,17 +4,91 @@ dataset <: Union{Nothing, Vector{<:Matrix{<:Real}}} priors <: Vector{<:Distribution} allstd::Vector{Vector{Float64}} + phynewstd::Vector{Float64} names::Tuple extraparams::Int init_params <: Union{AbstractVector, NamedTuple, ComponentArray} - full_loglikelihood - Φ + full_loglikelihood::Any + L2_loss2::Any + Φ::Any end function LogDensityProblems.logdensity(ltd::PDELogTargetDensity, θ) # for parameter estimation neccesarry to use multioutput case - return ltd.full_loglikelihood(setparameters(ltd, θ), ltd.allstd) + priorlogpdf(ltd, θ) + - L2LossData(ltd, θ) + if Tar.L2_loss2 === nothing + return Tar.full_loglikelihood(setparameters(Tar, θ), Tar.allstd) + + priorlogpdf(Tar, θ) + L2LossData(Tar, θ) + else + return Tar.full_loglikelihood(setparameters(Tar, θ), Tar.allstd) + + priorlogpdf(Tar, θ) + L2LossData(Tar, θ) + + Tar.L2_loss2(setparameters(Tar, θ), Tar.phynewstd) + end +end + +# you get a vector of losses +function get_lossy(pinnrep, dataset, Dict_differentials) + eqs = pinnrep.eqs + depvars = pinnrep.depvars #depvar order is same as dataset + + # Dict_differentials is filled with Differential operator => diff_i key-value pairs + # masking operation + eqs_new = substitute.(eqs, Ref(Dict_differentials)) + + to_subs, tobe_subs = get_symbols(dataset, depvars, eqs) + + # for values of all depvars at corresponding indvar values in dataset, create dictionaries {Dict(x(t) => 1.0496435863173237, y(t) => 1.9227770685615337)} + # In each Dict, num form of depvar is key to its value at certain coords of indvars, n_dicts = n_rows_dataset(or n_indvar_coords_dataset) + eq_subs = [Dict(tobe_subs[depvar] => to_subs[depvar][i] for depvar in depvars) + for i in 1:size(dataset[1][:, 1])[1]] + + # for each dataset point(eq_sub dictionary), substitute in masked equations + # n_collocated_equations = n_rows_dataset(or n_indvar_coords_dataset) + masked_colloc_equations = [[substitute(eq, eq_sub) for eq in eqs_new] + for eq_sub in eq_subs] + # now we have vector of dataset depvar's collocated equations + + # reverse dict for re-substituting values of Differential(t)(u(t)) etc + rev_Dict_differentials = Dict(value => key for (key, value) in Dict_differentials) + + # unmask Differential terms in masked_colloc_equations + colloc_equations = [substitute.(masked_colloc_equation, Ref(rev_Dict_differentials)) + for masked_colloc_equation in masked_colloc_equations] + + # nested vector of datafree_pde_loss_functions (as in discretize.jl) + # each sub vector has dataset's indvar coord's datafree_colloc_loss_function, n_subvectors = n_rows_dataset(or n_indvar_coords_dataset) + # zip each colloc equation with args for each build_loss call per equation vector + datafree_colloc_loss_functions = [[build_loss_function(pinnrep, eq, pde_indvar) + for (eq, pde_indvar, integration_indvar) in zip( + colloc_equation, + pinnrep.pde_indvars, + pinnrep.pde_integration_vars)] + for colloc_equation in colloc_equations] + + return datafree_colloc_loss_functions +end + +function get_symbols(dataset, depvars, eqs) + # take only values of 
depvars from dataset + depvar_vals = [dataset_i[:, 1] for dataset_i in dataset] + # order of pinnrep.depvars, depvar_vals, BayesianPINN.dataset must be same + to_subs = Dict(depvars .=> depvar_vals) + + numform_vars = Symbolics.get_variables.(eqs) + Eq_vars = unique(reduce(vcat, numform_vars)) + # got equation's depvar num format {x(t)} for use in substitute() + + tobe_subs = Dict() + for a in depvars + for i in Eq_vars + expr = toexpr(i) + if (expr isa Expr) && (expr.args[1] == a) + tobe_subs[a] = i + end + end + end + # depvar symbolic and num format got, tobe_subs : Dict{Any, Any}(:y => y(t), :x => x(t)) + + return to_subs, tobe_subs end @views function setparameters(ltd::PDELogTargetDensity, θ) @@ -55,8 +129,6 @@ function L2LossData(ltd::PDELogTargetDensity, θ) # dataset of form Vector[matrix_x, matrix_y, matrix_z] # matrix_i is of form [i,indvar1,indvar2,..] (needed in case if heterogenous domains) - # note that indvar1,indvar2.. cols can be different values for different depvar matrices - # dataset,phi order follows pinnrep.depvars orders of variables (order of declaration in @variables macro) # Phi is the trial solution for each NN in chain array # Creating logpdf( MvNormal(Phi(t,θ),std), dataset[i] ) @@ -90,6 +162,8 @@ function priorlogpdf(ltd::PDELogTargetDensity, θ) invlogpdf = sum((length(θ) - ltd.extraparams + 1):length(θ)) do i logpdf(invpriors[length(θ) - i + 1], θ[i]) end + + return invlogpdf + logpdf(nnwparams, θ[1:(length(θ) - ltd.extraparams)]) end function integratorchoice(Integratorkwargs, initial_ϵ) @@ -177,27 +251,6 @@ function inference(samples, pinnrep, saveats, numensemble, ℓπ) return ensemblecurves, estimatedLuxparams, estimated_params, timepoints end -function integratorchoice(Integratorkwargs, initial_ϵ) - Integrator = Integratorkwargs[:Integrator] - if Integrator == JitteredLeapfrog - jitter_rate = Integratorkwargs[:jitter_rate] - Integrator(initial_ϵ, jitter_rate) - elseif Integrator == TemperedLeapfrog - tempering_rate = Integratorkwargs[:tempering_rate] - Integrator(initial_ϵ, tempering_rate) - else - Integrator(initial_ϵ) - end -end - -function adaptorchoice(Adaptor, mma, ssa) - if Adaptor != AdvancedHMC.NoAdaptation() - Adaptor(mma, ssa) - else - AdvancedHMC.NoAdaptation() - end -end - """ ahmc_bayesian_pinn_pde(pde_system, discretization; draw_samples = 1000, bcstd = [0.01], l2std = [0.05], phystd = [0.05], @@ -255,15 +308,12 @@ end releases. """ function ahmc_bayesian_pinn_pde(pde_system, discretization; - draw_samples = 1000, - bcstd = [0.01], l2std = [0.05], - phystd = [0.05], phystdnew = [0.05], priorsNNw = (0.0, 2.0), - param = [], nchains = 1, Kernel = HMC(0.1, 30), - Adaptorkwargs = (Adaptor = StanHMCAdaptor, + draw_samples = 1000, bcstd = [0.01], l2std = [0.05], phystd = [0.05], + phynewstd = [0.05], priorsNNw = (0.0, 2.0), param = [], nchains = 1, + Kernel = HMC(0.1, 30), Adaptorkwargs = (Adaptor = StanHMCAdaptor, Metric = DiagEuclideanMetric, targetacceptancerate = 0.8), Integratorkwargs = (Integrator = Leapfrog,), saveats = [1 / 10.0], - numensemble = floor(Int, draw_samples / 3), Dict_differentials = nothing, - progress = false, verbose = false) + numensemble = floor(Int, draw_samples / 3), progress = false, verbose = false) pinnrep = symbolic_discretize(pde_system, discretization) dataset_pde, dataset_bc = discretization.dataset @@ -275,31 +325,31 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization; # add case for if parameters present in bcs? 
train_sets_pde = get_dataset_train_points(pde_system.eqs, - dataset_pde, - pinnrep) - colloc_train_sets = [[hcat(train_sets_pde[i][:, j]...)' for i in eachindex(datafree_colloc_loss_functions[1])] for j in eachindex(datafree_colloc_loss_functions)] + dataset_pde, + pinnrep) + colloc_train_sets = [[hcat(train_sets_pde[i][:, j]...)' + for i in eachindex(datafree_colloc_loss_functions[1])] + for j in eachindex(datafree_colloc_loss_functions)] # for each datafree_colloc_loss_function create loss_functions by passing dataset's indvar coords as train_sets_pde. # placeholder strategy = GridTraining(0.1), datafree_bc_loss_function and train_sets_bc must be nothing # order of indvar coords will be same as corresponding depvar coords values in dataset provided in get_lossy() call. pde_loss_function_points = [merge_strategy_with_loglikelihood_function( - pinnrep, - GridTraining(0.1), - datafree_colloc_loss_functions[i], - nothing; - train_sets_pde = colloc_train_sets[i], - train_sets_bc = nothing)[1] - for i in eachindex(datafree_colloc_loss_functions)] - - function L2_loss2(θ, allstd) - stdpdesnew = allstd[4] - + pinnrep, + GridTraining(0.1), + datafree_colloc_loss_functions[i], + nothing; + train_sets_pde = colloc_train_sets[i], + train_sets_bc = nothing)[1] + for i in eachindex(datafree_colloc_loss_functions)] + + function L2_loss2(θ, phynewstd) # first vector of losses,from tuple -> pde losses, first[1] pde loss - pde_loglikelihoods = [sum([pde_loss_function(θ, stdpdesnew[i]) + pde_loglikelihoods = [sum([pde_loss_function(θ, phynewstd[i]) for (i, pde_loss_function) in enumerate(pde_loss_functions)]) for pde_loss_functions in pde_loss_function_points] - # bc_loglikelihoods = [sum([bc_loss_function(θ, stdpdesnew[i]) for (i, bc_loss_function) in enumerate(pde_loss_function_points[1])]) for pde_loss_function_points in pde_loss_functions] + # bc_loglikelihoods = [sum([bc_loss_function(θ, phynewstd[i]) for (i, bc_loss_function) in enumerate(pde_loss_function_points[1])]) for pde_loss_function_points in pde_loss_functions] # for (j, bc_loss_function) in enumerate(bc_loss_functions)] return sum(pde_loglikelihoods) @@ -368,18 +418,10 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization; # vector in case of N-dimensional domains strategy = discretization.strategy - # dimensions would be total no of params,initial_nnθ for Lux namedTuples - ℓπ = PDELogTargetDensity(nparameters, - strategy, - dataset, - priors, - [phystd, bcstd, l2std, phystdnew], - names, - ninv, - initial_nnθ, - full_weighted_loglikelihood, - newloss, - Φ) + # dimensions would be total no of params,initial_nnθ for Lux namedTuples + ℓπ = PDELogTargetDensity( + nparameters, strategy, dataset, priors, [phystd, bcstd, l2std], phynewstd, + names, ninv, initial_nnθ, full_weighted_loglikelihood, newloss, Φ) Adaptor, Metric, targetacceptancerate = Adaptorkwargs[:Adaptor], Adaptorkwargs[:Metric], Adaptorkwargs[:targetacceptancerate] @@ -394,10 +436,16 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization; @printf("Current Prior Log-likelihood : %g\n", priorlogpdf(ℓπ, initial_θ)) @printf("Current MSE against dataset Log-likelihood : %g\n", L2LossData(ℓπ, initial_θ)) + if !(newloss isa Nothing) + @printf("Current new loss : %g\n", + ℓπ.L2_loss2(setparameters(ℓπ, initial_θ), + ℓπ.phynewstd)) + end end # parallel sampling option if nchains != 1 + # Cache to store the chains bpinnsols = Vector{Any}(undef, nchains) @@ -441,11 +489,16 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization; if verbose @printf("Sampling Complete.\n") 
- @printf("Current Physics Log-likelihood : %g\n", + @printf("Final Physics Log-likelihood : %g\n", ℓπ.full_loglikelihood(setparameters(ℓπ, samples[end]), ℓπ.allstd)) - @printf("Current Prior Log-likelihood : %g\n", priorlogpdf(ℓπ, samples[end])) - @printf("Current MSE against dataset Log-likelihood : %g\n", + @printf("Final Prior Log-likelihood : %g\n", priorlogpdf(ℓπ, samples[end])) + @printf("Final MSE against dataset Log-likelihood : %g\n", L2LossData(ℓπ, samples[end])) + if !(newloss isa Nothing) + @printf("Final L2_LOSSY : %g\n", + ℓπ.L2_loss2(setparameters(ℓπ, samples[end]), + ℓπ.phynewstd)) + end end fullsolution = BPINNstats(mcmc_chain, samples, stats) @@ -455,4 +508,4 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization; return BPINNsolution( fullsolution, ensemblecurves, estimnnparams, estimated_params, timepoints) end -end +end \ No newline at end of file diff --git a/src/advancedHMC_MCMC.jl b/src/advancedHMC_MCMC.jl index 8b996fce5c..5ac4213c92 100644 --- a/src/advancedHMC_MCMC.jl +++ b/src/advancedHMC_MCMC.jl @@ -1,72 +1,42 @@ -mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I, - P <: Vector{<:Distribution}, - D <: - Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}} -} +@concrete struct LogTargetDensity dim::Int - prob::SciMLBase.ODEProblem - chain::C - st::S - strategy::ST - dataset::D - priors::P + prob <: SciMLBase.ODEProblem + smodel <: StatefulLuxLayer + strategy <: AbstractTrainingStrategy + dataset <: Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}} + priors <: Vector{<:Distribution} phystd::Vector{Float64} phynewstd::Vector{Float64} l2std::Vector{Float64} autodiff::Bool physdt::Float64 extraparams::Int - init_params::I + init_params <: Union{NamedTuple, ComponentArray} estim_collocate::Bool +end - function LogTargetDensity(dim, prob, chain::Optimisers.Restructure, st, strategy, - dataset, - priors, phystd, phynewstd, l2std, autodiff, physdt, extraparams, - init_params::AbstractVector, estim_collocate) - new{ - typeof(chain), - Nothing, - typeof(strategy), - typeof(init_params), - typeof(priors), - typeof(dataset) - }(dim, - prob, - chain, - nothing, strategy, - dataset, - priors, - phystd, - phynewstd, - l2std, - autodiff, - physdt, - extraparams, - init_params, - estim_collocate) - end - function LogTargetDensity(dim, prob, chain::Lux.AbstractExplicitLayer, st, strategy, - dataset, - priors, phystd, phynewstd, l2std, autodiff, physdt, extraparams, - init_params::NamedTuple, estim_collocate) - new{ - typeof(chain), - typeof(st), - typeof(strategy), - typeof(init_params), - typeof(priors), - typeof(dataset) - }(dim, - prob, - chain, st, strategy, - dataset, priors, - phystd, phynewstd, - l2std, - autodiff, - physdt, - extraparams, - init_params, - estim_collocate) +""" +NN OUTPUT AT t,θ ~ phi(t,θ). +""" +function (f::LogTargetDensity)(t::AbstractVector, θ) + θ = vector_to_parameters(θ, f.init_params) + dev = safe_get_device(θ) + t = safe_expand(dev, t) + u0 = f.prob.u0 |> dev + return u0 .+ (t' .- f.prob.tspan[1]) .* f.smodel(t', θ) +end + +(f::LogTargetDensity)(t::Number, θ) = f([t], θ)[:, 1] + +""" +Similar to ode_dfdx() in NNODE. 
+""" +function ode_dfdx(phi::LogTargetDensity, t::AbstractVector, θ, autodiff::Bool) + if autodiff + return ForwardDiff.jacobian(Base.Fix2(phi, θ), t) + else + ϵ = sqrt(eps(eltype(t))) + return (phi(t .+ ϵ, θ) .- phi(t, θ)) ./ ϵ end end @@ -74,344 +44,239 @@ end Function needed for converting vector of sampled parameters into ComponentVector in case of Lux chain output, derivatives the sampled parameters are of exotic type `Dual` due to ForwardDiff's autodiff tagging. """ -function vector_to_parameters(ps_new::AbstractVector, - ps::Union{NamedTuple, ComponentArrays.ComponentVector}) - @assert length(ps_new) == Lux.parameterlength(ps) +function vector_to_parameters(ps_new::AbstractVector, ps::Union{NamedTuple, ComponentArray}) + @assert length(ps_new) == LuxCore.parameterlength(ps) i = 1 function get_ps(x) z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x)) i += length(x) return z end - return Functors.fmap(get_ps, ps) + return fmap(get_ps, ps) end -vector_to_parameters(ps_new::AbstractVector, ps::AbstractVector) = ps_new +vector_to_parameters(ps_new::AbstractVector, _::AbstractVector) = ps_new -function LogDensityProblems.logdensity(Tar::LogTargetDensity, θ) - if Tar.estim_collocate - return physloglikelihood(Tar, θ) + priorweights(Tar, θ) + L2LossData(Tar, θ) + - L2loss2(Tar, θ) - else - return physloglikelihood(Tar, θ) + priorweights(Tar, θ) + L2LossData(Tar, θ) - end +function LogDensityProblems.logdensity(ltd::LogTargetDensity, θ) + ldensity = physloglikelihood(ltd, θ) + priorweights(ltd, θ) + L2LossData(ltd, θ) + ltd.estim_collocate && return ldensity + L2loss2(ltd, θ) + return ldensity end -LogDensityProblems.dimension(Tar::LogTargetDensity) = Tar.dim +LogDensityProblems.dimension(ltd::LogTargetDensity) = ltd.dim function LogDensityProblems.capabilities(::LogTargetDensity) - LogDensityProblems.LogDensityOrder{1}() + return LogDensityProblems.LogDensityOrder{1}() end """ suggested extra loss function for ODE solver case """ -function L2loss2(Tar::LogTargetDensity, θ) - f = Tar.prob.f +@views function L2loss2(ltd::LogTargetDensity, θ) + ltd.extraparams ≤ 0 && return false # XXX: type-stability? - # parameter estimation chosen or not - if Tar.extraparams > 0 - autodiff = Tar.autodiff - # Timepoints to enforce Physics - t = Tar.dataset[end] - u1 = Tar.dataset[2] - û = Tar.dataset[1] - - nnsol = NNodederi(Tar, t, θ[1:(length(θ) - Tar.extraparams)], autodiff) - - ode_params = Tar.extraparams == 1 ? - θ[((length(θ) - Tar.extraparams) + 1):length(θ)][1] : - θ[((length(θ) - Tar.extraparams) + 1):length(θ)] - - if length(Tar.prob.u0) == 1 - physsol = [f(û[i], - ode_params, - t[i]) - for i in 1:length(û[:, 1])] - else - physsol = [f([û[i], u1[i]], - ode_params, - t[i]) - for i in 1:length(û)] - end - #form of NN output matrix output dim x n - deri_physsol = reduce(hcat, physsol) - - physlogprob = 0 - for i in 1:length(Tar.prob.u0) - # can add phystdnew[i] for u[i] - physlogprob += logpdf(MvNormal(deri_physsol[i, :], - LinearAlgebra.Diagonal(map(abs2, - (Tar.phynewstd[i]) .* - ones(length(nnsol[i, :]))))), - nnsol[i, :]) - end - return physlogprob + f = ltd.prob.f + t = ltd.dataset[end] + u1 = ltd.dataset[2] + û = ltd.dataset[1] + + nnsol = ode_dfdx(ltd, t, θ[1:(length(θ) - ltd.extraparams)], ltd.autodiff) + + ode_params = ltd.extraparams == 1 ? 
θ[((length(θ) - ltd.extraparams) + 1)] : + θ[((length(θ) - ltd.extraparams) + 1):length(θ)] + + physsol = if length(ltd.prob.u0) == 1 + [f(û[i], ode_params, tᵢ) for (i, tᵢ) in enumerate(t)] else - return 0 + [f([û[i], u1[i]], ode_params, tᵢ) for (i, tᵢ) in enumerate(t)] + end + # form of NN output matrix output dim x n + deri_physsol = reduce(hcat, physsol) + T = promote_type(eltype(deri_physsol), eltype(nnsol)) + + physlogprob = T(0) + for i in 1:length(ltd.prob.u0) + physlogprob += logpdf( + MvNormal(deri_physsol[i, :], + Diagonal(abs2.(T(ltd.phynewstd[i]) .* ones(T, length(nnsol[i, :]))))), + nnsol[i, :] + ) end + return physlogprob end """ L2 loss loglikelihood(needed for ODE parameter estimation). """ -function L2LossData(Tar::LogTargetDensity, θ) - # check if dataset is provided - if Tar.dataset isa Vector{Nothing} || Tar.extraparams == 0 - return 0 - else - # matrix(each row corresponds to vector u's rows) - nn = Tar(Tar.dataset[end], θ[1:(length(θ) - Tar.extraparams)]) - - L2logprob = 0 - for i in 1:length(Tar.prob.u0) - # for u[i] ith vector must be added to dataset, nn[1,:] is the dx in lotka_volterra - L2logprob += logpdf( - MvNormal(nn[i, :], - LinearAlgebra.Diagonal(abs2.(Tar.l2std[i] .* - ones(length(Tar.dataset[i]))))), - Tar.dataset[i]) - end - return L2logprob +@views function L2LossData(ltd::LogTargetDensity, θ) + (ltd.dataset isa Vector{Nothing} || ltd.extraparams == 0) && return 0 + + # matrix(each row corresponds to vector u's rows) + nn = ltd(ltd.dataset[end], θ[1:(length(θ) - ltd.extraparams)]) + T = eltype(nn) + + L2logprob = zero(T) + for i in 1:length(ltd.prob.u0) + # for u[i] ith vector must be added to dataset,nn[1, :] is the dx in lotka_volterra + L2logprob += logpdf( + MvNormal( + nn[i, :], + Diagonal(abs2.(T(ltd.l2std[i]) .* ones(T, length(ltd.dataset[i])))) + ), + ltd.dataset[i] + ) end + return L2logprob end """ Physics loglikelihood over problem timespan + dataset timepoints. """ -function physloglikelihood(Tar::LogTargetDensity, θ) - f = Tar.prob.f - p = Tar.prob.p - tspan = Tar.prob.tspan - autodiff = Tar.autodiff - strategy = Tar.strategy +function physloglikelihood(ltd::LogTargetDensity, θ) + (; f, p, tspan) = ltd.prob + (; autodiff, strategy) = ltd # parameter estimation chosen or not - if Tar.extraparams > 0 - ode_params = Tar.extraparams == 1 ? - θ[((length(θ) - Tar.extraparams) + 1):length(θ)][1] : - θ[((length(θ) - Tar.extraparams) + 1):length(θ)] + if ltd.extraparams > 0 + ode_params = ltd.extraparams == 1 ? θ[((length(θ) - ltd.extraparams) + 1)] : + θ[((length(θ) - ltd.extraparams) + 1):length(θ)] else - ode_params = p == SciMLBase.NullParameters() ? [] : p + ode_params = p isa SciMLBase.NullParameters ? Float64[] : p end - return getlogpdf(strategy, Tar, f, autodiff, tspan, ode_params, θ) + return getlogpdf(strategy, ltd, f, autodiff, tspan, ode_params, θ) end -function getlogpdf(strategy::GridTraining, Tar::LogTargetDensity, f, autodiff::Bool, - tspan, - ode_params, θ) - if Tar.dataset isa Vector{Nothing} - t = collect(eltype(strategy.dx), tspan[1]:(strategy.dx):tspan[2]) - else - t = vcat(collect(eltype(strategy.dx), tspan[1]:(strategy.dx):tspan[2]), - Tar.dataset[end]) - end - - sum(innerdiff(Tar, f, autodiff, t, θ, - ode_params)) +function getlogpdf(strategy::GridTraining, ltd::LogTargetDensity, f, autodiff::Bool, + tspan, ode_params, θ) + ts = collect(eltype(strategy.dx), tspan[1]:(strategy.dx):tspan[2]) + t = ltd.dataset isa Vector{Nothing} ? 
ts : vcat(ts, ltd.dataset[end]) + return sum(innerdiff(ltd, f, autodiff, t, θ, ode_params)) end -function getlogpdf(strategy::StochasticTraining, - Tar::LogTargetDensity, - f, - autodiff::Bool, - tspan, - ode_params, - θ) - if Tar.dataset isa Vector{Nothing} - t = [(tspan[2] - tspan[1]) * rand() + tspan[1] for i in 1:(strategy.points)] - else - t = vcat([(tspan[2] - tspan[1]) * rand() + tspan[1] for i in 1:(strategy.points)], - Tar.dataset[end]) - end - - sum(innerdiff(Tar, f, autodiff, t, θ, - ode_params)) +function getlogpdf(strategy::StochasticTraining, ltd::LogTargetDensity, + f, autodiff::Bool, tspan, ode_params, θ) + T = promote_type(eltype(tspan[1]), eltype(tspan[2])) + samples = (tspan[2] - tspan[1]) .* rand(T, strategy.points) .+ tspan[1] + t = ltd.dataset isa Vector{Nothing} ? samples : vcat(samples, ltd.dataset[end]) + return sum(innerdiff(ltd, f, autodiff, t, θ, ode_params)) end -function getlogpdf(strategy::QuadratureTraining, Tar::LogTargetDensity, f, - autodiff::Bool, - tspan, - ode_params, θ) - function integrand(t::Number, θ) - innerdiff(Tar, f, autodiff, [t], θ, ode_params) - end +function getlogpdf(strategy::QuadratureTraining, ltd::LogTargetDensity, f, autodiff::Bool, + tspan, ode_params, θ) + integrand(t::Number, θ) = innerdiff(ltd, f, autodiff, [t], θ, ode_params) intprob = IntegralProblem( - integrand, (tspan[1], tspan[2]), θ; nout = length(Tar.prob.u0)) - sol = solve(intprob, QuadGKJL(); abstol = strategy.abstol, reltol = strategy.reltol) - sum(sol.u) + integrand, (tspan[1], tspan[2]), θ; nout = length(ltd.prob.u0)) + sol = solve(intprob, QuadGKJL(); strategy.abstol, strategy.reltol) + return sum(sol.u) end -function getlogpdf(strategy::WeightedIntervalTraining, Tar::LogTargetDensity, f, - autodiff::Bool, - tspan, - ode_params, θ) - minT = tspan[1] - maxT = tspan[2] - +function getlogpdf(strategy::WeightedIntervalTraining, ltd::LogTargetDensity, f, + autodiff::Bool, tspan, ode_params, θ) + minT, maxT = tspan weights = strategy.weights ./ sum(strategy.weights) - N = length(weights) - points = strategy.points - difference = (maxT - minT) / N - data = Float64[] + ts = eltype(difference)[] for (index, item) in enumerate(weights) - temp_data = rand(1, trunc(Int, points * item)) .* difference .+ minT .+ + temp_data = rand(1, trunc(Int, strategy.points * item)) .* difference .+ minT .+ ((index - 1) * difference) - data = append!(data, temp_data) + append!(ts, temp_data) end - if Tar.dataset isa Vector{Nothing} - t = data - else - t = vcat(data, - Tar.dataset[end]) - end - - sum(innerdiff(Tar, f, autodiff, t, θ, - ode_params)) + t = ltd.dataset isa Vector{Nothing} ? ts : vcat(ts, ltd.dataset[end]) + return sum(innerdiff(ltd, f, autodiff, t, θ, ode_params)) end """ MvNormal likelihood at each `ti` in time `t` for ODE collocation residue with NN with parameters θ. 
""" -function innerdiff(Tar::LogTargetDensity, f, autodiff::Bool, t::AbstractVector, θ, +@views function innerdiff(ltd::LogTargetDensity, f, autodiff::Bool, t::AbstractVector, θ, ode_params) + # ltd used for phi and LogTargetDensity object attributes access + out = ltd(t, θ[1:(length(θ) - ltd.extraparams)]) - # Tar used for phi and LogTargetDensity object attributes access - out = Tar(t, θ[1:(length(θ) - Tar.extraparams)]) - - # # reject samples case(write clear reason why) - if any(isinf, out[:, 1]) || any(isinf, ode_params) - return -Inf - end + # reject samples case(write clear reason why) + (any(isinf, out[:, 1]) || any(isinf, ode_params)) && return convert(eltype(out), -Inf) # this is a vector{vector{dx,dy}}(handle case single u(float passed)) if length(out[:, 1]) == 1 - physsol = [f(out[:, i][1], - ode_params, - t[i]) - for i in 1:length(out[1, :])] + physsol = [f(out[:, i][1], ode_params, t[i]) for i in 1:length(out[1, :])] else - physsol = [f(out[:, i], - ode_params, - t[i]) - for i in 1:length(out[1, :])] + physsol = [f(out[:, i], ode_params, t[i]) for i in 1:length(out[1, :])] end physsol = reduce(hcat, physsol) - nnsol = NNodederi(Tar, t, θ[1:(length(θ) - Tar.extraparams)], autodiff) + nnsol = ode_dfdx(ltd, t, θ[1:(length(θ) - ltd.extraparams)], autodiff) vals = nnsol .- physsol + T = eltype(vals) - # N dimensional vector if N outputs for NN(each row has logpdf of u[i] where u is vector of dependant variables) + # N dimensional vector if N outputs for NN(each row has logpdf of u[i] where u is vector + # of dependant variables) return [logpdf( MvNormal(vals[i, :], - LinearAlgebra.Diagonal(abs2.(Tar.phystd[i] .* - ones(length(vals[i, :]))))), - zeros(length(vals[i, :]))) for i in 1:length(Tar.prob.u0)] + Diagonal(abs2.(T(ltd.phystd[i]) .* ones(T, length(vals[i, :]))))), + zeros(T, length(vals[i, :])) + ) for i in 1:length(ltd.prob.u0)] end """ Prior logpdf for NN parameters + ODE constants. """ -function priorweights(Tar::LogTargetDensity, θ) - allparams = Tar.priors - # nn weights - nnwparams = allparams[1] - - if Tar.extraparams > 0 - # Vector of ode parameters priors - invpriors = allparams[2:end] - - invlogpdf = sum( - logpdf(invpriors[length(θ) - i + 1], θ[i]) - for i in (length(θ) - Tar.extraparams + 1):length(θ); - init = 0.0) - - return (invlogpdf - + - logpdf(nnwparams, θ[1:(length(θ) - Tar.extraparams)])) - else - return logpdf(nnwparams, θ) - end -end +@views function priorweights(ltd::LogTargetDensity, θ) + allparams = ltd.priors + nnwparams = allparams[1] # nn weights -function generate_Tar(chain::Lux.AbstractExplicitLayer, init_params) - θ, st = Lux.setup(Random.default_rng(), chain) - return init_params, chain, st -end + ltd.extraparams ≤ 0 && return logpdf(nnwparams, θ) -function generate_Tar(chain::Lux.AbstractExplicitLayer, init_params::Nothing) - θ, st = Lux.setup(Random.default_rng(), chain) - return θ, chain, st -end + # Vector of ode parameters priors + invpriors = allparams[2:end] -""" -NN OUTPUT AT t,θ ~ phi(t,θ). 
-""" -function (f::LogTargetDensity{C, S})(t::AbstractVector, - θ) where {C <: Lux.AbstractExplicitLayer, S} - θ = vector_to_parameters(θ, f.init_params) - y, st = f.chain(adapt(parameterless_type(ComponentArrays.getdata(θ)), t'), θ, f.st) - ChainRulesCore.@ignore_derivatives f.st = st - f.prob.u0 .+ (t' .- f.prob.tspan[1]) .* y + invlogpdf = sum( + logpdf(invpriors[length(θ) - i + 1], θ[i]) + for i in (length(θ) - ltd.extraparams + 1):length(θ)) + + return invlogpdf + logpdf(nnwparams, θ[1:(length(θ) - ltd.extraparams)]) end -function (f::LogTargetDensity{C, S})(t::Number, - θ) where {C <: Lux.AbstractExplicitLayer, S} - θ = vector_to_parameters(θ, f.init_params) - y, st = f.chain(adapt(parameterless_type(ComponentArrays.getdata(θ)), [t]), θ, f.st) - ChainRulesCore.@ignore_derivatives f.st = st - f.prob.u0 .+ (t .- f.prob.tspan[1]) .* y +function generate_ltd(chain::AbstractLuxLayer, init_params) + return init_params, chain, LuxCore.initialstates(Random.default_rng(), chain) end -""" -Similar to ode_dfdx() in NNODE. -""" -function NNodederi(phi::LogTargetDensity, t::AbstractVector, θ, autodiff::Bool) - if autodiff - hcat(ForwardDiff.derivative.(ti -> phi(ti, θ), t)...) - else - (phi(t .+ sqrt(eps(eltype(t))), θ) - phi(t, θ)) ./ sqrt(eps(eltype(t))) - end +function generate_ltd(chain::AbstractLuxLayer, ::Nothing) + θ, st = LuxCore.setup(Random.default_rng(), chain) + return θ, chain, st end function kernelchoice(Kernel, MCMCkwargs) if Kernel == HMCDA - δ, λ = MCMCkwargs[:δ], MCMCkwargs[:λ] - Kernel(δ, λ) + Kernel(MCMCkwargs[:δ], MCMCkwargs[:λ]) elseif Kernel == NUTS δ, max_depth, Δ_max = MCMCkwargs[:δ], MCMCkwargs[:max_depth], MCMCkwargs[:Δ_max] - Kernel(δ, max_depth = max_depth, Δ_max = Δ_max) - else - # HMC - n_leapfrog = MCMCkwargs[:n_leapfrog] - Kernel(n_leapfrog) + Kernel(δ; max_depth, Δ_max) + else # HMC + Kernel(MCMCkwargs[:n_leapfrog]) end end """ - ahmc_bayesian_pinn_ode(prob, chain; strategy = GridTraining, - dataset = [nothing],init_params = nothing, - draw_samples = 1000, physdt = 1 / 20.0f0,l2std = [0.05], - phystd = [0.05], phynewstd = [0.05], priorsNNw = (0.0, 2.0), - param = [], nchains = 1, autodiff = false, Kernel = HMC, - Adaptorkwargs = (Adaptor = StanHMCAdaptor, - Metric = DiagEuclideanMetric, - targetacceptancerate = 0.8), - Integratorkwargs = (Integrator = Leapfrog,), - MCMCkwargs = (n_leapfrog = 30,), - progress = false, verbose = false) + ahmc_bayesian_pinn_ode(prob, chain; strategy = GridTraining, dataset = [nothing], + init_params = nothing, draw_samples = 1000, physdt = 1 / 20.0f0, + l2std = [0.05], phystd = [0.05], phynewstd = [0.05], priorsNNw = (0.0, 2.0), + param = [], nchains = 1, autodiff = false, Kernel = HMC, + Adaptorkwargs = (Adaptor = StanHMCAdaptor, + Metric = DiagEuclideanMetric, targetacceptancerate = 0.8), + Integratorkwargs = (Integrator = Leapfrog,), + MCMCkwargs = (n_leapfrog = 30,), progress = false, + verbose = false) !!! warn - Note that `ahmc_bayesian_pinn_ode()` only supports ODEs which are written in the out-of-place form, i.e. - `du = f(u,p,t)`, and not `f(du,u,p,t)`. If not declared out-of-place, then the `ahmc_bayesian_pinn_ode()` - will exit with an error. + Note that `ahmc_bayesian_pinn_ode()` only supports ODEs which are written in the + out-of-place form, i.e. `du = f(u,p,t)`, and not `f(du,u,p,t)`. If not declared + out-of-place, then `ahmc_bayesian_pinn_ode()` will exit with an error. 
## Example @@ -463,22 +328,29 @@ Incase you are only solving the Equations for solution, do not provide dataset ## Keyword Arguments -* `strategy`: The training strategy used to choose the points for the evaluations. By default GridTraining is used with given physdt discretization. -* `init_params`: initial parameter values for BPINN (ideally for multiple chains different initializations preferred) +* `strategy`: The training strategy used to choose the points for the evaluations. By + default GridTraining is used with given physdt discretization. +* `init_params`: initial parameter values for BPINN (ideally for multiple chains different + initializations preferred) * `nchains`: number of chains you want to sample -* `draw_samples`: number of samples to be drawn in the MCMC algorithms (warmup samples are ~2/3 of draw samples) +* `draw_samples`: number of samples to be drawn in the MCMC algorithms (warmup samples are + ~2/3 of draw samples) * `l2std`: standard deviation of BPINN prediction against L2 losses/Dataset * `phystd`: standard deviation of BPINN prediction against Chosen Underlying ODE System * `phynewstd`: standard deviation of new loss func term -* `priorsNNw`: Tuple of (mean, std) for BPINN Network parameters. Weights and Biases of BPINN are Normal Distributions by default. +* `priorsNNw`: Tuple of (mean, std) for BPINN Network parameters. Weights and Biases of + BPINN are Normal Distributions by default. * `param`: Vector of chosen ODE parameters Distributions in case of Inverse problems. * `autodiff`: Boolean Value for choice of Derivative Backend(default is numerical) * `physdt`: Timestep for approximating ODE in it's Time domain. (1/20.0 by default) * `Kernel`: Choice of MCMC Sampling Algorithm (AdvancedHMC.jl implementations HMC/NUTS/HMCDA) -* `Integratorkwargs`: `Integrator`, `jitter_rate`, `tempering_rate`. Refer: https://turinglang.org/AdvancedHMC.jl/stable/ -* `Adaptorkwargs`: `Adaptor`, `Metric`, `targetacceptancerate`. Refer: https://turinglang.org/AdvancedHMC.jl/stable/ - Note: Target percentage(in decimal) of iterations in which the proposals are accepted (0.8 by default) -* `MCMCargs`: A NamedTuple containing all the chosen MCMC kernel's(HMC/NUTS/HMCDA) Arguments, as follows : +* `Integratorkwargs`: `Integrator`, `jitter_rate`, `tempering_rate`. + Refer: https://turinglang.org/AdvancedHMC.jl/stable/ +* `Adaptorkwargs`: `Adaptor`, `Metric`, `targetacceptancerate`. + Refer: https://turinglang.org/AdvancedHMC.jl/stable/ Note: Target percentage (in decimal) + of iterations in which the proposals are accepted (0.8 by default) +* `MCMCargs`: A NamedTuple containing all the chosen MCMC kernel's (HMC/NUTS/HMCDA) + Arguments, as follows : * `n_leapfrog`: number of leapfrog steps for HMC * `δ`: target acceptance probability for NUTS and HMCDA * `λ`: target trajectory length for HMCDA @@ -488,67 +360,53 @@ Incase you are only solving the Equations for solution, do not provide dataset * `progress`: controls whether to show the progress meter or not. * `verbose`: controls the verbosity. (Sample call args in AHMC) -## Warnings +!!! warning -* AdvancedHMC.jl is still developing convenience structs so might need changes on new releases. + AdvancedHMC.jl is still developing convenience structs so might need changes on new + releases. 
""" -function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain; - strategy = GridTraining, dataset = [nothing], - init_params = nothing, draw_samples = 1000, - physdt = 1 / 20.0, l2std = [0.05], - phystd = [0.05], phynewstd = [0.05], priorsNNw = (0.0, 2.0), - param = [], nchains = 1, autodiff = false, - Kernel = HMC, +function ahmc_bayesian_pinn_ode( + prob::SciMLBase.ODEProblem, chain; strategy = GridTraining, dataset = [nothing], + init_params = nothing, draw_samples = 1000, physdt = 1 / 20.0, l2std = [0.05], + phystd = [0.05], phynewstd = [0.05], priorsNNw = (0.0, 2.0), param = [], nchains = 1, + autodiff = false, Kernel = HMC, Adaptorkwargs = (Adaptor = StanHMCAdaptor, Metric = DiagEuclideanMetric, targetacceptancerate = 0.8), - Integratorkwargs = (Integrator = Leapfrog,), - MCMCkwargs = (n_leapfrog = 30,), - progress = false, verbose = false, - estim_collocate = false) - !(chain isa Lux.AbstractExplicitLayer) && - (chain = adapt(FromFluxAdaptor(false, false), chain)) - # NN parameter prior mean and variance(PriorsNN must be a tuple) - if isinplace(prob) - throw(error("The BPINN ODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t).")) - end + Integratorkwargs = (Integrator = Leapfrog,), MCMCkwargs = (n_leapfrog = 30,), + progress = false, verbose = false, estim_collocate = false) + @assert !isinplace(prob) "The BPINN ODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t)." + + chain isa AbstractLuxLayer || (chain = FromFluxAdaptor()(chain)) strategy = strategy == GridTraining ? strategy(physdt) : strategy if dataset != [nothing] && (length(dataset) < 2 || !(dataset isa Vector{<:Vector{<:AbstractFloat}})) - throw(error("Invalid dataset. dataset would be timeseries (x̂,t) where type: Vector{Vector{AbstractFloat}")) + error("Invalid dataset. 
dataset would be timeseries (x̂,t) where type: Vector{Vector{AbstractFloat}") end if dataset != [nothing] && param == [] println("Dataset is only needed for Parameter Estimation + Forward Problem, not in only Forward Problem case.") elseif dataset == [nothing] && param != [] - throw(error("Dataset Required for Parameter Estimation.")) + error("Dataset Required for Parameter Estimation.") end - if chain isa Lux.AbstractExplicitLayer - # Lux-Named Tuple - initial_nnθ, recon, st = generate_Tar(chain, init_params) - else - error("Only Lux.AbstractExplicitLayer Neural networks are supported") - end + initial_nnθ, chain, st = generate_ltd(chain, init_params) - if nchains > Threads.nthreads() - throw(error("number of chains is greater than available threads")) - elseif nchains < 1 - throw(error("number of chains must be greater than 1")) - end + @assert nchains≤Threads.nthreads() "number of chains is greater than available threads" + @assert nchains≥1 "number of chains must be greater than 1" # eltype(physdt) cause needs Float64 for find_good_stepsize # Lux chain(using component array later as vector_to_parameter need namedtuple) - initial_θ = collect(eltype(physdt), - vcat(ComponentArrays.ComponentArray(initial_nnθ))) + T = eltype(physdt) + initial_θ = getdata(ComponentArray{T}(initial_nnθ)) # adding ode parameter estimation nparameters = length(initial_θ) ninv = length(param) priors = [ - MvNormal(priorsNNw[1] * ones(nparameters), - LinearAlgebra.Diagonal(abs2.(priorsNNw[2] .* ones(nparameters)))) + MvNormal(T(priorsNNw[1]) * ones(T, nparameters), + Diagonal(abs2.(T(priorsNNw[2]) .* ones(T, nparameters)))) ] # append Ode params to all paramvector @@ -560,30 +418,25 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain; end t0 = prob.tspan[1] + smodel = StatefulLuxLayer{true}(chain, nothing, st) # dimensions would be total no of params,initial_nnθ for Lux namedTuples - ℓπ = LogTargetDensity(nparameters, prob, recon, st, strategy, dataset, priors, + ℓπ = LogTargetDensity(nparameters, prob, smodel, strategy, dataset, priors, phystd, phynewstd, l2std, autodiff, physdt, ninv, initial_nnθ, estim_collocate) - try - ℓπ(t0, initial_θ[1:(nparameters - ninv)]) - catch err - if isa(err, DimensionMismatch) - throw(DimensionMismatch("Dimensions of the initial u0 and chain should match")) - else - throw(err) + if verbose + @printf("Current Physics Log-likelihood: %g\n", physloglikelihood(ℓπ, initial_θ)) + @printf("Current Prior Log-likelihood: %g\n", priorweights(ℓπ, initial_θ)) + @printf("Current MSE against dataset Log-likelihood: %g\n", + L2LossData(ℓπ, initial_θ)) + if estim_collocate + @printf("Current gradient loss against dataset Log-likelihood: %g\n", + L2loss2(ℓπ, initial_θ)) end end - @info("Current Physics Log-likelihood : ", physloglikelihood(ℓπ, initial_θ)) - @info("Current Prior Log-likelihood : ", priorweights(ℓπ, initial_θ)) - @info("Current MSE against dataset Log-likelihood : ", L2LossData(ℓπ, initial_θ)) - if estim_collocate - @info("Current gradient loss against dataset Log-likelihood : ", - L2loss2(ℓπ, initial_θ)) - end - - Adaptor, Metric, targetacceptancerate = Adaptorkwargs[:Adaptor], - Adaptorkwargs[:Metric], Adaptorkwargs[:targetacceptancerate] + Adaptor = Adaptorkwargs[:Adaptor] + Metric = Adaptorkwargs[:Metric] + targetacceptancerate = Adaptorkwargs[:targetacceptancerate] # Define Hamiltonian system (nparameters ~ dimensionality of the sampling space) metric = Metric(nparameters) @@ -598,8 +451,10 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain; 
Threads.@threads for i in 1:nchains # each chain has different initial NNparameter values(better posterior exploration) - initial_θ = vcat(randn(nparameters - ninv), - initial_θ[(nparameters - ninv + 1):end]) + initial_θ = vcat( + randn(eltype(initial_θ), nparameters - ninv), + initial_θ[(nparameters - ninv + 1):end] + ) initial_ϵ = find_good_stepsize(hamiltonian, initial_θ) integrator = integratorchoice(Integratorkwargs, initial_ϵ) adaptor = adaptorchoice(Adaptor, MassMatrixAdaptor(metric), @@ -612,7 +467,7 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain; samplesc[i] = samples statsc[i] = stats - mcmc_chain = Chains(hcat(samples...)') + mcmc_chain = Chains(reduce(hcat, samples)') chains[i] = mcmc_chain end @@ -628,13 +483,17 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain; samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples, adaptor; progress = progress, verbose = verbose) - @info("Sampling Complete.") - @info("Final Physics Log-likelihood : ", physloglikelihood(ℓπ, samples[end])) - @info("Final Prior Log-likelihood : ", priorweights(ℓπ, samples[end])) - @info("Final MSE against dataset Log-likelihood : ", L2LossData(ℓπ, samples[end])) - if estim_collocate - @info("Final gradient loss against dataset Log-likelihood : ", - L2loss2(ℓπ, samples[end])) + if verbose + println("Sampling Complete.") + @printf("Final Physics Log-likelihood: %g\n", + physloglikelihood(ℓπ, samples[end])) + @printf("Final Prior Log-likelihood: %g\n", priorweights(ℓπ, samples[end])) + @printf("Final MSE against dataset Log-likelihood: %g\n", + L2LossData(ℓπ, samples[end])) + if estim_collocate + @printf("Final gradient loss against dataset Log-likelihood: %g\n", + L2loss2(ℓπ, samples[end])) + end end # return a chain(basic chain),samples and stats diff --git a/src/discretize.jl b/src/discretize.jl index 757c1f8b8f..5187a0638a 100644 --- a/src/discretize.jl +++ b/src/discretize.jl @@ -445,14 +445,13 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, discretization::Ab # assume one single additional loss function if there is one. 
this means that the user needs to lump all their functions into a single one, num_additional_loss = convert(Int, additional_loss !== nothing) - adaloss_T = eltype(adaloss.pde_loss_weights) + adaloss_T = eltype(adaloss.pde_loss_weights) - # this will error if the user has provided a number of initial weights that is more than 1 and doesn't match the number of loss functions - adaloss.pde_loss_weights = ones(adaloss_T, num_pde_losses) .* - adaloss.pde_loss_weights - adaloss.bc_loss_weights = ones(adaloss_T, num_bc_losses) .* adaloss.bc_loss_weights - adaloss.additional_loss_weights = ones(adaloss_T, num_additional_loss) .* - adaloss.additional_loss_weights + # this will error if the user has provided a number of initial weights that is more than 1 and doesn't match the number of loss functions + adaloss.pde_loss_weights = ones(adaloss_T, num_pde_losses) .* adaloss.pde_loss_weights + adaloss.bc_loss_weights = ones(adaloss_T, num_bc_losses) .* adaloss.bc_loss_weights + adaloss.additional_loss_weights = ones(adaloss_T, num_additional_loss) .* + adaloss.additional_loss_weights reweight_losses_func = generate_adaptive_loss_function(pinnrep, adaloss, pde_loss_functions, bc_loss_functions) @@ -521,36 +520,10 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, discretization::Ab return full_weighted_loss end - return bc_loss_functions, pde_loss_functions, full_loss_function + return full_loss_function end function get_likelihood_estimate_function(discretization::BayesianPINN) - # Because separate reweighting code section needed and loglikelihood is pointwise independent - pde_loss_functions, bc_loss_functions = merge_strategy_with_loglikelihood_function( - pinnrep, - strategy, - datafree_pde_loss_functions, - datafree_bc_loss_functions) - - # setup for all adaptive losses - num_pde_losses = length(pde_loss_functions) - num_bc_losses = length(bc_loss_functions) - # assume one single additional loss function if there is one. this means that the user needs to lump all their functions into a single one, - num_additional_loss = additional_loss isa Nothing ? 
0 : 1 - - adaloss_T = eltype(adaloss.pde_loss_weights) - - # this will error if the user has provided a number of initial weights that is more than 1 and doesn't match the number of loss functions - adaloss.pde_loss_weights = ones(adaloss_T, num_pde_losses) .* - adaloss.pde_loss_weights - adaloss.bc_loss_weights = ones(adaloss_T, num_bc_losses) .* adaloss.bc_loss_weights - adaloss.additional_loss_weights = ones(adaloss_T, num_additional_loss) .* - adaloss.additional_loss_weights - - reweight_losses_func = generate_adaptive_loss_function(pinnrep, adaloss, - pde_loss_functions, - bc_loss_functions) - dataset_pde, dataset_bc = discretization.dataset # required as Physics loss also needed on the discrete dataset domain points @@ -566,7 +539,7 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, discretization::Ab end function full_loss_function(θ, allstd::Vector{Vector{Float64}}) - stdpdes, stdbcs, stdextra, stdpdesnew = allstd + stdpdes, stdbcs, stdextra = allstd # the aggregation happens on cpu even if the losses are gpu, probably fine since it's only a few of them pde_loglikelihoods = [logpdf(Normal(0, stdpdes[i]), pde_loss_function(θ)) for (i, pde_loss_function) in enumerate(pde_loss_functions)] @@ -578,6 +551,7 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, discretization::Ab pde_loglikelihoods += [logpdf(Normal(0, stdpdes[j]), pde_loss_function(θ)) for (j, pde_loss_function) in enumerate(datapde_loss_functions)] end + if !(databc_loss_functions isa Nothing) bc_loglikelihoods += [logpdf(Normal(0, stdbcs[j]), bc_loss_function(θ)) for (j, bc_loss_function) in enumerate(databc_loss_functions)] @@ -618,11 +592,10 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, discretization::Ab return full_weighted_loglikelihood end - return bc_loss_functions, pde_loss_functions, full_loss_function + return full_loss_function end - bc_loss_functions, pde_loss_functions, full_loss_function = get_likelihood_estimate_function(discretization) - + full_loss_function = get_likelihood_estimate_function(discretization) pinnrep.loss_functions = PINNLossFunctions(bc_loss_functions, pde_loss_functions, full_loss_function, additional_loss, datafree_pde_loss_functions, datafree_bc_loss_functions) @@ -641,4 +614,4 @@ function SciMLBase.discretize(pde_system::PDESystem, discretization::PhysicsInfo pinnrep = symbolic_discretize(pde_system, discretization) f = OptimizationFunction(pinnrep.loss_functions.full_loss_function, AutoZygote()) return Optimization.OptimizationProblem(f, pinnrep.flat_init_params) -end +end \ No newline at end of file
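
A minimal usage sketch of the reworked `ahmc_bayesian_pinn_ode` interface follows. It is illustrative only: the logistic right-hand side, the noise-free synthetic dataset, the network size, and the parameter priors are all assumed for the example and are not part of this commit; only the keyword names, the `[û, t]` dataset layout, and the out-of-place requirement follow the docstring and argument checks in `src/advancedHMC_MCMC.jl`.

using NeuralPDE, Lux, Distributions
using SciMLBase: ODEProblem

# assumed out-of-place logistic ODE with two parameters to estimate
f(u, p, t) = p[1] * u - p[2] * u^2
prob = ODEProblem(f, 1.0, (0.0, 5.0), [1.5, 0.7])

# small Lux network mapping t -> u (chain must be an AbstractLuxLayer)
chain = Chain(Dense(1, 8, tanh), Dense(8, 1))

# assumed synthetic, noise-free measurements from the analytic logistic solution,
# packed in the documented Vector{Vector{<:AbstractFloat}} layout: [û, t]
t_obs = collect(0.0:0.25:5.0)
u_obs = 1.5 ./ (0.7 .+ 0.8 .* exp.(-1.5 .* t_obs))
dataset = [u_obs, t_obs]

mcmc_chain, samples, stats = ahmc_bayesian_pinn_ode(
    prob, chain;
    dataset = dataset,
    draw_samples = 1000,
    l2std = [0.05], phystd = [0.05], phynewstd = [0.05],
    priorsNNw = (0.0, 2.0),
    param = [Normal(1.5, 0.5), Normal(0.7, 0.5)],  # assumed priors for p[1], p[2]
    estim_collocate = true,                        # enables the extra L2loss2 term
    progress = true)

With `estim_collocate = true`, the sampled log density adds the `L2loss2` gradient-matching term (weighted by `phynewstd`) on top of the physics, prior, and data log-likelihoods, mirroring the `phynewstd`/`L2_loss2` term that `ahmc_bayesian_pinn_pde` gains in `src/PDE_BPINN.jl`. Per the comment on the single-chain path, the call returns the basic MCMC chain, the raw samples, and the sampler statistics.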