Speed up back-projection + cleanup
JakobAsslaender committed Dec 21, 2022
36 changes: 17 additions & 19 deletions src/BackProjection.jl
function calculateBackProjection(data::AbstractArray{T}, trj, U, cmaps) where {T}
function calculateBackProjection(data::AbstractArray{T}, trj, U, cmaps; verbose = false) where {T}
test_dimension(data, trj, U, cmaps)

Nt, Ncoef = size(U)
_, Ncoef = size(U)
img_shape = size(cmaps[1])
Ncoils = length(cmaps)


p = NFFT.NFFTPlan(trj[1], img_shape; precompute=POLYNOMIAL, blocking = false, fftflags = FFTW.MEASURE)
pv = [copy(p) for _ = 1:Threads.nthreads()]
xbp = [zeros(T, img_shape..., Ncoef) for _ = 1:Threads.nthreads()]
xtmp = [Array{T}(undef, img_shape) for _ = 1:Threads.nthreads()]

@batch for it ∈ 1:Nt
tid = Threads.threadid()
Ui = reshape(conj.(U[it, :]), one.(img_shape)..., Ncoef)
NFFT.nodes!(pv[tid], trj[it])

for icoil ∈ 1:Ncoils
@views mul!(xtmp[tid], adjoint(pv[tid]), data[:, it, icoil])
@views xbp[tid] .+= conj.(cmaps[icoil]) .* xtmp[tid] .* Ui
p = plan_nfft(reduce(hcat,trj), img_shape; precompute=TENSOR, blocking = true, fftflags = FFTW.MEASURE)
xbp = zeros(T, img_shape..., Ncoef)
xtmp = Array{T}(undef, img_shape)

dataU = similar(@view data[:,:,1]) # size = Ncycles*Nr x Nt
img_idx = CartesianIndices(img_shape)
for icoef ∈ axes(U,2)
t = @elapsed for icoil ∈ eachindex(cmaps)
@simd for i ∈ CartesianIndices(dataU)
@inbounds dataU[i] = data[i,icoil] * conj(U[i[2],icoef])
mul!(xtmp, adjoint(p), vec(dataU))
xbp[img_idx,icoef] .+= conj.(cmaps[icoil]) .* xtmp
verbose && println("coefficient = $icoef: t = $t s"); flush(stdout)
return sum(xbp)
return xbp

function test_dimension(data, trj, U, cmaps)
30 changes: 12 additions & 18 deletions src/CoilMaps.jl
function calcCoilMaps(data::AbstractArray{Complex{T},3}, trj::AbstractVector{<:AbstractMatrix{T}}, U::AbstractMatrix{Complex{T}}, img_shape::NTuple{N,Int}; kernel_size = ntuple(_->6, N), calib_size = ntuple(_->24, N), eigThresh_1=0.04, eigThresh_2=0.0, nmaps=1) where {N,T}
function calcCoilMaps(data::AbstractArray{Complex{T},3}, trj::AbstractVector{<:AbstractMatrix{T}}, U::AbstractMatrix{Complex{T}}, img_shape::NTuple{N,Int}; kernel_size = ntuple(_->6, N), calib_size = ntuple(_->24, N), eigThresh_1=0.04, eigThresh_2=0.0, nmaps=1, verbose = false) where {N,T}
Ncoils = size(data,3)
Ndims = length(img_shape)
imdims = ntuple(i->i, Ndims)

# dataU = data .* U[:,1]'
# dataU .*= dropdims(sum(abs2, combinedimsview(trj), dims=1), dims=1)
dataU = similar(data) # size = Ncycles*Nr x Nt x Ncoils
cU1 = conj(U[:,1])
@batch for i ∈ CartesianIndices(dataU)
dataU[i] = data[i] * cU1[i[2]] * sum(abs2, @view trj[i[2]][:,i[1]])
dataU = reshape(dataU, :, size(dataU, 3))

