From 6bb601cf6c1351c913a2024c05f8b8d93d95a185 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Mon, 27 Mar 2023 15:31:04 -0400 Subject: [PATCH] Add tests for recursive updates of f in JacVec etc. --- test/test_jaches_products.jl | 82 ++++++++++++++++++++++++++---------- 1 file changed, 59 insertions(+), 23 deletions(-) diff --git a/test/test_jaches_products.jl b/test/test_jaches_products.jl index 0dd34a92..26d1508b 100644 --- a/test/test_jaches_products.jl +++ b/test/test_jaches_products.jl @@ -1,24 +1,62 @@ using SparseDiffTools, ForwardDiff, FiniteDiff, Zygote, IterativeSolvers using LinearAlgebra, Test +import SciMLOperators: update_coefficients, update_coefficients! using Random Random.seed!(123) N = 300 const A = rand(N, N) -f(y, x) = mul!(y, A, x) -f(x) = A * x + +_f(y, x) = mul!(y, A, x) +_f(x) = A * x + x = rand(N) v = rand(N) a, b = rand(2) dy = similar(x) -g(x) = sum(abs2, x) -function h(x) - FiniteDiff.finite_difference_gradient(g, x) +_g(x) = sum(abs2, x) +function _h(x) + FiniteDiff.finite_difference_gradient(_g, x) +end +function _h(dy, x) + FiniteDiff.finite_difference_gradient!(dy, _g, x) +end + +# Define state-dependent (i.e. dependent on u/p/t) functions for tests of operators + +mutable struct WrapFunc{F,U,P,T} + func::F + u::U + p::P + t::T +end + +(w::WrapFunc)(u) = sum(w.u) * w.p * w.t * w.func(u) +function (w::WrapFunc)(v, u) + w.func(v, u) + lmul!(sum(w.u) * w.p * w.t, v) +end + +update_coefficients(w::WrapFunc, u, p, t) = WrapFunc(w.func, u, p, t) +function update_coefficients!(w::WrapFunc, u, p, t) + w.u = u + w.p = p + w.t = t end -function h(dy, x) - FiniteDiff.finite_difference_gradient!(dy, g, x) + +# Helper function for testing correct update coefficients behaviour of operators +function update_coefficients_for_test!(L, u, p, t) + update_coefficients!(L, u, p, t) + # Force function hiding inside L to update. Should be a null-op if previous line works correctly + update_coefficients!(L.op.f, u, p, t) end +f = WrapFunc(_f, ones(N) * 2, 1.0, 1.0) +g = WrapFunc(_g, ones(N), 1.0, 1.0) +h = WrapFunc(_h, ones(N), 1.0, 1.0) + +### + cache1 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), eltype(x))), eltype(x), 1}.(x, ForwardDiff.Partials.(Tuple.(v))) cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), eltype(x))), eltype(x), 1}.(x, ForwardDiff.Partials.(Tuple.(v))) @@ -67,21 +105,21 @@ cache4 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing, eltype(x))), eltype(x) @info "JacVec" -L = JacVec(f, x) +L = JacVec(f, x, 1.0, 1.0) @test L * x ≈ auto_jacvec(f, x, x) @test L * v ≈ auto_jacvec(f, x, v) @test mul!(dy, L, v) ≈ auto_jacvec(f, x, v) dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*auto_jacvec(f,x,v) + b*_dy -update_coefficients!(L, v, nothing, 0.0) +update_coefficients_for_test!(L, v, 3.0, 4.0) @test mul!(dy, L, v) ≈ auto_jacvec(f, v, v) dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*auto_jacvec(f,x,v) + b*_dy -L = JacVec(f, x, autodiff = AutoFiniteDiff()) +L = JacVec(f, x, 1.0, 1.0; autodiff = AutoFiniteDiff()) @test L * x ≈ num_jacvec(f, x, x) @test L * v ≈ num_jacvec(f, x, v) @test mul!(dy, L, v)≈num_jacvec(f, x, v) rtol=1e-6 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*num_jacvec(f,x,v) + b*_dy rtol=1e-6 -update_coefficients!(L, v, nothing, 0.0) +update_coefficients_for_test!(L, v, 3.0, 4.0) @test mul!(dy, L, v)≈num_jacvec(f, v, v) rtol=1e-6 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*num_jacvec(f,x,v) + b*_dy rtol=1e-6 @@ -92,38 +130,36 @@ gmres!(out, L, v) x = rand(N) v = rand(N) -L = HesVec(g, x, autodiff = AutoFiniteDiff()) +L = HesVec(g, x, 1.0, 1.0, autodiff = AutoFiniteDiff()) @test L * x ≈ num_hesvec(g, x, x) @test L * v ≈ num_hesvec(g, x, v) @test mul!(dy, L, v)≈num_hesvec(g, x, v) rtol=1e-2 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*num_hesvec(g,x,v) + b*_dy rtol=1e-2 -update_coefficients!(L, v, nothing, 0.0) +update_coefficients_for_test!(L, v, 3.0, 4.0) @test mul!(dy, L, v)≈num_hesvec(g, v, v) rtol=1e-2 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*num_hesvec(g,x,v) + b*_dy rtol=1e-2 -L = HesVec(g, x) +L = HesVec(g, x, 1.0, 1.0) @test L * x ≈ numauto_hesvec(g, x, x) @test L * v ≈ numauto_hesvec(g, x, v) @test mul!(dy, L, v)≈numauto_hesvec(g, x, v) rtol=1e-8 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8 -update_coefficients!(L, v, nothing, 0.0) +update_coefficients_for_test!(L, v, 3.0, 4.0) @test mul!(dy, L, v)≈numauto_hesvec(g, v, v) rtol=1e-8 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8 out = similar(v) gmres!(out, L, v) -using Zygote - x = rand(N) v = rand(N) -L = HesVec(g, x, autodiff = AutoZygote()) +L = HesVec(g, x, 1.0, 1.0; autodiff = AutoZygote()) @test L * x ≈ autoback_hesvec(g, x, x) @test L * v ≈ autoback_hesvec(g, x, v) @test mul!(dy, L, v)≈autoback_hesvec(g, x, v) rtol=1e-8 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*autoback_hesvec(g,x,v)+b*_dy rtol=1e-8 -update_coefficients!(L, v, nothing, 0.0) +update_coefficients_for_test!(L, v, 3.0, 4.0) @test mul!(dy, L, v)≈autoback_hesvec(g, v, v) rtol=1e-8 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*autoback_hesvec(g,x,v)+b*_dy rtol=1e-8 @@ -134,21 +170,21 @@ gmres!(out, L, v) x = rand(N) v = rand(N) -L = HesVecGrad(h, x, autodiff = AutoFiniteDiff()) +L = HesVecGrad(h, x, 1.0, 1.0, autodiff = AutoFiniteDiff()) @test L * x ≈ num_hesvec(g, x, x) @test L * v ≈ num_hesvec(g, x, v) @test mul!(dy, L, v)≈num_hesvec(g, x, v) rtol=1e-2 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*num_hesvec(g,x,v)+b*_dy rtol=1e-2 -update_coefficients!(L, v, nothing, 0.0) +update_coefficients_for_test!(L, v, 3.0, 4.0) @test mul!(dy, L, v)≈num_hesvec(g, v, v) rtol=1e-2 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*num_hesvec(g,x,v)+b*_dy rtol=1e-2 -L = HesVecGrad(h, x) +L = HesVecGrad(h, x, 1.0, 1.0) @test L * x ≈ autonum_hesvec(g, x, x) @test L * v ≈ numauto_hesvec(g, x, v) @test mul!(dy, L, v)≈numauto_hesvec(g, x, v) rtol=1e-8 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8 -update_coefficients!(L, v, nothing, 0.0) +update_coefficients_for_test!(L, v, 3.0, 4.0) @test mul!(dy, L, v)≈numauto_hesvec(g, v, v) rtol=1e-8 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8