From 1e4557a69163651f5e6054af59c911dbca97b554 Mon Sep 17 00:00:00 2001 From: Som Tambe Date: Tue, 24 Aug 2021 06:24:31 +0000 Subject: [PATCH 1/4] Add ImageFitering to deps --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index df4ed07..289cded 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ version = "0.1.0" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" CoordinateTransformations = "150eb455-5306-5404-9cee-2592286d6298" ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534" +ImageFiltering = "6a3955dd-da59-5b1f-98d4-e7296123deb5" ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" From a47911a6984b925157f93d7970ec1a0e9a47e5c1 Mon Sep 17 00:00:00 2001 From: Som Tambe Date: Wed, 25 Aug 2021 12:00:26 +0000 Subject: [PATCH 2/4] Add deps version --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 289cded..9811f3e 100644 --- a/Project.toml +++ b/Project.toml @@ -20,6 +20,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" ChainRulesCore = "1.3.0" CoordinateTransformations = "0.6.1" ImageCore = "0.9" +ImageFiltering = "0.6.22" ImageTransformations = "0.8, 0.9" Interpolations = "0.13.4" Rotations = "1.0.2" From 270cb4b37d49f959a9512de665c0b05272a59b81 Mon Sep 17 00:00:00 2001 From: Som Tambe Date: Wed, 25 Aug 2021 12:04:19 +0000 Subject: [PATCH 3/4] Add adjoints to make ImageFiltering.imfilter differentiable --- src/filters/adjoints.jl | 150 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 150 insertions(+) create mode 100644 src/filters/adjoints.jl diff --git a/src/filters/adjoints.jl b/src/filters/adjoints.jl new file mode 100644 index 0000000..7dbb1ea --- /dev/null +++ b/src/filters/adjoints.jl @@ -0,0 +1,150 @@ +using ImageFiltering +using ImageFiltering.TiledIteration +using ImageFiltering: imfilter, + imfilter!, + default_resource, + alg_defaults, + Alg, + ProcessedKernel, + AbstractBorder, + __imfilter_inbounds!, + safe_for_prod, + copydata!, + factorkernel, + factorstridedkernel, + padarray, + filter_algorithm + +function ChainRulesCore.rrule(::typeof(imfilter!), out::AbstractArray, + img::AbstractArray, + kernel::ProcessedKernel, + border::AbstractBorder, + alg::Alg) + # imfilter! places a snag because of a try catch block. + y = imfilter!(out, img, kernel, border, alg) + function ∇imfilter!_try(Δy) + @show typeof(Δy) + println("entered first") + k = default_resource(alg_defaults(alg, out, kernel)) + ret = imfilter!(k, out, img, kernel, border) + _, ∇ret = rrule_via_ad(Zygote.ZygoteRuleConfig(), imfilter!, k, out, img, kernel, border) + ∇k, ∇out, ∇img, ∇kernel, ∇border = ∇ret(Δy) + return NoTangent(), ∇out, ∇img, ∇kernel, ∇border, ∇k + end + return y, ∇imfilter!_try +end # needed and works + +## writing adjoints for `__imfilter_inbounds!` -> where the mutation takes place +function ChainRulesCore.rrule(::typeof(__imfilter_inbounds!), r, + out, + A::OffsetArray, + kern::OffsetArray, + border, + R, + z) + y = __imfilter_inbounds!(r, out, A, kern, border, R, z) + function ∇__imfilter_inbounds!(Δy) + # ∇out should not have any gradients + # since it is just being alloted the values + # after processing. ∇border also should not have + # gradients since it does not make sense (for now). + @show typeof(Δy) + ∇out = NoTangent() + ∇border = NoTangent() + + # Don't exactly know what r, R and z are actually. + + off, k = CartesianIndex(kern.offsets), parent(kern) + o, O = safehead(off), safetail(off) + Rnew = CartesianIndices(map((x,y)->x.+y, R.indices, Tuple(off))) + Rk = CartesianIndices(axes(k)) + offA, pA = CartesianIndex(A.offsets), parent(A) + oA, OA = safehead(offA), safetail(offA) + # ∇A, ∇kern should have some values. + ∇A = 0 + ∇kern = 0 # since k is not an OffsetArray + + for I in safetail(Rnew) + IA = I-OA + for i in safehead(Rnew) + tmp = z + iA = i-oA + dk = zeros(eltype(k), size(k)) + dA = zeros(eltype(pA), size(pA)) + @inbounds for J in safetail(Rk), j in safehead(Rk) + _, ∇prod = rrule_via_ad(Zygote.ZygoteRuleConfig(), (a, b, c) -> safe_for_prod(a, b) * c, + pA[iA+j, IA+J], + tmp, + k[j, J]) + dA_j_J, _, dk_j_J = ∇prod(Δy[iA+j, IA+J]) + dA[iA+j, IA+J] += dA_j_J + dk[j+J] += dk_j_J + end + ∇A += dA + ∇kern += dk + end + end + ∇z = NoTangent() + ∇R = NoTangent() + ∇r = NoTangent() + + return NoTangent(), ∇r, ∇out, ∇A, ∇kern, ∇border, ∇R, ∇z + end + return y, ∇__imfilter_inbounds! +end + +Zygote.@nograd TiledIteration.TileBuffer # needed, works +# Zygote.@nograd ImageFiltering.padindices # not needed +Zygote.@nograd ImageFiltering.filter_algorithm # ~~should be correct~~ is correct +Zygote.@nograd ImageFiltering.Pad{N} where N + +# what should the gradient of copyto! be? It is being used in various places throughout the filters + +function ChainRulesCore.rrule(::typeof(padarray), t::Type{T}, img::AbstractArray, border) where T + y = padarray(t, img, border) + function padarray_pb(Δy) + ba, ba_pb = rrule_via_ad(Zygote.ZygoteRuleConfig(), BorderArray, img, border) + out = similar(ba, T, axes(ba)) + copy!(out, ba) + ∇img, ∇border = ba_pb(Δy) + return NoTangent(), NoTangent(), ∇img, ∇border + end + return y, padarray_pb +end + +function ChainRulesCore.rrule(::typeof(factorkernel), kernel::AbstractMatrix{T}) where T + y = factorkernel(kernel) + function factorkernel_pb(Δy) + ## + inds = axes(kernel) + m, n = map(length, inds) + kern = Array{T}(undef, m, n) + copyto!(kern, 1:m, 1:n, kernel, inds[1], inds[2]) + ## + _, kernel_pb = rrule_via_ad(Zygote.ZygoteRuleConfig(), factorstridedkernel, inds, kern) + + return NoTangent(), kernel_pb(Δy) + end + return y, factorkernel_pb +end + +# function ChainRulesCore.rrule(::typeof(copydata!), dest::OffsetArray, img, inds::Tuple{Vararg{OffsetArray}}) +# y = copydata!(dest, img, inds) +# function copydata!_pb(Δy) +# @show typeof(Δy) +# println("copydata! here") +# # dest = parent(dest) +# # inds = map(parent, inds) +# # if isempty(img) +# # ∇img = canonicalize(Tangent{typeof(img)}()) +# # else +# # ∇img = Tangent{typeof(img)}(;ones(eltype(img), size(img))) +# # end +# return NoTangent(), NoTangent(), Δy, NoTangent() +# end +# return y, copydata!_pb +# end + +## ~~make copyto! gradients correct~~ final task + +## it is still not getting inside the final mutation loop adjoint, figure that out asap. From cf6fbf8cc459134ce0f83028cfe9f0c24d1185c7 Mon Sep 17 00:00:00 2001 From: Som Tambe Date: Wed, 25 Aug 2021 12:10:04 +0000 Subject: [PATCH 4/4] remove unnecessary show lines --- src/filters/adjoints.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/filters/adjoints.jl b/src/filters/adjoints.jl index 7dbb1ea..cedf7dd 100644 --- a/src/filters/adjoints.jl +++ b/src/filters/adjoints.jl @@ -23,8 +23,6 @@ function ChainRulesCore.rrule(::typeof(imfilter!), out::AbstractArray, # imfilter! places a snag because of a try catch block. y = imfilter!(out, img, kernel, border, alg) function ∇imfilter!_try(Δy) - @show typeof(Δy) - println("entered first") k = default_resource(alg_defaults(alg, out, kernel)) ret = imfilter!(k, out, img, kernel, border) _, ∇ret = rrule_via_ad(Zygote.ZygoteRuleConfig(), imfilter!, k, out, img, kernel, border) @@ -48,7 +46,6 @@ function ChainRulesCore.rrule(::typeof(__imfilter_inbounds!), r, # since it is just being alloted the values # after processing. ∇border also should not have # gradients since it does not make sense (for now). - @show typeof(Δy) ∇out = NoTangent() ∇border = NoTangent()