Skip to content

Commit

Permalink
Merge pull request #232 from vpuri3/update_coeffs
Browse files Browse the repository at this point in the history
have update_coeffs(L::ADVecProd,) recursively update L.f
  • Loading branch information
ChrisRackauckas committed Apr 8, 2023
2 parents 4e4fc7b + 57f4a55 commit c337dd5
Show file tree
Hide file tree
Showing 6 changed files with 203 additions and 90 deletions.
1 change: 1 addition & 0 deletions ext/SparseDiffToolsZygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ function SparseDiffTools.numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache
g(cache1, x)
@. x -= 2ϵ * v
g(cache2, x)
@. x += ϵ * v
@. dy = (cache1 - cache2) / (2ϵ)
end

Expand Down
7 changes: 6 additions & 1 deletion src/differentiation/jaches_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ function num_hesvec!(dy,
g(cache2, x)
@. x -= 2ϵ * v
g(cache3, x)
@. x += ϵ * v
@. dy = (cache2 - cache3) / (2ϵ)
end

Expand Down Expand Up @@ -110,6 +111,7 @@ function numauto_hesvec!(dy,
g(cache1, x)
@. x -= 2ϵ * v
g(cache2, x)
@. x += ϵ * v
@. dy = (cache1 - cache2) / (2ϵ)
end

Expand Down Expand Up @@ -158,6 +160,7 @@ function num_hesvecgrad!(dy, g, x, v, cache2 = similar(v), cache3 = similar(v))
g(cache2, x)
@. x -= 2ϵ * v
g(cache3, x)
@. x += ϵ * v
@. dy = (cache2 - cache3) / (2ϵ)
end

Expand Down Expand Up @@ -207,10 +210,12 @@ struct FwdModeAutoDiffVecProd{F,U,C,V,V!} <: AbstractAutoDiffVecProd
end

# Out-of-place coefficient update for a forward-mode AD vector-product operator.
#
# Returns a NEW `FwdModeAutoDiffVecProd`; `L` itself is not modified.
#   * the wrapped function is first updated recursively via
#     `update_coefficients(L.f, u, p, t)`, so state-dependent functions see the
#     new `(u, p, t)`;
#   * `u` becomes the point stored in the returned operator (it is passed
#     through as-is, not copied);
#   * `L.cache`, `L.vecprod` and `L.vecprod!` are reused from `L`.
function update_coefficients(L::FwdModeAutoDiffVecProd, u, p, t)
    f = update_coefficients(L.f, u, p, t)
    # Argument order follows the struct's type-parameter order {F,U,C,V,V!}:
    # function, point, cache, out-of-place product, in-place product.
    return FwdModeAutoDiffVecProd(f, u, L.cache, L.vecprod, L.vecprod!)
end

# In-place coefficient update for a forward-mode AD vector-product operator.
#
# Mutates `L` and returns it:
#   1. recursively updates the wrapped function `L.f` with `(u, p, t)`;
#   2. overwrites the stored point `L.u` with the contents of `u`.
function update_coefficients!(L::FwdModeAutoDiffVecProd, u, p, t)
    # Refresh the wrapped function before touching the stored point.
    update_coefficients!(L.f, u, p, t)

    # Copy the values so that `L.u` keeps its own storage.
    copy!(L.u, u)

    return L
end
Expand Down
16 changes: 9 additions & 7 deletions src/differentiation/vecjac_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,26 +65,28 @@ struct RevModeAutoDiffVecProd{ad,iip,oop,F,U,C,V,V!} <: AbstractAutoDiffVecProd
end

# Out-of-place coefficient update for a reverse-mode AD vector-product operator.
#
# Returns a NEW `RevModeAutoDiffVecProd`; `L` itself is not modified.
#   * the wrapped function is first updated recursively via
#     `update_coefficients(L.f, u, p, t)`;
#   * `u` becomes the point stored in the returned operator (it is passed
#     through as-is, not copied);
#   * `L.cache`, `L.vecprod` and `L.vecprod!` are reused from `L`.
function update_coefficients(L::RevModeAutoDiffVecProd, u, p, t)
    f = update_coefficients(L.f, u, p, t)
    # Positional order is (f, u, cache, vecprod, vecprod!), matching the
    # constructor call in `VecJac` and the forward-mode counterpart. The previous
    # order (f, u, vecprod, vecprod!, cache) put each value in the wrong slot.
    # NOTE(review): `VecJac` also passes keyword arguments (e.g. `autodiff`) to
    # this constructor; confirm whether they must be forwarded here so the
    # `ad`/`iip`/`oop` type parameters of `L` are preserved.
    return RevModeAutoDiffVecProd(f, u, L.cache, L.vecprod, L.vecprod!)
end

# In-place coefficient update for a reverse-mode AD vector-product operator.
#
# Mutates `L` and returns it:
#   1. recursively updates the wrapped function `L.f` with `(u, p, t)`;
#   2. overwrites the stored point `L.u` with the contents of `u`.
function update_coefficients!(L::RevModeAutoDiffVecProd, u, p, t)
    # Refresh the wrapped function before touching the stored point.
    update_coefficients!(L.f, u, p, t)

    # Copy the values so that `L.u` keeps its own storage.
    copy!(L.u, u)

    return L
end

# Out-of-place application: interpret the call as df/du' * v, evaluated at the
# stored point `L.u`.
#
# `L.f` is handed to `L.vecprod` directly. `p` and `t` are accepted to keep the
# `(v, p, t)` call signature but are not used here; the wrapped function is
# expected to have been brought up to date through
# `update_coefficients`/`update_coefficients!`.
function (L::RevModeAutoDiffVecProd)(v, p, t)
    return L.vecprod(L.f, L.u, v)
end

# In-place application for operators whose third type parameter (`oop`) is
# `true`. This method applies for any `iip`, so when both forms of `f` are
# available the non in-place one is preferred.
#
# Writes the product into `dv` via `L.vecprod!`, evaluated at the stored point
# `L.u`, with the entries of `L.cache` splatted as trailing work arguments.
# `p` and `t` are accepted to keep the `(dv, v, p, t)` call signature but are
# not used here.
function (L::RevModeAutoDiffVecProd{ad, iip, true})(dv, v, p, t) where {ad, iip}
    return L.vecprod!(dv, L.f, L.u, v, L.cache...)
end

# In-place application for operators with `iip = true` and `oop = false`,
# i.e. only the in-place form of `f` is available.
#
# Writes the product into `dv` via `L.vecprod!`, evaluated at the stored point
# `L.u`, with the entries of `L.cache` splatted as trailing work arguments.
# `p` and `t` are accepted to keep the `(dv, v, p, t)` call signature but are
# not used here.
function (L::RevModeAutoDiffVecProd{ad, true, false})(dv, v, p, t) where {ad}
    return L.vecprod!(dv, L.f, L.u, v, L.cache...)
end

function VecJac(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFiniteDiff(),
Expand All @@ -100,11 +102,11 @@ function VecJac(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFi

cache = (similar(u), similar(u),)

outofplace = static_hasmethod(f, typeof((u, p, t)))
isinplace = static_hasmethod(f, typeof((u, u, p, t)))
outofplace = static_hasmethod(f, typeof((u,)))
isinplace = static_hasmethod(f, typeof((u, u,)))

if !(isinplace) & !(outofplace)
error("$f must have signature f(u, p, t), or f(du, u, p, t)")
error("$f must have signature f(u), or f(du, u)")
end

L = RevModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!; autodiff = autodiff,
Expand Down
186 changes: 115 additions & 71 deletions test/test_jaches_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,40 @@ using LinearAlgebra, Test
using Random
Random.seed!(123)
N = 300
const A = rand(N, N)
f(y, x) = mul!(y, A, x)
f(x) = A * x

x = rand(N)
v = rand(N)

# Save original values of x and v to make sure they are not ever mutated
x0 = copy(x)
v0 = copy(v)

a, b = rand(2)
dy = similar(x)
g(x) = sum(abs2, x)
function h(x)
FiniteDiff.finite_difference_gradient(g, x)

# Define functions for testing

A = rand(N, N)
_f(y, x) = mul!(y, A, x.^2)
_f(x) = A * (x.^2)

_g(x) = sum(abs2, x.^2)
function _h(x)
FiniteDiff.finite_difference_gradient(_g, x)
end
function h(dy, x)
FiniteDiff.finite_difference_gradient!(dy, g, x)
function _h(dy, x)
FiniteDiff.finite_difference_gradient!(dy, _g, x)
end

# Make functions state-dependent for operator tests

include("update_coeffs_testutils.jl")
f = WrapFunc(_f, 1.0, 1.0)
g = WrapFunc(_g, 1.0, 1.0)
h = WrapFunc(_h, 1.0, 1.0)

###

cache1 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), eltype(x))),
eltype(x), 1}.(x, ForwardDiff.Partials.(tuple.(v)))
cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), eltype(x))), eltype(x), 1}.(x, ForwardDiff.Partials.(tuple.(v)))
Expand All @@ -36,122 +55,147 @@ cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), e
similar(v))ForwardDiff.hessian(g, x) * v rtol=1e-2
@test num_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-2

