@@ -143,13 +143,13 @@ def ifftn_centered_t(ksp: Tensor, spatial_dims: int, is_complex: bool = True) ->
143143 if is_complex :
144144 if ksp .shape [- 1 ] != 2 :
145145 raise ValueError (f"ksp.shape[-1] is not 2 ({ ksp .shape [- 1 ]} )." )
146- x = torch .view_as_complex (ksp )
146+ x = torch .view_as_complex (ifftshift (ksp , [d - 1 for d in dims ]))
147+ else :
148+ x = ifftshift (ksp , dims )
147149
148- x = ifftshift ( x , dims )
150+ x = torch . view_as_real ( torch . fft . fftn ( x , dim = dims , norm = "ortho" ) )
149151
150- x = fftshift (torch .fft .ifftn (x , dim = dims , norm = "ortho" ), dims )
151-
152- out : Tensor = torch .view_as_real (x )
152+ out : Tensor = fftshift (x , [d - 1 for d in dims ])
153153
154154 return out
155155
@@ -187,12 +187,12 @@ def fftn_centered_t(im: Tensor, spatial_dims: int, is_complex: bool = True) -> T
187187 if is_complex :
188188 if im .shape [- 1 ] != 2 :
189189 raise ValueError (f"img.shape[-1] is not 2 ({ im .shape [- 1 ]} )." )
190- x = torch .view_as_complex (im )
191-
192- x = ifftshift (im , dims )
190+ x = torch .view_as_complex (ifftshift ( im , [ d - 1 for d in dims ]) )
191+ else :
192+ x = ifftshift (im , dims )
193193
194- x = fftshift (torch .fft .fftn (x , dim = dims , norm = "ortho" ), dims )
194+ x = torch . view_as_real (torch .fft .fftn (x , dim = dims , norm = "ortho" ))
195195
196- out : Tensor = torch . view_as_real ( x )
196+ out : Tensor = fftshift ( x , [ d - 1 for d in dims ] )
197197
198198 return out
0 commit comments