From b6aa2327161900aaa13f67955bc1107b5889753a Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Mon, 27 Mar 2023 23:12:37 -0400 Subject: [PATCH] Remove update_coefficients_for_test!, test L(v, u, p, t) --- test/test_jaches_products.jl | 44 ++++++++++++++++++++++++--------- test/test_vecjac_products.jl | 8 ++++++ test/update_coeffs_testutils.jl | 11 +++------ 3 files changed, 43 insertions(+), 20 deletions(-) diff --git a/test/test_jaches_products.jl b/test/test_jaches_products.jl index cbff19fc..f89ec418 100644 --- a/test/test_jaches_products.jl +++ b/test/test_jaches_products.jl @@ -1,6 +1,5 @@ using SparseDiffTools, ForwardDiff, FiniteDiff, Zygote, IterativeSolvers using LinearAlgebra, Test -import SciMLOperators: update_coefficients, update_coefficients! using Random Random.seed!(123) @@ -85,9 +84,12 @@ update_coefficients!(f, x, 1.0, 1.0) @test L * v ≈ auto_jacvec(f, x, v) @test mul!(dy, L, v) ≈ auto_jacvec(f, x, v) dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*auto_jacvec(f,x,v) + b*_dy -update_coefficients_for_test!(L, v, 3.0, 4.0) +update_coefficients!(L, v, 3.0, 4.0) +update_coefficients!(f, v, 3.0, 4.0) @test mul!(dy, L, v) ≈ auto_jacvec(f, v, v) dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*auto_jacvec(f,x,v) + b*_dy +update_coefficients!(f, v, 5.0, 6.0) +@test L(dy, v, 5.0, 6.0) ≈ auto_jacvec(f,x,v) L = JacVec(f, x, 1.0, 1.0; autodiff = AutoFiniteDiff()) update_coefficients!(f, x, 1.0, 1.0) @@ -95,12 +97,15 @@ update_coefficients!(f, x, 1.0, 1.0) @test L * v ≈ num_jacvec(f, x, v) @test mul!(dy, L, v)≈num_jacvec(f, x, v) rtol=1e-6 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*num_jacvec(f,x,v) + b*_dy rtol=1e-6 -update_coefficients_for_test!(L, v, 3.0, 4.0) +update_coefficients!(L, v, 3.0, 4.0) +update_coefficients!(f, v, 3.0, 4.0) @test mul!(dy, L, v)≈num_jacvec(f, v, v) rtol=1e-6 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*num_jacvec(f,x,v) + b*_dy rtol=1e-6 +update_coefficients!(f, v, 5.0, 6.0) +@test L(dy, v, 5.0, 6.0) ≈ num_jacvec(f,x,v) rtol=1e-6 out = similar(v) -gmres!(out, L, v) +@test_nowarn gmres!(out, L, v) @info "HesVec" @@ -113,9 +118,12 @@ num_hesvec(g, x, x) @test L * v ≈ num_hesvec(g, x, v) rtol=1e-2 @test mul!(dy, L, v)≈num_hesvec(g, x, v) rtol=1e-2 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*num_hesvec(g,x,v) + b*_dy rtol=1e-2 -update_coefficients_for_test!(L, v, 3.0, 4.0) +update_coefficients!(L, v, 3.0, 4.0) +update_coefficients!(g, v, 3.0, 4.0) @test mul!(dy, L, v)≈num_hesvec(g, v, v) rtol=1e-2 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*num_hesvec(g,x,v) + b*_dy rtol=1e-2 +update_coefficients!(g, v, 5.0, 6.0) +@test L(dy, v, 5.0, 6.0) ≈ num_hesvec(g,x,v) rtol=1e-2 L = HesVec(g, x, 1.0, 1.0) update_coefficients!(g, x, 1.0, 1.0) @@ -125,9 +133,12 @@ num_hesvec(g, x, x) @test L * v ≈ numauto_hesvec(g, x, v) @test mul!(dy, L, v)≈numauto_hesvec(g, x, v) rtol=1e-8 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8 -update_coefficients_for_test!(L, v, 3.0, 4.0) +update_coefficients!(L, v, 3.0, 4.0) +update_coefficients!(g, v, 3.0, 4.0) @test mul!(dy, L, v)≈numauto_hesvec(g, v, v) rtol=1e-8 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8 +update_coefficients!(g, v, 5.0, 6.0) +@test L(dy, v, 5.0, 6.0) ≈ numauto_hesvec(g,x,v) rtol=1e-2 out = similar(v) gmres!(out, L, v) @@ -141,9 +152,12 @@ update_coefficients!(g, x, 1.0, 1.0) @test L * v ≈ autoback_hesvec(g, x, v) @test mul!(dy, L, v)≈autoback_hesvec(g, x, v) rtol=1e-8 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*autoback_hesvec(g,x,v)+b*_dy rtol=1e-8 -update_coefficients_for_test!(L, v, 3.0, 4.0) +update_coefficients!(L, v, 3.0, 4.0) +update_coefficients!(g, v, 3.0, 4.0) @test mul!(dy, L, v)≈autoback_hesvec(g, v, v) rtol=1e-8 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*autoback_hesvec(g,x,v)+b*_dy rtol=1e-8 +update_coefficients!(g, v, 5.0, 6.0) +@test L(dy, v, 5.0, 6.0) ≈ autoback_hesvec(g,x,v) rtol=1e-2 out = similar(v) gmres!(out, L, v) @@ -159,20 +173,26 @@ update_coefficients!(g, x, 1.0, 1.0) @test L * v ≈ num_hesvec(g, x, v) rtol=1e-2 @test mul!(dy, L, v)≈num_hesvec(g, x, v) rtol=1e-2 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*num_hesvec(g,x,v)+b*_dy rtol=1e-2 -update_coefficients_for_test!(L, v, 3.0, 4.0) -update_coefficients!(g, x, 3.0, 4.0) +for op in (L, g, h) update_coefficients!(op, v, 3.0, 4.0) end @test mul!(dy, L, v)≈num_hesvec(g, v, v) rtol=1e-2 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*num_hesvec(g,x,v)+b*_dy rtol=1e-2 +update_coefficients!(g, v, 5.0, 6.0) +@test L(dy, v, 5.0, 6.0) ≈ num_hesvec(g,x,v) rtol=1e-2 L = HesVecGrad(h, x, 1.0, 1.0) -update_coefficients!(h, x, 1.0, 1.0) update_coefficients!(g, x, 1.0, 1.0) +update_coefficients!(h, x, 1.0, 1.0) @test L * x ≈ autonum_hesvec(g, x, x) @test L * v ≈ numauto_hesvec(g, x, v) @test mul!(dy, L, v)≈numauto_hesvec(g, x, v) rtol=1e-8 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8 -update_coefficients_for_test!(L, v, 3.0, 4.0) -update_coefficients!(g, x, 3.0, 4.0) +for op in (L, g, h) update_coefficients!(op, v, 3.0, 4.0) end @test mul!(dy, L, v)≈numauto_hesvec(g, v, v) rtol=1e-8 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8 +update_coefficients!(g, v, 5.0, 6.0) +update_coefficients!(h, v, 5.0, 6.0) +@test L(dy, v, 5.0, 6.0) ≈ num_hesvec(g,x,v) rtol=1e-2 + +out = similar(v) +gmres!(out, L, v) # diff --git a/test/test_vecjac_products.jl b/test/test_vecjac_products.jl index 6aea35de..eb6511e5 100644 --- a/test/test_vecjac_products.jl +++ b/test/test_vecjac_products.jl @@ -20,9 +20,17 @@ f = WrapFunc(_f, 1.0, 1.0) L = VecJac(f, x, 1.0, 1.0) update_coefficients!(L, v, 3.0, 4.0) +update_coefficients!(f, v, 3.0, 4.0) actual_vjp = Zygote.jacobian(f, x)[1]' * v @test L * v ≈ actual_vjp +update_coefficients!(f, v, 5.0, 6.0) +actual_vjp2 = Zygote.jacobian(f, x)[1]' * v +@test L(copy(v), v, 5.0, 6.0) ≈ actual_vjp2 + L = VecJac(f, x, 1.0, 1.0; autodiff = AutoFiniteDiff()) update_coefficients!(L, v, 3.0, 4.0) +update_coefficients!(f, v, 3.0, 4.0) @test L * v ≈ actual_vjp +update_coefficients!(f, v, 5.0, 6.0) +@test L(copy(v), v, 5.0, 6.0) ≈ actual_vjp2 # diff --git a/test/update_coeffs_testutils.jl b/test/update_coeffs_testutils.jl index 9cb75150..8e82c394 100644 --- a/test/update_coeffs_testutils.jl +++ b/test/update_coeffs_testutils.jl @@ -1,4 +1,6 @@ -# Utilities for testing update coefficient behaviour with state-dependent (i.e. dependent on u/p/t) functions +import SciMLOperators: update_coefficients, update_coefficients! + +# Wrapper function for testing update coefficient behaviour with state-dependent (i.e. dependent on u/p/t) functions mutable struct WrapFunc{F,P,T} func::F @@ -16,11 +18,4 @@ update_coefficients(w::WrapFunc, u, p, t) = WrapFunc(w.func, p, t) function update_coefficients!(w::WrapFunc, u, p, t) w.p = p w.t = t -end - -# Helper function for testing correct update coefficients behaviour of operators -function update_coefficients_for_test!(L, u, p, t) - update_coefficients!(L, u, p, t) - # Force function hiding inside L to update. Should be a null-op if previous line works correctly - update_coefficients!(L.op.f, u, p, t) end \ No newline at end of file