Skip to content

Commit

Permalink
Remove update_coefficients_for_test!, test L(v, u, p, t)
Browse files Browse the repository at this point in the history
  • Loading branch information
gaurav-arya committed Mar 28, 2023
1 parent c887b5f commit b072b13
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 20 deletions.
44 changes: 32 additions & 12 deletions test/test_jaches_products.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using SparseDiffTools, ForwardDiff, FiniteDiff, Zygote, IterativeSolvers
using LinearAlgebra, Test
import SciMLOperators: update_coefficients, update_coefficients!

using Random
Random.seed!(123)
Expand Down Expand Up @@ -85,22 +84,28 @@ update_coefficients!(f, x, 1.0, 1.0)
@test L * v auto_jacvec(f, x, v)
@test mul!(dy, L, v) auto_jacvec(f, x, v)
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*auto_jacvec(f,x,v) + b*_dy
update_coefficients_for_test!(L, v, 3.0, 4.0)
update_coefficients!(L, v, 3.0, 4.0)
update_coefficients!(f, v, 3.0, 4.0)
@test mul!(dy, L, v) auto_jacvec(f, v, v)
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*auto_jacvec(f,x,v) + b*_dy
update_coefficients!(f, v, 5.0, 6.0)
@test L(dy, v, 5.0, 6.0) auto_jacvec(f,x,v)

L = JacVec(f, x, 1.0, 1.0; autodiff = AutoFiniteDiff())
update_coefficients!(f, x, 1.0, 1.0)
@test L * x num_jacvec(f, x, x)
@test L * v num_jacvec(f, x, v)
@test mul!(dy, L, v)num_jacvec(f, x, v) rtol=1e-6
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*num_jacvec(f,x,v) + b*_dy rtol=1e-6
update_coefficients_for_test!(L, v, 3.0, 4.0)
update_coefficients!(L, v, 3.0, 4.0)
update_coefficients!(f, v, 3.0, 4.0)
@test mul!(dy, L, v)num_jacvec(f, v, v) rtol=1e-6
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*num_jacvec(f,x,v) + b*_dy rtol=1e-6
update_coefficients!(f, v, 5.0, 6.0)
@test L(dy, v, 5.0, 6.0) num_jacvec(f,x,v) rtol=1e-6

out = similar(v)
gmres!(out, L, v)
@test_nowarn gmres!(out, L, v)

@info "HesVec"

Expand All @@ -113,9 +118,12 @@ num_hesvec(g, x, x)
@test L * v num_hesvec(g, x, v) rtol=1e-2
@test mul!(dy, L, v)num_hesvec(g, x, v) rtol=1e-2
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*num_hesvec(g,x,v) + b*_dy rtol=1e-2
update_coefficients_for_test!(L, v, 3.0, 4.0)
update_coefficients!(L, v, 3.0, 4.0)
update_coefficients!(g, v, 3.0, 4.0)
@test mul!(dy, L, v)num_hesvec(g, v, v) rtol=1e-2
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*num_hesvec(g,x,v) + b*_dy rtol=1e-2
update_coefficients!(g, v, 5.0, 6.0)
@test L(dy, v, 5.0, 6.0) num_hesvec(g,x,v) rtol=1e-2

L = HesVec(g, x, 1.0, 1.0)
update_coefficients!(g, x, 1.0, 1.0)
Expand All @@ -125,9 +133,12 @@ num_hesvec(g, x, x)
@test L * v numauto_hesvec(g, x, v)
@test mul!(dy, L, v)numauto_hesvec(g, x, v) rtol=1e-8
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8
update_coefficients_for_test!(L, v, 3.0, 4.0)
update_coefficients!(L, v, 3.0, 4.0)
update_coefficients!(g, v, 3.0, 4.0)
@test mul!(dy, L, v)numauto_hesvec(g, v, v) rtol=1e-8
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8
update_coefficients!(g, v, 5.0, 6.0)
@test L(dy, v, 5.0, 6.0) numauto_hesvec(g,x,v) rtol=1e-2

