diff --git a/vision_toolbox/backbones/vit.py b/vision_toolbox/backbones/vit.py index 829cf21..ae54506 100644 --- a/vision_toolbox/backbones/vit.py +++ b/vision_toolbox/backbones/vit.py @@ -81,6 +81,7 @@ def __init__( norm: _norm = partial(nn.LayerNorm, eps=1e-6), act: _act = nn.GELU, ) -> None: + assert img_size % patch_size == 0 super().__init__() self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size) self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model)) if cls_token else None