@@ -139,21 +139,17 @@ def ifftn_centered_t(ksp: Tensor, spatial_dims: int, is_complex: bool = True) ->
139139 output2 = ifftn_centered(ksp, spatial_dims=2, is_complex=True)
140140 """
141141 # define spatial dims to perform ifftshift, fftshift, and ifft
142- shift = list (range (- spatial_dims , 0 ))
142+ dims = list (range (- spatial_dims , 0 ))
143143 if is_complex :
144144 if ksp .shape [- 1 ] != 2 :
145145 raise ValueError (f"ksp.shape[-1] is not 2 ({ ksp .shape [- 1 ]} )." )
146- shift = list (range (- spatial_dims - 1 , - 1 ))
147- dims = list (range (- spatial_dims , 0 ))
146+ x = torch .view_as_complex (ksp )
148147
149- x = ifftshift (ksp , shift )
148+ x = ifftshift (ksp , dims )
150149
151- if is_complex :
152- x = torch .view_as_real (torch .fft .ifftn (torch .view_as_complex (x ), dim = dims , norm = "ortho" ))
153- else :
154- x = torch .view_as_real (torch .fft .ifftn (x , dim = dims , norm = "ortho" ))
150+ x = fftshift (torch .fft .ifftn (x , dim = dims , norm = "ortho" ), dims )
155151
156- out : Tensor = fftshift ( x , shift )
152+ out : Tensor = torch . view_as_real ( x )
157153
158154 return out
159155
@@ -191,12 +187,12 @@ def fftn_centered_t(im: Tensor, spatial_dims: int, is_complex: bool = True) -> T
191187 if is_complex :
192188 if im .shape [- 1 ] != 2 :
193189 raise ValueError (f"img.shape[-1] is not 2 ({ im .shape [- 1 ]} )." )
194- x = torch .view_as_complex (ifftshift ( im , [ d - 1 for d in dims ]) )
195- else :
196- x = ifftshift (im , dims )
190+ x = torch .view_as_complex (im )
191+
192+ x = ifftshift (im , dims )
197193
198- x = torch . view_as_real (torch .fft .fftn (x , dim = dims , norm = "ortho" ))
194+ x = fftshift (torch .fft .fftn (x , dim = dims , norm = "ortho" ), dims )
199195
200- out : Tensor = fftshift ( x , [ d - 1 for d in dims ] )
196+ out : Tensor = torch . view_as_real ( x )
201197
202198 return out
0 commit comments