diff --git a/test/test_jaches_products.jl b/test/test_jaches_products.jl index e1b826dd..cbff19fc 100644 --- a/test/test_jaches_products.jl +++ b/test/test_jaches_products.jl @@ -25,9 +25,9 @@ end # Define state-dependent functions for operator tests include("update_coeffs_testutils.jl") -f = WrapFunc(_f, ones(N), 1.0, 1.0) -g = WrapFunc(_g, ones(N), 1.0, 1.0) -h = WrapFunc(_h, ones(N), 1.0, 1.0) +f = WrapFunc(_f, 1.0, 1.0) +g = WrapFunc(_g, 1.0, 1.0) +h = WrapFunc(_h, 1.0, 1.0) ### @@ -80,6 +80,7 @@ cache4 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing, eltype(x))), eltype(x) @info "JacVec" L = JacVec(f, x, 1.0, 1.0) +update_coefficients!(f, x, 1.0, 1.0) @test L * x ≈ auto_jacvec(f, x, x) @test L * v ≈ auto_jacvec(f, x, v) @test mul!(dy, L, v) ≈ auto_jacvec(f, x, v) @@ -89,6 +90,7 @@ update_coefficients_for_test!(L, v, 3.0, 4.0) dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*auto_jacvec(f,x,v) + b*_dy L = JacVec(f, x, 1.0, 1.0; autodiff = AutoFiniteDiff()) +update_coefficients!(f, x, 1.0, 1.0) @test L * x ≈ num_jacvec(f, x, x) @test L * v ≈ num_jacvec(f, x, v) @test mul!(dy, L, v)≈num_jacvec(f, x, v) rtol=1e-6 @@ -105,8 +107,10 @@ gmres!(out, L, v) x = rand(N) v = rand(N) L = HesVec(g, x, 1.0, 1.0, autodiff = AutoFiniteDiff()) -@test L * x ≈ num_hesvec(g, x, x) -@test L * v ≈ num_hesvec(g, x, v) +update_coefficients!(g, x, 1.0, 1.0) +@test L * x ≈ num_hesvec(g, x, x) rtol=1e-2 +num_hesvec(g, x, x) +@test L * v ≈ num_hesvec(g, x, v) rtol=1e-2 @test mul!(dy, L, v)≈num_hesvec(g, x, v) rtol=1e-2 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*num_hesvec(g,x,v) + b*_dy rtol=1e-2 update_coefficients_for_test!(L, v, 3.0, 4.0) @@ -114,6 +118,9 @@ update_coefficients_for_test!(L, v, 3.0, 4.0) dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*num_hesvec(g,x,v) + b*_dy rtol=1e-2 L = HesVec(g, x, 1.0, 1.0) +update_coefficients!(g, x, 1.0, 1.0) +numauto_hesvec(g, x, x) +num_hesvec(g, x, x) @test L * x ≈ numauto_hesvec(g, x, x) @test L * v ≈ numauto_hesvec(g, x, v) @test mul!(dy, L, v)≈numauto_hesvec(g, x, v) rtol=1e-8 @@ -129,6 +136,7 @@ x = rand(N) v = rand(N) L = HesVec(g, x, 1.0, 1.0; autodiff = AutoZygote()) +update_coefficients!(g, x, 1.0, 1.0) @test L * x ≈ autoback_hesvec(g, x, x) @test L * v ≈ autoback_hesvec(g, x, v) @test mul!(dy, L, v)≈autoback_hesvec(g, x, v) rtol=1e-8 @@ -145,23 +153,26 @@ gmres!(out, L, v) x = rand(N) v = rand(N) L = HesVecGrad(h, x, 1.0, 1.0; autodiff = AutoFiniteDiff()) +update_coefficients!(h, x, 1.0, 1.0) +update_coefficients!(g, x, 1.0, 1.0) @test L * x ≈ num_hesvec(g, x, x) rtol=1e-2 @test L * v ≈ num_hesvec(g, x, v) rtol=1e-2 @test mul!(dy, L, v)≈num_hesvec(g, x, v) rtol=1e-2 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*num_hesvec(g,x,v)+b*_dy rtol=1e-2 update_coefficients_for_test!(L, v, 3.0, 4.0) +update_coefficients!(g, x, 3.0, 4.0) @test mul!(dy, L, v)≈num_hesvec(g, v, v) rtol=1e-2 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*num_hesvec(g,x,v)+b*_dy rtol=1e-2 L = HesVecGrad(h, x, 1.0, 1.0) +update_coefficients!(h, x, 1.0, 1.0) +update_coefficients!(g, x, 1.0, 1.0) @test L * x ≈ autonum_hesvec(g, x, x) @test L * v ≈ numauto_hesvec(g, x, v) @test mul!(dy, L, v)≈numauto_hesvec(g, x, v) rtol=1e-8 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8 update_coefficients_for_test!(L, v, 3.0, 4.0) +update_coefficients!(g, x, 3.0, 4.0) @test mul!(dy, L, v)≈numauto_hesvec(g, v, v) rtol=1e-8 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8 - -out = similar(v) -gmres!(out, L, v) # diff --git a/test/update_coeffs_testutils.jl b/test/update_coeffs_testutils.jl index 669ec291..9cb75150 100644 --- a/test/update_coeffs_testutils.jl +++ b/test/update_coeffs_testutils.jl @@ -1,8 +1,7 @@ # Utilities for testing update coefficient behaviour with state-dependent (i.e. dependent on u/p/t) functions -mutable struct WrapFunc{F,U,P,T} +mutable struct WrapFunc{F,P,T} func::F - u::U p::P t::T end @@ -13,9 +12,8 @@ function (w::WrapFunc)(v, u) lmul!(w.p * w.t, v) end -update_coefficients(w::WrapFunc, u, p, t) = WrapFunc(w.func, u, p, t) +update_coefficients(w::WrapFunc, u, p, t) = WrapFunc(w.func, p, t) function update_coefficients!(w::WrapFunc, u, p, t) - w.u = u w.p = p w.t = t end