Skip to content

Commit 8a375b6

Browse files
committed
Fixed overrestrictive embed_dim check, improved code style
Signed-off-by: NabJa <[email protected]>
1 parent f0ff88f commit 8a375b6

File tree

1 file changed

+10
-12
lines changed

1 file changed

+10
-12
lines changed

monai/networks/blocks/pos_embed_utils.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import torch
1919
import torch.nn as nn
2020

21-
__all__ = ["build_sincos_position_embedding", "build_fourier_position_embedding"]
21+
__all__ = ["build_fourier_position_embedding", "build_sincos_position_embedding"]
2222

2323

2424
# From PyTorch internals
@@ -36,16 +36,17 @@ def build_fourier_position_embedding(
3636
grid_size: Union[int, List[int]], embed_dim: int, spatial_dims: int = 3, scales: Union[float, List[float]] = 1.0
3737
) -> torch.nn.Parameter:
3838
"""
39-
Builds a (Anistropic) Fourier Feature based positional encoding based on the given grid size, embed dimension,
39+
Builds an (anisotropic) Fourier feature position embedding based on the given grid size, embed dimension,
4040
spatial dimensions, and scales. The scales control the variance of the Fourier features, higher values make distant
4141
points more distinguishable.
42+
The position embedding is made anisotropic by allowing a different scale to be set for each spatial dimension.
4243
Reference: https://arxiv.org/abs/2509.02488
4344
4445
Args:
45-
grid_size (List[int]): The size of the grid in each spatial dimension.
46+
grid_size (int | List[int]): The size of the grid in each spatial dimension.
4647
embed_dim (int): The dimension of the embedding.
4748
spatial_dims (int): The number of spatial dimensions (2 for 2D, 3 for 3D).
48-
scales (List[float]): The scale for every spatial dimension. If a single float is provided,
49+
scales (float | List[float]): The scale for every spatial dimension. If a single float is provided,
4950
the same scale is used for all dimensions.
5051
5152
Returns:
@@ -57,10 +58,8 @@ def build_fourier_position_embedding(
5758
if len(grid_size_t) != spatial_dims:
5859
raise ValueError(f"Length of grid_size ({len(grid_size_t)}) must be the same as spatial_dims.")
5960

60-
if embed_dim % (2 * spatial_dims) != 0:
61-
raise AssertionError(
62-
f"Embed dimension must be divisible by {2 * spatial_dims} for {spatial_dims}D Fourier feature position embedding"
63-
)
61+
if embed_dim % 2 != 0:
62+
raise ValueError("embed_dim must be even for Fourier position embedding")
6463

6564
# Ensure scales is a tensor of shape (spatial_dims,)
6665
if isinstance(scales, float):
@@ -72,11 +71,10 @@ def build_fourier_position_embedding(
7271
else:
7372
raise TypeError(f"scales must be float or list of floats, got {type(scales)}")
7473

75-
gaussians = torch.normal(0.0, 1.0, (embed_dim // 2, spatial_dims))
76-
gaussians = gaussians * scales_tensor
74+
gaussians = torch.randn(embed_dim // 2, spatial_dims, dtype=torch.float32) * scales_tensor
7775

78-
position_indeces = [torch.linspace(0, 1, x) for x in grid_size_t]
79-
positions = torch.stack(torch.meshgrid(*position_indeces, indexing="ij"), dim=-1)
76+
position_indices = [torch.linspace(0, 1, x, dtype=torch.float32) for x in grid_size_t]
77+
positions = torch.stack(torch.meshgrid(*position_indices, indexing="ij"), dim=-1)
8078
positions = positions.flatten(end_dim=-2)
8179

8280
x_proj = (2.0 * torch.pi * positions) @ gaussians.T

0 commit comments

Comments
 (0)