Fix rrule for rfft and ifft for CuArray #96
Codecov Report
Patch coverage:
@@ Coverage Diff @@
## master #96 +/- ##
==========================================
+ Coverage 87.08% 88.05% +0.97%
==========================================
Files 3 3
Lines 209 226 +17
==========================================
+ Hits 182 199 +17
Misses 27 27
Regarding testing, #78 wants to make AbstractFFTsTestUtils a separate package for downstream packages to use. Once it's a separate dependency, we could probably lump in a ChainRulesTestUtils dependency there to test these chain rules, and CUDA could then use AbstractFFTsTestUtils in its tests. cc @devmotion
For now, I think we can only test this change in this package if there's a way to make this code error without using a GPU array?
ext/AbstractFFTsChainRulesCoreExt.jl
@@ -37,7 +37,7 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims)
     project_x = ChainRulesCore.ProjectTo(x)
     function rfft_pullback(ȳ)
-        x̄ = project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ scale, d, dims))
+        x̄ = project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ typeof(x)(scale), d, dims))
Is it better to make scale the appropriate array type at the point of construction, rather than converting it here, to avoid unnecessary allocations?
Thanks for your suggestion. I addressed the
How about an
Seems like you might just be able to add an offset array e.g. (Unfortunately, one can imagine a situation where your fix here actually fails, e.g. if
ext/AbstractFFTsChainRulesCoreExt.jl
@@ -30,10 +30,10 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims)
     halfdim = first(dims)
     d = size(x, halfdim)
     n = size(y, halfdim)
-    scale = reshape(
+    scale = typeof(x)(reshape(
Actually, would typeof(y) be better here, since scale is broadcast against y's adjoint?
(Since fft and friends don't preserve OffsetArrays in their output, this unfortunately means that the test I proposed won't work anymore: we'd need a case where fft does not produce a Vector, and I don't know of any other than GPU arrays. Maybe it's OK to satisfy ourselves with the existing tests passing for now?)
If it's not too much trouble, an OffsetArrays test could still be good to add, though: it would at least catch that it should be typeof(y) rather than typeof(x) here.
I did:
x = OffsetArray(randn(3), 2:4)
test_rrule(rfft, x, 1) # errors
test_rrule: rfft on OffsetVector{Int64, Vector{Int64}},Int64: Test Failed at /Users/ziyiyin/.julia/packages/ChainRulesTestUtils/lERVj/src/testers.jl:314
  Expression: ad_cotangent isa NoTangent
   Evaluated: [-8.2, -6.698780366869516, 5.598780366869516] isa NoTangent
Stacktrace:
 [1] macro expansion
   @ /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/Test/src/Test.jl:464 [inlined]
 [2] _test_cotangent(::NoTangent, ad_cotangent::Any, ::NoTangent; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/lERVj/src/testers.jl:314
Test Summary: | Pass Fail Total Time
test_rrule: rfft on OffsetVector{Int64, Vector{Int64}},Int64 | 6 1 7 0.0s
ERROR: Some tests did not pass: 6 passed, 1 failed, 0 errored, 0 broken.
Seems nasty, and I am not quite sure how the ChainRules testing package works in this case. Suggestions?
It might be fixable with some type piracy to
Should also note that the
If this fixes the CUDA behaviour locally, it looks good to me to merge as is and improve later as needed; hopefully we will get CUDA tests of this behaviour in the near future. Someone with merge rights would need to sign off, though.
Yes, it works on my end locally.
Julia Version 1.8.5 (2023-01-08), official https://julialang.org/ release
(@v1.8) pkg> activate --temp
Activating new project at `/tmp/jl_S6w38G`
(jl_S6w38G) pkg> add https://github.com/ziyiyin97/AbstractFFTs.jl.git
Updating git-repo `https://github.com/ziyiyin97/AbstractFFTs.jl.git`
Updating registry at `~/.julia/registries/General.toml`
Resolving package versions...
Updating `/tmp/jl_S6w38G/Project.toml`
[621f4979] + AbstractFFTs v1.3.1 `https://github.com/ziyiyin97/AbstractFFTs.jl.git#master`
Updating `/tmp/jl_S6w38G/Manifest.toml`
[621f4979] + AbstractFFTs v1.3.1 `https://github.com/ziyiyin97/AbstractFFTs.jl.git#master`
[d360d2e6] + ChainRulesCore v1.15.7
[34da2185] + Compat v4.6.1
[56f22d72] + Artifacts
[ade2ca70] + Dates
[8f399da3] + Libdl
[37e2e46d] + LinearAlgebra
[de0858da] + Printf
[9a3f8284] + Random
[ea8e919c] + SHA v0.7.0
[9e88b42a] + Serialization
[2f01184e] + SparseArrays
[cf7118a7] + UUIDs
[4ec0a83e] + Unicode
[e66e0078] + CompilerSupportLibraries_jll v1.0.1+0
[4536629a] + OpenBLAS_jll v0.3.20+0
[8e850b90] + libblastrampoline_jll v5.1.1+0
Precompiling project...
1 dependency successfully precompiled in 1 seconds. 5 already precompiled.
(jl_S6w38G) pkg> add CUDA, Flux, FFTW
Resolving package versions...
Installed CUDA_Driver_jll ──────── v0.5.0+0
Installed CUDA_Runtime_Discovery ─ v0.2.0
Installed CUDA_Runtime_jll ─────── v0.5.0+0
Installed OrderedCollections ───── v1.6.0
Installed CUDA ─────────────────── v4.1.2
Updating `/tmp/jl_S6w38G/Project.toml`
[052768ef] + CUDA v4.1.2
[7a1cc6ca] + FFTW v1.6.0
[587475ba] + Flux v0.13.14
Updating `/tmp/jl_S6w38G/Manifest.toml`
[7d9f7c33] + Accessors v0.1.28
[79e6a3ab] + Adapt v3.6.1
[dce04be8] + ArgCheck v2.3.0
[a9b6321e] + Atomix v0.1.0
[ab4f0b2a] + BFloat16s v0.4.2
[198e06fe] + BangBang v0.3.37
[9718e550] + Baselet v0.1.1
[fa961155] + CEnum v0.4.2
[052768ef] + CUDA v4.1.2
[1af6417a] + CUDA_Runtime_Discovery v0.2.0
[082447d4] + ChainRules v1.48.0
[9e997f8a] + ChangesOfVariables v0.1.6
[bbf7d656] + CommonSubexpressions v0.3.0
[a33af91c] + CompositionsBase v0.1.1
[187b0558] + ConstructionBase v1.5.1
[6add18c4] + ContextVariablesX v0.1.3
[9a962f9c] + DataAPI v1.14.0
[864edb3b] + DataStructures v0.18.13
[e2d170a0] + DataValueInterfaces v1.0.0
[244e2a9f] + DefineSingletons v0.1.2
[163ba53b] + DiffResults v1.1.0
[b552c78f] + DiffRules v1.13.0
[ffbed154] + DocStringExtensions v0.9.3
[e2ba6199] + ExprTools v0.1.9
[7a1cc6ca] + FFTW v1.6.0
[cc61a311] + FLoops v0.2.1
[b9860ae5] + FLoopsBase v0.1.1
[1a297f60] + FillArrays v0.13.11
[587475ba] + Flux v0.13.14
[9c68100b] + FoldsThreads v0.1.1
[f6369f11] + ForwardDiff v0.10.35
[069b7b12] + FunctionWrappers v1.1.3
[d9f16b24] + Functors v0.4.4
[0c68f7d7] + GPUArrays v8.6.5
[46192b85] + GPUArraysCore v0.1.4
[61eb1bfa] + GPUCompiler v0.18.0
[7869d1d1] + IRTools v0.4.9
[22cec73e] + InitialValues v0.3.1
[3587e190] + InverseFunctions v0.1.8
[92d709cd] + IrrationalConstants v0.2.2
[82899510] + IteratorInterfaceExtensions v1.0.0
[692b3bcd] + JLLWrappers v1.4.1
[b14d175d] + JuliaVariables v0.2.4
[63c18a36] + KernelAbstractions v0.9.1
⌅ [929cbde3] + LLVM v4.17.1
[2ab3a3ac] + LogExpFunctions v0.3.23
[d8e11817] + MLStyle v0.4.17
[f1d291b0] + MLUtils v0.4.1
[1914dd2f] + MacroTools v0.5.10
[128add7d] + MicroCollections v0.1.4
[e1d29d7a] + Missings v1.1.0
[872c559c] + NNlib v0.8.19
[a00861dc] + NNlibCUDA v0.2.7
[77ba4419] + NaNMath v1.0.2
[71a1bf82] + NameResolution v0.1.5
[0b1bfda6] + OneHotArrays v0.2.3
[3bd65402] + Optimisers v0.2.17
[bac558e1] + OrderedCollections v1.6.0
[21216c6a] + Preferences v1.3.0
[8162dcfd] + PrettyPrint v0.2.0
[33c8b6b6] + ProgressLogging v0.1.4
[74087812] + Random123 v1.6.0
[e6cf234a] + RandomNumbers v1.5.3
[c1ae055f] + RealDot v0.1.0
[189a3867] + Reexport v1.2.2
[ae029012] + Requires v1.3.0
[efcf1570] + Setfield v1.1.1
[605ecd9f] + ShowCases v0.1.0
[699a6c99] + SimpleTraits v0.9.4
[66db9d55] + SnoopPrecompile v1.0.3
[a2af1166] + SortingAlgorithms v1.1.0
[276daf66] + SpecialFunctions v2.2.0
[171d559e] + SplittablesBase v0.1.15
[90137ffa] + StaticArrays v1.5.19
[1e83bf80] + StaticArraysCore v1.4.0
[82ae8749] + StatsAPI v1.6.0
[2913bbd2] + StatsBase v0.33.21
[09ab397b] + StructArrays v0.6.15
[3783bdb8] + TableTraits v1.0.1
[bd369af6] + Tables v1.10.1
[a759f4b9] + TimerOutputs v0.5.22
[28d57a85] + Transducers v0.4.75
[013be700] + UnsafeAtomics v0.2.1
⌃ [d80eeb9a] + UnsafeAtomicsLLVM v0.1.1
[e88e6eb3] + Zygote v0.6.59
[700de1a5] + ZygoteRules v0.2.3
[02a925ec] + cuDNN v1.0.2
[4ee394cb] + CUDA_Driver_jll v0.5.0+0
[76a88914] + CUDA_Runtime_jll v0.5.0+0
[62b44479] + CUDNN_jll v8.8.1+0
[f5851436] + FFTW_jll v3.3.10+0
[1d5cc7b8] + IntelOpenMP_jll v2018.0.3+2
⌅ [dad2f222] + LLVMExtra_jll v0.0.18+0
[856f044c] + MKL_jll v2022.2.0+0
[efe28fd5] + OpenSpecFun_jll v0.5.5+0
[0dad84c5] + ArgTools v1.1.1
[2a0f44e3] + Base64
[8bb1440f] + DelimitedFiles
[8ba89e20] + Distributed
[f43a241f] + Downloads v1.6.0
[7b1f6079] + FileWatching
[9fa8497b] + Future
[b77e0a4c] + InteractiveUtils
[4af54fe1] + LazyArtifacts
[b27032c2] + LibCURL v0.6.3
[76f85450] + LibGit2
[56ddb016] + Logging
[d6f4376e] + Markdown
[a63ad114] + Mmap
[ca575930] + NetworkOptions v1.2.0
[44cfe95a] + Pkg v1.8.0
[3fa0cd96] + REPL
[6462fe0b] + Sockets
[10745b16] + Statistics
[fa267f1f] + TOML v1.0.0
[a4e569a6] + Tar v1.10.1
[8dfed614] + Test
[deac9b47] + LibCURL_jll v7.84.0+0
[29816b5a] + LibSSH2_jll v1.10.2+0
[c8ffd9c3] + MbedTLS_jll v2.28.0+0
[14a3606d] + MozillaCACerts_jll v2022.2.1
[05823500] + OpenLibm_jll v0.8.1+0
[83775a58] + Zlib_jll v1.2.12+3
[8e850ede] + nghttp2_jll v1.48.0+0
[3f19e933] + p7zip_jll v17.4.0+0
Info Packages marked with ⌃ and ⌅ have new versions available, but those with ⌅ are restricted by compatibility constraints from upgrading. To see why use `status --outdated -m`
Precompiling project...
19 dependencies successfully precompiled in 56 seconds. 86 already precompiled.
julia> using CUDA, FFTW, Flux
[ Info: Precompiling FFTW [7a1cc6ca-52ef-59f5-83cd-3a7055c09341]
julia> x = CUDA.randn(3)
3-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}:
-0.08069154
0.49658614
-0.9882421
julia> gradient(()->sum(abs.(rfft(x))), Flux.params(x))
Grads(...)
julia> y = rfft(x)
2-element CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}:
-0.57234746f0 + 0.0f0im
0.16513643f0 - 1.2858989f0im
julia> gradient(()->sum(abs.(irfft(y, 3))), Flux.params(x))
Grads(...)
julia> gradient(()->sum(abs.(brfft(y, 3))), Flux.params(x))
Grads(...)
Could you please help me notify the reviewers with merge rights? Thanks again!
One worry I have is whether the way
using CUDA
N = 100000000
x = CuArray(rand(N))
# f1 is similar to our current approach: build the full scale array on the CPU, then convert it
function f1(x)
    y = typeof(x)([i == 7 ? 2 : 1 for i in 1:N])
    return x ./ y
end
# f2 does not explicitly construct a scale array
function f2(x)
    y = copy(x)
    y[7] /= 2  # halve entry 7 only, matching f1 (scalar indexing: allowed with a warning in the REPL)
    return y
end
@time f1(x) # 0.633765 seconds (105.59 k allocations: 1.496 GiB, 3.71% gc time, 6.02% compilation time)
@time f2(x) # 0.007896 seconds (4.60 k allocations: 224.639 KiB, 73.60% compilation time)
An alternative might be to write out the division in the pullback without broadcasting against a separately constructed scale array, similar to:
function rfft_pullback(ȳ)
    dY = ChainRulesCore.unthunk(ȳ)
    dY_scaled = similar(dY)
    dY_scaled .= dY
    dY_scaled ./= 2
    selectdim(dY_scaled, halfdim, 1) .*= 2
    if 2 * (n - 1) == d
        selectdim(dY_scaled, halfdim, n) .*= 2
    end
    x̄ = project_x(brfft(dY_scaled, d, dims))
    return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
end
I'm not experienced enough with CUDA array programming in Julia to be sure about this. It would also be nice if there were a way of keeping the "style" of the original broadcasting solution without having to allocate a CPU array. Pinging @maleadt for any more info :)
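To check that the rfft_pullback sketch above (halve everything, then double the first slice and, for even d, the last slice) gives the same result as dividing by an explicit scale array, here is a small CPU-only comparison. scale_cotangent and scale_reference are hypothetical helpers written for illustration, not code from AbstractFFTs; the views are bound to variables first so the sketch also runs on Julia < 1.2.

```julia
# Hypothetical helper: rfft-pullback scaling of a cotangent dY without a scale array.
function scale_cotangent(dY::AbstractArray, halfdim::Integer, n::Integer, d::Integer)
    dY_scaled = similar(dY)
    dY_scaled .= dY
    dY_scaled ./= 2                                   # halve every bin ...
    first_slice = selectdim(dY_scaled, halfdim, 1)
    first_slice .*= 2                                 # ... then undo it for the first bin
    if 2 * (n - 1) == d                               # and for the last bin when d is even
        last_slice = selectdim(dY_scaled, halfdim, n)
        last_slice .*= 2
    end
    return dY_scaled
end

# Reference: divide by an explicit scale array shaped to broadcast along halfdim.
function scale_reference(dY::AbstractArray, halfdim::Integer, n::Integer, d::Integer)
    scale = reshape([i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n],
                    ntuple(i -> i == halfdim ? n : 1, ndims(dY)))
    return dY ./ scale
end

dY = ComplexF64[1 2 3; 4 5 6; 7 8 9; 10 11 12]   # 4×3 cotangent: d = 6 along dim 1, n = 4
println(scale_cotangent(dY, 1, 4, 6) == scale_reference(dY, 1, 4, 6))
```

Halving and doubling by 2 are exact in binary floating point (outside overflow/underflow), so the two results can be compared with ==.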
Here's an approach that keeps the original broadcasting style, but makes the
# Make scaling array in a GPU friendly way
scale = similar(y, ntuple(i -> i == halfdim ? n : 1, Val(ndims(x))))
scale .= 2
selectdim(scale, halfdim, 1) .= 1
if 2 * (n - 1) == d
    selectdim(scale, halfdim, n) .= 1
end
My worry is that the current approach in this PR would be slow on the GPU, since it makes a large CPU array allocation within the rule. @ziyiyin97, do you think the same, and if so, could you replace the construction of the
cc'ing @devmotion as the author of the original scaling code.
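The similar-based construction in the comment above can be checked on the CPU against the comprehension-based scale used earlier in the PR. A minimal sketch, assuming a plain Array stands in for the rfft output y (no CUDA involved); make_scale_similar and make_scale_comprehension are hypothetical helper names, not functions from AbstractFFTs.

```julia
# Two ways of building the rfft pullback scale: 1 for the first bin
# (and the last bin when d is even), 2 for every other bin.

# GPU-friendly style: allocate with `similar`, so the scale lives wherever `y` lives.
function make_scale_similar(y::AbstractArray, halfdim::Integer, n::Integer, d::Integer)
    scale = similar(y, ntuple(i -> i == halfdim ? n : 1, Val(ndims(y))))
    scale .= 2
    first_slice = selectdim(scale, halfdim, 1)   # bind the view first (Julia < 1.2 workaround)
    first_slice .= 1
    if 2 * (n - 1) == d                          # even d: the last (Nyquist) bin is not doubled
        last_slice = selectdim(scale, halfdim, n)
        last_slice .= 1
    end
    return scale
end

# Earlier PR style: a CPU comprehension reshaped to broadcast along `halfdim`.
function make_scale_comprehension(halfdim::Integer, n::Integer, d::Integer, N::Integer)
    return reshape([i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n],
                   ntuple(i -> i == halfdim ? n : 1, N))
end

# CPU stand-ins: d = 6 (even) along dimension 1, so n = d ÷ 2 + 1 = 4.
d, halfdim = 6, 1
n = d ÷ 2 + 1
y = zeros(ComplexF64, n, 3)                  # stand-in for rfft(x) of a 6×3 real array
s1 = make_scale_similar(y, halfdim, n, d)    # 4×1 ComplexF64 array
s2 = make_scale_comprehension(halfdim, n, d, ndims(y))
println(vec(s1) == [1, 2, 2, 1], " ", s1 == s2)
```

Because similar takes its array type from y, the same function would allocate on the GPU when y is a CuArray; that is the design point of the suggestion, although this sketch only exercises the CPU path.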
Sorry for the late response. I think the solution you proposed sounds great! I've replicated your approach to
It seems that Julia 1.0 does not like it.
ext/AbstractFFTsChainRulesCoreExt.jl
dY = ChainRulesCore.unthunk(ȳ) ./ 2
selectdim(dY, halfdim, 1) .*= 2
if 2 * (n - 1) == d
    selectdim(dY, halfdim, n) .*= 2
end
This seems both inefficient (many unnecessary computations) and it breaks non-mutating arrays.
Good point. Any suggestions? I used the code block below in the previous commit, but @gaurav-arya also made a good point that this type conversion might be slow in certain cases.
scale = typeof(y)(reshape(
[i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n],
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
))
I tried to write this code to be GPU compatible, so avoiding broadcasting with CPU arrays. That's why I divide the whole array by 2, then multiply some slices of the array by 2. I wasn't sure if what I did was the right approach, so I'd appreciate any feedback on how to write it better:)
Regarding the speed issue, I benchmarked the following code before and after this PR:
using FFTW
using ChainRulesCore
using BenchmarkTools  # provides @btime
function tobenchmark(x, dims)
    y, pb = rrule(rfft, x, dims)
    return pb(y)
end
julia> @btime tobenchmark(rand(1000, 1000), 1:2);
13.913 ms (71 allocations: 45.85 MiB) [BEFORE]
13.897 ms (78 allocations: 45.84 MiB) [AFTER]
Regarding the mutable array issue, that's why I used similar in the code I originally suggested, which is semantically guaranteed to return a mutable array. I agree it's not a perfect solution for the immutable array case (perhaps Adapt.jl or ArrayInterface.jl could help with that). But also note that this is about the type of the output array rather than the input, and afaik there is no existing case in the ecosystem where the output array is immutable: FFTW converts all CPU arrays to vectors. So it's not perfect, but it did seem to fix the CUDA case, which the previous approach didn't support (and with similar it would even be correct for a hypothetical static array, although admittedly not an ideal approach). Hopefully that helps explain my reasoning :)
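The point that similar hands back a mutable array even when its argument is immutable can be illustrated with a range, an immutable AbstractVector from Base. This is a small sketch of that one case, not a statement about every array type.

```julia
# `similar` on an immutable AbstractArray (here a UnitRange) returns a mutable Array,
# so in-place broadcasts like the scaling steps discussed above can be applied to it.
r = 1:4
s = similar(r)            # uninitialized Vector{Int}, same shape as r
s .= r                    # fill it with the range's values
s[1] = 10                 # allowed on the copy

range_is_immutable = try  # the same assignment on the range itself throws
    r[1] = 10
    false
catch
    true
end

println(typeof(s), " ", s, " ", range_is_immutable)
```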
@ziyiyin97 regarding the invalid assignment location error, the following workaround seemed to work for me on Julia 1.0:
julia> x = rand(2,2)
2×2 Array{Float64,2}:
 0.000300982  0.405891
 0.903893     0.814312
julia> v = selectdim(x, 1, 2); # place view in a separate variable (workaround for Julia <1.2)
julia> v .+= 1
2-element view(::Array{Float64,2}, 2, :) with eltype Float64:
 1.9038934352514685
 1.8143118255443202
It looks like it's a bug that was fixed only in Julia 1.2: https://discourse.julialang.org/t/invalid-assignment-location-on-function-call-returning-view/23346/5. It only seems to appear for the
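The workaround from the REPL session above, written as a self-contained snippet: the view is bound to a variable before the in-place broadcast, instead of writing selectdim(x, 1, 2) .+= 1 directly (the form reported in this thread to fail on Julia < 1.2). The matrix values here are made up for illustration.

```julia
# Bind the view to a variable first, then update through it in place.
x = [1.0 2.0; 3.0 4.0]
v = selectdim(x, 1, 2)    # view of row 2 (no copy)
v .+= 1                   # updates x in place through the view
println(x == [1.0 2.0; 4.0 5.0])   # only row 2 changed
```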
dX_scaled = similar(dX)
dX_scaled .= dX
dX_scaled .*= 2
v = selectdim(dX_scaled, halfdim, 1)
Maybe add a comment above this line saying something like # assign view to a separate variable before assignment, to support Julia <1.2?
ext/AbstractFFTsChainRulesCoreExt.jl
-dX = rfft(real.(ChainRulesCore.unthunk(ȳ)), dims) .* 2
-selectdim(dX, halfdim, 1) ./= 2
+dX = rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)
+# apply scaling
Maybe also add a comment saying something like # below approach is for ensuring GPU compatibility, see PR #96?
ext/AbstractFFTsChainRulesCoreExt.jl
@@ -33,10 +33,12 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims)
     project_x = ChainRulesCore.ProjectTo(x)
     function rfft_pullback(ȳ)
         dY = ChainRulesCore.unthunk(ȳ)
-        # apply scaling
+        # apply scaling; below approach is for GPU CuArray compatibility, see PR #96
         dY_scaled = similar(dY)
         dY_scaled .= dY
         dY_scaled ./= 2
You could actually fuse this line with the previous one, I think, e.g. dY_scaled .= dY ./ 2, and similarly for the other scalings.
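One possible reading of the fusion suggestion is sketched below. It assumes that "similarly for the other scalings" means restoring the edge slices by copying them from the unscaled cotangent rather than multiplying them back by 2; fused_scale_cotangent is a hypothetical name, not the code that was merged.

```julia
# Hypothetical fused variant: one broadcast for the bulk, then copy the unscaled edge slices.
function fused_scale_cotangent(dY::AbstractArray, halfdim::Integer, n::Integer, d::Integer)
    dY_scaled = similar(dY)
    dY_scaled .= dY ./ 2                            # fused copy-and-halve
    first_slice = selectdim(dY_scaled, halfdim, 1)  # separate variable: Julia < 1.2 workaround
    first_slice .= selectdim(dY, halfdim, 1)        # first bin keeps its original value
    if 2 * (n - 1) == d
        last_slice = selectdim(dY_scaled, halfdim, n)
        last_slice .= selectdim(dY, halfdim, n)     # so does the last bin when d is even
    end
    return dY_scaled
end

dY = ComplexF64[1 2; 3 4; 5 6; 7 8]                 # 4×2 cotangent: d = 6 along dim 1, n = 4
println(fused_scale_cotangent(dY, 1, 4, 6))
```

Compared with halving everything and then doubling the edge slices, this touches the edge slices once instead of twice, and dY itself is never mutated.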
Any other comments?
I don't have any other comments, but this would need an approval from someone with merge rights, if @stevengj or @devmotion have time for a quick look?
Any update?
#105 seems simpler.
I think this conversion approach was discussed (and opposed) earlier here: #96 (comment) ... but I'm fine with it. I can close this PR if #105 is merged. Feel free to let me know.
Closed by #105
@stevengj @devmotion, could we revisit this? It does not suffer from either of the issues observed in #112 that were introduced by #105. I'd be happy to open a new PR for review if you think this is the right path. If there's a better solution, that would of course be wonderful too 🙂
I wonder if we could just use
I think that should fix the Zygote issue, which is the more important issue, so I'd be happy with that as a stopgap. I don't think it would handle the subarray issue, though.
Actually, it looks like a reasonable solution, since it just relies on the output of the FFT function being convertible to. I've implemented it in #114.
Edit: This would fix reverse rules for non-plans (issue #115), but not reverse rules for plans (issue #112), because in the current design we delegate to
I think a solution without new dependencies would be preferable (hence I didn't suggest Adapt), but maybe it's not possible.
Ok. I think #114 is the right way forward as a first step, to first fix downstream. #112 at least is not going to be a regression in any case, because the rules for real plans in Zygote are currently incorrect.
Edit: As for a solution without dependencies, the
Fix #95
Right now there is no test on CuArray, so this fix cannot be tested easily. Any suggestions?