Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding adjoints for statistics functions in ImageBase.jl #27

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,21 @@ version = "0.1.0"
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CoordinateTransformations = "150eb455-5306-5404-9cee-2592286d6298"
ImageBase = "c817782e-172a-44cc-b673-b171935fbb9e"
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ChainRulesCore = "1.3.0"
CoordinateTransformations = "0.6.1"
ImageBase = "0.1.5"
ImageCore = "0.9"
ImageTransformations = "0.8, 0.9"
Interpolations = "0.13.4"
Expand Down
7 changes: 5 additions & 2 deletions src/DiffImages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@ using ImageCore,
Interpolations,
ChainRulesCore,
LinearAlgebra,
Rotations
Rotations,
ImageBase

using Zygote: @adjoint
using ChainRulesCore: NoTangent

export colorify, channelify
export colorify, channelify, fdiff
include("colors/conversions.jl")
include("geometry/warp.jl")
include("geometry/adjoints.jl")
include("ImageBase.jl/fdiff.jl")
include("ImageBase.jl/statistics.jl")

end
9 changes: 9 additions & 0 deletions src/ImageBase.jl/fdiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# TODO(arcAman07): support RGB inputs, currently works only for GrayScale Images
# TODO(arcAman07): support N dimensional case, currently works only for 2 dimensional case
@adjoint function fdiff(A::AbstractArray; kwargs...)
y = fdiff!(similar(A, maybe_floattype(eltype(A))), A; kwargs...)
function pullback(Δ)
return (fill(Δ, size(A)),)
end
return (y, pullback)
end
35 changes: 35 additions & 0 deletions src/ImageBase.jl/statistics.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
@adjoint function sumfinite(A::AbstractArray{T,N}; kwargs...) where {T,N}
y = ImageBase.sumfinite(identity, A; kwargs...)
function pullback(Δ)
return (fill(Δ,size(A)),)
end
return (y, pullback)
end

@adjoint function meanfinite(A::AbstractArray{T,N}; kwargs...) where {T,N}
y = ImageBase.meanfinite(identity, A; kwargs...)
function pullback(Δ)
return (fill(Δ / length(A),size(A)),)
end
return (y, pullback)
end

@adjoint function maximum_finite(A::AbstractArray{T,N}; kwargs...) where {T,N}
y = ImageBase.maximum_finite(identity, A; kwargs...)
final = zeros(Float64, size(A))
function pullback(Δ)
final[last(findall(x -> x == y, A))] = Δ
return (final,)
end
return (y, pullback)
end

@adjoint function minimum_finite(A::AbstractArray{T,N}; kwargs...) where {T,N}
y = ImageBase.minimum_finite(identity, A; kwargs...)
final = zeros(Float64, size(A))
function pullback(Δ)
final[first(findall(x -> x == y, A))] = Δ
return (final,)
end
return (y, pullback)
end
62 changes: 62 additions & 0 deletions test/ImageBase.jl/fdiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
using ImageBase.FiniteDiff: fdiff, fdiff!
@testset "fdiff" begin
# Base.diff doesn't promote integer to float
@test ImageBase.FiniteDiff.maybe_floattype(Int) == Int
@test ImageBase.FiniteDiff.maybe_floattype(N0f8) == Float32
@test ImageBase.FiniteDiff.maybe_floattype(RGB{N0f8}) == RGB{Float32}
@testset "NumericalTests" begin
a = reshape(collect(1:9), 3, 3)
b_fd_1 = [1 1 1; 1 1 1; -2 -2 -2]
b_fd_2 = [3 3 -6; 3 3 -6; 3 3 -6]
b_bd_1 = [-2 -2 -2; 1 1 1; 1 1 1]
b_bd_2 = [-6 3 3; -6 3 3; -6 3 3]
out = similar(a)

