Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Addition of the adjoint of the fdiff function from ImageBase.jl #26

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.1.0"
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CoordinateTransformations = "150eb455-5306-5404-9cee-2592286d6298"
ImageBase = "c817782e-172a-44cc-b673-b171935fbb9e"
arcAman07 marked this conversation as resolved.
Show resolved Hide resolved
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
Expand All @@ -24,6 +25,7 @@ Interpolations = "0.13.4"
Rotations = "1.0.2"
StaticArrays = "1.2"
Zygote = "0.6.17"
ImageBase = "0.1.5"

[extras]
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Expand Down
6 changes: 4 additions & 2 deletions src/DiffImages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@ using ImageCore,
Interpolations,
ChainRulesCore,
LinearAlgebra,
Rotations
Rotations,
ImageBase

using Zygote: @adjoint
using ChainRulesCore: NoTangent

export colorify, channelify
export colorify, channelify, fdiff
include("colors/conversions.jl")
include("geometry/warp.jl")
include("geometry/adjoints.jl")
include("ImageBase.jl/fdiff.jl")

end
9 changes: 9 additions & 0 deletions src/ImageBase.jl/fdiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# TODO(arcAman07): support RGB inputs, currently works only for GrayScale Images
# TODO(arcAman07): support N dimensional case, currently works only for 2 dimensional case
@adjoint function fdiff(A::AbstractArray; kwargs...)
y = fdiff!(similar(A, maybe_floattype(eltype(A))), A; kwargs...)
function pullback(Δ)
return (fill(Δ, size(A)),)
end
return (y, pullback)
end
62 changes: 62 additions & 0 deletions test/ImageBase.jl/fdiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
using ImageBase.FiniteDiff: fdiff, fdiff!
@testset "fdiff" begin
# Base.diff doesn't promote integer to float
@test ImageBase.FiniteDiff.maybe_floattype(Int) == Int
@test ImageBase.FiniteDiff.maybe_floattype(N0f8) == Float32
@test ImageBase.FiniteDiff.maybe_floattype(RGB{N0f8}) == RGB{Float32}
@testset "NumericalTests" begin
a = reshape(collect(1:9), 3, 3)
b_fd_1 = [1 1 1; 1 1 1; -2 -2 -2]
b_fd_2 = [3 3 -6; 3 3 -6; 3 3 -6]
b_bd_1 = [-2 -2 -2; 1 1 1; 1 1 1]
b_bd_2 = [-6 3 3; -6 3 3; -6 3 3]
out = similar(a)

@test fdiff(a, dims=1) == b_fd_1
@test fdiff(a, dims=2) == b_fd_2
@test fdiff(a, dims=1, rev=true) == b_bd_1
@test fdiff(a, dims=2, rev=true) == b_bd_2
fdiff!(out, a, dims=1)
@test out == b_fd_1
fdiff!(out, a, dims=2)
@test out == b_fd_2
fdiff!(out, a, dims=1, rev=true)
@test out == b_bd_1
fdiff!(out, a, dims=2, rev=true)
@test out == b_bd_2
end
@testset "Differentiability" begin
a_fd_1 = [2 4 8; 3 9 27; 4 16 64]
a_fd_2 = [3 6 9; 6 18 27; 9 27 54; 12 36 81]
@testset "Testing basic fdiff" begin
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2))[1] == ones(Float64,size(a_fd_2))
end
@testset "Testing fdiff with rev" begin
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,rev=true))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,rev=true))[1] == ones(Float64,size(a_fd_2))
end
@testset "Testing fdiff with boundary condition" begin
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,boundary=:periodic))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,boundary=:periodic))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,boundary=:zero))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,boundary=:zero))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,boundary=:periodic))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,boundary=:periodic))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,boundary=:zero))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,boundary=:zero))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_2))
end
end
end
7 changes: 6 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ using Test,
FiniteDifferences,
ChainRulesCore,
CoordinateTransformations,
Rotations
Rotations,
ImageBase

@testset "DiffImages" begin
@info "Testing Colorspace modules"
Expand All @@ -23,4 +24,8 @@ using Test,
@testset "Warps" begin
include("geometry/warp.jl")
end
@info "Testing ImageBase modules"
@testset "FiniteDifferences" begin
include("ImageBase.jl/fdiff.jl")
end
end