@@ -139,21 +139,17 @@ def ifftn_centered_t(ksp: Tensor, spatial_dims: int, is_complex: bool = True) ->
139139 output2 = ifftn_centered(ksp, spatial_dims=2, is_complex=True)
140140 """
141141 # define spatial dims to perform ifftshift, fftshift, and ifft
142- shift = list (range (- spatial_dims , 0 ))
142+ dims = list (range (- spatial_dims , 0 ))
143143 if is_complex :
144144 if ksp .shape [- 1 ] != 2 :
145145 raise ValueError (f"ksp.shape[-1] is not 2 ({ ksp .shape [- 1 ]} )." )
146- shift = list (range (- spatial_dims - 1 , - 1 ))
147- dims = list (range (- spatial_dims , 0 ))
148-
149- x = ifftshift (ksp , shift )
150-
151- if is_complex :
152- x = torch .view_as_real (torch .fft .ifftn (torch .view_as_complex (x ), dim = dims , norm = "ortho" ))
146+ x = torch .view_as_complex (ifftshift (ksp , [d - 1 for d in dims ]))
153147 else :
154- x = torch .view_as_real (torch .fft .ifftn (x , dim = dims , norm = "ortho" ))
148+ x = ifftshift (ksp , dims )
149+
150+ x = torch .view_as_real (torch .fft .ifftn (x , dim = dims , norm = "ortho" ))
155151
156- out : Tensor = fftshift (x , shift )
152+ out : Tensor = fftshift (x , [ d - 1 for d in dims ] )
157153
158154 return out
159155
@@ -187,20 +183,16 @@ def fftn_centered_t(im: Tensor, spatial_dims: int, is_complex: bool = True) -> T
187183 output2 = fftn_centered(im, spatial_dims=2, is_complex=True)
188184 """
189185 # define spatial dims to perform ifftshift, fftshift, and fft
190- shift = list (range (- spatial_dims , 0 ))
186+ dims = list (range (- spatial_dims , 0 ))
191187 if is_complex :
192188 if im .shape [- 1 ] != 2 :
193189 raise ValueError (f"img.shape[-1] is not 2 ({ im .shape [- 1 ]} )." )
194- shift = list (range (- spatial_dims - 1 , - 1 ))
195- dims = list (range (- spatial_dims , 0 ))
196-
197- x = ifftshift (im , shift )
198-
199- if is_complex :
200- x = torch .view_as_real (torch .fft .fftn (torch .view_as_complex (x ), dim = dims , norm = "ortho" ))
190+ x = torch .view_as_complex (ifftshift (im , [d - 1 for d in dims ]))
201191 else :
202- x = torch .view_as_real (torch .fft .fftn (x , dim = dims , norm = "ortho" ))
192+ x = ifftshift (im , dims )
193+
194+ x = torch .view_as_real (torch .fft .fftn (x , dim = dims , norm = "ortho" ))
203195
204- out : Tensor = fftshift (x , shift )
196+ out : Tensor = fftshift (x , [ d - 1 for d in dims ] )
205197
206198 return out
0 commit comments