2525from monai .transforms .transform import Transform
2626from monai .utils import convert_data_type
2727
28- # Optional imports for type checking
29- # spatial, _ = optional_import("monai.utils", name="spatial") # Commented out unused import
30-
3128
3229class RadialFourier3D (Transform ):
3330 """
@@ -37,24 +34,25 @@ class RadialFourier3D(Transform):
3734 (e.g., different resolution in axial vs coronal planes).
3835 The radial transform provides rotation-invariant frequency analysis and can help
3936 normalize frequency representations across datasets with different acquisition parameters.
37+
4038 Args:
41- normalize: if True, normalize the output by the number of voxels.
42- return_magnitude: if True, return magnitude of the complex result.
43- return_phase: if True, return phase of the complex result.
44- radial_bins: number of radial bins for frequency aggregation. If None, returns full 3D spectrum.
45- max_frequency: maximum normalized frequency to include (0.0 to 1.0).
46- spatial_dims: spatial dimensions to apply transform to. Default is last three dimensions.
39+ normalize (bool): if True, normalize the output by the number of voxels.
40+ return_magnitude (bool): if True, return magnitude of the complex result.
41+ return_phase (bool): if True, return phase of the complex result.
42+ radial_bins (Optional[int]): number of radial bins for frequency aggregation.
43+ If None, returns full 3D spectrum.
44+ max_frequency (float): maximum normalized frequency to include (0.0 to 1.0).
45+ spatial_dims (Union[int, Sequence[int]]): spatial dimensions to apply transform to.
46+ Default is last three dimensions.
47+
4748 Returns:
4849 Radial Fourier transform of input data. Shape depends on parameters:
4950 - If radial_bins is None: complex tensor of same spatial shape as input
5051 - If radial_bins is set: real tensor of shape (radial_bins,) for magnitude/phase
51- Example:
52- >>> transform = RadialFourier3D(radial_bins=64, return_magnitude=True)
53- >>> image = torch.randn(1, 128, 128, 96) # Batch, Height, Width, Depth
54- >>> result = transform(image) # Shape: (1, 64)
52+
5553 Raises:
56- ValueError: If max_frequency not in (0.0, 1.0], radial_bins < 1, or both
57- return_magnitude and return_phase are False.
54+ ValueError: If max_frequency not in (0.0, 1.0], radial_bins < 1,
55+ or both return_magnitude and return_phase are False.
5856 """
5957
6058 def __init__ (
@@ -120,7 +118,7 @@ def _compute_radial_coordinates(self, shape: tuple[int, ...], device: torch.devi
120118 mesh = torch .meshgrid (coords )
121119 # Note: older meshgrid uses ij indexing by default in PyTorch
122120
123- radial = torch .sqrt (sum ( c ** 2 for c in mesh ))
121+ radial = torch .sqrt (torch . stack ([ c ** 2 for c in mesh ]). sum ( dim = 0 ))
124122
125123 return radial
126124
@@ -271,10 +269,19 @@ def inverse(self, radial_data: NdarrayOrTensor, original_shape: tuple[int, ...])
271269
272270 # Separate magnitude and phase if needed
273271 if self .return_magnitude and self .return_phase :
274- # Assuming they were concatenated along last dimension
275- split_idx = radial_tensor .shape [- 1 ] // 2
276- magnitude = radial_tensor [..., :split_idx ]
277- phase = radial_tensor [..., split_idx :]
272+ # When radial_bins is None, magnitude and phase were concatenated along last dimension
273+ # The last dimension was doubled (magnitude + phase)
274+ last_dim = radial_tensor .shape [- 1 ]
275+ if last_dim != original_shape [- 1 ] * 2 :
276+ raise ValueError (
277+ f"For inverse with magnitude+phase and radial_bins=None, "
278+ f"expected last dimension to be doubled. "
279+ f"Got { last_dim } , expected { original_shape [- 1 ] * 2 } "
280+ )
281+
282+ split_size = original_shape [- 1 ]
283+ magnitude = radial_tensor [..., :split_size ]
284+ phase = radial_tensor [..., split_size :]
278285 radial_tensor = torch .complex (magnitude * torch .cos (phase ), magnitude * torch .sin (phase ))
279286
280287 # Apply inverse FFT
@@ -302,12 +309,15 @@ class RadialFourierFeatures3D(Transform):
302309 Extract radial Fourier features for medical image analysis.
303310 Computes multiple radial Fourier transforms with different parameters
304311 to create a comprehensive frequency feature representation.
312+
305313 Args:
306314 n_bins_list: list of radial bin counts to compute.
307315 return_types: list of return types: 'magnitude', 'phase', or 'complex'.
308316 normalize: if True, normalize the output.
317+
309318 Returns:
310319 Concatenated radial Fourier features.
320+
311321 Example:
312322 >>> transform = RadialFourierFeatures3D(n_bins_list=[32, 64, 128])
313323 >>> image = torch.randn(1, 128, 128, 96)
@@ -325,6 +335,12 @@ def __init__(
325335 self .return_types = return_types
326336 self .normalize = normalize
327337
338+ # Validate parameters
339+ if not n_bins_list :
340+ raise ValueError ("n_bins_list must not be empty" )
341+ if not return_types :
342+ raise ValueError ("return_types must not be empty" )
343+
328344 # Create individual transforms
329345 self .transforms = []
330346 for n_bins in n_bins_list :
@@ -355,7 +371,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
355371 features_tensors .append (feat )
356372 output = torch .cat (features_tensors , dim = - 1 )
357373 else :
358- output = img
374+ raise ValueError ( "No features extracted. This should not happen with validated parameters." )
359375
360376 # Convert to original type if needed
361377 if isinstance (img , np .ndarray ):
0 commit comments