Skip to content

Commit

Permalink
Merge pull request #235 from gaurav-arya/ag-ci
Browse files Browse the repository at this point in the history
Fix CI
  • Loading branch information
ChrisRackauckas committed Mar 28, 2023
2 parents f191085 + 7e41089 commit 53755d8
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
fail-fast: false
matrix:
group:
- Core
- All
version:
- '1' # Latest Release
- '~1.6' # Current LTS
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ VertexSafeGraphs = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f"
[compat]
ADTypes = "0.1"
Adapt = "3.0"
ArrayInterface = "7"
ArrayInterface = "7.4.2"
Compat = "4"
DataStructures = "0.18"
FiniteDiff = "2.8.1"
Expand Down
14 changes: 7 additions & 7 deletions ext/SparseDiffToolsZygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,27 +41,27 @@ function SparseDiffTools.numback_hesvec(f, x, v)
end

function SparseDiffTools.autoback_hesvec!(dy, f, x, v,
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))),
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag(), eltype(x))),
eltype(x), 1
}.(x,
ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))),
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))),
ForwardDiff.Partials.(tuple.(reshape(v, size(x))))),
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag(), eltype(x))),
eltype(x), 1
}.(x,
ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))))
ForwardDiff.Partials.(tuple.(reshape(v, size(x))))))
g = let f = f
(dx, x) -> dx .= first(Zygote.gradient(f, x))
end
cache1 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))), eltype(x), 1
}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
# Reset each dual number in cache1 to primal = dual = 1.
cache1 .= eltype(cache1).(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x)))))
g(cache2, cache1)
dy .= partials.(cache2, 1)
end

function SparseDiffTools.autoback_hesvec(f, x, v)
g = x -> first(Zygote.gradient(f, x))
y = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))), eltype(x), 1
}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x)))))
ForwardDiff.partials.(g(y), 1)
end

Expand Down
2 changes: 1 addition & 1 deletion src/differentiation/jaches_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFo
cache1 = Dual{
typeof(ForwardDiff.Tag(DeivVecTag(),eltype(u))), eltype(u), 1
}.(u, ForwardDiff.Partials.(tuple.(u)))
cache2 = copy(u)
cache2 = copy(cache1)

(cache1, cache2), autoback_hesvec, autoback_hesvec!
else
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ if GROUP == "All"
@time @safetestset "Greedy star coloring" begin include("test_greedy_star.jl") end
@time @safetestset "Acyclic coloring" begin include("test_acyclic.jl") end
@time @safetestset "Matrix to graph conversion" begin include("test_matrix2graph.jl") end
@time @safetestset "AD using colorvec vector" begin include("test_ad.jl") end
@time @safetestset "Hessian colorvecs" begin include("test_sparse_hessian.jl") end
@time @safetestset "Integration test" begin include("test_integration.jl") end
@time @safetestset "Special matrices" begin include("test_specialmatrices.jl") end
@time @safetestset "Jac Vecs and Hes Vecs" begin include("test_jaches_products.jl") end
@time @safetestset "Vec Jac Products" begin include("test_vecjac_products.jl") end
@time @safetestset "AD using colorvec vector" begin include("test_ad.jl") end
end

if GROUP == "GPU"
Expand Down
12 changes: 6 additions & 6 deletions test/test_jaches_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ function h(dy, x)
end

cache1 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), eltype(x))),
eltype(x), 1}.(x, ForwardDiff.Partials.(Tuple.(v)))
cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), eltype(x))), eltype(x), 1}.(x, ForwardDiff.Partials.(Tuple.(v)))
eltype(x), 1}.(x, ForwardDiff.Partials.(tuple.(v)))
cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), eltype(x))), eltype(x), 1}.(x, ForwardDiff.Partials.(tuple.(v)))
@test num_jacvec!(dy, f, x, v)ForwardDiff.jacobian(f, similar(x), x) * v rtol=1e-6
@test num_jacvec!(dy, f, x, v, similar(v),
similar(v))ForwardDiff.jacobian(f, similar(x), x) * v rtol=1e-6
Expand Down Expand Up @@ -50,9 +50,9 @@ cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), e
@test numback_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-8

cache3 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing, eltype(x))), eltype(x), 1
}.(x, ForwardDiff.Partials.(Tuple.(v)))
}.(x, ForwardDiff.Partials.(tuple.(v)))
cache4 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing, eltype(x))), eltype(x), 1
}.(x, ForwardDiff.Partials.(Tuple.(v)))
}.(x, ForwardDiff.Partials.(tuple.(v)))
@test autoback_hesvec!(dy, g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-8
@test autoback_hesvec!(dy, g, x, v, cache3, cache4)ForwardDiff.hessian(g, x) * v rtol=1e-8
@test autoback_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-8
Expand Down Expand Up @@ -135,8 +135,8 @@ gmres!(out, L, v)
x = rand(N)
v = rand(N)
L = HesVecGrad(h, x, autodiff = AutoFiniteDiff())
@test L * x num_hesvec(g, x, x)
@test L * v num_hesvec(g, x, v)
@test L * x num_hesvec(g, x, x) rtol=1e-2
@test L * v num_hesvec(g, x, v) rtol=1e-2
@test mul!(dy, L, v)num_hesvec(g, x, v) rtol=1e-2
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*num_hesvec(g,x,v)+b*_dy rtol=1e-2
update_coefficients!(L, v, nothing, 0.0)
Expand Down
12 changes: 1 addition & 11 deletions test/test_vecjac_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,4 @@ update_coefficients!(L, v, nothing, 0.0)
L = VecJac(f, x; autodiff = AutoFiniteDiff())
update_coefficients!(L, v, nothing, 0.0)
@test L * v actual_vjp

@info "ZygoteVecJac"

L = ZygoteVecJac(f, x)
actual_vjp = Zygote.jacobian(x -> f(x, nothing, 0.0), x)[1]' * v
update_coefficients!(L, v, nothing, 0.0)
@test L * v actual_vjp
L = ZygoteVecJac(f, x; autodiff = AutoFiniteDiff())
update_coefficients!(L, v, nothing, 0.0)
@test L * v actual_vjp
#
#

0 comments on commit 53755d8

Please sign in to comment.