Skip to content

Commit 33a1cbb

Browse files
authored
Fix: correctly apply fftshift to real-valued data inputs
Correctly apply fftshift to real-valued data inputs Signed-off-by: Puyang Wang <[email protected]>
1 parent bfcb318 commit 33a1cbb

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

monai/networks/blocks/fft_utils_t.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -187,20 +187,19 @@ def fftn_centered_t(im: Tensor, spatial_dims: int, is_complex: bool = True) -> T
187187
output2 = fftn_centered(im, spatial_dims=2, is_complex=True)
188188
"""
189189
# define spatial dims to perform ifftshift, fftshift, and fft
190-
shift = list(range(-spatial_dims, 0))
190+
dims = list(range(-spatial_dims, 0))
191191
if is_complex:
192192
if im.shape[-1] != 2:
193193
raise ValueError(f"img.shape[-1] is not 2 ({im.shape[-1]}).")
194-
shift = list(range(-spatial_dims - 1, -1))
195-
dims = list(range(-spatial_dims, 0))
196-
197-
x = ifftshift(im, shift)
194+
x = ifftshift(im, [d - 1 for d in dims])
195+
else:
196+
x = ifftshift(im, dims)
198197

199198
if is_complex:
200199
x = torch.view_as_real(torch.fft.fftn(torch.view_as_complex(x), dim=dims, norm="ortho"))
201200
else:
202201
x = torch.view_as_real(torch.fft.fftn(x, dim=dims, norm="ortho"))
203202

204-
out: Tensor = fftshift(x, shift)
203+
out: Tensor = fftshift(x, [d - 1 for d in dims])
205204

206205
return out

0 commit comments

Comments
 (0)