Skip to content

Commit

Permalink
optimize multi threading
Browse files (browse the repository at this point in the history)
  • Loading branch information
JakobAsslaender committed May 3, 2022
1 parent 9ea37df commit 3cfeb13
Showing 1 changed file with 29 additions and 35 deletions.
64 changes: 29 additions & 35 deletions src/NFFTNormalOpBasisFunc.jl
Original file line number Diff line number Diff line change
@@ -1,34 +1,31 @@
function calculateToeplitzKernelBasis(
img_shape_os,
trj::Vector{Matrix{T}},
U::Matrix{Complex{T}};
fftplan = plan_fft(Array{Complex{T}}(undef, img_shape_os); flags = FFTW.MEASURE),
nfftflag = FFTW.ESTIMATE,
) where {T}
function calculateToeplitzKernelBasis(img_shape_os, trj::Vector{Matrix{T}}, U::Matrix{Complex{T}}; verbose = false) where {T}

Ncoeff = size(U, 2)

FFTW.set_num_threads(Threads.nthreads())
nfftplan = plan_nfft(trj[1], img_shape_os; precompute = LINEAR, blocking = false, fftflags = nfftflag, σ = T(2))
fftplan = plan_fft(Array{Complex{T}}(undef, img_shape_os); flags = FFTW.MEASURE)
nfftplan = plan_nfft(trj[1], img_shape_os; precompute = LINEAR, blocking = false, fftflags = FFTW.ESTIMATE)

λ = Array{Complex{T}}(undef, img_shape_os)
Λ = Array{Complex{T}}(undef, Ncoeff, Ncoeff, prod(img_shape_os))
Λ .= 0
@info "Planned FFTs and Λ initialized, iterating throught the time frames: "
flush(stdout)

for i eachindex(trj)
t_kernel = @elapsed calculateToeplitzKernel!(λ, nfftplan, trj[i], fftplan)

@views U2 = conj.(U[i, :]) * transpose(U[i, :])
t_multiplcation = @elapsed begin
@batch for j eachindex(λ)
Threads.@threads for j eachindex(λ)
@simd for iu CartesianIndices(U2)
@inbounds Λ[iu, j] += U2[iu] * λ[j]
end
end
end
@info "Time frame $i" t_kernel t_multiplcation
flush(stdout)

if verbose
println("Time frame $i: t_kernel = $t_kernel; t_multiplcation = $t_multiplcation")
flush(stdout)
end
end

return Λ
@@ -53,13 +50,13 @@ function NFFTNormalOpBasisFunc(
trj::Vector{Matrix{T}},
U::Matrix{Complex{T}};
cmaps = (1,),
fftplan = plan_fft(Array{Complex{T}}(undef, 2 .* img_shape); flags = FFTW.MEASURE),
nfftflag = FFTW.ESTIMATE,
Λ = calculateToeplitzKernelBasis(2 .* img_shape, trj, U; fftplan = fftplan, nfftflag = nfftflag),
verbose = false,
Λ = calculateToeplitzKernelBasis(2 .* img_shape, trj, U; verbose = verbose),
) where {T}
# FFTW.set_num_threads(Threads.nthreads())

ifftplan = plan_ifft(Array{Complex{T}}(undef, 2 .* img_shape); flags = FFTW.MEASURE)
FFTW.set_num_threads(1)
fftplan = plan_fft!(Array{Complex{T}}(undef, 2 .* img_shape); flags = FFTW.MEASURE)
ifftplan = plan_ifft!(Array{Complex{T}}(undef, 2 .* img_shape); flags = FFTW.MEASURE)
Ncoeff = size(U, 2)
kL1 = Array{Complex{T}}(undef, (2 .* img_shape)..., Ncoeff)
kL2 = similar(kL1)
@@ -74,7 +71,6 @@ function LinearAlgebra.mul!(x::Vector{T}, S::NFFTNormalOpBasisFunc, b, α, β) w
Ncoils = length(S.cmaps)

b = reshape(b, S.shape..., S.Ncoeff)
xL = @view S.kL2[idxos, 1]
if β == 0
fill!(x, zero(T)) # to avoid 0 * NaN == NaN
else
@@ -85,24 +81,23 @@ function LinearAlgebra.mul!(x::Vector{T}, S::NFFTNormalOpBasisFunc, b, α, β) w
try
BLAS.set_num_threads(1)
for icoil 1:Ncoils
fill!(xL, zero(T))

@inbounds for i = 1:S.Ncoeff
@views xL[idx] .= S.cmaps[icoil] .* b[idx, i]
@views mul!(S.kL1[idxos, i], S.fftplan, xL)
Threads.@threads for i = 1:S.Ncoeff
S.kL1[idxos, i] .= 0
@views S.kL1[idx, i] .= S.cmaps[icoil] .* b[idx, i]
@views S.fftplan * S.kL1[idxos, i]
end

kin = reshape(S.kL1, :, S.Ncoeff)
kout = reshape(S.kL2, :, S.Ncoeff)
@batch for i eachindex(view(S.Λ, 1, 1, :))
@views @inbounds mul!(kout[i, :], S.Λ[:, :, i], kin[i, :])
kL1_rs = reshape(S.kL1, :, S.Ncoeff)
kL2_rs = reshape(S.kL2, :, S.Ncoeff)
Threads.@threads for i eachindex(view(S.Λ, 1, 1, :))
@views @inbounds mul!(kL2_rs[i, :], S.Λ[:, :, i], kL1_rs[i, :])
end

@batch for i = 1:S.Ncoeff
@views mul!(S.kL1[idxos, i], S.ifftplan, S.kL2[idxos, i])
Threads.@threads for i = 1:S.Ncoeff
@views S.ifftplan * S.kL2[idxos, i]
end

@views x .+= α .* vec(conj.(S.cmaps[icoil]) .* S.kL1[idx, :])
@views x .+= α .* vec(conj.(S.cmaps[icoil]) .* S.kL2[idx, :])
end
finally
BLAS.set_num_threads(bthreads)
@@ -138,11 +133,10 @@ function NFFTNormalOpBasisFuncLO(
trj::Vector{Matrix{T}},
U::Matrix{Complex{T}};
cmaps = (1,),
fftplan = plan_fft(Array{Complex{T}}(undef, 2 .* img_shape); flags = FFTW.MEASURE),
nfftflag = FFTW.ESTIMATE,
Λ = calculateToeplitzKernelBasis(2 .* img_shape, trj, U; fftplan = fftplan, nfftflag = nfftflag),
verbose = false,
Λ = calculateToeplitzKernelBasis(2 .* img_shape, trj, U; verbose = verbose),
) where {T}

S = NFFTNormalOpBasisFunc(img_shape, trj, U; cmaps = cmaps, fftplan = fftplan, Λ = Λ)
S = NFFTNormalOpBasisFunc(img_shape, trj, U; cmaps = cmaps, Λ = Λ)
return NFFTNormalOpBasisFuncLO(S)
end

0 comments on commit 3cfeb13

Please sign in to comment.