Skip to content

Commit

Permalink
Implement Polyester Colored AD for oop functions
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 16, 2024
1 parent 49b5ab9 commit 4d0c091
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 14 deletions.
68 changes: 67 additions & 1 deletion ext/SparseDiffToolsPolyesterExt.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,69 @@
module SparseDiffToolsPolyesterExt


using Adapt, ArrayInterface, ForwardDiff, FiniteDiff, Polyester, SparseDiffTools,
SparseArrays
import SparseDiffTools: polyesterforwarddiff_color_jacobian, ForwardColorJacCache,
__parameterless_type

function cld_fast(a::A, b::B) where {A, B}
T = promote_type(A, B)
return cld_fast(a % T, b % T)
end
function cld_fast(n::T, d::T) where {T}
x = Base.udiv_int(n, d)
x += n != d * x
return x
end

function polyesterforwarddiff_color_jacobian(J::AbstractMatrix{<:Number}, f::F,
x::AbstractArray{<:Number}, jac_cache::ForwardColorJacCache) where {F}
t = jac_cache.t
dx = jac_cache.dx
p = jac_cache.p
colorvec = jac_cache.colorvec
sparsity = jac_cache.sparsity
chunksize = jac_cache.chunksize
maxcolor = maximum(colorvec)

vecx = vec(x)

nrows, ncols = size(J)

if !(sparsity isa Nothing)
rows_index, cols_index = ArrayInterface.findstructralnz(sparsity)
rows_index = [rows_index[i] for i in 1:length(rows_index)]
cols_index = [cols_index[i] for i in 1:length(cols_index)]
else
rows_index = 1:nrows
cols_index = 1:ncols
end

if J isa AbstractSparseMatrix
fill!(nonzeros(J), zero(eltype(J)))
else
fill!(J, zero(eltype(J)))
end

batch((length(p), min(length(p), Threads.nthreads()))) do _, start, stop
for i in start:stop
partial_i = p[i]
color_i = i
t_ = reshape(eltype(t).(vecx, ForwardDiff.Partials.(partial_i)), size(t))
fx = f(t_)
for j in 1:chunksize
dx = vec(ForwardDiff.partials.(fx, j))
pick_inds = [idx
for idx in 1:length(rows_index)
if colorvec[cols_index[idx]] == color_i]
rows_index_c = rows_index[pick_inds]
cols_index_c = cols_index[pick_inds]
@inbounds @simd for i in 1:length(rows_index_c)
J[rows_index_c[i], cols_index_c[i]] = dx[rows_index_c[i]]
end
end
end
end
return J
end

end
13 changes: 7 additions & 6 deletions ext/SparseDiffToolsPolyesterForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ import ForwardDiff
import SparseDiffTools: AbstractMaybeSparseJacobianCache, AbstractMaybeSparsityDetection,
ForwardColorJacCache, NoMatrixColoring, sparse_jacobian_cache,
sparse_jacobian!,
sparse_jacobian_static_array, __standard_tag, __chunksize
sparse_jacobian_static_array, __standard_tag, __chunksize,
polyesterforwarddiff_color_jacobian,
polyesterforwarddiff_color_jacobian!

struct PolyesterForwardDiffJacobianCache{CO, CA, J, FX, X} <:
AbstractMaybeSparseJacobianCache
Expand All @@ -25,8 +27,6 @@ function sparse_jacobian_cache(
cache = __chunksize(ad, x)
jac_prototype = nothing
else
@warn """Currently PolyesterForwardDiff does not support sparsity detection
natively. Falling back to using ForwardDiff.jl""" maxlog=1
tag = __standard_tag(nothing, x)
# Colored ForwardDiff passes `tag` directly into Dual so we need the `typeof`
cache = ForwardColorJacCache(f, x, __chunksize(ad); coloring_result.colorvec,
Expand All @@ -45,7 +45,8 @@ function sparse_jacobian_cache(
jac_prototype = nothing
else
@warn """Currently PolyesterForwardDiff does not support sparsity detection
natively. Falling back to using ForwardDiff.jl""" maxlog=1
natively for inplace functions. Falling back to using
ForwardDiff.jl""" maxlog=1
tag = __standard_tag(nothing, x)
# Colored ForwardDiff passes `tag` directly into Dual so we need the `typeof`
cache = ForwardColorJacCache(f!, x, __chunksize(ad); coloring_result.colorvec,
Expand All @@ -58,7 +59,7 @@ end
function sparse_jacobian!(J::AbstractMatrix, _, cache::PolyesterForwardDiffJacobianCache,
f::F, x) where {F}
if cache.cache isa ForwardColorJacCache
forwarddiff_color_jacobian(J, f, x, cache.cache) # Use Sparse ForwardDiff
polyesterforwarddiff_color_jacobian(J, f, x, cache.cache)
else
PolyesterForwardDiff.threaded_jacobian!(f, J, x, cache.cache) # Don't try to exploit sparsity
end
Expand All @@ -68,7 +69,7 @@ end
function sparse_jacobian!(J::AbstractMatrix, _, cache::PolyesterForwardDiffJacobianCache,
f!::F, fx, x) where {F}
if cache.cache isa ForwardColorJacCache
forwarddiff_color_jacobian!(J, f!, x, cache.cache) # Use Sparse ForwardDiff
forwarddiff_color_jacobian!(J, f!, x, cache.cache)
else
PolyesterForwardDiff.threaded_jacobian!(f!, fx, J, x, cache.cache) # Don't try to exploit sparsity
end
Expand Down
16 changes: 9 additions & 7 deletions src/differentiation/compute_jacobian_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ function forwarddiff_color_jacobian(f::F, x::AbstractArray{<:Number},
end
end

# Defined in extension. Polyester version of `forwarddiff_color_jacobian`
function polyesterforwarddiff_color_jacobian end

# When J is mutable, this version of forwarddiff_color_jacobian will mutate J to avoid allocations
function forwarddiff_color_jacobian(J::AbstractMatrix{<:Number}, f::F,
x::AbstractArray{<:Number},
Expand Down Expand Up @@ -249,9 +252,8 @@ function forwarddiff_color_jacobian(J::AbstractMatrix{<:Number}, f::F,
end

# When J is immutable, this version of forwarddiff_color_jacobian will avoid mutating J
function forwarddiff_color_jacobian_immutable(f, x::AbstractArray{<:Number},
jac_cache::ForwardColorJacCache,
jac_prototype = nothing)
function forwarddiff_color_jacobian_immutable(f::F, x::AbstractArray{<:Number},
jac_cache::ForwardColorJacCache, jac_prototype = nothing) where {F}
t = jac_cache.t
dx = jac_cache.dx
p = jac_cache.p
Expand Down Expand Up @@ -315,16 +317,16 @@ function forwarddiff_color_jacobian_immutable(f, x::AbstractArray{<:Number},
return J
end

function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number}, f,
function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number}, f::F,
x::AbstractArray{<:Number}; dx = similar(x, size(J, 1)), colorvec = 1:length(x),
sparsity = ArrayInterface.has_sparsestruct(J) ? J : nothing)
sparsity = ArrayInterface.has_sparsestruct(J) ? J : nothing) where {F}
forwarddiff_color_jacobian!(J, f, x, ForwardColorJacCache(f, x; dx, colorvec, sparsity))
end

function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
f,
f::F,
x::AbstractArray{<:Number},
jac_cache::ForwardColorJacCache)
jac_cache::ForwardColorJacCache) where {F}
t = jac_cache.t
fx = jac_cache.fx
dx = jac_cache.dx
Expand Down

0 comments on commit 4d0c091

Please sign in to comment.