Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster czt #43

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "FourierTools"
uuid = "b18b359b-aebc-45ac-a139-9c0ccbb2871e"
authors = ["Felix Wechsler (roflmaostc) <[email protected]>", "Rainer Heintzmann (rheintzmann) <[email protected]>"]
version = "0.4.3"
version = "0.4.4"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -15,17 +15,16 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
ShiftedArrays = "1277b4bf-5013-50f5-be3d-901d8477a67a"

[compat]
ChainRulesCore = "1, 1.0, 1.1"
FFTW = "1.5"
ChainRulesCore = "1, 1.0, 1.1, 1.24"
FFTW = "1.5, 1.6, 1.7, 1.8"
ImageTransformations = "0.9"
IndexFunArrays = "0.2"
NDTools = "0.5.1, 0.6"
NDTools = "0.5.1, 0.6, 0.7"
NFFT = "0.11, 0.12, 0.13"
PaddedViews = "0.5"
Reexport = "1"
ShiftedArrays = "2"
Zygote = "0.6"
julia = "1, 1.6, 1.7, 1.8"
julia = "1, 1.6, 1.7, 1.8, 1.9, 1.10"

[extras]
FractionalTransforms = "e50ca838-b4f0-4a10-ad18-4b920bf1ae5c"
Expand Down
70 changes: 41 additions & 29 deletions src/czt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@
`aw`: factor to multiply input with
`fft_fv`: fourier-transform (FFTW) of the convolutio kernel
`wd`: factor to multiply the result of the convolution by
`fftw_plan`: plan for the forward FFTW of the convolution kernel
`ifftw_plan`: plan for the inverse FFTW of the convolution kernel
`fftw_plan!`: plan for the in-place forward FFTW of the convolution kernel
`ifftw_plan!`: plan for the in-place inverse FFTW of the convolution kernel
"""
struct CZTPlan_1D{CT, PT, D} # <: AbstractArray{T,D}
d :: Int
Expand All @@ -80,8 +80,8 @@
aw :: Array{CT, D}
fft_fv :: Array{CT, D}
wd :: Array{CT, D}
fftw_plan :: FFTW.cFFTWPlan
ifftw_plan :: AbstractFFTs.ScaledPlan
fftw_plan! :: FFTW.cFFTWPlan
ifftw_plan! :: AbstractFFTs.ScaledPlan
# dimension of this transformation
# as :: Array{T, D} # not needed since it is just the conjugate of ws
end
Expand Down Expand Up @@ -161,10 +161,12 @@
nsz = ntuple((dd) -> (d==dd) ? size(fft_fv, 1) : size(xin, dd), Val(ndims(xin)))
y = Array{eltype(aw), ndims(xin)}(undef, nsz)

fft_p = plan_fft(y, (d,); flags=fft_flags)
ifft_p = plan_ifft(y, (d,); flags=fft_flags) # inv(fft_p)
fft_p! = plan_fft!(y, (d,); flags=fft_flags)
ifft_p! = plan_ifft!(y, (d,); flags=fft_flags) # inv(fft_p)

plan = CZTPlan_1D(d, pad_value, (start_range, end_range), reorient(aw, d, Val(ndims(xin))), reorient(fft_fv, d, Val(ndims(xin))), reorient(wd, d, Val(ndims(xin))), fft_p, ifft_p)
plan = CZTPlan_1D(d, pad_value, (start_range, end_range),
reorient(aw, d, Val(ndims(xin))), reorient(fft_fv, d, Val(ndims(xin))),
reorient(wd, d, Val(ndims(xin))), fft_p!, ifft_p!)
return plan
end

Expand All @@ -175,9 +177,9 @@
creates a plan for an N-dimensional chirp z-transformation (CZT). The generated plan is then applied via
muliplication. For details about the arguments, see `czt()`.
"""
function plan_czt(xin, scale, dims, dsize=size(xin); a=nothing, w=nothing, damp=ones(ndims(xin)), src_center=size(xin).÷2 .+1, dst_center=dsize.÷2 .+1, remove_wrap=false, pad_value=zero(eltype(xin)), fft_flags=FFTW.ESTIMATE)
function plan_czt(xin::AbstractArray{U,D}, scale, dims, dsize=size(xin); a=nothing, w=nothing, damp=ones(ndims(xin)), src_center=size(xin).÷2 .+1, dst_center=dsize.÷2 .+1, remove_wrap=false, pad_value=zero(eltype(xin)), fft_flags=FFTW.ESTIMATE) where {U,D}
CT = (eltype(xin) <: Real) ? Complex{eltype(xin)} : eltype(xin)
D = ndims(xin)
# D = ndims(xin)
plans = [] # Vector{CZT1DPlan{CT,D}}
sz = size(xin)
for d in dims
Expand All @@ -189,7 +191,7 @@
return CZTPlan_ND{CT, typeof(pad_value),D}(plans)
end

