diff --git a/ext/SparseDiffToolsZygote.jl b/ext/SparseDiffToolsZygote.jl index 6af37347..b91e7d31 100644 --- a/ext/SparseDiffToolsZygote.jl +++ b/ext/SparseDiffToolsZygote.jl @@ -44,16 +44,16 @@ function SparseDiffTools.autoback_hesvec!(dy, f, x, v, cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag(), eltype(x))), eltype(x), 1 }.(x, - ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))), + ForwardDiff.Partials.(tuple.(reshape(v, size(x))))), cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag(), eltype(x))), eltype(x), 1 }.(x, - ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))) + ForwardDiff.Partials.(tuple.(reshape(v, size(x)))))) g = let f = f (dx, x) -> dx .= first(Zygote.gradient(f, x)) end # Reset each dual number in cache1 to primal = dual = 1. - cache1 .= eltype(cache1).(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))) + cache1 .= eltype(cache1).(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x))))) g(cache2, cache1) dy .= partials.(cache2, 1) end @@ -61,7 +61,7 @@ end function SparseDiffTools.autoback_hesvec(f, x, v) g = x -> first(Zygote.gradient(f, x)) y = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))), eltype(x), 1 - }.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))) + }.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x))))) ForwardDiff.partials.(g(y), 1) end diff --git a/test/test_jaches_products.jl b/test/test_jaches_products.jl index bf4eca9d..f693df2f 100644 --- a/test/test_jaches_products.jl +++ b/test/test_jaches_products.jl @@ -20,8 +20,8 @@ function h(dy, x) end cache1 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), eltype(x))), - eltype(x), 1}.(x, ForwardDiff.Partials.(Tuple.(v))) -cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), eltype(x))), eltype(x), 1}.(x, ForwardDiff.Partials.(Tuple.(v))) + eltype(x), 1}.(x, ForwardDiff.Partials.(tuple.(v))) +cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), eltype(x))), eltype(x), 1}.(x, ForwardDiff.Partials.(tuple.(v))) @test num_jacvec!(dy, f, x, v)≈ForwardDiff.jacobian(f, similar(x), x) * v rtol=1e-6 @test num_jacvec!(dy, f, x, v, similar(v), similar(v))≈ForwardDiff.jacobian(f, similar(x), x) * v rtol=1e-6 @@ -50,9 +50,9 @@ cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), e @test numback_hesvec(g, x, v)≈ForwardDiff.hessian(g, x) * v rtol=1e-8 cache3 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing, eltype(x))), eltype(x), 1 - }.(x, ForwardDiff.Partials.(Tuple.(v))) + }.(x, ForwardDiff.Partials.(tuple.(v))) cache4 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing, eltype(x))), eltype(x), 1 - }.(x, ForwardDiff.Partials.(Tuple.(v))) + }.(x, ForwardDiff.Partials.(tuple.(v))) @test autoback_hesvec!(dy, g, x, v)≈ForwardDiff.hessian(g, x) * v rtol=1e-8 @test autoback_hesvec!(dy, g, x, v, cache3, cache4)≈ForwardDiff.hessian(g, x) * v rtol=1e-8 @test autoback_hesvec(g, x, v)≈ForwardDiff.hessian(g, x) * v rtol=1e-8