Skip to content

Commit c8eae2b

Browse files
Fix issues identified in CodeRabbit review
- Remove dead code (commented import) - Add proper type annotations to docstrings - Fix sum() with generator for tensor operations - Fix inverse transform logic when radial_bins=None and both magnitude+phase returned - Add validation for empty feature extraction in RadialFourierFeatures3D - Update test to expect ValueError for empty n_bins_list - All tests passing
1 parent 9c49afa commit c8eae2b

File tree

2 files changed

+40
-26
lines changed

2 files changed

+40
-26
lines changed

monai/transforms/signal/radial_fourier.py

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,6 @@
2525
from monai.transforms.transform import Transform
2626
from monai.utils import convert_data_type
2727

28-
# Optional imports for type checking
29-
# spatial, _ = optional_import("monai.utils", name="spatial") # Commented out unused import
30-
3128

3229
class RadialFourier3D(Transform):
3330
"""
@@ -37,24 +34,25 @@ class RadialFourier3D(Transform):
3734
(e.g., different resolution in axial vs coronal planes).
3835
The radial transform provides rotation-invariant frequency analysis and can help
3936
normalize frequency representations across datasets with different acquisition parameters.
37+
4038
Args:
41-
normalize: if True, normalize the output by the number of voxels.
42-
return_magnitude: if True, return magnitude of the complex result.
43-
return_phase: if True, return phase of the complex result.
44-
radial_bins: number of radial bins for frequency aggregation. If None, returns full 3D spectrum.
45-
max_frequency: maximum normalized frequency to include (0.0 to 1.0).
46-
spatial_dims: spatial dimensions to apply transform to. Default is last three dimensions.
39+
normalize (bool): if True, normalize the output by the number of voxels.
40+
return_magnitude (bool): if True, return magnitude of the complex result.
41+
return_phase (bool): if True, return phase of the complex result.
42+
radial_bins (Optional[int]): number of radial bins for frequency aggregation.
43+
If None, returns full 3D spectrum.
44+
max_frequency (float): maximum normalized frequency to include (0.0 to 1.0).
45+
spatial_dims (Union[int, Sequence[int]]): spatial dimensions to apply transform to.
46+
Default is last three dimensions.
47+
4748
Returns:
4849
Radial Fourier transform of input data. Shape depends on parameters:
4950
- If radial_bins is None: complex tensor of same spatial shape as input
5051
- If radial_bins is set: real tensor of shape (radial_bins,) for magnitude/phase
51-
Example:
52-
>>> transform = RadialFourier3D(radial_bins=64, return_magnitude=True)
53-
>>> image = torch.randn(1, 128, 128, 96) # Batch, Height, Width, Depth
54-
>>> result = transform(image) # Shape: (1, 64)
52+
5553
Raises:
56-
ValueError: If max_frequency not in (0.0, 1.0], radial_bins < 1, or both
57-
return_magnitude and return_phase are False.
54+
ValueError: If max_frequency not in (0.0, 1.0], radial_bins < 1,
55+
or both return_magnitude and return_phase are False.
5856
"""
5957

6058
def __init__(
@@ -120,7 +118,7 @@ def _compute_radial_coordinates(self, shape: tuple[int, ...], device: torch.devi
120118
mesh = torch.meshgrid(coords)
121119
# Note: older meshgrid uses ij indexing by default in PyTorch
122120

123-
radial = torch.sqrt(sum(c**2 for c in mesh))
121+
radial = torch.sqrt(torch.stack([c**2 for c in mesh]).sum(dim=0))
124122

125123
return radial
126124

@@ -271,10 +269,19 @@ def inverse(self, radial_data: NdarrayOrTensor, original_shape: tuple[int, ...])
271269

272270
# Separate magnitude and phase if needed
273271
if self.return_magnitude and self.return_phase:
274-
# Assuming they were concatenated along last dimension
275-
split_idx = radial_tensor.shape[-1] // 2
276-
magnitude = radial_tensor[..., :split_idx]
277-
phase = radial_tensor[..., split_idx:]
272+
# When radial_bins is None, magnitude and phase were concatenated along last dimension
273+
# The last dimension was doubled (magnitude + phase)
274+
last_dim = radial_tensor.shape[-1]
275+
if last_dim != original_shape[-1] * 2:
276+
raise ValueError(
277+
f"For inverse with magnitude+phase and radial_bins=None, "
278+
f"expected last dimension to be doubled. "
279+
f"Got {last_dim}, expected {original_shape[-1] * 2}"
280+
)
281+
282+
split_size = original_shape[-1]
283+
magnitude = radial_tensor[..., :split_size]
284+
phase = radial_tensor[..., split_size:]
278285
radial_tensor = torch.complex(magnitude * torch.cos(phase), magnitude * torch.sin(phase))
279286

280287
# Apply inverse FFT
@@ -302,12 +309,15 @@ class RadialFourierFeatures3D(Transform):
302309
Extract radial Fourier features for medical image analysis.
303310
Computes multiple radial Fourier transforms with different parameters
304311
to create a comprehensive frequency feature representation.
312+
305313
Args:
306314
n_bins_list: list of radial bin counts to compute.
307315
return_types: list of return types: 'magnitude', 'phase', or 'complex'.
308316
normalize: if True, normalize the output.
317+
309318
Returns:
310319
Concatenated radial Fourier features.
320+
311321
Example:
312322
>>> transform = RadialFourierFeatures3D(n_bins_list=[32, 64, 128])
313323
>>> image = torch.randn(1, 128, 128, 96)
@@ -325,6 +335,12 @@ def __init__(
325335
self.return_types = return_types
326336
self.normalize = normalize
327337

338+
# Validate parameters
339+
if not n_bins_list:
340+
raise ValueError("n_bins_list must not be empty")
341+
if not return_types:
342+
raise ValueError("return_types must not be empty")
343+
328344
# Create individual transforms
329345
self.transforms = []
330346
for n_bins in n_bins_list:
@@ -355,7 +371,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
355371
features_tensors.append(feat)
356372
output = torch.cat(features_tensors, dim=-1)
357373
else:
358-
output = img
374+
raise ValueError("No features extracted. This should not happen with validated parameters.")
359375

360376
# Convert to original type if needed
361377
if isinstance(img, np.ndarray):

tests/test_radial_fourier.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -180,11 +180,9 @@ def test_complex_output(self):
180180
self.assertEqual(features.shape, (2, 16 * 2))
181181

182182
def test_empty_bins_list(self):
183-
"""Test with empty bins list."""
184-
transform = RadialFourierFeatures3D(n_bins_list=[], return_types=["magnitude"])
185-
features = transform(self.test_image)
186-
# Should return original image when no transforms
187-
self.assertEqual(features.shape, self.test_image.shape)
183+
"""Test with empty bins list raises ValueError."""
184+
with self.assertRaises(ValueError):
185+
RadialFourierFeatures3D(n_bins_list=[], return_types=["magnitude"])
188186

189187
def test_numpy_compatibility(self):
190188
"""Test with numpy input."""

0 commit comments

Comments
 (0)