Skip to content

Commit 9c49afa

Browse files
Add 3D Radial Fourier Transform for medical imaging
- Implements RadialFourier3D for anisotropic resolution normalization - Adds RadialFourierFeatures3D for multi-scale frequency analysis - Includes comprehensive test suite (20/20 passing) - Adds version compatibility for older PyTorch/Python versions - Follows MONAI transform conventions - Exclude transforms/__init__.py from pycln to avoid import removal
1 parent ead0815 commit 9c49afa

File tree

2 files changed

+25
-16
lines changed

2 files changed

+25
-16
lines changed

monai/transforms/signal/radial_fourier.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
from __future__ import annotations
1616

17-
import math
1817
from collections.abc import Sequence
1918
from typing import Optional, Union
2019

@@ -33,32 +32,26 @@
3332
class RadialFourier3D(Transform):
3433
"""
3534
Computes the 3D Radial Fourier Transform of medical imaging data.
36-
3735
This transform converts 3D medical images into radial frequency domain representations,
3836
which is particularly useful for handling anisotropic resolution common in medical scans
3937
(e.g., different resolution in axial vs coronal planes).
40-
4138
The radial transform provides rotation-invariant frequency analysis and can help
4239
normalize frequency representations across datasets with different acquisition parameters.
43-
4440
Args:
4541
normalize: if True, normalize the output by the number of voxels.
4642
return_magnitude: if True, return magnitude of the complex result.
4743
return_phase: if True, return phase of the complex result.
4844
radial_bins: number of radial bins for frequency aggregation. If None, returns full 3D spectrum.
4945
max_frequency: maximum normalized frequency to include (0.0 to 1.0).
5046
spatial_dims: spatial dimensions to apply transform to. Default is last three dimensions.
51-
5247
Returns:
5348
Radial Fourier transform of input data. Shape depends on parameters:
5449
- If radial_bins is None: complex tensor of same spatial shape as input
5550
- If radial_bins is set: real tensor of shape (radial_bins,) for magnitude/phase
56-
5751
Example:
5852
>>> transform = RadialFourier3D(radial_bins=64, return_magnitude=True)
5953
>>> image = torch.randn(1, 128, 128, 96) # Batch, Height, Width, Depth
6054
>>> result = transform(image) # Shape: (1, 64)
61-
6255
Raises:
6356
ValueError: If max_frequency not in (0.0, 1.0], radial_bins < 1, or both
6457
return_magnitude and return_phase are False.
@@ -107,11 +100,26 @@ def _compute_radial_coordinates(self, shape: tuple[int, ...], device: torch.devi
107100
coords = []
108101
for dim_size in shape:
109102
# Create frequency range from -0.5 to 0.5
110-
freq = torch.fft.fftfreq(dim_size, device=device)
103+
# Compatible with older PyTorch versions
104+
if hasattr(torch.fft, 'fftfreq'):
105+
freq = torch.fft.fftfreq(dim_size, device=device)
106+
else:
107+
# Fallback for older PyTorch versions (pre-1.8)
108+
n = dim_size
109+
val = 1.0 / n
110+
freq = torch.arange(-(n//2), (n+1)//2, device=device) * val
111+
freq = torch.roll(freq, n//2)
111112
coords.append(freq)
112113

113114
# Create meshgrid and compute radial distance
114-
mesh = torch.meshgrid(coords, indexing="ij")
115+
# Compatible with older PyTorch versions (pre-1.10)
116+
try:
117+
mesh = torch.meshgrid(coords, indexing="ij")
118+
except TypeError:
119+
# Older PyTorch doesn't support indexing parameter
120+
mesh = torch.meshgrid(coords)
121+
# Note: older meshgrid uses ij indexing by default in PyTorch
122+
115123
radial = torch.sqrt(sum(c**2 for c in mesh))
116124

117125
return radial
@@ -176,7 +184,9 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
176184

177185
# Normalize if requested
178186
if self.normalize:
179-
norm_factor = math.prod(spatial_shape)
187+
norm_factor = 1
188+
for dim in spatial_shape:
189+
norm_factor *= dim
180190
spectrum = spectrum / norm_factor
181191

182192
# Compute radial coordinates
@@ -272,7 +282,10 @@ def inverse(self, radial_data: NdarrayOrTensor, original_shape: tuple[int, ...])
272282
result = fftshift(result, dim=self.spatial_dims)
273283

274284
if self.normalize:
275-
result = result * math.prod(original_shape)
285+
shape_product = 1
286+
for dim in original_shape:
287+
shape_product *= dim
288+
result = result * shape_product
276289

277290
result, *_ = convert_data_type(result.real, type(radial_data))
278291
return result
@@ -287,18 +300,14 @@ def inverse(self, radial_data: NdarrayOrTensor, original_shape: tuple[int, ...])
287300
class RadialFourierFeatures3D(Transform):
288301
"""
289302
Extract radial Fourier features for medical image analysis.
290-
291303
Computes multiple radial Fourier transforms with different parameters
292304
to create a comprehensive frequency feature representation.
293-
294305
Args:
295306
n_bins_list: list of radial bin counts to compute.
296307
return_types: list of return types: 'magnitude', 'phase', or 'complex'.
297308
normalize: if True, normalize the output.
298-
299309
Returns:
300310
Concatenated radial Fourier features.
301-
302311
Example:
303312
>>> transform = RadialFourierFeatures3D(n_bins_list=[32, 64, 128])
304313
>>> image = torch.randn(1, 128, 128, 96)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ exclude = '''
3636

3737
[tool.pycln]
3838
all = true
39-
exclude = "monai/bundle/__main__.py"
39+
exclude = "monai/bundle/__main__.py|monai/transforms/__init__.py"
4040

4141
[tool.ruff]
4242
line-length = 133

0 commit comments

Comments
 (0)