1414
1515from __future__ import annotations
1616
17- import math
1817from collections .abc import Sequence
1918from typing import Optional , Union
2019
3332class RadialFourier3D (Transform ):
3433 """
3534 Computes the 3D Radial Fourier Transform of medical imaging data.
36-
3735 This transform converts 3D medical images into radial frequency domain representations,
3836 which is particularly useful for handling anisotropic resolution common in medical scans
3937 (e.g., different resolution in axial vs coronal planes).
40-
4138 The radial transform provides rotation-invariant frequency analysis and can help
4239 normalize frequency representations across datasets with different acquisition parameters.
43-
4440 Args:
4541 normalize: if True, normalize the output by the number of voxels.
4642 return_magnitude: if True, return magnitude of the complex result.
4743 return_phase: if True, return phase of the complex result.
4844 radial_bins: number of radial bins for frequency aggregation. If None, returns full 3D spectrum.
4945 max_frequency: maximum normalized frequency to include (0.0 to 1.0).
5046 spatial_dims: spatial dimensions to apply transform to. Default is last three dimensions.
51-
5247 Returns:
5348 Radial Fourier transform of input data. Shape depends on parameters:
5449 - If radial_bins is None: complex tensor of same spatial shape as input
5550 - If radial_bins is set: real tensor of shape (radial_bins,) for magnitude/phase
56-
5751 Example:
5852 >>> transform = RadialFourier3D(radial_bins=64, return_magnitude=True)
5953 >>> image = torch.randn(1, 128, 128, 96) # Batch, Height, Width, Depth
6054 >>> result = transform(image) # Shape: (1, 64)
61-
6255 Raises:
6356 ValueError: If max_frequency not in (0.0, 1.0], radial_bins < 1, or both
6457 return_magnitude and return_phase are False.
@@ -107,11 +100,26 @@ def _compute_radial_coordinates(self, shape: tuple[int, ...], device: torch.devi
107100 coords = []
108101 for dim_size in shape :
109102 # Create frequency range from -0.5 to 0.5
110- freq = torch .fft .fftfreq (dim_size , device = device )
103+ # Compatible with older PyTorch versions
104+ if hasattr (torch .fft , 'fftfreq' ):
105+ freq = torch .fft .fftfreq (dim_size , device = device )
106+ else :
107+ # Fallback for older PyTorch versions (pre-1.8)
108+ n = dim_size
109+ val = 1.0 / n
110+ freq = torch .arange (- (n // 2 ), (n + 1 )// 2 , device = device ) * val
111+ freq = torch .roll (freq , n // 2 )
111112 coords .append (freq )
112113
113114 # Create meshgrid and compute radial distance
114- mesh = torch .meshgrid (coords , indexing = "ij" )
115+ # Compatible with older PyTorch versions (pre-1.10)
116+ try :
117+ mesh = torch .meshgrid (coords , indexing = "ij" )
118+ except TypeError :
119+ # Older PyTorch doesn't support indexing parameter
120+ mesh = torch .meshgrid (coords )
121+ # Note: older meshgrid uses ij indexing by default in PyTorch
122+
115123 radial = torch .sqrt (sum (c ** 2 for c in mesh ))
116124
117125 return radial
@@ -176,7 +184,9 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
176184
177185 # Normalize if requested
178186 if self .normalize :
179- norm_factor = math .prod (spatial_shape )
187+ norm_factor = 1
188+ for dim in spatial_shape :
189+ norm_factor *= dim
180190 spectrum = spectrum / norm_factor
181191
182192 # Compute radial coordinates
@@ -272,7 +282,10 @@ def inverse(self, radial_data: NdarrayOrTensor, original_shape: tuple[int, ...])
272282 result = fftshift (result , dim = self .spatial_dims )
273283
274284 if self .normalize :
275- result = result * math .prod (original_shape )
285+ shape_product = 1
286+ for dim in original_shape :
287+ shape_product *= dim
288+ result = result * shape_product
276289
277290 result , * _ = convert_data_type (result .real , type (radial_data ))
278291 return result
@@ -287,18 +300,14 @@ def inverse(self, radial_data: NdarrayOrTensor, original_shape: tuple[int, ...])
287300class RadialFourierFeatures3D (Transform ):
288301 """
289302 Extract radial Fourier features for medical image analysis.
290-
291303 Computes multiple radial Fourier transforms with different parameters
292304 to create a comprehensive frequency feature representation.
293-
294305 Args:
295306 n_bins_list: list of radial bin counts to compute.
296307 return_types: list of return types: 'magnitude', 'phase', or 'complex'.
297308 normalize: if True, normalize the output.
298-
299309 Returns:
300310 Concatenated radial Fourier features.
301-
302311 Example:
303312 >>> transform = RadialFourierFeatures3D(n_bins_list=[32, 64, 128])
304313 >>> image = torch.randn(1, 128, 128, 96)
0 commit comments