2525
2626from monai .config import NdarrayOrTensor
2727from monai .transforms .transform import Transform
28- from monai .utils import convert_data_type , optional_import
28+ from monai .utils import convert_data_type
2929
3030# Optional imports for type checking
31- spatial , _ = optional_import ("monai.utils" , name = "spatial" )
31+ # spatial, _ = optional_import("monai.utils", name="spatial") # Commented out unused import
3232
3333
3434class RadialFourier3D (Transform ):
@@ -59,6 +59,10 @@ class RadialFourier3D(Transform):
5959 >>> transform = RadialFourier3D(radial_bins=64, return_magnitude=True)
6060 >>> image = torch.randn(1, 128, 128, 96) # Batch, Height, Width, Depth
6161 >>> result = transform(image) # Shape: (1, 64)
62+
63+ Raises:
64+ ValueError: If max_frequency not in (0.0, 1.0], radial_bins < 1, or both
65+ return_magnitude and return_phase are False.
6266 """
6367
6468 def __init__ (
@@ -89,12 +93,13 @@ def __init__(
8993 if not return_magnitude and not return_phase :
9094 raise ValueError ("At least one of return_magnitude or return_phase must be True" )
9195
92- def _compute_radial_coordinates (self , shape : tuple [int , ...]) -> torch .Tensor :
96+ def _compute_radial_coordinates (self , shape : tuple [int , ...], device : torch . device = None ) -> torch .Tensor :
9397 """
9498 Compute radial distance from frequency domain center.
9599
96100 Args:
97101 shape: spatial dimensions (D, H, W) or (H, W, D) depending on dims order.
102+ device: device to create tensor on.
98103
99104 Returns:
100105 Tensor of same spatial shape with radial distances.
@@ -103,7 +108,7 @@ def _compute_radial_coordinates(self, shape: tuple[int, ...]) -> torch.Tensor:
103108 coords = []
104109 for dim_size in shape :
105110 # Create frequency range from -0.5 to 0.5
106- freq = torch .fft .fftfreq (dim_size )
111+ freq = torch .fft .fftfreq (dim_size , device = device )
107112 coords .append (freq )
108113
109114 # Create meshgrid and compute radial distance
@@ -176,7 +181,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
176181 spectrum = spectrum / norm_factor
177182
178183 # Compute radial coordinates
179- radial_coords = self ._compute_radial_coordinates (spatial_shape )
184+ radial_coords = self ._compute_radial_coordinates (spatial_shape , device = spectrum . device )
180185
181186 # Apply radial binning if requested
182187 if self .radial_bins is not None :
@@ -217,7 +222,8 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
217222 if self .max_frequency < 1.0 :
218223 freq_mask = radial_coords <= (self .max_frequency * 0.5 )
219224 # Expand mask to match spectrum dimensions
220- for _ in range (len (self .spatial_dims )):
225+ n_non_spatial = len (spectrum .shape ) - len (spatial_shape )
226+ for _ in range (n_non_spatial ):
221227 freq_mask = freq_mask .unsqueeze (0 )
222228 spectrum = spectrum * freq_mask
223229
0 commit comments