1818import torch
1919import torch .nn as nn
2020
21- __all__ = ["build_sincos_position_embedding " , "build_fourier_position_embedding " ]
21+ __all__ = ["build_fourier_position_embedding " , "build_sincos_position_embedding " ]
2222
2323
2424# From PyTorch internals
@@ -36,16 +36,17 @@ def build_fourier_position_embedding(
3636 grid_size : Union [int , List [int ]], embed_dim : int , spatial_dims : int = 3 , scales : Union [float , List [float ]] = 1.0
3737) -> torch .nn .Parameter :
3838 """
39- Builds a (Anistropic) Fourier Feature based positional encoding based on the given grid size, embed dimension,
39+ Builds an (Anisotropic) Fourier feature position embedding based on the given grid size, embed dimension,
4040 spatial dimensions, and scales. The scales control the variance of the Fourier features, higher values make distant
4141 points more distinguishable.
42+ The position embedding is made anisotropic by allowing a different scale to be set for each spatial dimension.
4243 Reference: https://arxiv.org/abs/2509.02488
4344
4445 Args:
45- grid_size (List[int]): The size of the grid in each spatial dimension.
46+ grid_size (int | List[int]): The size of the grid in each spatial dimension.
4647 embed_dim (int): The dimension of the embedding.
4748 spatial_dims (int): The number of spatial dimensions (2 for 2D, 3 for 3D).
48- scales (List[float]): The scale for every spatial dimension. If a single float is provided,
49+ scales (float | List[float]): The scale for every spatial dimension. If a single float is provided,
4950 the same scale is used for all dimensions.
5051
5152 Returns:
@@ -57,10 +58,8 @@ def build_fourier_position_embedding(
5758 if len (grid_size_t ) != spatial_dims :
5859 raise ValueError (f"Length of grid_size ({ len (grid_size_t )} ) must be the same as spatial_dims." )
5960
60- if embed_dim % (2 * spatial_dims ) != 0 :
61- raise AssertionError (
62- f"Embed dimension must be divisible by { 2 * spatial_dims } for { spatial_dims } D Fourier feature position embedding"
63- )
61+ if embed_dim % 2 != 0 :
62+ raise ValueError ("embed_dim must be even for Fourier position embedding" )
6463
6564 # Ensure scales is a tensor of shape (spatial_dims,)
6665 if isinstance (scales , float ):
@@ -72,11 +71,10 @@ def build_fourier_position_embedding(
7271 else :
7372 raise TypeError (f"scales must be float or list of floats, got { type (scales )} " )
7473
75- gaussians = torch .normal (0.0 , 1.0 , (embed_dim // 2 , spatial_dims ))
76- gaussians = gaussians * scales_tensor
74+ gaussians = torch .randn (embed_dim // 2 , spatial_dims , dtype = torch .float32 ) * scales_tensor
7775
78- position_indeces = [torch .linspace (0 , 1 , x ) for x in grid_size_t ]
79- positions = torch .stack (torch .meshgrid (* position_indeces , indexing = "ij" ), dim = - 1 )
76+ position_indices = [torch .linspace (0 , 1 , x , dtype = torch . float32 ) for x in grid_size_t ]
77+ positions = torch .stack (torch .meshgrid (* position_indices , indexing = "ij" ), dim = - 1 )
8078 positions = positions .flatten (end_dim = - 2 )
8179
8280 x_proj = (2.0 * torch .pi * positions ) @ gaussians .T
0 commit comments