Adaptive Reweighting of BPINN Loglikelihood #798

Draft · wants to merge 12 commits into base: master
begin again
AstitvaAggarwal committed Feb 3, 2024
commit 9ccd327c6c661952b2934f3b0bd1e8d9542761ad
139 changes: 73 additions & 66 deletions src/adaptive_losses.jl
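For orientation before the hunks below: all three adaptive loss types in this file are selected through the `adaptive_loss` keyword of `PhysicsInformedNN`. A minimal usage sketch (the chain and training strategy here are illustrative placeholders, not part of this diff):

using NeuralPDE, Lux, OptimizationOptimisers

# Any Lux network works here; this 1-in/1-out chain is a placeholder.
chain = Lux.Chain(Lux.Dense(1, 16, Lux.σ), Lux.Dense(16, 1))

# Rebalance BC weights against PDE gradient statistics every 50 iterations.
discretization = PhysicsInformedNN(chain, QuadratureTraining();
    adaptive_loss = GradientScaleAdaptiveLoss(50))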
@@ -24,27 +24,28 @@ mutable struct NonAdaptiveLoss{T <: Real} <: AbstractAdaptiveLoss
bc_loss_weights::Vector{T}
additional_loss_weights::Vector{T}
SciMLBase.@add_kwonly function NonAdaptiveLoss{T}(; pde_loss_weights = 1.0,
-bc_loss_weights = 1.0,
-additional_loss_weights = 1.0) where {
-T <:
-Real
-}
+bc_loss_weights = 1.0,
+additional_loss_weights = 1.0) where {
+T <:
+Real,
+}
new(vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, T),
vectorify(additional_loss_weights, T))
end
end

# default to Float64
-SciMLBase.@add_kwonly function NonAdaptiveLoss(; pde_loss_weights = 1.0, bc_loss_weights = 1.0,
-additional_loss_weights = 1.0)
+SciMLBase.@add_kwonly function NonAdaptiveLoss(; pde_loss_weights = 1.0,
+bc_loss_weights = 1.0,
+additional_loss_weights = 1.0)
NonAdaptiveLoss{Float64}(; pde_loss_weights = pde_loss_weights,
-bc_loss_weights = bc_loss_weights,
-additional_loss_weights = additional_loss_weights)
+bc_loss_weights = bc_loss_weights,
+additional_loss_weights = additional_loss_weights)
end
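As a reading aid (not part of the diff): assuming the usual behaviour of `vectorify` here, a scalar weight is turned into a vector so scalar and per-equation forms are handled uniformly, and both calls in this hedged sketch are accepted:

# One fixed weight reused for every PDE residual and boundary condition:
NonAdaptiveLoss(; pde_loss_weights = 1.0, bc_loss_weights = 10.0)

# Or explicit per-equation weights, e.g. for a system with two PDEs and two BCs:
NonAdaptiveLoss(; pde_loss_weights = [1.0, 1.0], bc_loss_weights = [10.0, 5.0])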

function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
-adaloss::NonAdaptiveLoss,
-pde_loss_functions, bc_loss_functions)
+adaloss::NonAdaptiveLoss,
+pde_loss_functions, bc_loss_functions)
function null_nonadaptive_loss(θ, pde_losses, bc_losses)
nothing
end
@@ -87,31 +88,31 @@ mutable struct GradientScaleAdaptiveLoss{T <: Real, R <: Real} <: AbstractAdaptiveLoss
bc_loss_weights::Vector{R}
additional_loss_weights::Vector{T}
SciMLBase.@add_kwonly function GradientScaleAdaptiveLoss{T, R}(reweight_every;
-weight_change_inertia = 0.9,
-pde_loss_weights = 1.0,
-bc_loss_weights = 1.0,
-additional_loss_weights = 1.0) where {
-T <:
-Real,
-R <:
-Real
-}
+weight_change_inertia = 0.9,
+pde_loss_weights = 1.0,
+bc_loss_weights = 1.0,
+additional_loss_weights = 1.0) where {
+T <:
+Real,
+R <:
+Real,
+}
new(convert(Int64, reweight_every), convert(T, weight_change_inertia),
vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, R),
vectorify(additional_loss_weights, T))
end
end
# default to Float64
SciMLBase.@add_kwonly function GradientScaleAdaptiveLoss(reweight_every;
-weight_change_inertia = 0.9,
-pde_loss_weights = 1,
-bc_loss_weights = 1,
-additional_loss_weights = 1)
+weight_change_inertia = 0.9,
+pde_loss_weights = 1,
+bc_loss_weights = 1,
+additional_loss_weights = 1)
GradientScaleAdaptiveLoss{Float64, Float64}(reweight_every;
-weight_change_inertia = weight_change_inertia,
-pde_loss_weights = pde_loss_weights,
-bc_loss_weights = bc_loss_weights,
-additional_loss_weights = additional_loss_weights)
+weight_change_inertia = weight_change_inertia,
+pde_loss_weights = pde_loss_weights,
+bc_loss_weights = bc_loss_weights,
+additional_loss_weights = additional_loss_weights)
end
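A hedged construction example for the type above (the values mirror the defaults in this file and are illustrative, not a recommendation):

# Recompute BC weights from gradient statistics every 100 iterations,
# smoothing each update with 90% inertia.
adaloss = GradientScaleAdaptiveLoss(100; weight_change_inertia = 0.9,
    pde_loss_weights = 1.0, bc_loss_weights = 1.0)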

# function GradientScaleAdaptiveLoss(reweight_every;
@@ -127,12 +128,12 @@ end
# end

function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
-adaloss::GradientScaleAdaptiveLoss,
-pde_loss_functions, bc_loss_functions)
+adaloss::GradientScaleAdaptiveLoss,
+pde_loss_functions, bc_loss_functions)
weight_change_inertia = adaloss.weight_change_inertia
iteration = pinnrep.iteration
adaloss_T = eltype(adaloss.pde_loss_weights)

function run_loss_gradients_adaptive_loss(θ, pde_losses, bc_losses)
if iteration[1] % adaloss.reweight_every == 0
# the paper assumes a single pde loss function, so here we grab the maximum of the maximums of each pde loss function
@@ -145,13 +146,13 @@ function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
convert(adaloss_T, 1e-7)
bc_loss_weights_proposed = pde_grads_max ./
(bc_grads_mean .+ nonzero_divisor_eps)

# println("adaloss.bc_loss_weights :", adaloss.bc_loss_weights)
if bc_loss_weights_proposed[1] isa ForwardDiff.Dual
bc_loss_weights_proposed = [bc_loss_weights_propose.value
for bc_loss_weights_propose in bc_loss_weights_proposed]
end

adaloss.bc_loss_weights .= weight_change_inertia .*
adaloss.bc_loss_weights .+
(1 .- weight_change_inertia) .*
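To make the update truncated above concrete (a worked sketch with made-up numbers, not output of this code): the proposed BC weight is the max PDE gradient magnitude over the mean BC gradient magnitude, and the inertia term keeps only a fraction of the jump.

# Suppose max |∇θ pde_loss| = 2.0 and mean |∇θ bc_loss| = 0.01:
w_proposed = 2.0 / (0.01 + 1e-7)            # ≈ 200
# With weight_change_inertia = 0.9 and a current weight of 1.0:
w_new = 0.9 * 1.0 + (1 - 0.9) * w_proposed  # ≈ 20.9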
@@ -205,8 +206,8 @@ Levi McClenny, Ulisses Braga-Neto
https://arxiv.org/abs/2009.04544
"""
mutable struct MiniMaxAdaptiveLoss{T <: Real,
-PDE_OPT,
-BC_OPT} <:
+PDE_OPT,
+BC_OPT} <:
AbstractAdaptiveLoss
reweight_every::Int64
pde_max_optimiser::PDE_OPT
@@ -215,17 +216,17 @@ mutable struct MiniMaxAdaptiveLoss{T <: Real,
bc_loss_weights::Vector{T}
additional_loss_weights::Vector{T}
SciMLBase.@add_kwonly function MiniMaxAdaptiveLoss{T,
-PDE_OPT, BC_OPT}(reweight_every;
-pde_max_optimiser = OptimizationOptimisers.Adam(1e-4),
-bc_max_optimiser = OptimizationOptimisers.Adam(0.5),
-pde_loss_weights = 1.0,
-bc_loss_weights = 1.0,
-additional_loss_weights = 1.0) where {
-T <:
-Real,
-PDE_OPT,
-BC_OPT
-}
+PDE_OPT, BC_OPT}(reweight_every;
+pde_max_optimiser = OptimizationOptimisers.Adam(1e-4),
+bc_max_optimiser = OptimizationOptimisers.Adam(0.5),
+pde_loss_weights = 1.0,
+bc_loss_weights = 1.0,
+additional_loss_weights = 1.0) where {
+T <:
+Real,
+PDE_OPT,
+BC_OPT,
+}
new(convert(Int64, reweight_every), convert(PDE_OPT, pde_max_optimiser),
convert(BC_OPT, bc_max_optimiser),
vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, T),
@@ -235,38 +236,44 @@ end

# default to Float64, ADAM, ADAM
SciMLBase.@add_kwonly function MiniMaxAdaptiveLoss(reweight_every;
-pde_max_optimiser = OptimizationOptimisers.Adam(1e-4),
-bc_max_optimiser = OptimizationOptimisers.Adam(0.5),
-pde_loss_weights = 1.0,
-bc_loss_weights = 1.0,
-additional_loss_weights = 1.0)
+pde_max_optimiser = OptimizationOptimisers.Adam(1e-4),
+bc_max_optimiser = OptimizationOptimisers.Adam(0.5),
+pde_loss_weights = 1.0,
+bc_loss_weights = 1.0,
+additional_loss_weights = 1.0)
MiniMaxAdaptiveLoss{Float64, typeof(pde_max_optimiser),
-typeof(bc_max_optimiser)}(reweight_every;
-pde_max_optimiser = pde_max_optimiser,
-bc_max_optimiser = bc_max_optimiser,
-pde_loss_weights = pde_loss_weights,
-bc_loss_weights = bc_loss_weights,
-additional_loss_weights = additional_loss_weights)
+typeof(bc_max_optimiser)}(reweight_every;
+pde_max_optimiser = pde_max_optimiser,
+bc_max_optimiser = bc_max_optimiser,
+pde_loss_weights = pde_loss_weights,
+bc_loss_weights = bc_loss_weights,
+additional_loss_weights = additional_loss_weights)
end
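A hedged construction example mirroring the defaults above: the BC weight optimiser deliberately uses a much larger step size than the PDE one, so boundary terms are re-emphasised faster.

# Ascend the loss weights every 10 iterations, with the file's default optimisers.
adaloss = MiniMaxAdaptiveLoss(10;
    pde_max_optimiser = OptimizationOptimisers.Adam(1e-4),
    bc_max_optimiser = OptimizationOptimisers.Adam(0.5))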

function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
-adaloss::MiniMaxAdaptiveLoss,
-pde_loss_functions, bc_loss_functions)
+adaloss::MiniMaxAdaptiveLoss,
+pde_loss_functions, bc_loss_functions)
pde_max_optimiser = adaloss.pde_max_optimiser
-pde_max_optimiser_setup = OptimizationOptimisers.Optimisers.setup(pde_max_optimiser, adaloss.pde_loss_weights)
+pde_max_optimiser_setup = OptimizationOptimisers.Optimisers.setup(pde_max_optimiser,
+adaloss.pde_loss_weights)
bc_max_optimiser = adaloss.bc_max_optimiser
-bc_max_optimiser_setup = OptimizationOptimisers.Optimisers.setup(bc_max_optimiser, adaloss.bc_loss_weights)
+bc_max_optimiser_setup = OptimizationOptimisers.Optimisers.setup(bc_max_optimiser,
+adaloss.bc_loss_weights)
iteration = pinnrep.iteration

function run_minimax_adaptive_loss(θ, pde_losses, bc_losses)
if iteration[1] % adaloss.reweight_every == 0
-OptimizationOptimisers.Optimisers.update!(pde_max_optimiser_setup, adaloss.pde_loss_weights, -pde_losses)
-OptimizationOptimisers.Optimisers.update!(bc_max_optimiser_setup, adaloss.bc_loss_weights, -bc_losses)
+OptimizationOptimisers.Optimisers.update!(pde_max_optimiser_setup,
+adaloss.pde_loss_weights,
+-pde_losses)
+OptimizationOptimisers.Optimisers.update!(bc_max_optimiser_setup,
+adaloss.bc_loss_weights,
+-bc_losses)
logvector(pinnrep.logger, adaloss.pde_loss_weights,
"adaptive_loss/pde_loss_weights", iteration[1])
"adaptive_loss/pde_loss_weights", iteration[1])
logvector(pinnrep.logger, adaloss.bc_loss_weights,
"adaptive_loss/bc_loss_weights",
iteration[1])
"adaptive_loss/bc_loss_weights",
iteration[1])
end
nothing
end
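The sign trick in `run_minimax_adaptive_loss`, isolated as a sketch using only Optimisers.jl (values are made up): an optimiser rule always descends along whatever "gradient" it is handed, so passing `-losses` performs gradient ascent on the weights, growing them fastest wherever the loss is currently worst.

using Optimisers

w = [1.0, 1.0]                        # current loss weights
opt = Optimisers.setup(Optimisers.Adam(0.5), w)
losses = [0.3, 2.0]                   # made-up per-condition losses
Optimisers.update!(opt, w, -losses)   # w[2] grows fastest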