function Base.:*(p::CZTPlan_ND, xin::AbstractArray{U,D}; kargs...) where {U,D} # Complex{U}
function Base.:*(p::CZTPlan_ND, xin::AbstractArray{U,D}; kargs...)::AbstractArray{complex(U),D} where {U,D} # Complex{U}
xout = xin
for pd in p.plans
xout = czt_1d(xout, pd)
Expand Down Expand Up @@ -230,13 +232,13 @@
+ `remove_wrap`: if true, the positions that represent a wrap-around will be set to zero
+ `pad_value`: the value to pad wrapped data with.
"""
function czt_1d(xin, scaled, d, dsize=size(xin,d); a=nothing, w=nothing, damp=1.0, src_center=size(xin,d)÷2+1,
dst_center=dsize÷2+1, extra_phase=nothing, global_phase=nothing, remove_wrap=false, pad_value=zero(eltype(xin)), fft_flags=FFTW.ESTIMATE)
function czt_1d(xin::AbstractArray{U,D}, scaled, d, dsize=size(xin,d); a=nothing, w=nothing, damp=1.0, src_center=size(xin,d)÷2+1,
dst_center=dsize÷2+1, extra_phase=nothing, global_phase=nothing, remove_wrap=false, pad_value=zero(eltype(xin)), fft_flags=FFTW.ESTIMATE)::AbstractArray{complex(U),D} where {U,D}
plan = plan_czt_1d(xin, scaled, d, dsize; a=a, w=w, extra_phase=extra_phase, global_phase=global_phase, damp, src_center=src_center, dst_center=dst_center, remove_wrap=remove_wrap, pad_value=pad_value, fft_flags=fft_flags);
return plan * xin
end

function Base.:*(p::CZTPlan_1D, xin::AbstractArray{U,D}; kargs...) where {U,D} # Complex{U}
function Base.:*(p::CZTPlan_1D, xin::AbstractArray{U,D}; kargs...)::AbstractArray{complex(U),D} where {U,D} # Complex{U}
return czt_1d(xin, p)
end

Expand All @@ -258,7 +260,7 @@
# Arguments
`plan`: A plan created via plan_czt_1d()
"""
function czt_1d(xin, plan::CZTPlan_1D)
function czt_1d(xin::AbstractArray{U,D}, plan::CZTPlan_1D)::AbstractArray{complex(U),D} where {U,D}
# destination position
# cispi(-1/scaled * half_pix_shift)
#
Expand All @@ -267,30 +269,38 @@
# which (intentionally) leads to non-real results for even-sized arrays at non-unit zoom

L = size(plan.fft_fv, plan.d)
nsz = ntuple((dd) -> (dd==plan.d) ? L : size(xin, dd), Val(ndims(xin)))
nsz = ntuple((dd) -> (dd==plan.d) ? L : size(xin, dd), Val(D)) # sizes may vary per dimension, so a new array has to be generated
# append zeros
y = zeros(eltype(plan.aw), nsz)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want this CUDA runable, do

y = similar(xin, eltype(plan.aw), nsz)
fill!(y, 0)

In my code, I try to avoid zeros or ones.

myrange = ntuple((dd) -> (1:size(xin,dd)), Val(ndims(xin)))
y[myrange...] = xin .* plan.aw
myrange = ntuple((dd) -> (1:size(xin,dd)), Val(D))
# writes values into the top half of the y-array
y[myrange...] .= xin .* plan.aw
# corner = ntuple((x)->1, Val(ndims(xin)))
# select_region(xin .* plan.aw, new_size=nsz, center=corner, dst_center=corner)

# g = ifft(fft(y, d) .* plan.fft_fv, d)
g = plan.ifftw_plan * (plan.fftw_plan * y .* plan.fft_fv)
plan.fftw_plan! * y # in-place application to y
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in general avoid comments behind lines, do better above.

y .*= plan.fft_fv
plan.ifftw_plan! * y # in-place application to y