@test numauto_hesvec!(dy, g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-8
@test numauto_hesvec!(dy, g, x, v)ForwardDiff.hessian(g, x) * v
@test numauto_hesvec!(dy, g, x, v, ForwardDiff.GradientConfig(g, x), similar(v),
similar(v))ForwardDiff.hessian(g, x) * v rtol=1e-8
@test numauto_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-8
similar(v))ForwardDiff.hessian(g, x) * v
@test numauto_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v

@test autonum_hesvec!(dy, g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-2
@test autonum_hesvec!(dy, g, x, v)ForwardDiff.hessian(g, x) * v
@test autonum_hesvec!(dy, g, x, v, cache1, cache2)ForwardDiff.hessian(g, x) * v rtol=1e-2
@test autonum_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-8
@test autonum_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v

@test numback_hesvec!(dy, g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-8
@test numback_hesvec!(dy, g, x, v, similar(v), similar(v))ForwardDiff.hessian(g, x) * v rtol=1e-8
@test numback_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-8
@test numback_hesvec!(dy, g, x, v)ForwardDiff.hessian(g, x) * v
@test numback_hesvec!(dy, g, x, v, similar(v), similar(v))ForwardDiff.hessian(g, x) * v
@test numback_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v

cache3 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing, eltype(x))), eltype(x), 1
}.(x, ForwardDiff.Partials.(tuple.(v)))
cache4 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing, eltype(x))), eltype(x), 1
}.(x, ForwardDiff.Partials.(tuple.(v)))
@test autoback_hesvec!(dy, g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-8
@test autoback_hesvec!(dy, g, x, v, cache3, cache4)ForwardDiff.hessian(g, x) * v rtol=1e-8
@test autoback_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-8
@test autoback_hesvec!(dy, g, x, v)ForwardDiff.hessian(g, x) * v
@test autoback_hesvec!(dy, g, x, v, cache3, cache4)ForwardDiff.hessian(g, x) * v
@test autoback_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v

@test num_hesvecgrad!(dy, h, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-2
@test num_hesvecgrad!(dy, h, x, v, similar(v), similar(v))ForwardDiff.hessian(g, x) * v rtol=1e-2
@test num_hesvecgrad(h, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-2

@test auto_hesvecgrad!(dy, h, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-2
@test auto_hesvecgrad!(dy, h, x, v, cache1, cache2)ForwardDiff.hessian(g, x) * v rtol=1e-2
@test auto_hesvecgrad(h, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-2
@test auto_hesvecgrad!(dy, h, x, v)ForwardDiff.hessian(g, x) * v
@test auto_hesvecgrad!(dy, h, x, v, cache1, cache2)ForwardDiff.hessian(g, x) * v
@test auto_hesvecgrad(h, x, v)ForwardDiff.hessian(g, x) * v

@info "JacVec"

L = JacVec(f, x)
L = JacVec(f, copy(x), 1.0, 1.0)
update_coefficients!(f, x, 1.0, 1.0)
@test L * x auto_jacvec(f, x, x)
@test L * v auto_jacvec(f, x, v)
@test mul!(dy, L, v) auto_jacvec(f, x, v)
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*auto_jacvec(f,x,v) + b*_dy
update_coefficients!(L, v, nothing, 0.0)
@test mul!(dy, L, v) auto_jacvec(f, v, v)
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*auto_jacvec(f,x,v) + b*_dy

L = JacVec(f, x, autodiff = AutoFiniteDiff())
update_coefficients!(L, v, 3.0, 4.0)
update_coefficients!(f, v, 3.0, 4.0)
@test mul!(dy, L, x) auto_jacvec(f, v, x)
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b) a*auto_jacvec(f,v,x) + b*_dy
update_coefficients!(f, v, 5.0, 6.0)
@test L(dy, v, 5.0, 6.0) auto_jacvec(f,v,v)

L = JacVec(f, copy(x), 1.0, 1.0; autodiff = AutoFiniteDiff())
update_coefficients!(f, x, 1.0, 1.0)
@test L * x num_jacvec(f, x, x)
@test L * v num_jacvec(f, x, v)
@test mul!(dy, L, v)num_jacvec(f, x, v) rtol=1e-6
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*num_jacvec(f,x,v) + b*_dy rtol=1e-6
update_coefficients!(L, v, nothing, 0.0)
@test mul!(dy, L, v)num_jacvec(f, v, v) rtol=1e-6
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*num_jacvec(f,x,v) + b*_dy rtol=1e-6
update_coefficients!(L, v, 3.0, 4.0)
update_coefficients!(f, v, 3.0, 4.0)
@test mul!(dy, L, x)num_jacvec(f, v, x) rtol=1e-6
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b) a*num_jacvec(f,v,x) + b*_dy rtol=1e-6
update_coefficients!(f, v, 5.0, 6.0)
@test L(dy, v, 5.0, 6.0) num_jacvec(f,v,v) rtol=1e-6

out = similar(v)
gmres!(out, L, v)
@test_nowarn gmres!(out, L, v)

@info "HesVec"

x = rand(N)
v = rand(N)
L = HesVec(g, x, autodiff = AutoFiniteDiff())
@test L * x num_hesvec(g, x, x)
@test L * v num_hesvec(g, x, v)
L = HesVec(g, copy(x), 1.0, 1.0, autodiff = AutoFiniteDiff())
update_coefficients!(g, x, 1.0, 1.0)
@test L * x num_hesvec(g, x, x) rtol=1e-2
@test L * v num_hesvec(g, x, v) rtol=1e-2
@test mul!(dy, L, v)num_hesvec(g, x, v) rtol=1e-2
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*num_hesvec(g,x,v) + b*_dy rtol=1e-2
update_coefficients!(L, v, nothing, 0.0)
@test mul!(dy, L, v)num_hesvec(g, v, v) rtol=1e-2
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*num_hesvec(g,x,v) + b*_dy rtol=1e-2

L = HesVec(g, x)
update_coefficients!(L, v, 3.0, 4.0)
update_coefficients!(g, v, 3.0, 4.0)
@test mul!(dy, L, x)num_hesvec(g, v, x) rtol=1e-2
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b) a*num_hesvec(g,v,x) + b*_dy rtol=1e-2
update_coefficients!(g, v, 5.0, 6.0)
@test L(dy, v, 5.0, 6.0) num_hesvec(g,v,v) rtol=1e-2

L = HesVec(g, copy(x), 1.0, 1.0)
@test L * x numauto_hesvec(g, x, x)
@test L * v numauto_hesvec(g, x, v)
@test mul!(dy, L, v)numauto_hesvec(g, x, v) rtol=1e-8
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8
update_coefficients!(L, v, nothing, 0.0)
@test mul!(dy, L, v)numauto_hesvec(g, v, v) rtol=1e-8
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8
@test mul!(dy, L, v)numauto_hesvec(g, x, v)
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,0)a*numauto_hesvec(g,x,v)+0*_dy
update_coefficients!(L, v, 3.0, 4.0)
update_coefficients!(g, v, 3.0, 4.0)
@test mul!(dy, L, x)numauto_hesvec(g, v, x)
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b)a*numauto_hesvec(g,v,x)+b*_dy
update_coefficients!(g, v, 5.0, 6.0)
@test L(dy, v, 5.0, 6.0) numauto_hesvec(g,v,v)