out = similar(v)
gmres!(out, L, v)
Expand All @@ -141,9 +152,12 @@ update_coefficients!(g, x, 1.0, 1.0)
@test L * v autoback_hesvec(g, x, v)
@test mul!(dy, L, v)autoback_hesvec(g, x, v) rtol=1e-8
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*autoback_hesvec(g,x,v)+b*_dy rtol=1e-8
update_coefficients_for_test!(L, v, 3.0, 4.0)
update_coefficients!(L, v, 3.0, 4.0)
update_coefficients!(g, v, 3.0, 4.0)
@test mul!(dy, L, v)autoback_hesvec(g, v, v) rtol=1e-8
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*autoback_hesvec(g,x,v)+b*_dy rtol=1e-8
update_coefficients!(g, v, 5.0, 6.0)
@test L(dy, v, 5.0, 6.0) autoback_hesvec(g,x,v) rtol=1e-2

out = similar(v)
gmres!(out, L, v)
Expand All @@ -159,20 +173,26 @@ update_coefficients!(g, x, 1.0, 1.0)
@test L * v num_hesvec(g, x, v) rtol=1e-2
@test mul!(dy, L, v)num_hesvec(g, x, v) rtol=1e-2
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*num_hesvec(g,x,v)+b*_dy rtol=1e-2
update_coefficients_for_test!(L, v, 3.0, 4.0)
update_coefficients!(g, x, 3.0, 4.0)
for op in (L, g, h) update_coefficients!(op, v, 3.0, 4.0) end
@test mul!(dy, L, v)num_hesvec(g, v, v) rtol=1e-2
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*num_hesvec(g,x,v)+b*_dy rtol=1e-2
update_coefficients!(g, v, 5.0, 6.0)
@test L(dy, v, 5.0, 6.0) num_hesvec(g,x,v) rtol=1e-2

L = HesVecGrad(h, x, 1.0, 1.0)
update_coefficients!(h, x, 1.0, 1.0)
update_coefficients!(g, x, 1.0, 1.0)
update_coefficients!(h, x, 1.0, 1.0)
@test L * x autonum_hesvec(g, x, x)
@test L * v numauto_hesvec(g, x, v)
@test mul!(dy, L, v)numauto_hesvec(g, x, v) rtol=1e-8
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8
update_coefficients_for_test!(L, v, 3.0, 4.0)
update_coefficients!(g, x, 3.0, 4.0)
for op in (L, g, h) update_coefficients!(op, v, 3.0, 4.0) end
@test mul!(dy, L, v)numauto_hesvec(g, v, v) rtol=1e-8
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8
update_coefficients!(g, v, 5.0, 6.0)
update_coefficients!(h, v, 5.0, 6.0)
@test L(dy, v, 5.0, 6.0) num_hesvec(g,x,v) rtol=1e-2

out = similar(v)
gmres!(out, L, v)
#
8 changes: 8 additions & 0 deletions test/test_vecjac_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,17 @@ f = WrapFunc(_f, 1.0, 1.0)

L = VecJac(f, x, 1.0, 1.0)
update_coefficients!(L, v, 3.0, 4.0)
update_coefficients!(f, v, 3.0, 4.0)
actual_vjp = Zygote.jacobian(f, x)[1]' * v
@test L * v actual_vjp
update_coefficients!(f, v, 5.0, 6.0)
actual_vjp2 = Zygote.jacobian(f, x)[1]' * v
@test L(copy(v), v, 5.0, 6.0) actual_vjp2

L = VecJac(f, x, 1.0, 1.0; autodiff = AutoFiniteDiff())
update_coefficients!(L, v, 3.0, 4.0)
update_coefficients!(f, v, 3.0, 4.0)
@test L * v actual_vjp
update_coefficients!(f, v, 5.0, 6.0)
@test L(copy(v), v, 5.0, 6.0) actual_vjp2
#
11 changes: 3 additions & 8 deletions test/update_coeffs_testutils.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Utilities for testing update coefficient behaviour with state-dependent (i.e. dependent on u/p/t) functions
import SciMLOperators: update_coefficients, update_coefficients!

# Wrapper function for testing update coefficient behaviour with state-dependent (i.e. dependent on u/p/t) functions

mutable struct WrapFunc{F,P,T}
func::F
Expand All @@ -16,11 +18,4 @@ update_coefficients(w::WrapFunc, u, p, t) = WrapFunc(w.func, p, t)
function update_coefficients!(w::WrapFunc, u, p, t)
w.p = p
w.t = t
end

# Helper function for testing correct update coefficients behaviour of operators
function update_coefficients_for_test!(L, u, p, t)
update_coefficients!(L, u, p, t)
# Force function hiding inside L to update. Should be a null-op if previous line works correctly
update_coefficients!(L.op.f, u, p, t)
end

0 comments on commit b072b13

Please sign in to comment.