Skip to content

Commit 28aefc2

Browse files
committed
Fix type and flake8 errors
Signed-off-by: NabJa <[email protected]>
1 parent 6c885a1 commit 28aefc2

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

monai/networks/blocks/pos_embed_utils.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,17 +60,23 @@ def build_fourier_position_embedding(
6060
f"Embed dimension must be divisible by {2 * spatial_dims} for {spatial_dims}D Fourier feature position embedding"
6161
)
6262

63-
scales: torch.Tensor = torch.as_tensor(scales, dtype=torch.float)
64-
if scales.ndim > 1 and scales.ndim != spatial_dims:
65-
raise ValueError("Scales must be either a float or a list of floats with length equal to spatial_dims")
66-
if scales.ndim == 0:
67-
scales = scales.repeat(spatial_dims)
63+
# Ensure scales is a tensor of shape (spatial_dims,)
64+
if isinstance(scales, float):
65+
scales_tensor = torch.full((spatial_dims,), scales, dtype=torch.float)
66+
elif isinstance(scales, (list, tuple)):
67+
if len(scales) != spatial_dims:
68+
raise ValueError(
69+
f"Length of scales {len(scales)} does not match spatial_dims {spatial_dims}"
70+
)
71+
scales_tensor = torch.tensor(scales, dtype=torch.float)
72+
else:
73+
raise TypeError(f"scales must be float or list of floats, got {type(scales)}")
6874

6975
gaussians = torch.normal(0.0, 1.0, (embed_dim // 2, spatial_dims))
70-
gaussians = gaussians * scales
76+
gaussians = gaussians * scales_tensor
7177

72-
positions = [torch.linspace(0, 1, x) for x in grid_size]
73-
positions = torch.stack(torch.meshgrid(*positions, indexing="ij"), dim=-1)
78+
position_indeces = [torch.linspace(0, 1, x) for x in grid_size]
79+
positions = torch.stack(torch.meshgrid(*position_indeces, indexing="ij"), dim=-1)
7480
positions = positions.flatten(end_dim=-2)
7581

7682
x_proj = (2.0 * torch.pi * positions) @ gaussians.T

0 commit comments

Comments
 (0)