Skip to content

Commit d1b13f2

Browse files
authored
fix fftn_centered_t and ifftn_centered_t
Signed-off-by: Puyang Wang <[email protected]>
1 parent 2779f2b commit d1b13f2

File tree

1 file changed

+10
-14
lines changed

1 file changed

+10
-14
lines changed

monai/networks/blocks/fft_utils_t.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -139,21 +139,17 @@ def ifftn_centered_t(ksp: Tensor, spatial_dims: int, is_complex: bool = True) ->
139139
output2 = ifftn_centered(ksp, spatial_dims=2, is_complex=True)
140140
"""
141141
# define spatial dims to perform ifftshift, fftshift, and ifft
142-
shift = list(range(-spatial_dims, 0))
142+
dims = list(range(-spatial_dims, 0))
143143
if is_complex:
144144
if ksp.shape[-1] != 2:
145145
raise ValueError(f"ksp.shape[-1] is not 2 ({ksp.shape[-1]}).")
146-
shift = list(range(-spatial_dims - 1, -1))
147-
dims = list(range(-spatial_dims, 0))
146+
x = torch.view_as_complex(ksp)
148147

149-
x = ifftshift(ksp, shift)
148+
x = ifftshift(ksp, dims)
150149

151-
if is_complex:
152-
x = torch.view_as_real(torch.fft.ifftn(torch.view_as_complex(x), dim=dims, norm="ortho"))
153-
else:
154-
x = torch.view_as_real(torch.fft.ifftn(x, dim=dims, norm="ortho"))
150+
x = fftshift(torch.fft.ifftn(x, dim=dims, norm="ortho"), dims)
155151

156-
out: Tensor = fftshift(x, shift)
152+
out: Tensor = torch.view_as_real(x)
157153

158154
return out
159155

@@ -191,12 +187,12 @@ def fftn_centered_t(im: Tensor, spatial_dims: int, is_complex: bool = True) -> T
191187
if is_complex:
192188
if im.shape[-1] != 2:
193189
raise ValueError(f"img.shape[-1] is not 2 ({im.shape[-1]}).")
194-
x = torch.view_as_complex(ifftshift(im, [d - 1 for d in dims]))
195-
else:
196-
x = ifftshift(im, dims)
190+
x = torch.view_as_complex(im)
191+
192+
x = ifftshift(im, dims)
197193

198-
x = torch.view_as_real(torch.fft.fftn(x, dim=dims, norm="ortho"))
194+
x = fftshift(torch.fft.fftn(x, dim=dims, norm="ortho"), dims)
199195

200-
out: Tensor = fftshift(x, [d - 1 for d in dims])
196+
out: Tensor = torch.view_as_real(x)
201197

202198
return out

0 commit comments

Comments
 (0)