Skip to content

Commit 68584e5

Browse files
committed
Add grid_size check and fix typing
Signed-off-by: NabJa <[email protected]>
1 parent 3ef2188 commit 68584e5

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

monai/networks/blocks/pos_embed_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,11 @@ def build_fourier_position_embedding(
5353
"""
5454

5555
to_tuple = _ntuple(spatial_dims)
56-
grid_size = to_tuple(grid_size)
56+
grid_size_t = to_tuple(grid_size)
57+
if len(grid_size_t) != spatial_dims:
58+
raise ValueError(
59+
f"Length of grid_size must be the same as spatial_dims. Got len(grid_size)={len(grid_size_t)}, should be {spatial_dims}."
60+
)
5761

5862
if embed_dim % (2 * spatial_dims) != 0:
5963
raise AssertionError(
@@ -73,7 +77,7 @@ def build_fourier_position_embedding(
7377
gaussians = torch.normal(0.0, 1.0, (embed_dim // 2, spatial_dims))
7478
gaussians = gaussians * scales_tensor
7579

76-
position_indeces = [torch.linspace(0, 1, x) for x in grid_size]
80+
position_indeces = [torch.linspace(0, 1, x) for x in grid_size_t]
7781
positions = torch.stack(torch.meshgrid(*position_indeces, indexing="ij"), dim=-1)
7882
positions = positions.flatten(end_dim=-2)
7983

0 commit comments

Comments
 (0)