out = similar(v)
gmres!(out, L, v)

using Zygote

x = rand(N)
v = rand(N)

L = HesVec(g, x, autodiff = AutoZygote())
L = HesVec(g, copy(x), 1.0, 1.0; autodiff = AutoZygote())
update_coefficients!(g, x, 1.0, 1.0)
@test L * x autoback_hesvec(g, x, x)
@test L * v autoback_hesvec(g, x, v)
@test mul!(dy, L, v)autoback_hesvec(g, x, v) rtol=1e-8
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*autoback_hesvec(g,x,v)+b*_dy rtol=1e-8
update_coefficients!(L, v, nothing, 0.0)
@test mul!(dy, L, v)autoback_hesvec(g, v, v) rtol=1e-8
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*autoback_hesvec(g,x,v)+b*_dy rtol=1e-8
@test mul!(dy, L, v)autoback_hesvec(g, x, v)
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*autoback_hesvec(g,x,v)+b*_dy
update_coefficients!(L, v, 3.0, 4.0)
update_coefficients!(g, v, 3.0, 4.0)
@test mul!(dy, L, x)autoback_hesvec(g, v, x)
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b)a*autoback_hesvec(g,v,x)+b*_dy
update_coefficients!(g, v, 5.0, 6.0)
@test L(dy, v, 5.0, 6.0) autoback_hesvec(g,v,v)

