@@ -78,6 +78,12 @@ def build_sincos2d_pos_embed(
7878 return pos_emb .to (dtype = dtype )
7979
8080
81+ def swap_shape_xy (seq : List [int ]) -> List [int ]:
82+ if len (seq ) < 2 :
83+ return seq
84+ return [seq [1 ], seq [0 ]] + seq [2 :]
85+
86+
8187def build_fourier_pos_embed (
8288 feat_shape : List [int ],
8389 bands : Optional [torch .Tensor ] = None ,
@@ -134,6 +140,11 @@ def build_fourier_pos_embed(
134140 if dtype is None :
135141 dtype = bands .dtype
136142
143+ if grid_indexing == 'xy' :
144+ feat_shape = swap_shape_xy (feat_shape )
145+ if ref_feat_shape is not None :
146+ ref_feat_shape = swap_shape_xy (ref_feat_shape )
147+
137148 if in_pixels :
138149 t = [
139150 torch .linspace (- 1. , 1. , steps = s , device = device , dtype = torch .float32 )
@@ -516,15 +527,16 @@ def init_random_2d_freqs(
516527@torch .fx .wrap
517528@register_notrace_function
518529def get_mixed_grid (
519- height : int ,
520- width : int ,
530+ shape : List [int ],
521531 grid_indexing : str = 'ij' ,
522532 device : Optional [torch .device ] = None ,
523533 dtype : torch .dtype = torch .float32 ,
524534) -> Tuple [torch .Tensor , torch .Tensor ]:
535+ if grid_indexing == 'xy' :
536+ shape = swap_shape_xy (shape )
525537 x_pos , y_pos = torch .meshgrid (
526- torch .arange (height , dtype = dtype , device = device ),
527- torch .arange (width , dtype = dtype , device = device ),
538+ torch .arange (shape [ 0 ] , dtype = dtype , device = device ),
539+ torch .arange (shape [ 1 ] , dtype = dtype , device = device ),
528540 indexing = grid_indexing ,
529541 )
530542 t_x = x_pos .flatten ()
@@ -599,8 +611,7 @@ def __init__(
599611 if feat_shape is not None :
600612 # cache pre-computed grid
601613 t_x , t_y = get_mixed_grid (
602- feat_shape [0 ],
603- feat_shape [1 ],
614+ feat_shape ,
604615 grid_indexing = grid_indexing ,
605616 device = self .freqs .device
606617 )
@@ -620,8 +631,7 @@ def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor:
620631 """
621632 if shape is not None :
622633 t_x , t_y = get_mixed_grid (
623- shape [0 ],
624- shape [1 ],
634+ shape ,
625635 grid_indexing = self .grid_indexing ,
626636 device = self .freqs .device
627637 )
0 commit comments