Skip to content

Commit d12876d

Browse files
Fix CodeRabbit review issues for radial Fourier transform
- Add device parameter to _compute_radial_coordinates to prevent CPU/GPU mismatch - Fix frequency mask expansion for multi-dimensional inputs - Add reconstruction accuracy test assertion (with proper magnitude+phase for inverse) - Add Raises section to docstring - Remove unused import - Address all review comments
1 parent cb0546d commit d12876d

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

monai/transforms/signal/radial_fourier.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@
2525

2626
from monai.config import NdarrayOrTensor
2727
from monai.transforms.transform import Transform
28-
from monai.utils import convert_data_type, optional_import
28+
from monai.utils import convert_data_type
2929

3030
# Optional imports for type checking
31-
spatial, _ = optional_import("monai.utils", name="spatial")
31+
# spatial, _ = optional_import("monai.utils", name="spatial") # Commented out unused import
3232

3333

3434
class RadialFourier3D(Transform):
@@ -59,6 +59,10 @@ class RadialFourier3D(Transform):
5959
>>> transform = RadialFourier3D(radial_bins=64, return_magnitude=True)
6060
>>> image = torch.randn(1, 128, 128, 96) # Batch, Height, Width, Depth
6161
>>> result = transform(image) # Shape: (1, 64)
62+
63+
Raises:
64+
ValueError: If max_frequency not in (0.0, 1.0], radial_bins < 1, or both
65+
return_magnitude and return_phase are False.
6266
"""
6367

6468
def __init__(
@@ -89,12 +93,13 @@ def __init__(
8993
if not return_magnitude and not return_phase:
9094
raise ValueError("At least one of return_magnitude or return_phase must be True")
9195

92-
def _compute_radial_coordinates(self, shape: tuple[int, ...]) -> torch.Tensor:
96+
def _compute_radial_coordinates(self, shape: tuple[int, ...], device: torch.device = None) -> torch.Tensor:
9397
"""
9498
Compute radial distance from frequency domain center.
9599
96100
Args:
97101
shape: spatial dimensions (D, H, W) or (H, W, D) depending on dims order.
102+
device: device to create tensor on.
98103
99104
Returns:
100105
Tensor of same spatial shape with radial distances.
@@ -103,7 +108,7 @@ def _compute_radial_coordinates(self, shape: tuple[int, ...]) -> torch.Tensor:
103108
coords = []
104109
for dim_size in shape:
105110
# Create frequency range from -0.5 to 0.5
106-
freq = torch.fft.fftfreq(dim_size)
111+
freq = torch.fft.fftfreq(dim_size, device=device)
107112
coords.append(freq)
108113

109114
# Create meshgrid and compute radial distance
@@ -176,7 +181,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
176181
spectrum = spectrum / norm_factor
177182

178183
# Compute radial coordinates
179-
radial_coords = self._compute_radial_coordinates(spatial_shape)
184+
radial_coords = self._compute_radial_coordinates(spatial_shape, device=spectrum.device)
180185

181186
# Apply radial binning if requested
182187
if self.radial_bins is not None:
@@ -217,7 +222,8 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
217222
if self.max_frequency < 1.0:
218223
freq_mask = radial_coords <= (self.max_frequency * 0.5)
219224
# Expand mask to match spectrum dimensions
220-
for _ in range(len(self.spatial_dims)):
225+
n_non_spatial = len(spectrum.shape) - len(spatial_shape)
226+
for _ in range(n_non_spatial):
221227
freq_mask = freq_mask.unsqueeze(0)
222228
spectrum = spectrum * freq_mask
223229

tests/test_radial_fourier.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def test_normalization(self):
7676
def test_inverse_transform(self):
7777
"""Test approximate inverse transform."""
7878
# Use full spectrum for invertibility
79-
transform = RadialFourier3D(radial_bins=None, normalize=True)
79+
transform = RadialFourier3D(radial_bins=None, normalize=True, return_magnitude=True, return_phase=True)
8080

8181
# Forward transform
8282
spectrum = transform(self.test_image_3d)
@@ -87,6 +87,9 @@ def test_inverse_transform(self):
8787
# Should have same shape
8888
self.assertEqual(reconstructed.shape, self.test_image_3d.shape)
8989

90+
# Should approximately reconstruct original
91+
self.assertTrue(torch.allclose(reconstructed, self.test_image_3d, atol=1e-5))
92+
9093
def test_deterministic(self):
9194
"""Test that transform is deterministic."""
9295
transform = RadialFourier3D(radial_bins=32)

0 commit comments

Comments
 (0)