Skip to content

Commit

Permalink
Change Tuple -> tuple for consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
gaurav-arya committed Mar 28, 2023
1 parent b34e28b commit 9ec8666
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
8 changes: 4 additions & 4 deletions ext/SparseDiffToolsZygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,24 +44,24 @@ function SparseDiffTools.autoback_hesvec!(dy, f, x, v,
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag(), eltype(x))),
eltype(x), 1
}.(x,
ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))),
ForwardDiff.Partials.(tuple.(reshape(v, size(x))))),
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag(), eltype(x))),
eltype(x), 1
}.(x,
ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))))
ForwardDiff.Partials.(tuple.(reshape(v, size(x))))))
g = let f = f
(dx, x) -> dx .= first(Zygote.gradient(f, x))
end
# Reset each dual number in cache1 to primal = dual = 1.
cache1 .= eltype(cache1).(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
cache1 .= eltype(cache1).(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x)))))
g(cache2, cache1)
dy .= partials.(cache2, 1)
end

function SparseDiffTools.autoback_hesvec(f, x, v)
g = x -> first(Zygote.gradient(f, x))
y = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))), eltype(x), 1
}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x)))))
ForwardDiff.partials.(g(y), 1)
end

Expand Down
8 changes: 4 additions & 4 deletions test/test_jaches_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ function h(dy, x)
end

cache1 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), eltype(x))),
eltype(x), 1}.(x, ForwardDiff.Partials.(Tuple.(v)))
cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), eltype(x))), eltype(x), 1}.(x, ForwardDiff.Partials.(Tuple.(v)))
eltype(x), 1}.(x, ForwardDiff.Partials.(tuple.(v)))
cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), eltype(x))), eltype(x), 1}.(x, ForwardDiff.Partials.(tuple.(v)))
@test num_jacvec!(dy, f, x, v)ForwardDiff.jacobian(f, similar(x), x) * v rtol=1e-6
@test num_jacvec!(dy, f, x, v, similar(v),
similar(v))ForwardDiff.jacobian(f, similar(x), x) * v rtol=1e-6
Expand Down Expand Up @@ -50,9 +50,9 @@ cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), e
@test numback_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-8

cache3 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing, eltype(x))), eltype(x), 1
}.(x, ForwardDiff.Partials.(Tuple.(v)))
}.(x, ForwardDiff.Partials.(tuple.(v)))
cache4 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing, eltype(x))), eltype(x), 1
}.(x, ForwardDiff.Partials.(Tuple.(v)))
}.(x, ForwardDiff.Partials.(tuple.(v)))
@test autoback_hesvec!(dy, g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-8
@test autoback_hesvec!(dy, g, x, v, cache3, cache4)ForwardDiff.hessian(g, x) * v rtol=1e-8
@test autoback_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-8
Expand Down

0 comments on commit 9ec8666

Please sign in to comment.