@test fdiff(a, dims=1) == b_fd_1
@test fdiff(a, dims=2) == b_fd_2
@test fdiff(a, dims=1, rev=true) == b_bd_1
@test fdiff(a, dims=2, rev=true) == b_bd_2
fdiff!(out, a, dims=1)
@test out == b_fd_1
fdiff!(out, a, dims=2)
@test out == b_fd_2
fdiff!(out, a, dims=1, rev=true)
@test out == b_bd_1
fdiff!(out, a, dims=2, rev=true)
@test out == b_bd_2
end
@testset "Differentiability" begin
a_fd_1 = [2 4 8; 3 9 27; 4 16 64]
a_fd_2 = [3 6 9; 6 18 27; 9 27 54; 12 36 81]
@testset "Testing basic fdiff" begin
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2))[1] == ones(Float64,size(a_fd_2))
end
@testset "Testing fdiff with rev" begin
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,rev=true))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,rev=true))[1] == ones(Float64,size(a_fd_2))
end
@testset "Testing fdiff with boundary condition" begin
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,boundary=:periodic))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,boundary=:periodic))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,boundary=:zero))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,boundary=:zero))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,boundary=:periodic))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,boundary=:periodic))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,boundary=:zero))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,boundary=:zero))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_2))
end
end
end
90 changes: 90 additions & 0 deletions test/ImageBase.jl/statistics.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
@testset "Statistics" begin
a_fd_1 = [2 4 8; 3 9 27; 4 16 64]
a_fd_2 = [3 6 9; 6 18 27; 9 27 54; 12 36 81]
a_fd_3 = rand(10, 10)
a_fd_4 = randn(6, 4)
b_fd_1 = zeros(Float64, size(a_fd_1))
b_fd_2 = zeros(Float64, size(a_fd_2))
b_fd_3 = zeros(Float64, size(a_fd_3))
b_fd_4 = zeros(Float64, size(a_fd_4))
e_fd_1 = zeros(Float64, size(a_fd_1))
e_fd_2 = zeros(Float64, size(a_fd_2))
e_fd_3 = zeros(Float64, size(a_fd_3))
e_fd_4 = zeros(Float64, size(a_fd_4))
c_fd_1 = minimum_finite(a_fd_1)
c_fd_2 = minimum_finite(a_fd_2)
c_fd_3 = minimum_finite(a_fd_3)
c_fd_4 = minimum_finite(a_fd_4)
d_fd_1 = maximum_finite(a_fd_1)
d_fd_2 = maximum_finite(a_fd_2)
d_fd_3 = maximum_finite(a_fd_3)
d_fd_4 = maximum_finite(a_fd_4)
@testset "NumericalTests" begin
@testset "Testing sumfinite" begin
@test sumfinite(a_fd_1) == sum(a_fd_1)
@test sumfinite(a_fd_2) == sum(a_fd_2)
@test sumfinite(a_fd_3) == sum(a_fd_3)
@test sumfinite(a_fd_4) == sum(a_fd_4)
@test sumfinite(a_fd_1) == 137
@test sumfinite(a_fd_2) == 288
end
@testset "Testing meanfinite" begin
@test meanfinite(a_fd_1) ≈ mean(a_fd_1)
@test meanfinite(a_fd_2) ≈ mean(a_fd_2)
@test meanfinite(a_fd_3) ≈ mean(a_fd_3)
@test meanfinite(a_fd_4) ≈ mean(a_fd_4)
@test meanfinite(a_fd_1) ≈ 15.222222222222221
@test meanfinite(a_fd_2) ≈ 24.0
end
@testset "Testing minimum_finite" begin
@test minimum_finite(a_fd_1) == minimum(a_fd_1)
@test minimum_finite(a_fd_2) == minimum(a_fd_2)
@test minimum_finite(a_fd_3) == minimum(a_fd_3)
@test minimum_finite(a_fd_4) == minimum(a_fd_4)
@test minimum_finite(a_fd_1) == 2
@test minimum_finite(a_fd_2) == 3
end
@testset "Testing maximum_finite" begin
@test maximum_finite(a_fd_1) == maximum(a_fd_1)
@test maximum_finite(a_fd_2) == maximum(a_fd_2)
@test maximum_finite(a_fd_3) == maximum(a_fd_3)
@test maximum_finite(a_fd_4) == maximum(a_fd_4)
@test maximum_finite(a_fd_1) == 64
@test maximum_finite(a_fd_2) == 81
end
end
@testset "Testing Differentiability" begin
@testset "Testing sumfinite" begin
@test Zygote.gradient(sumfinite, a_fd_1)[1] == ones(Float64, size(a_fd_1))
@test Zygote.gradient(sumfinite, a_fd_2)[1] == ones(Float64, size(a_fd_2))
@test Zygote.gradient(sumfinite, a_fd_3)[1] == ones(Float64, size(a_fd_3))
@test Zygote.gradient(sumfinite, a_fd_4)[1] == ones(Float64, size(a_fd_4))
end
@testset "Testing meanfinite" begin
@test Zygote.gradient(meanfinite, a_fd_1)[1] == fill((1 / length(a_fd_1)), size(a_fd_1))
@test Zygote.gradient(meanfinite, a_fd_2)[1] == fill((1 / length(a_fd_2)), size(a_fd_2))
@test Zygote.gradient(meanfinite, a_fd_3)[1] == fill(1 / length(a_fd_3), size(a_fd_3))
@test Zygote.gradient(meanfinite, a_fd_4)[1] == fill(1 / length(a_fd_4), size(a_fd_4))
end
@testset "Testing minimum_finite" begin
b_fd_1[first(findall(x -> x == c_fd_1, a_fd_1))] = 1
b_fd_2[first(findall(x -> x == c_fd_2, a_fd_2))] = 1
b_fd_3[first(findall(x -> x == c_fd_3, a_fd_3))] = 1
b_fd_4[first(findall(x -> x == c_fd_4, a_fd_4))] = 1
@test Zygote.gradient(minimum_finite, a_fd_1)[1] == b_fd_1
@test Zygote.gradient(minimum_finite, a_fd_2)[1] == b_fd_2
@test Zygote.gradient(minimum_finite, a_fd_3)[1] == b_fd_3
@test Zygote.gradient(minimum_finite, a_fd_4)[1] == b_fd_4
end
@testset "Testing maximum_finite" begin
e_fd_1[last(findall(x -> x == d_fd_1, a_fd_1))] = 1
e_fd_2[last(findall(x -> x == d_fd_2, a_fd_2))] = 1
e_fd_3[last(findall(x -> x == d_fd_3, a_fd_3))] = 1
e_fd_4[last(findall(x -> x == d_fd_4, a_fd_4))] = 1
@test Zygote.gradient(maximum_finite, a_fd_1)[1] == e_fd_1
@test Zygote.gradient(maximum_finite, a_fd_2)[1] == e_fd_2
@test Zygote.gradient(maximum_finite, a_fd_3)[1] == e_fd_3
@test Zygote.gradient(maximum_finite, a_fd_4)[1] == e_fd_4
end
end
end
11 changes: 10 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ using Test,
FiniteDifferences,
ChainRulesCore,
CoordinateTransformations,
Rotations
Rotations,
ImageBase,
Statistics

@testset "DiffImages" begin
@info "Testing Colorspace modules"
Expand All @@ -23,4 +25,11 @@ using Test,
@testset "Warps" begin
include("geometry/warp.jl")
end
@info "Testing ImageBase modules"
@testset "FiniteDifferences" begin
include("ImageBase.jl/fdiff.jl")
end
@testset "Statistics" begin
include("ImageBase.jl/statistics.jl")
end
end