diff --git a/ext/SparseDiffToolsZygote.jl b/ext/SparseDiffToolsZygote.jl index b91e7d31..eb6df440 100644 --- a/ext/SparseDiffToolsZygote.jl +++ b/ext/SparseDiffToolsZygote.jl @@ -25,6 +25,7 @@ function SparseDiffTools.numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache g(cache1, x) @. x -= 2ϵ * v g(cache2, x) + @. x += ϵ * v @. dy = (cache1 - cache2) / (2ϵ) end diff --git a/src/differentiation/jaches_products.jl b/src/differentiation/jaches_products.jl index 6ce3c4f8..846ba795 100644 --- a/src/differentiation/jaches_products.jl +++ b/src/differentiation/jaches_products.jl @@ -78,6 +78,7 @@ function num_hesvec!(dy, g(cache2, x) @. x -= 2ϵ * v g(cache3, x) + @. x += ϵ * v @. dy = (cache2 - cache3) / (2ϵ) end @@ -110,6 +111,7 @@ function numauto_hesvec!(dy, g(cache1, x) @. x -= 2ϵ * v g(cache2, x) + @. x += ϵ * v @. dy = (cache1 - cache2) / (2ϵ) end @@ -158,6 +160,7 @@ function num_hesvecgrad!(dy, g, x, v, cache2 = similar(v), cache3 = similar(v)) g(cache2, x) @. x -= 2ϵ * v g(cache3, x) + @. x += ϵ * v @. dy = (cache2 - cache3) / (2ϵ) end @@ -207,10 +210,12 @@ struct FwdModeAutoDiffVecProd{F,U,C,V,V!} <: AbstractAutoDiffVecProd end function update_coefficients(L::FwdModeAutoDiffVecProd, u, p, t) - FwdModeAutoDiffVecProd(L.f, u, L.vecprod, L.vecprod!, L.cache) + f = update_coefficients(L.f, u, p, t) + FwdModeAutoDiffVecProd(f, u, L.cache, L.vecprod, L.vecprod!) end function update_coefficients!(L::FwdModeAutoDiffVecProd, u, p, t) + update_coefficients!(L.f, u, p, t) copy!(L.u, u) L end diff --git a/src/differentiation/vecjac_products.jl b/src/differentiation/vecjac_products.jl index 003e7671..7893d890 100644 --- a/src/differentiation/vecjac_products.jl +++ b/src/differentiation/vecjac_products.jl @@ -65,26 +65,28 @@ struct RevModeAutoDiffVecProd{ad,iip,oop,F,U,C,V,V!} <: AbstractAutoDiffVecProd end function update_coefficients(L::RevModeAutoDiffVecProd, u, p, t) - RevModeAutoDiffVecProd(L.f, u, L.vecprod, L.vecprod!, L.cache) + f = update_coefficients(L.f, u, p, t) + RevModeAutoDiffVecProd(f, u, L.vecprod, L.vecprod!, L.cache) end function update_coefficients!(L::RevModeAutoDiffVecProd, u, p, t) + update_coefficients!(L.f, u, p, t) copy!(L.u, u) L end # Interpret the call as df/du' * u function (L::RevModeAutoDiffVecProd)(v, p, t) - L.vecprod(_u -> L.f(_u, p, t), L.u, v) + L.vecprod(L.f, L.u, v) end # prefer non in-place method function (L::RevModeAutoDiffVecProd{ad,iip,true})(dv, v, p, t) where{ad,iip} - L.vecprod!(dv, _u -> L.f(_u, p, t), L.u, v, L.cache...) + L.vecprod!(dv, L.f, L.u, v, L.cache...) end function (L::RevModeAutoDiffVecProd{ad,true,false})(dv, v, p, t) where{ad} - L.vecprod!(dv, (_du, _u) -> L.f(_du, _u, p, t), L.u, v, L.cache...) + L.vecprod!(dv, L.f, L.u, v, L.cache...) end function VecJac(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFiniteDiff(), @@ -100,11 +102,11 @@ function VecJac(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFi cache = (similar(u), similar(u),) - outofplace = static_hasmethod(f, typeof((u, p, t))) - isinplace = static_hasmethod(f, typeof((u, u, p, t))) + outofplace = static_hasmethod(f, typeof((u,))) + isinplace = static_hasmethod(f, typeof((u, u,))) if !(isinplace) & !(outofplace) - error("$f must have signature f(u, p, t), or f(du, u, p, t)") + error("$f must have signature f(u), or f(du, u)") end L = RevModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!; autodiff = autodiff, diff --git a/test/test_jaches_products.jl b/test/test_jaches_products.jl index f693df2f..200b2584 100644 --- a/test/test_jaches_products.jl +++ b/test/test_jaches_products.jl @@ -4,21 +4,40 @@ using LinearAlgebra, Test using Random Random.seed!(123) N = 300 -const A = rand(N, N) -f(y, x) = mul!(y, A, x) -f(x) = A * x + x = rand(N) v = rand(N) + +# Save original values of x and v to make sure they are not ever mutated +x0 = copy(x) +v0 = copy(v) + a, b = rand(2) dy = similar(x) -g(x) = sum(abs2, x) -function h(x) - FiniteDiff.finite_difference_gradient(g, x) + +# Define functions for testing + +A = rand(N, N) +_f(y, x) = mul!(y, A, x.^2) +_f(x) = A * (x.^2) + +_g(x) = sum(abs2, x.^2) +function _h(x) + FiniteDiff.finite_difference_gradient(_g, x) end -function h(dy, x) - FiniteDiff.finite_difference_gradient!(dy, g, x) +function _h(dy, x) + FiniteDiff.finite_difference_gradient!(dy, _g, x) end +# Make functions state-dependent for operator tests + +include("update_coeffs_testutils.jl") +f = WrapFunc(_f, 1.0, 1.0) +g = WrapFunc(_g, 1.0, 1.0) +h = WrapFunc(_h, 1.0, 1.0) + +### + cache1 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), eltype(x))), eltype(x), 1}.(x, ForwardDiff.Partials.(tuple.(v))) cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), eltype(x))), eltype(x), 1}.(x, ForwardDiff.Partials.(tuple.(v))) @@ -36,122 +55,147 @@ cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), e similar(v))≈ForwardDiff.hessian(g, x) * v rtol=1e-2 @test num_hesvec(g, x, v)≈ForwardDiff.hessian(g, x) * v rtol=1e-2 -@test numauto_hesvec!(dy, g, x, v)≈ForwardDiff.hessian(g, x) * v rtol=1e-8 +@test numauto_hesvec!(dy, g, x, v)≈ForwardDiff.hessian(g, x) * v @test numauto_hesvec!(dy, g, x, v, ForwardDiff.GradientConfig(g, x), similar(v), - similar(v))≈ForwardDiff.hessian(g, x) * v rtol=1e-8 -@test numauto_hesvec(g, x, v)≈ForwardDiff.hessian(g, x) * v rtol=1e-8 + similar(v))≈ForwardDiff.hessian(g, x) * v +@test numauto_hesvec(g, x, v)≈ForwardDiff.hessian(g, x) * v -@test autonum_hesvec!(dy, g, x, v)≈ForwardDiff.hessian(g, x) * v rtol=1e-2 +@test autonum_hesvec!(dy, g, x, v)≈ForwardDiff.hessian(g, x) * v @test autonum_hesvec!(dy, g, x, v, cache1, cache2)≈ForwardDiff.hessian(g, x) * v rtol=1e-2 -@test autonum_hesvec(g, x, v)≈ForwardDiff.hessian(g, x) * v rtol=1e-8 +@test autonum_hesvec(g, x, v)≈ForwardDiff.hessian(g, x) * v -@test numback_hesvec!(dy, g, x, v)≈ForwardDiff.hessian(g, x) * v rtol=1e-8 -@test numback_hesvec!(dy, g, x, v, similar(v), similar(v))≈ForwardDiff.hessian(g, x) * v rtol=1e-8 -@test numback_hesvec(g, x, v)≈ForwardDiff.hessian(g, x) * v rtol=1e-8 +@test numback_hesvec!(dy, g, x, v)≈ForwardDiff.hessian(g, x) * v +@test numback_hesvec!(dy, g, x, v, similar(v), similar(v))≈ForwardDiff.hessian(g, x) * v +@test numback_hesvec(g, x, v)≈ForwardDiff.hessian(g, x) * v cache3 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing, eltype(x))), eltype(x), 1 }.(x, ForwardDiff.Partials.(tuple.(v))) cache4 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing, eltype(x))), eltype(x), 1 }.(x, ForwardDiff.Partials.(tuple.(v))) -@test autoback_hesvec!(dy, g, x, v)≈ForwardDiff.hessian(g, x) * v rtol=1e-8 -@test autoback_hesvec!(dy, g, x, v, cache3, cache4)≈ForwardDiff.hessian(g, x) * v rtol=1e-8 -@test autoback_hesvec(g, x, v)≈ForwardDiff.hessian(g, x) * v rtol=1e-8 +@test autoback_hesvec!(dy, g, x, v)≈ForwardDiff.hessian(g, x) * v +@test autoback_hesvec!(dy, g, x, v, cache3, cache4)≈ForwardDiff.hessian(g, x) * v +@test autoback_hesvec(g, x, v)≈ForwardDiff.hessian(g, x) * v @test num_hesvecgrad!(dy, h, x, v)≈ForwardDiff.hessian(g, x) * v rtol=1e-2 @test num_hesvecgrad!(dy, h, x, v, similar(v), similar(v))≈ForwardDiff.hessian(g, x) * v rtol=1e-2 @test num_hesvecgrad(h, x, v)≈ForwardDiff.hessian(g, x) * v rtol=1e-2 -@test auto_hesvecgrad!(dy, h, x, v)≈ForwardDiff.hessian(g, x) * v rtol=1e-2 -@test auto_hesvecgrad!(dy, h, x, v, cache1, cache2)≈ForwardDiff.hessian(g, x) * v rtol=1e-2 -@test auto_hesvecgrad(h, x, v)≈ForwardDiff.hessian(g, x) * v rtol=1e-2 +@test auto_hesvecgrad!(dy, h, x, v)≈ForwardDiff.hessian(g, x) * v +@test auto_hesvecgrad!(dy, h, x, v, cache1, cache2)≈ForwardDiff.hessian(g, x) * v +@test auto_hesvecgrad(h, x, v)≈ForwardDiff.hessian(g, x) * v @info "JacVec" -L = JacVec(f, x) +L = JacVec(f, copy(x), 1.0, 1.0) +update_coefficients!(f, x, 1.0, 1.0) @test L * x ≈ auto_jacvec(f, x, x) @test L * v ≈ auto_jacvec(f, x, v) @test mul!(dy, L, v) ≈ auto_jacvec(f, x, v) dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*auto_jacvec(f,x,v) + b*_dy -update_coefficients!(L, v, nothing, 0.0) -@test mul!(dy, L, v) ≈ auto_jacvec(f, v, v) -dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*auto_jacvec(f,x,v) + b*_dy - -L = JacVec(f, x, autodiff = AutoFiniteDiff()) +update_coefficients!(L, v, 3.0, 4.0) +update_coefficients!(f, v, 3.0, 4.0) +@test mul!(dy, L, x) ≈ auto_jacvec(f, v, x) +dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b) ≈ a*auto_jacvec(f,v,x) + b*_dy +update_coefficients!(f, v, 5.0, 6.0) +@test L(dy, v, 5.0, 6.0) ≈ auto_jacvec(f,v,v) + +L = JacVec(f, copy(x), 1.0, 1.0; autodiff = AutoFiniteDiff()) +update_coefficients!(f, x, 1.0, 1.0) @test L * x ≈ num_jacvec(f, x, x) @test L * v ≈ num_jacvec(f, x, v) @test mul!(dy, L, v)≈num_jacvec(f, x, v) rtol=1e-6 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*num_jacvec(f,x,v) + b*_dy rtol=1e-6 -update_coefficients!(L, v, nothing, 0.0) -@test mul!(dy, L, v)≈num_jacvec(f, v, v) rtol=1e-6 -dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*num_jacvec(f,x,v) + b*_dy rtol=1e-6 +update_coefficients!(L, v, 3.0, 4.0) +update_coefficients!(f, v, 3.0, 4.0) +@test mul!(dy, L, x)≈num_jacvec(f, v, x) rtol=1e-6 +dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b) ≈ a*num_jacvec(f,v,x) + b*_dy rtol=1e-6 +update_coefficients!(f, v, 5.0, 6.0) +@test L(dy, v, 5.0, 6.0) ≈ num_jacvec(f,v,v) rtol=1e-6 out = similar(v) -gmres!(out, L, v) +@test_nowarn gmres!(out, L, v) @info "HesVec" -x = rand(N) -v = rand(N) -L = HesVec(g, x, autodiff = AutoFiniteDiff()) -@test L * x ≈ num_hesvec(g, x, x) -@test L * v ≈ num_hesvec(g, x, v) +L = HesVec(g, copy(x), 1.0, 1.0, autodiff = AutoFiniteDiff()) +update_coefficients!(g, x, 1.0, 1.0) +@test L * x ≈ num_hesvec(g, x, x) rtol=1e-2 +@test L * v ≈ num_hesvec(g, x, v) rtol=1e-2 @test mul!(dy, L, v)≈num_hesvec(g, x, v) rtol=1e-2 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*num_hesvec(g,x,v) + b*_dy rtol=1e-2 -update_coefficients!(L, v, nothing, 0.0) -@test mul!(dy, L, v)≈num_hesvec(g, v, v) rtol=1e-2 -dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*num_hesvec(g,x,v) + b*_dy rtol=1e-2 - -L = HesVec(g, x) +update_coefficients!(L, v, 3.0, 4.0) +update_coefficients!(g, v, 3.0, 4.0) +@test mul!(dy, L, x)≈num_hesvec(g, v, x) rtol=1e-2 +dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b) ≈ a*num_hesvec(g,v,x) + b*_dy rtol=1e-2 +update_coefficients!(g, v, 5.0, 6.0) +@test L(dy, v, 5.0, 6.0) ≈ num_hesvec(g,v,v) rtol=1e-2 + +L = HesVec(g, copy(x), 1.0, 1.0) @test L * x ≈ numauto_hesvec(g, x, x) @test L * v ≈ numauto_hesvec(g, x, v) -@test mul!(dy, L, v)≈numauto_hesvec(g, x, v) rtol=1e-8 -dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8 -update_coefficients!(L, v, nothing, 0.0) -@test mul!(dy, L, v)≈numauto_hesvec(g, v, v) rtol=1e-8 -dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8 +@test mul!(dy, L, v)≈numauto_hesvec(g, x, v) +dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,0)≈a*numauto_hesvec(g,x,v)+0*_dy +update_coefficients!(L, v, 3.0, 4.0) +update_coefficients!(g, v, 3.0, 4.0) +@test mul!(dy, L, x)≈numauto_hesvec(g, v, x) +dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b)≈a*numauto_hesvec(g,v,x)+b*_dy +update_coefficients!(g, v, 5.0, 6.0) +@test L(dy, v, 5.0, 6.0) ≈ numauto_hesvec(g,v,v) out = similar(v) gmres!(out, L, v) -using Zygote - -x = rand(N) -v = rand(N) - -L = HesVec(g, x, autodiff = AutoZygote()) +L = HesVec(g, copy(x), 1.0, 1.0; autodiff = AutoZygote()) +update_coefficients!(g, x, 1.0, 1.0) @test L * x ≈ autoback_hesvec(g, x, x) @test L * v ≈ autoback_hesvec(g, x, v) -@test mul!(dy, L, v)≈autoback_hesvec(g, x, v) rtol=1e-8 -dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*autoback_hesvec(g,x,v)+b*_dy rtol=1e-8 -update_coefficients!(L, v, nothing, 0.0) -@test mul!(dy, L, v)≈autoback_hesvec(g, v, v) rtol=1e-8 -dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*autoback_hesvec(g,x,v)+b*_dy rtol=1e-8 +@test mul!(dy, L, v)≈autoback_hesvec(g, x, v) +dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*autoback_hesvec(g,x,v)+b*_dy +update_coefficients!(L, v, 3.0, 4.0) +update_coefficients!(g, v, 3.0, 4.0) +@test mul!(dy, L, x)≈autoback_hesvec(g, v, x) +dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b)≈a*autoback_hesvec(g,v,x)+b*_dy +update_coefficients!(g, v, 5.0, 6.0) +@test L(dy, v, 5.0, 6.0) ≈ autoback_hesvec(g,v,v) out = similar(v) gmres!(out, L, v) @info "HesVecGrad" -x = rand(N) -v = rand(N) -L = HesVecGrad(h, x, autodiff = AutoFiniteDiff()) +L = HesVecGrad(h, copy(x), 1.0, 1.0; autodiff = AutoFiniteDiff()) +update_coefficients!(h, x, 1.0, 1.0) +update_coefficients!(g, x, 1.0, 1.0) @test L * x ≈ num_hesvec(g, x, x) rtol=1e-2 @test L * v ≈ num_hesvec(g, x, v) rtol=1e-2 @test mul!(dy, L, v)≈num_hesvec(g, x, v) rtol=1e-2 dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*num_hesvec(g,x,v)+b*_dy rtol=1e-2 -update_coefficients!(L, v, nothing, 0.0) -@test mul!(dy, L, v)≈num_hesvec(g, v, v) rtol=1e-2 -dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*num_hesvec(g,x,v)+b*_dy rtol=1e-2 - -L = HesVecGrad(h, x) +for op in (L, g, h) update_coefficients!(op, v, 3.0, 4.0) end +@test mul!(dy, L, x)≈num_hesvec(g, v, x) rtol=1e-2 +dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b)≈a*num_hesvec(g,v,x)+b*_dy rtol=1e-2 +update_coefficients!(g, v, 5.0, 6.0) +@test L(dy, v, 5.0, 6.0) ≈ num_hesvec(g,v,v) rtol=1e-2 + +L = HesVecGrad(h, copy(x), 1.0, 1.0) +update_coefficients!(g, x, 1.0, 1.0) +update_coefficients!(h, x, 1.0, 1.0) @test L * x ≈ autonum_hesvec(g, x, x) @test L * v ≈ numauto_hesvec(g, x, v) -@test mul!(dy, L, v)≈numauto_hesvec(g, x, v) rtol=1e-8 -dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8 -update_coefficients!(L, v, nothing, 0.0) -@test mul!(dy, L, v)≈numauto_hesvec(g, v, v) rtol=1e-8 -dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8 +@test mul!(dy, L, v)≈numauto_hesvec(g, x, v) +dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*numauto_hesvec(g,x,v)+b*_dy +for op in (L, g, h) update_coefficients!(op, v, 3.0, 4.0) end +@test mul!(dy, L, x)≈numauto_hesvec(g, v, x) +dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b)≈a*numauto_hesvec(g,v,x)+b*_dy +update_coefficients!(g, v, 5.0, 6.0) +update_coefficients!(h, v, 5.0, 6.0) +@test L(dy, v, 5.0, 6.0) ≈ numauto_hesvec(g,v,v) out = similar(v) gmres!(out, L, v) + +# Test that x and v were not mutated +# x's rtol can't be too large since it is mutated and then restored in some algorithms +@test x ≈ x0 +@test v ≈ v0 + # diff --git a/test/test_vecjac_products.jl b/test/test_vecjac_products.jl index a13a8b57..2978d8fc 100644 --- a/test/test_vecjac_products.jl +++ b/test/test_vecjac_products.jl @@ -4,21 +4,61 @@ using LinearAlgebra, Test using Random Random.seed!(123) N = 300 -const A = rand(N, N) +# Use Float32 since Zygote defaults to Float32 x = rand(Float32, N) v = rand(Float32, N) -f(du,u,p,t) = mul!(du, A, u) -f(u,p,t) = A * u +# Save original values of x and v to make sure they are not ever mutated +x0 = copy(x) +v0 = copy(v) + +a, b = rand(2) +dy = similar(x) + +A = rand(Float32, N, N) +_f(du,u) = mul!(du, A, u) +_f(u) = A * u + +# Define state-dependent functions for operator tests +include("update_coeffs_testutils.jl") +f = WrapFunc(_f, 1.0f0, 1.0f0) + +# Compute Jacobian via Zygote @info "VecJac" -L = VecJac(f, x) -actual_vjp = Zygote.jacobian(x -> f(x, nothing, 0.0), x)[1]' * v -update_coefficients!(L, v, nothing, 0.0) -@test L * v ≈ actual_vjp -L = VecJac(f, x; autodiff = AutoFiniteDiff()) -update_coefficients!(L, v, nothing, 0.0) -@test L * v ≈ actual_vjp -# \ No newline at end of file +L = VecJac(f, copy(x), 1.0f0, 1.0f0; autodiff = AutoZygote()) +update_coefficients!(f, x, 1.0, 1.0) +actual_jac = Zygote.jacobian(f, x)[1] +@test L * x ≈ actual_jac' * x +@test L * v ≈ actual_jac' * v +@test mul!(dy, L, v) ≈ actual_jac' * v +update_coefficients!(L, v, 3.0, 4.0) +update_coefficients!(f, v, 3.0, 4.0) +actual_jac = Zygote.jacobian(f, v)[1] +@test mul!(dy, L, x) ≈ actual_jac' * x +_dy=copy(dy); @test mul!(dy,L,x,a,b) ≈ a*actual_jac'*x + b*_dy +update_coefficients!(f, v, 5.0, 6.0) +actual_jac = Zygote.jacobian(f, v)[1] +@test L(dy, v, 5.0, 6.0) ≈ actual_jac' * v + +L = VecJac(f, copy(x), 1.0f0, 1.0f0; autodiff = AutoFiniteDiff()) +update_coefficients!(f, x, 1.0, 1.0) +actual_jac = Zygote.jacobian(f, x)[1] +@test L * x ≈ actual_jac' * x +@test L * v ≈ actual_jac' * v +@test mul!(dy, L, v) ≈ actual_jac' * v +update_coefficients!(L, v, 3.0, 4.0) +update_coefficients!(f, v, 3.0, 4.0) +actual_jac = Zygote.jacobian(f, v)[1] +@test mul!(dy, L, x) ≈ actual_jac' * x +_dy=copy(dy); @test mul!(dy,L,x,a,b) ≈ a*actual_jac'*x + b*_dy +update_coefficients!(f, v, 5.0, 6.0) +actual_jac = Zygote.jacobian(f, v)[1] +@test L(dy, v, 5.0, 6.0) ≈ actual_jac' * v + +# Test that x and v were not mutated +@test x ≈ x0 +@test v ≈ v0 +# diff --git a/test/update_coeffs_testutils.jl b/test/update_coeffs_testutils.jl new file mode 100644 index 00000000..8e82c394 --- /dev/null +++ b/test/update_coeffs_testutils.jl @@ -0,0 +1,21 @@ +import SciMLOperators: update_coefficients, update_coefficients! + +# Wrapper function for testing update coefficient behaviour with state-dependent (i.e. dependent on u/p/t) functions + +mutable struct WrapFunc{F,P,T} + func::F + p::P + t::T +end + +(w::WrapFunc)(u) = w.p * w.t * w.func(u) +function (w::WrapFunc)(v, u) + w.func(v, u) + lmul!(w.p * w.t, v) +end + +update_coefficients(w::WrapFunc, u, p, t) = WrapFunc(w.func, p, t) +function update_coefficients!(w::WrapFunc, u, p, t) + w.p = p + w.t = t +end \ No newline at end of file