# g = plan.ifftw_plan * (plan.fftw_plan * y .* plan.fft_fv)
# dsz = ntuple((dd) -> (d==dd) ? dsize : size(xin), Val(ndims(xin)))
# return only the wanted (valid) part
myrange = ntuple((dd) -> (dd==plan.d) ? (1:size(plan.wd,plan.d)) : (1:size(xin, dd)), Val(ndims(xin)))
res = g[myrange...] .* plan.wd
myrange = ntuple((dd) -> (dd==plan.d) ? (1:size(plan.wd, plan.d)) : (1:size(xin, dd)), Val(D))
res = y[myrange...] .* plan.wd
# pad_value=0 means that it is either already handled by plan.wd or no padding is wanted.
if plan.pad_value != 0
myrange = ntuple((dd) -> (dd==plan.d) ? plan.pad_ranges[1] : Colon(), Val(ndims(xin)))
# before the start position
myrange = ntuple((dd) -> (dd==plan.d) ? plan.pad_ranges[1] : Colon(), Val(D))
res[myrange...] .= plan.pad_value
myrange = ntuple((dd) -> (dd==plan.d) ? plan.pad_ranges[2] : Colon(), Val(ndims(xin)))
# after the stop position
myrange = ntuple((dd) -> (dd==plan.d) ? plan.pad_ranges[2] : Colon(), Val(D))
res[myrange...] .= plan.pad_value
end
return res
end


"""
czt(xin, scale, dims=1:ndims(xin), dsize=size(xin,d); a=nothing, w=nothing, damp=ones(ndims(xin)),
src_center=size(xin,d)÷2+1, dst_center=dsize÷2+1, remove_wrap=false, fft_flags=FFTW.ESTIMATE)
Expand Down Expand Up @@ -355,18 +365,20 @@
0.0239759 -0.028264 0.0541186 -0.0116475 -0.261294 0.312719 -0.261294 -0.0116475 0.0541186 -0.028264
```
"""
function czt(xin::AbstractArray{T,N}, scale, dims=1:ndims(xin), dsize=size(xin);
a=nothing, w=nothing, damp=ones(ndims(xin)), src_center=size(xin).÷2 .+1, dst_center=dsize.÷2 .+1, remove_wrap=false, pad_value=zero(eltype(xin)), fft_flags=FFTW.ESTIMATE)::AbstractArray{complex(T),N} where {T,N}
xout = xin
if length(scale) != ndims(xin)
error("Every of the $(ndims(xin)) dimension needs exactly one corresponding scale (zoom) factor, which should be equal to 1.0 for dimensions not contained in the dims argument.")
function czt(xin::AbstractArray{T,D}, scale, dims=1:D, dsize=size(xin);
a=nothing, w=nothing, damp=ones(ndims(xin)), src_center=size(xin).÷2 .+1,
dst_center=dsize.÷2 .+1, remove_wrap=false, pad_value=zero(eltype(xin)), fft_flags=FFTW.ESTIMATE)::AbstractArray{complex(T),D} where {T,D}
# xout = xin; # similar(xin, complex(T))
if length(scale) != D
error("Every of the $(D) dimension needs exactly one corresponding scale (zoom) factor, which should be equal to 1.0 for dimensions not contained in the dims argument.")

Check warning on line 373 in src/czt.jl

View check run for this annotation

Codecov / codecov/patch

src/czt.jl#L373

Added line #L373 was not covered by tests
end
for d = 1:ndims(xin)
for d = 1:D # check all the dims
if !(d in dims) && scale[d] != 1.0 && !isnothing(scale[d])
error("The scale factor $(scale[d]) needs to be nothing or 1.0, if this dimension is not in the list of dimensions to transform.")
end
end
for d in dims
xout = czt_1d(xin, scale[dims[1]], dims[1], dsize[dims[1]]; a=a, w=w, damp=damp[dims[1]], src_center=src_center[dims[1]], dst_center=dst_center[dims[1]], remove_wrap=remove_wrap, pad_value=pad_value, fft_flags=fft_flags)
for d in dims[2:end]
xout = czt_1d(xout, scale[d], d, dsize[d]; a=a, w=w, damp=damp[d], src_center=src_center[d], dst_center=dst_center[d], remove_wrap=remove_wrap, pad_value=pad_value, fft_flags=fft_flags)
end
return xout
Expand Down
Loading