@@ -60,17 +60,23 @@ def build_fourier_position_embedding(
6060 f"Embed dimension must be divisible by { 2 * spatial_dims } for { spatial_dims } D Fourier feature position embedding"
6161 )
6262
63- scales : torch .Tensor = torch .as_tensor (scales , dtype = torch .float )
64- if scales .ndim > 1 and scales .ndim != spatial_dims :
65- raise ValueError ("Scales must be either a float or a list of floats with length equal to spatial_dims" )
66- if scales .ndim == 0 :
67- scales = scales .repeat (spatial_dims )
63+ # Ensure scales is a tensor of shape (spatial_dims,)
64+ if isinstance (scales , float ):
65+ scales_tensor = torch .full ((spatial_dims ,), scales , dtype = torch .float )
66+ elif isinstance (scales , (list , tuple )):
67+ if len (scales ) != spatial_dims :
68+ raise ValueError (
69+ f"Length of scales { len (scales )} does not match spatial_dims { spatial_dims } "
70+ )
71+ scales_tensor = torch .tensor (scales , dtype = torch .float )
72+ else :
73+ raise TypeError (f"scales must be float or list of floats, got { type (scales )} " )
6874
6975 gaussians = torch .normal (0.0 , 1.0 , (embed_dim // 2 , spatial_dims ))
70- gaussians = gaussians * scales
76+ gaussians = gaussians * scales_tensor
7177
72- positions = [torch .linspace (0 , 1 , x ) for x in grid_size ]
73- positions = torch .stack (torch .meshgrid (* positions , indexing = "ij" ), dim = - 1 )
+    position_indices = [torch.linspace(0, 1, x) for x in grid_size]
+    positions = torch.stack(torch.meshgrid(*position_indices, indexing="ij"), dim=-1)
     positions = positions.flatten(end_dim=-2)
 
     x_proj = (2.0 * torch.pi * positions) @ gaussians.T
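For reference, below is a minimal, self-contained sketch of the construction this hunk performs. It mirrors the diff's per-axis scaling and projection, under the assumption that the function ultimately concatenates the sine and cosine of the projection (the standard Gaussian Fourier-feature recipe); the helper name `fourier_features` and its exact signature are illustrative, not the library API.

```python
import torch


def fourier_features(grid_size, embed_dim, scales):
    """Illustrative sketch only; mirrors the diff's per-axis scaling, not the library code."""
    spatial_dims = len(grid_size)

    # Accept a single float or one scale per spatial dimension, as in the diff above.
    if isinstance(scales, float):
        scales_tensor = torch.full((spatial_dims,), scales, dtype=torch.float)
    else:
        if len(scales) != spatial_dims:
            raise ValueError(f"Length of scales {len(scales)} does not match spatial_dims {spatial_dims}")
        scales_tensor = torch.tensor(scales, dtype=torch.float)

    # Random Gaussian frequencies, scaled independently per axis.
    gaussians = torch.normal(0.0, 1.0, (embed_dim // 2, spatial_dims)) * scales_tensor

    # Normalized grid coordinates in [0, 1], flattened to (num_positions, spatial_dims).
    axes = [torch.linspace(0, 1, n) for n in grid_size]
    positions = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1).flatten(end_dim=-2)

    # Project and take sin/cos (assumed final step; the hunk above is cut off after x_proj).
    x_proj = (2.0 * torch.pi * positions) @ gaussians.T
    return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)


# Example: a 4x4 grid with a larger frequency scale along the second axis.
emb = fourier_features((4, 4), embed_dim=8, scales=[1.0, 4.0])
print(emb.shape)  # torch.Size([16, 8])
```

Scaling the Gaussian frequencies independently per axis, as the new `scales_tensor` does, lets anisotropic grids receive a different frequency bandwidth along each spatial dimension instead of a single shared scale.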