out = similar(v)
gmres!(out, L, v)

@info "HesVecGrad"

x = rand(N)
v = rand(N)
L = HesVecGrad(h, x, autodiff = AutoFiniteDiff())
L = HesVecGrad(h, copy(x), 1.0, 1.0; autodiff = AutoFiniteDiff())
update_coefficients!(h, x, 1.0, 1.0)
update_coefficients!(g, x, 1.0, 1.0)
@test L * x num_hesvec(g, x, x) rtol=1e-2
@test L * v num_hesvec(g, x, v) rtol=1e-2
@test mul!(dy, L, v)num_hesvec(g, x, v) rtol=1e-2
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*num_hesvec(g,x,v)+b*_dy rtol=1e-2
update_coefficients!(L, v, nothing, 0.0)
@test mul!(dy, L, v)num_hesvec(g, v, v) rtol=1e-2
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*num_hesvec(g,x,v)+b*_dy rtol=1e-2

L = HesVecGrad(h, x)
for op in (L, g, h) update_coefficients!(op, v, 3.0, 4.0) end
@test mul!(dy, L, x)num_hesvec(g, v, x) rtol=1e-2
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b)a*num_hesvec(g,v,x)+b*_dy rtol=1e-2
update_coefficients!(g, v, 5.0, 6.0)
@test L(dy, v, 5.0, 6.0) num_hesvec(g,v,v) rtol=1e-2

