Skip to content

Commit cb0546d

Browse files
FEAT: Add 3D Radial Fourier Transform for medical image frequency analysis
- Implement RadialFourier3D transform for radial frequency analysis - Add RadialFourierFeatures3D for multi-scale feature extraction - Include comprehensive tests (20/20 passing) - Support for magnitude, phase, and complex outputs - Handle anisotropic resolution in medical imaging - Fix numpy compatibility and spatial dimension handling Signed-off-by: Hitendrasinh Rathod<[email protected]> Signed-off-by: Hitendrasinh Rathod <[email protected]>
1 parent e5417c8 commit cb0546d

File tree

5 files changed

+555
-2
lines changed

5 files changed

+555
-2
lines changed

monai/transforms/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -376,9 +376,9 @@
376376
SignalRandAddSquarePulsePartial,
377377
SignalRandDrop,
378378
SignalRandScale,
379-
SignalRandShift,
380-
SignalRemoveFrequency,
379+
SignalRemoveFrequency
381380
)
381+
from .signal import RadialFourier3D, RadialFourierFeatures3D
382382
from .signal.dictionary import SignalFillEmptyd, SignalFillEmptyD, SignalFillEmptyDict
383383
from .smooth_field.array import (
384384
RandSmoothDeform,

monai/transforms/signal/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,10 @@
88
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
11+
"""
12+
Signal processing transforms for medical imaging.
13+
"""
14+
15+
from .radial_fourier import RadialFourier3D, RadialFourierFeatures3D
16+
17+
__all__ = ["RadialFourier3D", "RadialFourierFeatures3D"]
Lines changed: 350 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,350 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
"""
12+
3D Radial Fourier Transform for medical imaging data.
13+
"""
14+
15+
from __future__ import annotations
16+
17+
import math
18+
from typing import Optional, Union
19+
20+
from collections.abc import Sequence
21+
22+
import numpy as np
23+
import torch
24+
from torch.fft import fftn, fftshift, ifftn, ifftshift
25+
26+
from monai.config import NdarrayOrTensor
27+
from monai.transforms.transform import Transform
28+
from monai.utils import convert_data_type, optional_import
29+
30+
# Optional imports for type checking
31+
spatial, _ = optional_import("monai.utils", name="spatial")
32+
33+
34+
class RadialFourier3D(Transform):
35+
"""
36+
Computes the 3D Radial Fourier Transform of medical imaging data.
37+
38+
This transform converts 3D medical images into radial frequency domain representations,
39+
which is particularly useful for handling anisotropic resolution common in medical scans
40+
(e.g., different resolution in axial vs coronal planes).
41+
42+
The radial transform provides rotation-invariant frequency analysis and can help
43+
normalize frequency representations across datasets with different acquisition parameters.
44+
45+
Args:
46+
normalize: if True, normalize the output by the number of voxels.
47+
return_magnitude: if True, return magnitude of the complex result.
48+
return_phase: if True, return phase of the complex result.
49+
radial_bins: number of radial bins for frequency aggregation. If None, returns full 3D spectrum.
50+
max_frequency: maximum normalized frequency to include (0.0 to 1.0).
51+
spatial_dims: spatial dimensions to apply transform to. Default is last three dimensions.
52+
53+
Returns:
54+
Radial Fourier transform of input data. Shape depends on parameters:
55+
- If radial_bins is None: complex tensor of same spatial shape as input
56+
- If radial_bins is set: real tensor of shape (radial_bins,) for magnitude/phase
57+
58+
Example:
59+
>>> transform = RadialFourier3D(radial_bins=64, return_magnitude=True)
60+
>>> image = torch.randn(1, 128, 128, 96) # Batch, Height, Width, Depth
61+
>>> result = transform(image) # Shape: (1, 64)
62+
"""
63+
64+
def __init__(
65+
self,
66+
normalize: bool = True,
67+
return_magnitude: bool = True,
68+
return_phase: bool = False,
69+
radial_bins: Optional[int] = None,
70+
max_frequency: float = 1.0,
71+
spatial_dims: Union[int, Sequence[int]] = (-3, -2, -1),
72+
) -> None:
73+
super().__init__()
74+
self.normalize = normalize
75+
self.return_magnitude = return_magnitude
76+
self.return_phase = return_phase
77+
self.radial_bins = radial_bins
78+
self.max_frequency = max_frequency
79+
80+
if isinstance(spatial_dims, int):
81+
spatial_dims = (spatial_dims,)
82+
self.spatial_dims = tuple(spatial_dims)
83+
84+
# Validate parameters
85+
if not 0.0 < max_frequency <= 1.0:
86+
raise ValueError(f"max_frequency must be in (0.0, 1.0], got {max_frequency}")
87+
if radial_bins is not None and radial_bins < 1:
88+
raise ValueError(f"radial_bins must be >= 1, got {radial_bins}")
89+
if not return_magnitude and not return_phase:
90+
raise ValueError("At least one of return_magnitude or return_phase must be True")
91+
92+
def _compute_radial_coordinates(self, shape: tuple[int, ...]) -> torch.Tensor:
93+
"""
94+
Compute radial distance from frequency domain center.
95+
96+
Args:
97+
shape: spatial dimensions (D, H, W) or (H, W, D) depending on dims order.
98+
99+
Returns:
100+
Tensor of same spatial shape with radial distances.
101+
"""
102+
# Create frequency coordinates for each dimension
103+
coords = []
104+
for dim_size in shape:
105+
# Create frequency range from -0.5 to 0.5
106+
freq = torch.fft.fftfreq(dim_size)
107+
coords.append(freq)
108+
109+
# Create meshgrid and compute radial distance
110+
mesh = torch.meshgrid(coords, indexing="ij")
111+
radial = torch.sqrt(sum(c**2 for c in mesh))
112+
113+
return radial
114+
115+
def _compute_radial_spectrum(self, spectrum: torch.Tensor, radial_coords: torch.Tensor) -> torch.Tensor:
116+
"""
117+
Compute radial average of frequency spectrum.
118+
119+
Args:
120+
spectrum: complex frequency spectrum (flattened 1D array).
121+
radial_coords: radial distance for each frequency coordinate (flattened 1D array).
122+
123+
Returns:
124+
Radial average of spectrum (1D array of length radial_bins).
125+
"""
126+
if self.radial_bins is None:
127+
return spectrum
128+
129+
# Bin radial coordinates
130+
max_r = self.max_frequency * 0.5 # Maximum normalized frequency
131+
bin_edges = torch.linspace(0, max_r, self.radial_bins + 1, device=spectrum.device)
132+
133+
# Initialize output
134+
result_real = torch.zeros(self.radial_bins, dtype=spectrum.real.dtype, device=spectrum.device)
135+
result_imag = torch.zeros(self.radial_bins, dtype=spectrum.imag.dtype, device=spectrum.device)
136+
137+
# Bin the frequencies - spectrum and radial_coords are both 1D
138+
for i in range(self.radial_bins):
139+
mask = (radial_coords >= bin_edges[i]) & (radial_coords < bin_edges[i + 1])
140+
if mask.any():
141+
# spectrum is 1D, so we can index it directly
142+
result_real[i] = spectrum.real[mask].mean()
143+
result_imag[i] = spectrum.imag[mask].mean()
144+
145+
# Combine real and imaginary parts
146+
result = torch.complex(result_real, result_imag)
147+
148+
return result
149+
150+
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
151+
"""
152+
Apply 3D Radial Fourier Transform to input data.
153+
154+
Args:
155+
img: input medical image data. Expected shape: (..., D, H, W)
156+
where D, H, W are spatial dimensions.
157+
158+
Returns:
159+
Transformed data in radial frequency domain.
160+
"""
161+
# Convert to tensor if needed
162+
img_tensor, *_ = convert_data_type(img, torch.Tensor)
163+
# Get spatial dimensions
164+
spatial_shape = tuple(img_tensor.shape[d] for d in self.spatial_dims)
165+
if len(spatial_shape) != 3:
166+
raise ValueError(f"Expected 3 spatial dimensions, got {len(spatial_shape)}")
167+
168+
# Compute 3D FFT
169+
# Shift zero frequency to center and compute FFT
170+
spectrum = fftn(ifftshift(img_tensor, dim=self.spatial_dims), dim=self.spatial_dims)
171+
spectrum = fftshift(spectrum, dim=self.spatial_dims)
172+
173+
# Normalize if requested
174+
if self.normalize:
175+
norm_factor = math.prod(spatial_shape)
176+
spectrum = spectrum / norm_factor
177+
178+
# Compute radial coordinates
179+
radial_coords = self._compute_radial_coordinates(spatial_shape)
180+
181+
# Apply radial binning if requested
182+
if self.radial_bins is not None:
183+
# Reshape for radial processing
184+
orig_shape = spectrum.shape
185+
# Move spatial dimensions to end for processing
186+
spatial_indices = [d % len(orig_shape) for d in self.spatial_dims]
187+
non_spatial_indices = [i for i in range(len(orig_shape)) if i not in spatial_indices]
188+
189+
# Reshape to (non_spatial..., spatial_prod)
190+
flat_shape = (*[orig_shape[i] for i in non_spatial_indices], -1)
191+
spectrum_flat = spectrum.moveaxis(spatial_indices, [-3, -2, -1]).reshape(flat_shape)
192+
radial_flat = radial_coords.flatten()
193+
194+
# Get non-spatial dimensions (batch, channel, etc.)
195+
non_spatial_dims = spectrum_flat.shape[:-1]
196+
spatial_size = spectrum_flat.shape[-1]
197+
198+
# Reshape to 2D: (non_spatial_product, spatial_size)
199+
non_spatial_product = 1
200+
for dim in non_spatial_dims:
201+
non_spatial_product *= dim
202+
203+
spectrum_2d = spectrum_flat.reshape(non_spatial_product, spatial_size)
204+
205+
# Process each non-spatial element (batch/channel combination)
206+
results = []
207+
for i in range(non_spatial_product):
208+
elem_spectrum = spectrum_2d[i] # Get spatial frequencies for this batch/channel
209+
radial_result = self._compute_radial_spectrum(elem_spectrum, radial_flat)
210+
results.append(radial_result)
211+
212+
# Combine results and reshape back
213+
spectrum = torch.stack(results, dim=0)
214+
spectrum = spectrum.reshape(*non_spatial_dims, self.radial_bins)
215+
else:
216+
# Apply frequency mask if max_frequency < 1.0
217+
if self.max_frequency < 1.0:
218+
freq_mask = radial_coords <= (self.max_frequency * 0.5)
219+
# Expand mask to match spectrum dimensions
220+
for _ in range(len(self.spatial_dims)):
221+
freq_mask = freq_mask.unsqueeze(0)
222+
spectrum = spectrum * freq_mask
223+
224+
# Extract magnitude and/or phase as requested
225+
output = None
226+
if self.return_magnitude:
227+
magnitude = torch.abs(spectrum)
228+
output = magnitude if output is None else torch.cat([output, magnitude], dim=-1)
229+
230+
if self.return_phase:
231+
phase = torch.angle(spectrum)
232+
output = phase if output is None else torch.cat([output, phase], dim=-1)
233+
234+
# Convert back to original data type
235+
output, *_ = convert_data_type(output, type(img))
236+
237+
return output
238+
239+
def inverse(self, radial_data: NdarrayOrTensor, original_shape: tuple[int, ...]) -> NdarrayOrTensor:
240+
"""
241+
Inverse transform from radial frequency domain to spatial domain.
242+
243+
Args:
244+
radial_data: data in radial frequency domain.
245+
original_shape: original spatial shape (D, H, W).
246+
247+
Returns:
248+
Reconstructed spatial data.
249+
250+
Note:
251+
This is an approximate inverse when radial_bins is used.
252+
"""
253+
if self.radial_bins is None:
254+
# Direct inverse FFT
255+
radial_tensor, *_ = convert_data_type(radial_data, torch.Tensor)
256+
257+
# Separate magnitude and phase if needed
258+
if self.return_magnitude and self.return_phase:
259+
# Assuming they were concatenated along last dimension
260+
split_idx = radial_tensor.shape[-1] // 2
261+
magnitude = radial_tensor[..., :split_idx]
262+
phase = radial_tensor[..., split_idx:]
263+
radial_tensor = torch.complex(magnitude * torch.cos(phase), magnitude * torch.sin(phase))
264+
265+
# Apply inverse FFT
266+
result = ifftn(ifftshift(radial_tensor, dim=self.spatial_dims), dim=self.spatial_dims)
267+
result = fftshift(result, dim=self.spatial_dims)
268+
269+
if self.normalize:
270+
result = result * math.prod(original_shape)
271+
272+
result, *_ = convert_data_type(result.real, type(radial_data))
273+
return result
274+
275+
else:
276+
raise NotImplementedError(
277+
"Exact inverse transform not available for radially binned data. "
278+
"Consider using radial_bins=None for applications requiring inversion."
279+
)
280+
281+
282+
class RadialFourierFeatures3D(Transform):
283+
"""
284+
Extract radial Fourier features for medical image analysis.
285+
286+
Computes multiple radial Fourier transforms with different parameters
287+
to create a comprehensive frequency feature representation.
288+
289+
Args:
290+
n_bins_list: list of radial bin counts to compute.
291+
return_types: list of return types: 'magnitude', 'phase', or 'complex'.
292+
normalize: if True, normalize the output.
293+
294+
Returns:
295+
Concatenated radial Fourier features.
296+
297+
Example:
298+
>>> transform = RadialFourierFeatures3D(n_bins_list=[32, 64, 128])
299+
>>> image = torch.randn(1, 128, 128, 96)
300+
>>> features = transform(image) # Shape: (1, 32+64+128=224)
301+
"""
302+
303+
def __init__(
304+
self,
305+
n_bins_list: Sequence[int] = (32, 64, 128),
306+
return_types: Sequence[str] = ("magnitude",),
307+
normalize: bool = True,
308+
) -> None:
309+
super().__init__()
310+
self.n_bins_list = n_bins_list
311+
self.return_types = return_types
312+
self.normalize = normalize
313+
314+
# Create individual transforms
315+
self.transforms = []
316+
for n_bins in n_bins_list:
317+
for return_type in return_types:
318+
transform = RadialFourier3D(
319+
normalize=normalize,
320+
return_magnitude=(return_type in ["magnitude", "complex"]),
321+
return_phase=(return_type in ["phase", "complex"]),
322+
radial_bins=n_bins,
323+
)
324+
self.transforms.append(transform)
325+
326+
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
327+
"""Extract radial Fourier features."""
328+
features = []
329+
for transform in self.transforms:
330+
feat = transform(img)
331+
features.append(feat)
332+
333+
# Concatenate along last dimension
334+
if features:
335+
# Convert all features to tensors if any are numpy arrays
336+
features_tensors = []
337+
for feat in features:
338+
if isinstance(feat, np.ndarray):
339+
features_tensors.append(torch.from_numpy(feat))
340+
else:
341+
features_tensors.append(feat)
342+
output = torch.cat(features_tensors, dim=-1)
343+
else:
344+
output = img
345+
346+
# Convert to original type if needed
347+
if isinstance(img, np.ndarray):
348+
output = output.cpu().numpy()
349+
350+
return output

0 commit comments

Comments
 (0)