From 7d99c31c76d0f726a364befd6b8de5b963a673d0 Mon Sep 17 00:00:00 2001 From: Aman Sharma <76823502+arcAman07@users.noreply.github.com> Date: Fri, 15 Apr 2022 17:45:42 +0530 Subject: [PATCH 01/11] Update adjoints.jl --- test/geometry/adjoints.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/geometry/adjoints.jl b/test/geometry/adjoints.jl index ef94be7..5224764 100644 --- a/test/geometry/adjoints.jl +++ b/test/geometry/adjoints.jl @@ -107,10 +107,10 @@ end for ind in ((2.5, 2.5), (5, 5)) if t <: Colorant zy = Zygote.gradient((x,y)->_sep(ImageTransformations._getindex(x,y)), itp, ind) - @test zy[2] ≈ fieldsum.(Interpolations.gradient(itp, ind...)) + @test all(zy[2] .≈ Tuple(_sep.(Interpolations.gradient(itp, ind...)))) else zy = Zygote.gradient((x,y)->ImageTransformations._getindex(x,y), itp, ind) - @test zy[2] ≈ Interpolations.gradient(itp, ind...) + @test all(zy[2] .≈ Tuple(Interpolations.gradient(itp, ind...))) end end end From e345063303a75b5b047911e0d50fd59c12ce8fa2 Mon Sep 17 00:00:00 2001 From: Aman Date: Thu, 5 May 2022 13:38:36 +0530 Subject: [PATCH 02/11] Errored tests fixed --- src/DiffImages.jl | 2 +- test/colors/conversions.jl | 46 +++++++++++++++++++------------------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/DiffImages.jl b/src/DiffImages.jl index 4d2f15d..187d695 100644 --- a/src/DiffImages.jl +++ b/src/DiffImages.jl @@ -6,7 +6,7 @@ using ImageCore, CoordinateTransformations, Interpolations, ChainRulesCore, - LinearAlgebra, + LinearAlgebra, Rotations using Zygote: @adjoint diff --git a/test/colors/conversions.jl b/test/colors/conversions.jl index 9164f16..5a7f9f4 100644 --- a/test/colors/conversions.jl +++ b/test/colors/conversions.jl @@ -23,14 +23,14 @@ ds3 = (7, 7, 3, 5) ds2 = (7, 7, 2, 5) ds1 = (7, 7, 1, 5) - + cspaces_with_random = (YIQ, LCHab, - Lab, - BGRA, ABGR, BGR, - RGBA, ARGB, RGB, - HSL, - AGray, GrayA, Gray, + Lab, + BGRA, ABGR, BGR, + RGBA, ARGB, RGB, + HSL, + AGray, GrayA, Gray, HSV) # colorspaces those have samplers defined for Base.Random cspaces_all = (HSV, AHSV, HSVA, @@ -47,7 +47,7 @@ DIN99, ADIN99, DIN99A, LMS, ALMS, LMSA, YIQ, AYIQ, YIQA) # tuple of all of Colorspaces and their transparent variants - + cspaces_4 = (BGRA, ABGR, RGBA, ARGB, AHSL, HSLA, @@ -120,13 +120,13 @@ for cs in cspaces_with_random i = rand(cs, ds...) if cs ∈ (Gray,) - @test channelview(gradient(x -> sum(channelview(x)), i)[1]) == ones(Float64, ds...) + @test channelview(Zygote.gradient(x -> sum(channelview(x)), i)[1]) == ones(Float64, ds...) elseif cs ∈ (AGray, GrayA) - @test channelview(gradient(x -> sum(channelview(x)), i)[1]) == ones(Float64, 2, ds...) + @test channelview(Zygote.gradient(x -> sum(channelview(x)), i)[1]) == ones(Float64, 2, ds...) elseif cs ∈ (BGRA, ABGR, RGBA, ARGB) - @test channelview(gradient(x -> sum(channelview(x)), i)[1]) == ones(Float64, 4, ds...) + @test channelview(Zygote.gradient(x -> sum(channelview(x)), i)[1]) == ones(Float64, 4, ds...) else - @test channelview(gradient(x -> sum(channelview(x)), i)[1]) == ones(Float64, 3, ds...) + @test channelview(Zygote.gradient(x -> sum(channelview(x)), i)[1]) == ones(Float64, 3, ds...) end end end @@ -136,16 +136,16 @@ for cs in cspaces_all if cs ∈ cspaces_4 i = rand(ds4[end - 1], ds4[1:end - 2]..., ds4[end]) - @test gradient(x->sum(channelview(colorview(cs,x))),i)[1]==ones(size(i)) + @test Zygote.gradient(x->sum(channelview(colorview(cs,x))),i)[1]==ones(size(i)) elseif cs ∈ (AGray, GrayA) i = rand(ds2[end - 1], ds2[1:end - 2]..., ds2[end]) - @test gradient(x->sum(channelview(colorview(cs,x))),i)[1]==ones(size(i)) + @test Zygote.gradient(x->sum(channelview(colorview(cs,x))),i)[1]==ones(size(i)) elseif cs ∈ (Gray,) i = rand(ds1[1:end - 2]..., ds1[end]) - @test gradient(x->sum(channelview(colorview(cs,x))),i)[1]==ones(size(i)) + @test Zygote.gradient(x->sum(channelview(colorview(cs,x))),i)[1]==ones(size(i)) else i = rand(ds3[end - 1], ds3[1:end - 2]..., ds3[end]) - @test gradient(x->sum(channelview(colorview(cs,x))),i)[1]==ones(size(i)) + @test Zygote.gradient(x->sum(channelview(colorview(cs,x))),i)[1]==ones(size(i)) end end end @@ -154,13 +154,13 @@ for cs in cspaces_with_random i = rand(cs, ds...) if cs ∈ (Gray,) - @test channelify(gradient(x -> sum(channelify(x)), i)[1]) == ones(Float64, ds[1:end-1]..., 1, ds[end]) + @test channelify(Zygote.gradient(x -> sum(channelify(x)), i)[1]) == ones(Float64, ds[1:end-1]..., 1, ds[end]) elseif cs ∈ (AGray, GrayA) - @test channelify(gradient(x -> sum(channelify(x)), i)[1]) == ones(Float64, ds[1:end-1]..., 2, ds[end]) + @test channelify(Zygote.gradient(x -> sum(channelify(x)), i)[1]) == ones(Float64, ds[1:end-1]..., 2, ds[end]) elseif cs ∈ (BGRA, ABGR, RGBA, ARGB) - @test channelify(gradient(x -> sum(channelify(x)), i)[1]) == ones(Float64, ds[1:end-1]..., 4, ds[end]) + @test channelify(Zygote.gradient(x -> sum(channelify(x)), i)[1]) == ones(Float64, ds[1:end-1]..., 4, ds[end]) else - @test channelify(gradient(x -> sum(channelify(x)), i)[1]) == ones(Float64, ds[1:end-1]..., 3, ds[end]) + @test channelify(Zygote.gradient(x -> sum(channelify(x)), i)[1]) == ones(Float64, ds[1:end-1]..., 3, ds[end]) end end end @@ -169,16 +169,16 @@ for cs in cspaces_all if cs ∈ cspaces_4 i = rand(ds4...) - @test gradient(x->sum(channelify(colorify(cs,x))),i)[1]==ones(size(i)) + @test Zygote.gradient(x->sum(channelify(colorify(cs,x))),i)[1]==ones(size(i)) elseif cs ∈ (AGray, GrayA) i = rand(ds2...) - @test gradient(x->sum(channelify(colorify(cs,x))),i)[1]==ones(size(i)) + @test Zygote.gradient(x->sum(channelify(colorify(cs,x))),i)[1]==ones(size(i)) elseif cs ∈ (Gray,) i = rand(ds1...) - @test gradient(x->sum(channelify(colorify(cs,x))),i)[1]==ones(size(i)) + @test Zygote.gradient(x->sum(channelify(colorify(cs,x))),i)[1]==ones(size(i)) else i = rand(ds3...) - @test gradient(x->sum(channelify(colorify(cs,x))),i)[1]==ones(size(i)) + @test Zygote.gradient(x->sum(channelify(colorify(cs,x))),i)[1]==ones(size(i)) end end end From 3a090d871766fbac0aa02a1818dfdaf9a18a1dd7 Mon Sep 17 00:00:00 2001 From: Aman Date: Thu, 5 May 2022 14:23:08 +0530 Subject: [PATCH 03/11] Fix whitespace --- src/DiffImages.jl | 2 +- test/colors/conversions.jl | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/DiffImages.jl b/src/DiffImages.jl index 187d695..38f3861 100644 --- a/src/DiffImages.jl +++ b/src/DiffImages.jl @@ -6,7 +6,7 @@ using ImageCore, CoordinateTransformations, Interpolations, ChainRulesCore, - LinearAlgebra, + LinearAlgebra, Rotations using Zygote: @adjoint diff --git a/test/colors/conversions.jl b/test/colors/conversions.jl index 5a7f9f4..8f7d417 100644 --- a/test/colors/conversions.jl +++ b/test/colors/conversions.jl @@ -23,14 +23,14 @@ ds3 = (7, 7, 3, 5) ds2 = (7, 7, 2, 5) ds1 = (7, 7, 1, 5) - - cspaces_with_random = (YIQ, - LCHab, - Lab, - BGRA, ABGR, BGR, - RGBA, ARGB, RGB, - HSL, - AGray, GrayA, Gray, + + cspaces_with_random = (YIQ, + LCHab, + Lab, + BGRA, ABGR, BGR, + RGBA, ARGB, RGB, + HSL, + AGray, GrayA, Gray, HSV) # colorspaces those have samplers defined for Base.Random cspaces_all = (HSV, AHSV, HSVA, @@ -47,7 +47,7 @@ DIN99, ADIN99, DIN99A, LMS, ALMS, LMSA, YIQ, AYIQ, YIQA) # tuple of all of Colorspaces and their transparent variants - + cspaces_4 = (BGRA, ABGR, RGBA, ARGB, AHSL, HSLA, From 55b9b4df42543eae1e5035f415899ca5595837c3 Mon Sep 17 00:00:00 2001 From: Aman Sharma <76823502+arcAman07@users.noreply.github.com> Date: Thu, 5 May 2022 14:30:49 +0530 Subject: [PATCH 04/11] Fix whitespaces --- test/colors/conversions.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/colors/conversions.jl b/test/colors/conversions.jl index 8f7d417..69a58f6 100644 --- a/test/colors/conversions.jl +++ b/test/colors/conversions.jl @@ -24,13 +24,13 @@ ds2 = (7, 7, 2, 5) ds1 = (7, 7, 1, 5) - cspaces_with_random = (YIQ, - LCHab, - Lab, - BGRA, ABGR, BGR, - RGBA, ARGB, RGB, - HSL, - AGray, GrayA, Gray, + cspaces_with_random = (YIQ, + LCHab, + Lab, + BGRA, ABGR, BGR, + RGBA, ARGB, RGB, + HSL, + AGray, GrayA, Gray, HSV) # colorspaces those have samplers defined for Base.Random cspaces_all = (HSV, AHSV, HSVA, From f5a852b978979796f6ee51396cac7a5a76745b20 Mon Sep 17 00:00:00 2001 From: Aman Sharma <76823502+arcAman07@users.noreply.github.com> Date: Thu, 5 May 2022 14:31:30 +0530 Subject: [PATCH 05/11] Fix whitespaces --- src/DiffImages.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DiffImages.jl b/src/DiffImages.jl index 38f3861..4d2f15d 100644 --- a/src/DiffImages.jl +++ b/src/DiffImages.jl @@ -6,7 +6,7 @@ using ImageCore, CoordinateTransformations, Interpolations, ChainRulesCore, - LinearAlgebra, + LinearAlgebra, Rotations using Zygote: @adjoint From a3798232ce19922239fb8e8bb2933b559ac26da2 Mon Sep 17 00:00:00 2001 From: Aman Date: Tue, 10 May 2022 17:50:20 +0530 Subject: [PATCH 06/11] fdiff adjoint added --- Project.toml | 1 + src/DiffImages.jl | 6 ++++-- src/ImageBase.jl/fdiff.jl | 11 ++++++++++ test/ImageBase.jl/fdiff.jl | 44 ++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 7 +++++- 5 files changed, 66 insertions(+), 3 deletions(-) create mode 100644 src/ImageBase.jl/fdiff.jl create mode 100644 test/ImageBase.jl/fdiff.jl diff --git a/Project.toml b/Project.toml index df4ed07..5b34bbc 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.1.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" CoordinateTransformations = "150eb455-5306-5404-9cee-2592286d6298" +ImageBase = "c817782e-172a-44cc-b673-b171935fbb9e" ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534" ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" diff --git a/src/DiffImages.jl b/src/DiffImages.jl index 4d2f15d..614b150 100644 --- a/src/DiffImages.jl +++ b/src/DiffImages.jl @@ -7,14 +7,16 @@ using ImageCore, Interpolations, ChainRulesCore, LinearAlgebra, - Rotations + Rotations, + ImageBase using Zygote: @adjoint using ChainRulesCore: NoTangent -export colorify, channelify +export colorify, channelify, fdiff include("colors/conversions.jl") include("geometry/warp.jl") include("geometry/adjoints.jl") +include("ImageBase.jl/fdiff.jl") end diff --git a/src/ImageBase.jl/fdiff.jl b/src/ImageBase.jl/fdiff.jl new file mode 100644 index 0000000..d2a59f5 --- /dev/null +++ b/src/ImageBase.jl/fdiff.jl @@ -0,0 +1,11 @@ +# TODO(arcAman07): support RGB inputs, currently works only for GrayScale Images +# TODO(arcAman07): support N dimensional case, currently works only for 2 dimensional case +@adjoint function fdiff(A::AbstractArray; kwargs...) + y = fdiff!(similar(A, maybe_floattype(eltype(A))), A; kwargs...) + final = similar(A, eltype(A)) + function pullback(Δ) + fill!(final, Δ) + return (final,) + end + return (y, pullback) +end \ No newline at end of file diff --git a/test/ImageBase.jl/fdiff.jl b/test/ImageBase.jl/fdiff.jl new file mode 100644 index 0000000..6208b99 --- /dev/null +++ b/test/ImageBase.jl/fdiff.jl @@ -0,0 +1,44 @@ +using ImageBase.FiniteDiff: fdiff, fdiff! +@testset "fdiff" begin + # Base.diff doesn't promote integer to float + @test ImageBase.FiniteDiff.maybe_floattype(Int) == Int + @test ImageBase.FiniteDiff.maybe_floattype(N0f8) == Float32 + @test ImageBase.FiniteDiff.maybe_floattype(RGB{N0f8}) == RGB{Float32} + @testset "NumericalTests" begin + a = reshape(collect(1:9), 3, 3) + b_fd_1 = [1 1 1; 1 1 1; -2 -2 -2] + b_fd_2 = [3 3 -6; 3 3 -6; 3 3 -6] + b_bd_1 = [-2 -2 -2; 1 1 1; 1 1 1] + b_bd_2 = [-6 3 3; -6 3 3; -6 3 3] + out = similar(a) + + @test fdiff(a, dims=1) == b_fd_1 + @test fdiff(a, dims=2) == b_fd_2 + @test fdiff(a, dims=1, rev=true) == b_bd_1 + @test fdiff(a, dims=2, rev=true) == b_bd_2 + fdiff!(out, a, dims=1) + @test out == b_fd_1 + fdiff!(out, a, dims=2) + @test out == b_fd_2 + fdiff!(out, a, dims=1, rev=true) + @test out == b_bd_1 + fdiff!(out, a, dims=2, rev=true) + @test out == b_bd_2 + end + @testset "Differentiability" begin + a_fd_1 = [2 4 8; 3 9 27; 4 16 64] + a_fd_2 = [3 6 9 12; 6 18 27 36; 9 27 54 81; 12 36 81 144] + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true, boundary=:zero))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1, boundary=:zero))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true, boundary=:zero))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2, boundary=:zero))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,rev=true))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,rev=true))[1] == ones(Float64,size(a_fd_2)) + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 75e60b8..b116311 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,7 +8,8 @@ using Test, FiniteDifferences, ChainRulesCore, CoordinateTransformations, - Rotations + Rotations, + ImageBase @testset "DiffImages" begin @info "Testing Colorspace modules" @@ -23,4 +24,8 @@ using Test, @testset "Warps" begin include("geometry/warp.jl") end + @info "Testing ImageBase modules" + @testset "FiniteDifferences" begin + include("ImageBase.jl/fdiff.jl") + end end From 01e033dca6ff782fb292c1a6d6e53cf0bf64ad71 Mon Sep 17 00:00:00 2001 From: Aman Date: Tue, 10 May 2022 19:45:24 +0530 Subject: [PATCH 07/11] More tests added --- test/ImageBase.jl/fdiff.jl | 44 +++++++++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/test/ImageBase.jl/fdiff.jl b/test/ImageBase.jl/fdiff.jl index 6208b99..229fa68 100644 --- a/test/ImageBase.jl/fdiff.jl +++ b/test/ImageBase.jl/fdiff.jl @@ -27,18 +27,36 @@ using ImageBase.FiniteDiff: fdiff, fdiff! end @testset "Differentiability" begin a_fd_1 = [2 4 8; 3 9 27; 4 16 64] - a_fd_2 = [3 6 9 12; 6 18 27 36; 9 27 54 81; 12 36 81 144] - @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2))[1] == ones(Float64,size(a_fd_1)) - @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1))[1] == ones(Float64,size(a_fd_1)) - @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true))[1] == ones(Float64,size(a_fd_1)) - @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true))[1] == ones(Float64,size(a_fd_1)) - @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true, boundary=:zero))[1] == ones(Float64,size(a_fd_1)) - @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1, boundary=:zero))[1] == ones(Float64,size(a_fd_1)) - @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true, boundary=:zero))[1] == ones(Float64,size(a_fd_1)) - @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2, boundary=:zero))[1] == ones(Float64,size(a_fd_1)) - @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2))[1] == ones(Float64,size(a_fd_2)) - @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1))[1] == ones(Float64,size(a_fd_2)) - @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,rev=true))[1] == ones(Float64,size(a_fd_2)) - @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,rev=true))[1] == ones(Float64,size(a_fd_2)) + a_fd_2 = [3 6 9; 6 18 27; 9 27 54; 12 36 81] + @testset "Testing basic fdiff" begin + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2))[1] == ones(Float64,size(a_fd_2)) + end + @testset "Testing fdiff with rev" begin + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,rev=true))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,rev=true))[1] == ones(Float64,size(a_fd_2)) + end + @testset "Testing fdiff with boundary condition" begin + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,boundary=:periodic))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,boundary=:periodic))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,boundary=:zero))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,boundary=:zero))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,boundary=:periodic))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,boundary=:periodic))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,boundary=:zero))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,boundary=:zero))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_2)) + end end end \ No newline at end of file From bf70f5f76b5853779f7625b30b630b3b6c96fb43 Mon Sep 17 00:00:00 2001 From: Aman Date: Tue, 10 May 2022 20:14:59 +0530 Subject: [PATCH 08/11] Add compat field --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 5b34bbc..9bcb97d 100644 --- a/Project.toml +++ b/Project.toml @@ -25,6 +25,7 @@ Interpolations = "0.13.4" Rotations = "1.0.2" StaticArrays = "1.2" Zygote = "0.6.17" +ImageBase = "0.1.5" [extras] FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" From ee7834ca02e2a7d5f7f5354a3415d93b0f62e032 Mon Sep 17 00:00:00 2001 From: Aman Date: Wed, 11 May 2022 18:25:48 +0530 Subject: [PATCH 09/11] Updated adjoint by making it shorter --- src/ImageBase.jl/fdiff.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/ImageBase.jl/fdiff.jl b/src/ImageBase.jl/fdiff.jl index d2a59f5..d89c90a 100644 --- a/src/ImageBase.jl/fdiff.jl +++ b/src/ImageBase.jl/fdiff.jl @@ -2,10 +2,8 @@ # TODO(arcAman07): support N dimensional case, currently works only for 2 dimensional case @adjoint function fdiff(A::AbstractArray; kwargs...) y = fdiff!(similar(A, maybe_floattype(eltype(A))), A; kwargs...) - final = similar(A, eltype(A)) function pullback(Δ) - fill!(final, Δ) - return (final,) + return (fill(Δ, size(A)),) end return (y, pullback) end \ No newline at end of file From 6ef1cb7ad86dc35756c03def1da70e35677bf929 Mon Sep 17 00:00:00 2001 From: Aman Date: Sun, 15 May 2022 13:54:44 +0530 Subject: [PATCH 10/11] Remove fdiff to be exported --- src/DiffImages.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DiffImages.jl b/src/DiffImages.jl index 614b150..f81eb9f 100644 --- a/src/DiffImages.jl +++ b/src/DiffImages.jl @@ -13,7 +13,7 @@ using ImageCore, using Zygote: @adjoint using ChainRulesCore: NoTangent -export colorify, channelify, fdiff +export colorify, channelify include("colors/conversions.jl") include("geometry/warp.jl") include("geometry/adjoints.jl") From 52639471ec781cb4f042b4e78b1e01466ab5bb5f Mon Sep 17 00:00:00 2001 From: Aman Date: Wed, 18 May 2022 12:40:50 +0530 Subject: [PATCH 11/11] function exported --- src/DiffImages.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DiffImages.jl b/src/DiffImages.jl index f81eb9f..614b150 100644 --- a/src/DiffImages.jl +++ b/src/DiffImages.jl @@ -13,7 +13,7 @@ using ImageCore, using Zygote: @adjoint using ChainRulesCore: NoTangent -export colorify, channelify +export colorify, channelify, fdiff include("colors/conversions.jl") include("geometry/warp.jl") include("geometry/adjoints.jl")