L = HesVecGrad(h, copy(x), 1.0, 1.0)
update_coefficients!(g, x, 1.0, 1.0)
update_coefficients!(h, x, 1.0, 1.0)
@test L * x autonum_hesvec(g, x, x)
@test L * v numauto_hesvec(g, x, v)
@test mul!(dy, L, v)numauto_hesvec(g, x, v) rtol=1e-8
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8
update_coefficients!(L, v, nothing, 0.0)
@test mul!(dy, L, v)numauto_hesvec(g, v, v) rtol=1e-8
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8
@test mul!(dy, L, v)numauto_hesvec(g, x, v)
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy
for op in (L, g, h) update_coefficients!(op, v, 3.0, 4.0) end
@test mul!(dy, L, x)numauto_hesvec(g, v, x)
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b)a*numauto_hesvec(g,v,x)+b*_dy
update_coefficients!(g, v, 5.0, 6.0)
update_coefficients!(h, v, 5.0, 6.0)
@test L(dy, v, 5.0, 6.0) numauto_hesvec(g,v,v)

out = similar(v)
gmres!(out, L, v)

# Test that x and v were not mutated
# x's rtol can't be too large since it is mutated and then restored in some algorithms
@test x x0
@test v v0

#
Loading

0 comments on commit c337dd5

Please sign in to comment.