From 7d99c31c76d0f726a364befd6b8de5b963a673d0 Mon Sep 17 00:00:00 2001 From: Aman Sharma <76823502+arcAman07@users.noreply.github.com> Date: Fri, 15 Apr 2022 17:45:42 +0530 Subject: [PATCH 01/10] Update adjoints.jl --- test/geometry/adjoints.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/geometry/adjoints.jl b/test/geometry/adjoints.jl index ef94be7..5224764 100644 --- a/test/geometry/adjoints.jl +++ b/test/geometry/adjoints.jl @@ -107,10 +107,10 @@ end for ind in ((2.5, 2.5), (5, 5)) if t <: Colorant zy = Zygote.gradient((x,y)->_sep(ImageTransformations._getindex(x,y)), itp, ind) - @test zy[2] ≈ fieldsum.(Interpolations.gradient(itp, ind...)) + @test all(zy[2] .≈ Tuple(_sep.(Interpolations.gradient(itp, ind...)))) else zy = Zygote.gradient((x,y)->ImageTransformations._getindex(x,y), itp, ind) - @test zy[2] ≈ Interpolations.gradient(itp, ind...) + @test all(zy[2] .≈ Tuple(Interpolations.gradient(itp, ind...))) end end end From e345063303a75b5b047911e0d50fd59c12ce8fa2 Mon Sep 17 00:00:00 2001 From: Aman Date: Thu, 5 May 2022 13:38:36 +0530 Subject: [PATCH 02/10] Errored tests fixed --- src/DiffImages.jl | 2 +- test/colors/conversions.jl | 46 +++++++++++++++++++------------------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/DiffImages.jl b/src/DiffImages.jl index 4d2f15d..187d695 100644 --- a/src/DiffImages.jl +++ b/src/DiffImages.jl @@ -6,7 +6,7 @@ using ImageCore, CoordinateTransformations, Interpolations, ChainRulesCore, - LinearAlgebra, + LinearAlgebra, Rotations using Zygote: @adjoint diff --git a/test/colors/conversions.jl b/test/colors/conversions.jl index 9164f16..5a7f9f4 100644 --- a/test/colors/conversions.jl +++ b/test/colors/conversions.jl @@ -23,14 +23,14 @@ ds3 = (7, 7, 3, 5) ds2 = (7, 7, 2, 5) ds1 = (7, 7, 1, 5) - + cspaces_with_random = (YIQ, LCHab, - Lab, - BGRA, ABGR, BGR, - RGBA, ARGB, RGB, - HSL, - AGray, GrayA, Gray, + Lab, + BGRA, ABGR, BGR, + RGBA, ARGB, RGB, + HSL, + AGray, GrayA, Gray, HSV) # colorspaces those have samplers defined for Base.Random cspaces_all = (HSV, AHSV, HSVA, @@ -47,7 +47,7 @@ DIN99, ADIN99, DIN99A, LMS, ALMS, LMSA, YIQ, AYIQ, YIQA) # tuple of all of Colorspaces and their transparent variants - + cspaces_4 = (BGRA, ABGR, RGBA, ARGB, AHSL, HSLA, @@ -120,13 +120,13 @@ for cs in cspaces_with_random i = rand(cs, ds...) if cs ∈ (Gray,) - @test channelview(gradient(x -> sum(channelview(x)), i)[1]) == ones(Float64, ds...) + @test channelview(Zygote.gradient(x -> sum(channelview(x)), i)[1]) == ones(Float64, ds...) elseif cs ∈ (AGray, GrayA) - @test channelview(gradient(x -> sum(channelview(x)), i)[1]) == ones(Float64, 2, ds...) + @test channelview(Zygote.gradient(x -> sum(channelview(x)), i)[1]) == ones(Float64, 2, ds...) elseif cs ∈ (BGRA, ABGR, RGBA, ARGB) - @test channelview(gradient(x -> sum(channelview(x)), i)[1]) == ones(Float64, 4, ds...) + @test channelview(Zygote.gradient(x -> sum(channelview(x)), i)[1]) == ones(Float64, 4, ds...) else - @test channelview(gradient(x -> sum(channelview(x)), i)[1]) == ones(Float64, 3, ds...) + @test channelview(Zygote.gradient(x -> sum(channelview(x)), i)[1]) == ones(Float64, 3, ds...) end end end @@ -136,16 +136,16 @@ for cs in cspaces_all if cs ∈ cspaces_4 i = rand(ds4[end - 1], ds4[1:end - 2]..., ds4[end]) - @test gradient(x->sum(channelview(colorview(cs,x))),i)[1]==ones(size(i)) + @test Zygote.gradient(x->sum(channelview(colorview(cs,x))),i)[1]==ones(size(i)) elseif cs ∈ (AGray, GrayA) i = rand(ds2[end - 1], ds2[1:end - 2]..., ds2[end]) - @test gradient(x->sum(channelview(colorview(cs,x))),i)[1]==ones(size(i)) + @test Zygote.gradient(x->sum(channelview(colorview(cs,x))),i)[1]==ones(size(i)) elseif cs ∈ (Gray,) i = rand(ds1[1:end - 2]..., ds1[end]) - @test gradient(x->sum(channelview(colorview(cs,x))),i)[1]==ones(size(i)) + @test Zygote.gradient(x->sum(channelview(colorview(cs,x))),i)[1]==ones(size(i)) else i = rand(ds3[end - 1], ds3[1:end - 2]..., ds3[end]) - @test gradient(x->sum(channelview(colorview(cs,x))),i)[1]==ones(size(i)) + @test Zygote.gradient(x->sum(channelview(colorview(cs,x))),i)[1]==ones(size(i)) end end end @@ -154,13 +154,13 @@ for cs in cspaces_with_random i = rand(cs, ds...) if cs ∈ (Gray,) - @test channelify(gradient(x -> sum(channelify(x)), i)[1]) == ones(Float64, ds[1:end-1]..., 1, ds[end]) + @test channelify(Zygote.gradient(x -> sum(channelify(x)), i)[1]) == ones(Float64, ds[1:end-1]..., 1, ds[end]) elseif cs ∈ (AGray, GrayA) - @test channelify(gradient(x -> sum(channelify(x)), i)[1]) == ones(Float64, ds[1:end-1]..., 2, ds[end]) + @test channelify(Zygote.gradient(x -> sum(channelify(x)), i)[1]) == ones(Float64, ds[1:end-1]..., 2, ds[end]) elseif cs ∈ (BGRA, ABGR, RGBA, ARGB) - @test channelify(gradient(x -> sum(channelify(x)), i)[1]) == ones(Float64, ds[1:end-1]..., 4, ds[end]) + @test channelify(Zygote.gradient(x -> sum(channelify(x)), i)[1]) == ones(Float64, ds[1:end-1]..., 4, ds[end]) else - @test channelify(gradient(x -> sum(channelify(x)), i)[1]) == ones(Float64, ds[1:end-1]..., 3, ds[end]) + @test channelify(Zygote.gradient(x -> sum(channelify(x)), i)[1]) == ones(Float64, ds[1:end-1]..., 3, ds[end]) end end end @@ -169,16 +169,16 @@ for cs in cspaces_all if cs ∈ cspaces_4 i = rand(ds4...) - @test gradient(x->sum(channelify(colorify(cs,x))),i)[1]==ones(size(i)) + @test Zygote.gradient(x->sum(channelify(colorify(cs,x))),i)[1]==ones(size(i)) elseif cs ∈ (AGray, GrayA) i = rand(ds2...) - @test gradient(x->sum(channelify(colorify(cs,x))),i)[1]==ones(size(i)) + @test Zygote.gradient(x->sum(channelify(colorify(cs,x))),i)[1]==ones(size(i)) elseif cs ∈ (Gray,) i = rand(ds1...) - @test gradient(x->sum(channelify(colorify(cs,x))),i)[1]==ones(size(i)) + @test Zygote.gradient(x->sum(channelify(colorify(cs,x))),i)[1]==ones(size(i)) else i = rand(ds3...) - @test gradient(x->sum(channelify(colorify(cs,x))),i)[1]==ones(size(i)) + @test Zygote.gradient(x->sum(channelify(colorify(cs,x))),i)[1]==ones(size(i)) end end end From 3a090d871766fbac0aa02a1818dfdaf9a18a1dd7 Mon Sep 17 00:00:00 2001 From: Aman Date: Thu, 5 May 2022 14:23:08 +0530 Subject: [PATCH 03/10] Fix whitespace --- src/DiffImages.jl | 2 +- test/colors/conversions.jl | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/DiffImages.jl b/src/DiffImages.jl index 187d695..38f3861 100644 --- a/src/DiffImages.jl +++ b/src/DiffImages.jl @@ -6,7 +6,7 @@ using ImageCore, CoordinateTransformations, Interpolations, ChainRulesCore, - LinearAlgebra, + LinearAlgebra, Rotations using Zygote: @adjoint diff --git a/test/colors/conversions.jl b/test/colors/conversions.jl index 5a7f9f4..8f7d417 100644 --- a/test/colors/conversions.jl +++ b/test/colors/conversions.jl @@ -23,14 +23,14 @@ ds3 = (7, 7, 3, 5) ds2 = (7, 7, 2, 5) ds1 = (7, 7, 1, 5) - - cspaces_with_random = (YIQ, - LCHab, - Lab, - BGRA, ABGR, BGR, - RGBA, ARGB, RGB, - HSL, - AGray, GrayA, Gray, + + cspaces_with_random = (YIQ, + LCHab, + Lab, + BGRA, ABGR, BGR, + RGBA, ARGB, RGB, + HSL, + AGray, GrayA, Gray, HSV) # colorspaces those have samplers defined for Base.Random cspaces_all = (HSV, AHSV, HSVA, @@ -47,7 +47,7 @@ DIN99, ADIN99, DIN99A, LMS, ALMS, LMSA, YIQ, AYIQ, YIQA) # tuple of all of Colorspaces and their transparent variants - + cspaces_4 = (BGRA, ABGR, RGBA, ARGB, AHSL, HSLA, From 55b9b4df42543eae1e5035f415899ca5595837c3 Mon Sep 17 00:00:00 2001 From: Aman Sharma <76823502+arcAman07@users.noreply.github.com> Date: Thu, 5 May 2022 14:30:49 +0530 Subject: [PATCH 04/10] Fix whitespaces --- test/colors/conversions.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/colors/conversions.jl b/test/colors/conversions.jl index 8f7d417..69a58f6 100644 --- a/test/colors/conversions.jl +++ b/test/colors/conversions.jl @@ -24,13 +24,13 @@ ds2 = (7, 7, 2, 5) ds1 = (7, 7, 1, 5) - cspaces_with_random = (YIQ, - LCHab, - Lab, - BGRA, ABGR, BGR, - RGBA, ARGB, RGB, - HSL, - AGray, GrayA, Gray, + cspaces_with_random = (YIQ, + LCHab, + Lab, + BGRA, ABGR, BGR, + RGBA, ARGB, RGB, + HSL, + AGray, GrayA, Gray, HSV) # colorspaces those have samplers defined for Base.Random cspaces_all = (HSV, AHSV, HSVA, From f5a852b978979796f6ee51396cac7a5a76745b20 Mon Sep 17 00:00:00 2001 From: Aman Sharma <76823502+arcAman07@users.noreply.github.com> Date: Thu, 5 May 2022 14:31:30 +0530 Subject: [PATCH 05/10] Fix whitespaces --- src/DiffImages.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DiffImages.jl b/src/DiffImages.jl index 38f3861..4d2f15d 100644 --- a/src/DiffImages.jl +++ b/src/DiffImages.jl @@ -6,7 +6,7 @@ using ImageCore, CoordinateTransformations, Interpolations, ChainRulesCore, - LinearAlgebra, + LinearAlgebra, Rotations using Zygote: @adjoint From a3798232ce19922239fb8e8bb2933b559ac26da2 Mon Sep 17 00:00:00 2001 From: Aman Date: Tue, 10 May 2022 17:50:20 +0530 Subject: [PATCH 06/10] fdiff adjoint added --- Project.toml | 1 + src/DiffImages.jl | 6 ++++-- src/ImageBase.jl/fdiff.jl | 11 ++++++++++ test/ImageBase.jl/fdiff.jl | 44 ++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 7 +++++- 5 files changed, 66 insertions(+), 3 deletions(-) create mode 100644 src/ImageBase.jl/fdiff.jl create mode 100644 test/ImageBase.jl/fdiff.jl diff --git a/Project.toml b/Project.toml index df4ed07..5b34bbc 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.1.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" CoordinateTransformations = "150eb455-5306-5404-9cee-2592286d6298" +ImageBase = "c817782e-172a-44cc-b673-b171935fbb9e" ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534" ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" diff --git a/src/DiffImages.jl b/src/DiffImages.jl index 4d2f15d..614b150 100644 --- a/src/DiffImages.jl +++ b/src/DiffImages.jl @@ -7,14 +7,16 @@ using ImageCore, Interpolations, ChainRulesCore, LinearAlgebra, - Rotations + Rotations, + ImageBase using Zygote: @adjoint using ChainRulesCore: NoTangent -export colorify, channelify +export colorify, channelify, fdiff include("colors/conversions.jl") include("geometry/warp.jl") include("geometry/adjoints.jl") +include("ImageBase.jl/fdiff.jl") end diff --git a/src/ImageBase.jl/fdiff.jl b/src/ImageBase.jl/fdiff.jl new file mode 100644 index 0000000..d2a59f5 --- /dev/null +++ b/src/ImageBase.jl/fdiff.jl @@ -0,0 +1,11 @@ +# TODO(arcAman07): support RGB inputs, currently works only for GrayScale Images +# TODO(arcAman07): support N dimensional case, currently works only for 2 dimensional case +@adjoint function fdiff(A::AbstractArray; kwargs...) + y = fdiff!(similar(A, maybe_floattype(eltype(A))), A; kwargs...) + final = similar(A, eltype(A)) + function pullback(Δ) + fill!(final, Δ) + return (final,) + end + return (y, pullback) +end \ No newline at end of file diff --git a/test/ImageBase.jl/fdiff.jl b/test/ImageBase.jl/fdiff.jl new file mode 100644 index 0000000..6208b99 --- /dev/null +++ b/test/ImageBase.jl/fdiff.jl @@ -0,0 +1,44 @@ +using ImageBase.FiniteDiff: fdiff, fdiff! +@testset "fdiff" begin + # Base.diff doesn't promote integer to float + @test ImageBase.FiniteDiff.maybe_floattype(Int) == Int + @test ImageBase.FiniteDiff.maybe_floattype(N0f8) == Float32 + @test ImageBase.FiniteDiff.maybe_floattype(RGB{N0f8}) == RGB{Float32} + @testset "NumericalTests" begin + a = reshape(collect(1:9), 3, 3) + b_fd_1 = [1 1 1; 1 1 1; -2 -2 -2] + b_fd_2 = [3 3 -6; 3 3 -6; 3 3 -6] + b_bd_1 = [-2 -2 -2; 1 1 1; 1 1 1] + b_bd_2 = [-6 3 3; -6 3 3; -6 3 3] + out = similar(a) + + @test fdiff(a, dims=1) == b_fd_1 + @test fdiff(a, dims=2) == b_fd_2 + @test fdiff(a, dims=1, rev=true) == b_bd_1 + @test fdiff(a, dims=2, rev=true) == b_bd_2 + fdiff!(out, a, dims=1) + @test out == b_fd_1 + fdiff!(out, a, dims=2) + @test out == b_fd_2 + fdiff!(out, a, dims=1, rev=true) + @test out == b_bd_1 + fdiff!(out, a, dims=2, rev=true) + @test out == b_bd_2 + end + @testset "Differentiability" begin + a_fd_1 = [2 4 8; 3 9 27; 4 16 64] + a_fd_2 = [3 6 9 12; 6 18 27 36; 9 27 54 81; 12 36 81 144] + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true, boundary=:zero))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1, boundary=:zero))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true, boundary=:zero))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2, boundary=:zero))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,rev=true))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,rev=true))[1] == ones(Float64,size(a_fd_2)) + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 75e60b8..b116311 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,7 +8,8 @@ using Test, FiniteDifferences, ChainRulesCore, CoordinateTransformations, - Rotations + Rotations, + ImageBase @testset "DiffImages" begin @info "Testing Colorspace modules" @@ -23,4 +24,8 @@ using Test, @testset "Warps" begin include("geometry/warp.jl") end + @info "Testing ImageBase modules" + @testset "FiniteDifferences" begin + include("ImageBase.jl/fdiff.jl") + end end From 01e033dca6ff782fb292c1a6d6e53cf0bf64ad71 Mon Sep 17 00:00:00 2001 From: Aman Date: Tue, 10 May 2022 19:45:24 +0530 Subject: [PATCH 07/10] More tests added --- test/ImageBase.jl/fdiff.jl | 44 +++++++++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/test/ImageBase.jl/fdiff.jl b/test/ImageBase.jl/fdiff.jl index 6208b99..229fa68 100644 --- a/test/ImageBase.jl/fdiff.jl +++ b/test/ImageBase.jl/fdiff.jl @@ -27,18 +27,36 @@ using ImageBase.FiniteDiff: fdiff, fdiff! end @testset "Differentiability" begin a_fd_1 = [2 4 8; 3 9 27; 4 16 64] - a_fd_2 = [3 6 9 12; 6 18 27 36; 9 27 54 81; 12 36 81 144] - @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2))[1] == ones(Float64,size(a_fd_1)) - @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1))[1] == ones(Float64,size(a_fd_1)) - @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true))[1] == ones(Float64,size(a_fd_1)) - @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true))[1] == ones(Float64,size(a_fd_1)) - @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true, boundary=:zero))[1] == ones(Float64,size(a_fd_1)) - @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1, boundary=:zero))[1] == ones(Float64,size(a_fd_1)) - @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true, boundary=:zero))[1] == ones(Float64,size(a_fd_1)) - @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2, boundary=:zero))[1] == ones(Float64,size(a_fd_1)) - @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2))[1] == ones(Float64,size(a_fd_2)) - @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1))[1] == ones(Float64,size(a_fd_2)) - @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,rev=true))[1] == ones(Float64,size(a_fd_2)) - @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,rev=true))[1] == ones(Float64,size(a_fd_2)) + a_fd_2 = [3 6 9; 6 18 27; 9 27 54; 12 36 81] + @testset "Testing basic fdiff" begin + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2))[1] == ones(Float64,size(a_fd_2)) + end + @testset "Testing fdiff with rev" begin + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,rev=true))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,rev=true))[1] == ones(Float64,size(a_fd_2)) + end + @testset "Testing fdiff with boundary condition" begin + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,boundary=:periodic))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,boundary=:periodic))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,boundary=:zero))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,boundary=:zero))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,boundary=:periodic))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,boundary=:periodic))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,boundary=:zero))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,boundary=:zero))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_2)) + end end end \ No newline at end of file From bf70f5f76b5853779f7625b30b630b3b6c96fb43 Mon Sep 17 00:00:00 2001 From: Aman Date: Tue, 10 May 2022 20:14:59 +0530 Subject: [PATCH 08/10] Add compat field --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 5b34bbc..9bcb97d 100644 --- a/Project.toml +++ b/Project.toml @@ -25,6 +25,7 @@ Interpolations = "0.13.4" Rotations = "1.0.2" StaticArrays = "1.2" Zygote = "0.6.17" +ImageBase = "0.1.5" [extras] FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" From aa6b56e3b5de7273549dd1f412b524ad918c51db Mon Sep 17 00:00:00 2001 From: Aman Date: Wed, 11 May 2022 18:24:35 +0530 Subject: [PATCH 09/10] statistics adjoints added --- Project.toml | 3 +- src/DiffImages.jl | 1 + src/ImageBase.jl/fdiff.jl | 4 +- src/ImageBase.jl/statistics.jl | 35 +++++++++++++ test/ImageBase.jl/statistics.jl | 93 +++++++++++++++++++++++++++++++++ test/runtests.jl | 6 ++- 6 files changed, 137 insertions(+), 5 deletions(-) create mode 100644 src/ImageBase.jl/statistics.jl create mode 100644 test/ImageBase.jl/statistics.jl diff --git a/Project.toml b/Project.toml index 9bcb97d..a26bb62 100644 --- a/Project.toml +++ b/Project.toml @@ -13,19 +13,20 @@ Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ChainRulesCore = "1.3.0" CoordinateTransformations = "0.6.1" +ImageBase = "0.1.5" ImageCore = "0.9" ImageTransformations = "0.8, 0.9" Interpolations = "0.13.4" Rotations = "1.0.2" StaticArrays = "1.2" Zygote = "0.6.17" -ImageBase = "0.1.5" [extras] FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" diff --git a/src/DiffImages.jl b/src/DiffImages.jl index 614b150..b07ecd4 100644 --- a/src/DiffImages.jl +++ b/src/DiffImages.jl @@ -18,5 +18,6 @@ include("colors/conversions.jl") include("geometry/warp.jl") include("geometry/adjoints.jl") include("ImageBase.jl/fdiff.jl") +include("ImageBase.jl/statistics.jl") end diff --git a/src/ImageBase.jl/fdiff.jl b/src/ImageBase.jl/fdiff.jl index d2a59f5..d89c90a 100644 --- a/src/ImageBase.jl/fdiff.jl +++ b/src/ImageBase.jl/fdiff.jl @@ -2,10 +2,8 @@ # TODO(arcAman07): support N dimensional case, currently works only for 2 dimensional case @adjoint function fdiff(A::AbstractArray; kwargs...) y = fdiff!(similar(A, maybe_floattype(eltype(A))), A; kwargs...) - final = similar(A, eltype(A)) function pullback(Δ) - fill!(final, Δ) - return (final,) + return (fill(Δ, size(A)),) end return (y, pullback) end \ No newline at end of file diff --git a/src/ImageBase.jl/statistics.jl b/src/ImageBase.jl/statistics.jl new file mode 100644 index 0000000..10cc61d --- /dev/null +++ b/src/ImageBase.jl/statistics.jl @@ -0,0 +1,35 @@ +@adjoint function sumfinite(A::AbstractArray{T,N}; kwargs...) where {T,N} + y = ImageBase.sumfinite(identity, A; kwargs...) + function pullback(Δ) + return (fill(Δ,size(A)),) + end + return (y, pullback) +end + +@adjoint function meanfinite(A::AbstractArray{T,N}; kwargs...) where {T,N} + y = ImageBase.meanfinite(identity, A; kwargs...) + function pullback(Δ) + return (fill(Δ / length(A),size(A)),) + end + return (y, pullback) +end + +@adjoint function maximum_finite(A::AbstractArray{T,N}; kwargs...) where {T,N} + y = ImageBase.maximum_finite(identity, A; kwargs...) + final = zeros(Float64, size(A)) + function pullback(Δ) + final[last(findall(x -> x == y, A))] = Δ + return (final,) + end + return (y, pullback) +end + +@adjoint function minimum_finite(A::AbstractArray{T,N}; kwargs...) where {T,N} + y = ImageBase.minimum_finite(identity, A; kwargs...) + final = zeros(Float64, size(A)) + function pullback(Δ) + final[first(findall(x -> x == y, A))] = Δ + return (final,) + end + return (y, pullback) +end \ No newline at end of file diff --git a/test/ImageBase.jl/statistics.jl b/test/ImageBase.jl/statistics.jl new file mode 100644 index 0000000..551cc7c --- /dev/null +++ b/test/ImageBase.jl/statistics.jl @@ -0,0 +1,93 @@ +using ImageBase +using ImageBase: varmult +using Statistics +@testset "Statistics" begin + a_fd_1 = [2 4 8; 3 9 27; 4 16 64] + a_fd_2 = [3 6 9; 6 18 27; 9 27 54; 12 36 81] + a_fd_3 = rand(10, 10) + a_fd_4 = randn(6, 4) + b_fd_1 = zeros(Float64, size(a_fd_1)) + b_fd_2 = zeros(Float64, size(a_fd_2)) + b_fd_3 = zeros(Float64, size(a_fd_3)) + b_fd_4 = zeros(Float64, size(a_fd_4)) + e_fd_1 = zeros(Float64, size(a_fd_1)) + e_fd_2 = zeros(Float64, size(a_fd_2)) + e_fd_3 = zeros(Float64, size(a_fd_3)) + e_fd_4 = zeros(Float64, size(a_fd_4)) + c_fd_1 = minimum_finite(a_fd_1) + c_fd_2 = minimum_finite(a_fd_2) + c_fd_3 = minimum_finite(a_fd_3) + c_fd_4 = minimum_finite(a_fd_4) + d_fd_1 = maximum_finite(a_fd_1) + d_fd_2 = maximum_finite(a_fd_2) + d_fd_3 = maximum_finite(a_fd_3) + d_fd_4 = maximum_finite(a_fd_4) + @testset "NumericalTests" begin + @testset "Testing sumfinite" begin + @test sumfinite(a_fd_1) == sum(a_fd_1) + @test sumfinite(a_fd_2) == sum(a_fd_2) + @test sumfinite(a_fd_3) == sum(a_fd_3) + @test sumfinite(a_fd_4) == sum(a_fd_4) + @test sumfinite(a_fd_1) == 137 + @test sumfinite(a_fd_2) == 288 + end + @testset "Testing meanfinite" begin + @test meanfinite(a_fd_1) ≈ mean(a_fd_1) + @test meanfinite(a_fd_2) ≈ mean(a_fd_2) + @test meanfinite(a_fd_3) ≈ mean(a_fd_3) + @test meanfinite(a_fd_4) ≈ mean(a_fd_4) + @test meanfinite(a_fd_1) ≈ 15.222222222222221 + @test meanfinite(a_fd_2) ≈ 24.0 + end + @testset "Testing minimum_finite" begin + @test minimum_finite(a_fd_1) == minimum(a_fd_1) + @test minimum_finite(a_fd_2) == minimum(a_fd_2) + @test minimum_finite(a_fd_3) == minimum(a_fd_3) + @test minimum_finite(a_fd_4) == minimum(a_fd_4) + @test minimum_finite(a_fd_1) == 2 + @test minimum_finite(a_fd_2) == 3 + end + @testset "Testing maximum_finite" begin + @test maximum_finite(a_fd_1) == maximum(a_fd_1) + @test maximum_finite(a_fd_2) == maximum(a_fd_2) + @test maximum_finite(a_fd_3) == maximum(a_fd_3) + @test maximum_finite(a_fd_4) == maximum(a_fd_4) + @test maximum_finite(a_fd_1) == 64 + @test maximum_finite(a_fd_2) == 81 + end + end + @testset "Testing Differentiability" begin + @testset "Testing sumfinite" begin + @test Zygote.gradient(sumfinite, a_fd_1)[1] == ones(Float64, size(a_fd_1)) + @test Zygote.gradient(sumfinite, a_fd_2)[1] == ones(Float64, size(a_fd_2)) + @test Zygote.gradient(sumfinite, a_fd_3)[1] == ones(Float64, size(a_fd_3)) + @test Zygote.gradient(sumfinite, a_fd_4)[1] == ones(Float64, size(a_fd_4)) + end + @testset "Testing meanfinite" begin + @test Zygote.gradient(meanfinite, a_fd_1)[1] == fill((1 / length(a_fd_1)), size(a_fd_1)) + @test Zygote.gradient(meanfinite, a_fd_2)[1] == fill((1 / length(a_fd_2)), size(a_fd_2)) + @test Zygote.gradient(meanfinite, a_fd_3)[1] == fill(1 / length(a_fd_3), size(a_fd_3)) + @test Zygote.gradient(meanfinite, a_fd_4)[1] == fill(1 / length(a_fd_4), size(a_fd_4)) + end + @testset "Testing minimum_finite" begin + b_fd_1[first(findall(x -> x == c_fd_1, a_fd_1))] = 1 + b_fd_2[first(findall(x -> x == c_fd_2, a_fd_2))] = 1 + b_fd_3[first(findall(x -> x == c_fd_3, a_fd_3))] = 1 + b_fd_4[first(findall(x -> x == c_fd_4, a_fd_4))] = 1 + @test Zygote.gradient(minimum_finite, a_fd_1)[1] == b_fd_1 + @test Zygote.gradient(minimum_finite, a_fd_2)[1] == b_fd_2 + @test Zygote.gradient(minimum_finite, a_fd_3)[1] == b_fd_3 + @test Zygote.gradient(minimum_finite, a_fd_4)[1] == b_fd_4 + end + @testset "Testing maximum_finite" begin + e_fd_1[last(findall(x -> x == d_fd_1, a_fd_1))] = 1 + e_fd_2[last(findall(x -> x == d_fd_2, a_fd_2))] = 1 + e_fd_3[last(findall(x -> x == d_fd_3, a_fd_3))] = 1 + e_fd_4[last(findall(x -> x == d_fd_4, a_fd_4))] = 1 + @test Zygote.gradient(maximum_finite, a_fd_1)[1] == e_fd_1 + @test Zygote.gradient(maximum_finite, a_fd_2)[1] == e_fd_2 + @test Zygote.gradient(maximum_finite, a_fd_3)[1] == e_fd_3 + @test Zygote.gradient(maximum_finite, a_fd_4)[1] == e_fd_4 + end + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index b116311..34173dd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,7 +9,8 @@ using Test, ChainRulesCore, CoordinateTransformations, Rotations, - ImageBase + ImageBase, + Statistics @testset "DiffImages" begin @info "Testing Colorspace modules" @@ -28,4 +29,7 @@ using Test, @testset "FiniteDifferences" begin include("ImageBase.jl/fdiff.jl") end + @testset "Statistics" begin + include("ImageBase.jl/statistics.jl") + end end From 7e7a830f8756c44b7c7aa3ce2c27071380ac8b02 Mon Sep 17 00:00:00 2001 From: Aman Date: Fri, 13 May 2022 18:12:05 +0530 Subject: [PATCH 10/10] code refined --- test/ImageBase.jl/statistics.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/ImageBase.jl/statistics.jl b/test/ImageBase.jl/statistics.jl index 551cc7c..06f14e9 100644 --- a/test/ImageBase.jl/statistics.jl +++ b/test/ImageBase.jl/statistics.jl @@ -1,6 +1,3 @@ -using ImageBase -using ImageBase: varmult -using Statistics @testset "Statistics" begin a_fd_1 = [2 4 8; 3 9 27; 4 16 64] a_fd_2 = [3 6 9; 6 18 27; 9 27 54; 12 36 81]