p = plan_nfft(reduce(hcat,trj), img_shape)
p = plan_nfft(reduce(hcat,trj), img_shape; precompute=TENSOR, blocking = true, fftflags = FFTW.MEASURE)
xbp = Array{Complex{T}}(undef, img_shape..., Ncoils)

dataU = similar(@view data[:,:,1]) # size = Ncycles*Nr x Nt
img_idx = CartesianIndices(img_shape)

@info "BP for coils maps: "
@time begin
@batch for ic ∈ 1:Ncoils
@views mul!(xbp[img_idx,ic], adjoint(copy(p)), dataU[:,ic])
t = @elapsed for icoil ∈ axes(data,3)
@simd for i ∈ CartesianIndices(dataU)
dataU[i] = data[i,icoil] * conj(U[i[2],1]) * sum(abs2, @view trj[i[2]][:,i[1]])
@views mul!(xbp[img_idx,icoil], adjoint(p), vec(dataU))
verbose && println("BP for coils maps: $t s")

kbp = fftshift(xbp, imdims)
fft!(kbp, imdims)
Expand All @@ -31,8 +23,10 @@ function calcCoilMaps(data::AbstractArray{Complex{T},3}, trj::AbstractVector{<:A
m = CartesianIndices(calib_size) .+ CartesianIndex((img_shape .- calib_size) .÷ 2)
kbp = kbp[m,:]

@info "espirit: "
cmaps = @time espirit(kbp, img_shape, kernel_size, eigThresh_1=eigThresh_1, eigThresh_2=eigThresh_2, nmaps=nmaps)
t = @elapsed begin
cmaps = espirit(kbp, img_shape, kernel_size, eigThresh_1=eigThresh_1, eigThresh_2=eigThresh_2, nmaps=nmaps)
verbose && println("espirit: $t s")

cmaps = [cmaps[img_idx,ic,1] for ic=1:Ncoils]
xbp = [ xbp[img_idx,ic ] for ic=1:Ncoils]
34 changes: 15 additions & 19 deletions src/NFFTNormalOpBasisFunc.jl
Expand Up @@ -16,26 +16,22 @@ function calculateToeplitzKernelBasis(img_shape_os, trj::Vector{Matrix{T}}, U::M
for ic2 ∈ axes(Λ, 2), ic1 ∈ axes(Λ, 1)
if ic2 >= ic1 # eval. only upper triangular matrix
t = @elapsed begin
@simd for it ∈ axes(U,1)
@inbounds S[:,it] .= conj(U[it,ic1]) * U[it,ic2]

mul!(λ, adjoint(nfftplan), vec(S))
fftshift!(λ2, λ)
mul!(λ, fftplan, λ2)
λ2 .= conj.(λ2)
mul!(λ3, fftplan, λ2)

Threads.@threads for it ∈ eachindex(λ)
@inbounds Λ[ic2,ic1,it] = λ3[it]
@inbounds Λ[ic1,ic2,it] = λ[it]

if verbose
println("ic = ($ic1, $ic2): t = $t")
@simd for it ∈ axes(U,1)
@inbounds S[:,it] .= conj(U[it,ic1]) * U[it,ic2]

mul!(λ, adjoint(nfftplan), vec(S))
fftshift!(λ2, λ)
mul!(λ, fftplan, λ2)
λ2 .= conj.(λ2)
mul!(λ3, fftplan, λ2)

Threads.@threads for it ∈ eachindex(λ)
@inbounds Λ[ic2,ic1,it] = λ3[it]
@inbounds Λ[ic1,ic2,it] = λ[it]
verbose && println("ic = ($ic1, $ic2): t = $t s"); flush(stdout)

2 changes: 1 addition & 1 deletion test/reconstruct.jl
Expand Up @@ -44,7 +44,7 @@ for it ∈ axes(data,2)

## BackProjection
b = vec(calculateBackProjection(data, trj, U, [ones(T, Nx,Nx)]))
b = vec(calculateBackProjection(data, trj, U, [ones(T, Nx,Nx)], verbose = true))

## construct forward operator
A = NFFTNormalOpBasisFuncLO((Nx,Nx), trj, U; verbose = true)
Expand Down

