Skip to content

Commit

Permalink
Add Polyester approximate sparsity detection
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 16, 2024
1 parent 01a5e5f commit 49b5ab9
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 12 deletions.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SparseDiffTools"
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
authors = ["Pankaj Mishra <[email protected]>", "Chris Rackauckas <[email protected]>"]
version = "2.16.0"
version = "2.17.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -27,12 +27,14 @@ VertexSafeGraphs = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f"

[weakdeps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
SparseDiffToolsEnzymeExt = "Enzyme"
SparseDiffToolsPolyesterExt = "Polyester"
SparseDiffToolsPolyesterForwardDiffExt = "PolyesterForwardDiff"
SparseDiffToolsSymbolicsExt = "Symbolics"
SparseDiffToolsZygoteExt = "Zygote"
Expand All @@ -49,6 +51,7 @@ ForwardDiff = "0.10"
Graphs = "1"
LinearAlgebra = "<0.0.1, 1"
PackageExtensionCompat = "1"
Polyester = "0.7.9"
PolyesterForwardDiff = "0.1.1"
Random = "1.6"
Reexport = "1"
Expand All @@ -70,6 +73,7 @@ BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Expand Down
3 changes: 3 additions & 0 deletions ext/SparseDiffToolsPolyesterExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
module SparseDiffToolsPolyesterExt

end
48 changes: 39 additions & 9 deletions ext/SparseDiffToolsPolyesterForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module SparseDiffToolsPolyesterForwardDiffExt

using ADTypes, SparseDiffTools, PolyesterForwardDiff
using ADTypes, SparseDiffTools, PolyesterForwardDiff, UnPack, Random, SparseArrays
import ForwardDiff
import SparseDiffTools: AbstractMaybeSparseJacobianCache, AbstractMaybeSparsityDetection,
ForwardColorJacCache, NoMatrixColoring, sparse_jacobian_cache,
Expand All @@ -17,10 +17,8 @@ struct PolyesterForwardDiffJacobianCache{CO, CA, J, FX, X} <:
end

function sparse_jacobian_cache(
ad::Union{AutoSparsePolyesterForwardDiff,
AutoPolyesterForwardDiff},
sd::AbstractMaybeSparsityDetection, f::F, x;
fx = nothing) where {F}
ad::Union{AutoSparsePolyesterForwardDiff, AutoPolyesterForwardDiff},
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
coloring_result = sd(ad, f, x)
fx = fx === nothing ? similar(f(x)) : fx
if coloring_result isa NoMatrixColoring
Expand All @@ -39,10 +37,8 @@ function sparse_jacobian_cache(
end

function sparse_jacobian_cache(
ad::Union{AutoSparsePolyesterForwardDiff,
AutoPolyesterForwardDiff},
sd::AbstractMaybeSparsityDetection, f!::F, fx,
x) where {F}
ad::Union{AutoSparsePolyesterForwardDiff, AutoPolyesterForwardDiff},
sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F}
coloring_result = sd(ad, f!, fx, x)
if coloring_result isa NoMatrixColoring
cache = __chunksize(ad, x)
Expand Down Expand Up @@ -79,4 +75,38 @@ function sparse_jacobian!(J::AbstractMatrix, _, cache::PolyesterForwardDiffJacob
return J
end

## Approximate Sparsity Detection
function (alg::ApproximateJacobianSparsity)(
ad::AutoSparsePolyesterForwardDiff, f::F, x; fx = nothing, kwargs...) where {F}
@unpack ntrials, rng = alg
fx = fx === nothing ? f(x) : fx
ck = __chunksize(ad, x)
J = fill!(similar(fx, length(fx), length(x)), 0)
J_cache = similar(J)
x_ = similar(x)
for _ in 1:ntrials
randn!(rng, x_)
PolyesterForwardDiff.threaded_jacobian!(f, J_cache, x_, ck)
@. J += abs(J_cache)
end
return (JacPrototypeSparsityDetection(; jac_prototype = sparse(J), alg.alg))(ad, f, x;
fx, kwargs...)
end

function (alg::ApproximateJacobianSparsity)(ad::AutoSparsePolyesterForwardDiff, f::F, fx, x;
kwargs...) where {F}
@unpack ntrials, rng = alg
ck = __chunksize(ad, x)
J = fill!(similar(fx, length(fx), length(x)), 0)
J_cache = similar(J)
x_ = similar(x)
for _ in 1:ntrials
randn!(rng, x_)
PolyesterForwardDiff.threaded_jacobian!(f, fx, J_cache, x_, ck)
@. J += abs(J_cache)
end
return (JacPrototypeSparsityDetection(; jac_prototype = sparse(J), alg.alg))(ad, f, x;
fx, kwargs...)
end

end
12 changes: 10 additions & 2 deletions src/highlevel/coloring.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ function (alg::ApproximateJacobianSparsity)(
ad::AbstractSparseADType, f::F, x; fx = nothing,
kwargs...) where {F}
if !(ad isa AutoSparseForwardDiff)
@warn "$(ad) support for approximate jacobian not implemented. Using ForwardDiff instead." maxlog=1
if ad isa AutoSparsePolyesterForwardDiff
@warn "$(ad) is only supported if `PolyesterForwardDiff` is explicitly loaded. Using ForwardDiff instead." maxlog=1
else
@warn "$(ad) support for approximate jacobian not implemented. Using ForwardDiff instead." maxlog=1
end
end
@unpack ntrials, rng = alg
fx = fx === nothing ? f(x) : fx
Expand All @@ -56,7 +60,11 @@ end
function (alg::ApproximateJacobianSparsity)(ad::AbstractSparseADType, f::F, fx, x;
kwargs...) where {F}
if !(ad isa AutoSparseForwardDiff)
@warn "$(ad) support for approximate jacobian not implemented. Using ForwardDiff instead." maxlog=1
if ad isa AutoSparsePolyesterForwardDiff
@warn "$(ad) is only supported if `PolyesterForwardDiff` is explicitly loaded. Using ForwardDiff instead." maxlog=1
else
@warn "$(ad) support for approximate jacobian not implemented. Using ForwardDiff instead." maxlog=1
end
end
@unpack ntrials, rng = alg
cfg = ForwardDiff.JacobianConfig(f, fx, x)
Expand Down

0 comments on commit 49b5ab9

Please sign in to comment.