From a8bf4e83e723484131f6dddfb07dc1ab7dc72237 Mon Sep 17 00:00:00 2001 From: Jakob Asslaender Date: Tue, 20 Dec 2022 23:20:43 -0500 Subject: [PATCH] Speed up back-projection + cleanup --- src/BackProjection.jl | 36 +++++++++++++++++------------------- src/CoilMaps.jl | 30 ++++++++++++------------------ src/NFFTNormalOpBasisFunc.jl | 34 +++++++++++++++------------------- test/reconstruct.jl | 2 +- 4 files changed, 45 insertions(+), 57 deletions(-) diff --git a/src/BackProjection.jl b/src/BackProjection.jl index 964c67e..3c1bc17 100644 --- a/src/BackProjection.jl +++ b/src/BackProjection.jl @@ -1,28 +1,26 @@ -function calculateBackProjection(data::AbstractArray{T}, trj, U, cmaps) where {T} +function calculateBackProjection(data::AbstractArray{T}, trj, U, cmaps; verbose = false) where {T} test_dimension(data, trj, U, cmaps) - Nt, Ncoef = size(U) + _, Ncoef = size(U) img_shape = size(cmaps[1]) - Ncoils = length(cmaps) - - FFTW.set_num_threads(1) - - p = NFFT.NFFTPlan(trj[1], img_shape; precompute=POLYNOMIAL, blocking = false, fftflags = FFTW.MEASURE) - pv = [copy(p) for _ = 1:Threads.nthreads()] - xbp = [zeros(T, img_shape..., Ncoef) for _ = 1:Threads.nthreads()] - xtmp = [Array{T}(undef, img_shape) for _ = 1:Threads.nthreads()] - - @batch for it ∈ 1:Nt - tid = Threads.threadid() - Ui = reshape(conj.(U[it, :]), one.(img_shape)..., Ncoef) - NFFT.nodes!(pv[tid], trj[it]) - for icoil ∈ 1:Ncoils - @views mul!(xtmp[tid], adjoint(pv[tid]), data[:, it, icoil]) - @views xbp[tid] .+= conj.(cmaps[icoil]) .* xtmp[tid] .* Ui + p = plan_nfft(reduce(hcat,trj), img_shape; precompute=TENSOR, blocking = true, fftflags = FFTW.MEASURE) + xbp = zeros(T, img_shape..., Ncoef) + xtmp = Array{T}(undef, img_shape) + + dataU = similar(@view data[:,:,1]) # size = Ncycles*Nr x Nt + img_idx = CartesianIndices(img_shape) + for icoef ∈ axes(U,2) + t = @elapsed for icoil ∈ eachindex(cmaps) + @simd for i ∈ CartesianIndices(dataU) + @inbounds dataU[i] = data[i,icoil] * conj(U[i[2],icoef]) + end + mul!(xtmp, adjoint(p), vec(dataU)) + xbp[img_idx,icoef] .+= conj.(cmaps[icoil]) .* xtmp end + verbose && println("coefficient = $icoef: t = $t s"); flush(stdout) end - return sum(xbp) + return xbp end function test_dimension(data, trj, U, cmaps) diff --git a/src/CoilMaps.jl b/src/CoilMaps.jl index 5d9018e..5b695ff 100644 --- a/src/CoilMaps.jl +++ b/src/CoilMaps.jl @@ -1,28 +1,20 @@ -function calcCoilMaps(data::AbstractArray{Complex{T},3}, trj::AbstractVector{<:AbstractMatrix{T}}, U::AbstractMatrix{Complex{T}}, img_shape::NTuple{N,Int}; kernel_size = ntuple(_->6, N), calib_size = ntuple(_->24, N), eigThresh_1=0.04, eigThresh_2=0.0, nmaps=1) where {N,T} +function calcCoilMaps(data::AbstractArray{Complex{T},3}, trj::AbstractVector{<:AbstractMatrix{T}}, U::AbstractMatrix{Complex{T}}, img_shape::NTuple{N,Int}; kernel_size = ntuple(_->6, N), calib_size = ntuple(_->24, N), eigThresh_1=0.04, eigThresh_2=0.0, nmaps=1, verbose = false) where {N,T} Ncoils = size(data,3) Ndims = length(img_shape) imdims = ntuple(i->i, Ndims) - # dataU = data .* U[:,1]' - # dataU .*= dropdims(sum(abs2, combinedimsview(trj), dims=1), dims=1) - dataU = similar(data) # size = Ncycles*Nr x Nt x Ncoils - cU1 = conj(U[:,1]) - @batch for i ∈ CartesianIndices(dataU) - dataU[i] = data[i] * cU1[i[2]] * sum(abs2, @view trj[i[2]][:,i[1]]) - end - dataU = reshape(dataU, :, size(dataU, 3)) - - p = plan_nfft(reduce(hcat,trj), img_shape) + p = plan_nfft(reduce(hcat,trj), img_shape; precompute=TENSOR, blocking = true, fftflags = FFTW.MEASURE) xbp = Array{Complex{T}}(undef, img_shape..., Ncoils) + dataU = similar(@view data[:,:,1]) # size = Ncycles*Nr x Nt img_idx = CartesianIndices(img_shape) - - @info "BP for coils maps: " - @time begin - @batch for ic ∈ 1:Ncoils - @views mul!(xbp[img_idx,ic], adjoint(copy(p)), dataU[:,ic]) + t = @elapsed for icoil ∈ axes(data,3) + @simd for i ∈ CartesianIndices(dataU) + dataU[i] = data[i,icoil] * conj(U[i[2],1]) * sum(abs2, @view trj[i[2]][:,i[1]]) end + @views mul!(xbp[img_idx,icoil], adjoint(p), vec(dataU)) end + verbose && println("BP for coils maps: $t s") kbp = fftshift(xbp, imdims) fft!(kbp, imdims) @@ -31,8 +23,10 @@ function calcCoilMaps(data::AbstractArray{Complex{T},3}, trj::AbstractVector{<:A m = CartesianIndices(calib_size) .+ CartesianIndex((img_shape .- calib_size) .÷ 2) kbp = kbp[m,:] - @info "espirit: " - cmaps = @time espirit(kbp, img_shape, kernel_size, eigThresh_1=eigThresh_1, eigThresh_2=eigThresh_2, nmaps=nmaps) + t = @elapsed begin + cmaps = espirit(kbp, img_shape, kernel_size, eigThresh_1=eigThresh_1, eigThresh_2=eigThresh_2, nmaps=nmaps) + end + verbose && println("espirit: $t s") cmaps = [cmaps[img_idx,ic,1] for ic=1:Ncoils] xbp = [ xbp[img_idx,ic ] for ic=1:Ncoils] diff --git a/src/NFFTNormalOpBasisFunc.jl b/src/NFFTNormalOpBasisFunc.jl index 703866c..908036e 100644 --- a/src/NFFTNormalOpBasisFunc.jl +++ b/src/NFFTNormalOpBasisFunc.jl @@ -16,26 +16,22 @@ function calculateToeplitzKernelBasis(img_shape_os, trj::Vector{Matrix{T}}, U::M for ic2 ∈ axes(Λ, 2), ic1 ∈ axes(Λ, 1) if ic2 >= ic1 # eval. only upper triangular matrix t = @elapsed begin - @simd for it ∈ axes(U,1) - @inbounds S[:,it] .= conj(U[it,ic1]) * U[it,ic2] - end - - mul!(λ, adjoint(nfftplan), vec(S)) - fftshift!(λ2, λ) - mul!(λ, fftplan, λ2) - λ2 .= conj.(λ2) - mul!(λ3, fftplan, λ2) - - Threads.@threads for it ∈ eachindex(λ) - @inbounds Λ[ic2,ic1,it] = λ3[it] - @inbounds Λ[ic1,ic2,it] = λ[it] - end - end - - if verbose - println("ic = ($ic1, $ic2): t = $t") - flush(stdout) + @simd for it ∈ axes(U,1) + @inbounds S[:,it] .= conj(U[it,ic1]) * U[it,ic2] + end + + mul!(λ, adjoint(nfftplan), vec(S)) + fftshift!(λ2, λ) + mul!(λ, fftplan, λ2) + λ2 .= conj.(λ2) + mul!(λ3, fftplan, λ2) + + Threads.@threads for it ∈ eachindex(λ) + @inbounds Λ[ic2,ic1,it] = λ3[it] + @inbounds Λ[ic1,ic2,it] = λ[it] + end end + verbose && println("ic = ($ic1, $ic2): t = $t s"); flush(stdout) end end diff --git a/test/reconstruct.jl b/test/reconstruct.jl index 3a91ef8..4dab613 100644 --- a/test/reconstruct.jl +++ b/test/reconstruct.jl @@ -44,7 +44,7 @@ for it ∈ axes(data,2) end ## BackProjection -b = vec(calculateBackProjection(data, trj, U, [ones(T, Nx,Nx)])) +b = vec(calculateBackProjection(data, trj, U, [ones(T, Nx,Nx)], verbose = true)) ## construct forward operator A = NFFTNormalOpBasisFuncLO((Nx,Nx), trj, U; verbose = true)