Skip to content

Commit

Permalink
Make sure to reset state of stateful functions when reconstructing op…
Browse files Browse the repository at this point in the history
…erator
  • Loading branch information
gaurav-arya committed Mar 28, 2023
1 parent 7e4ac73 commit c71b336
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 12 deletions.
27 changes: 19 additions & 8 deletions test/test_jaches_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ end
# Define state-dependent functions for operator tests

include("update_coeffs_testutils.jl")
f = WrapFunc(_f, ones(N), 1.0, 1.0)
g = WrapFunc(_g, ones(N), 1.0, 1.0)
h = WrapFunc(_h, ones(N), 1.0, 1.0)
f = WrapFunc(_f, 1.0, 1.0)
g = WrapFunc(_g, 1.0, 1.0)
h = WrapFunc(_h, 1.0, 1.0)

###

Expand Down Expand Up @@ -80,6 +80,7 @@ cache4 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing, eltype(x))), eltype(x)
@info "JacVec"

L = JacVec(f, x, 1.0, 1.0)
update_coefficients!(f, x, 1.0, 1.0)
@test L * x auto_jacvec(f, x, x)
@test L * v auto_jacvec(f, x, v)
@test mul!(dy, L, v) auto_jacvec(f, x, v)
Expand All @@ -89,6 +90,7 @@ update_coefficients_for_test!(L, v, 3.0, 4.0)
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*auto_jacvec(f,x,v) + b*_dy

L = JacVec(f, x, 1.0, 1.0; autodiff = AutoFiniteDiff())
update_coefficients!(f, x, 1.0, 1.0)
@test L * x num_jacvec(f, x, x)
@test L * v num_jacvec(f, x, v)
@test mul!(dy, L, v)num_jacvec(f, x, v) rtol=1e-6
Expand All @@ -105,15 +107,20 @@ gmres!(out, L, v)
x = rand(N)
v = rand(N)
L = HesVec(g, x, 1.0, 1.0, autodiff = AutoFiniteDiff())
@test L * x num_hesvec(g, x, x)
@test L * v num_hesvec(g, x, v)
update_coefficients!(g, x, 1.0, 1.0)
@test L * x num_hesvec(g, x, x) rtol=1e-2
num_hesvec(g, x, x)
@test L * v num_hesvec(g, x, v) rtol=1e-2
@test mul!(dy, L, v)num_hesvec(g, x, v) rtol=1e-2
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*num_hesvec(g,x,v) + b*_dy rtol=1e-2
update_coefficients_for_test!(L, v, 3.0, 4.0)
@test mul!(dy, L, v)num_hesvec(g, v, v) rtol=1e-2
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*num_hesvec(g,x,v) + b*_dy rtol=1e-2

L = HesVec(g, x, 1.0, 1.0)
update_coefficients!(g, x, 1.0, 1.0)
numauto_hesvec(g, x, x)
num_hesvec(g, x, x)
@test L * x numauto_hesvec(g, x, x)
@test L * v numauto_hesvec(g, x, v)
@test mul!(dy, L, v)numauto_hesvec(g, x, v) rtol=1e-8
Expand All @@ -129,6 +136,7 @@ x = rand(N)
v = rand(N)

L = HesVec(g, x, 1.0, 1.0; autodiff = AutoZygote())
update_coefficients!(g, x, 1.0, 1.0)
@test L * x autoback_hesvec(g, x, x)
@test L * v autoback_hesvec(g, x, v)
@test mul!(dy, L, v)autoback_hesvec(g, x, v) rtol=1e-8
Expand All @@ -145,23 +153,26 @@ gmres!(out, L, v)
x = rand(N)
v = rand(N)
L = HesVecGrad(h, x, 1.0, 1.0; autodiff = AutoFiniteDiff())
update_coefficients!(h, x, 1.0, 1.0)
update_coefficients!(g, x, 1.0, 1.0)
@test L * x num_hesvec(g, x, x) rtol=1e-2
@test L * v num_hesvec(g, x, v) rtol=1e-2
@test mul!(dy, L, v)num_hesvec(g, x, v) rtol=1e-2
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*num_hesvec(g,x,v)+b*_dy rtol=1e-2
update_coefficients_for_test!(L, v, 3.0, 4.0)
update_coefficients!(g, x, 3.0, 4.0)
@test mul!(dy, L, v)num_hesvec(g, v, v) rtol=1e-2
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*num_hesvec(g,x,v)+b*_dy rtol=1e-2

L = HesVecGrad(h, x, 1.0, 1.0)
update_coefficients!(h, x, 1.0, 1.0)
update_coefficients!(g, x, 1.0, 1.0)
@test L * x autonum_hesvec(g, x, x)
@test L * v numauto_hesvec(g, x, v)
@test mul!(dy, L, v)numauto_hesvec(g, x, v) rtol=1e-8
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8
update_coefficients_for_test!(L, v, 3.0, 4.0)
update_coefficients!(g, x, 3.0, 4.0)
@test mul!(dy, L, v)numauto_hesvec(g, v, v) rtol=1e-8
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8

out = similar(v)
gmres!(out, L, v)
#
6 changes: 2 additions & 4 deletions test/update_coeffs_testutils.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# Utilities for testing update coefficient behaviour with state-dependent (i.e. dependent on u/p/t) functions

mutable struct WrapFunc{F,U,P,T}
mutable struct WrapFunc{F,P,T}
func::F
u::U
p::P
t::T
end
Expand All @@ -13,9 +12,8 @@ function (w::WrapFunc)(v, u)
lmul!(w.p * w.t, v)
end

update_coefficients(w::WrapFunc, u, p, t) = WrapFunc(w.func, u, p, t)
update_coefficients(w::WrapFunc, u, p, t) = WrapFunc(w.func, p, t)
function update_coefficients!(w::WrapFunc, u, p, t)
w.u = u
w.p = p
w.t = t
end
Expand Down

0 comments on commit c71b336

Please sign in to comment.