diff --git a/tests/test_cait.py b/tests/test_cait.py
index 79d8c3f..41fca49 100644
--- a/tests/test_cait.py
+++ b/tests/test_cait.py
@@ -9,6 +9,13 @@ def test_forward():
     m(torch.randn(1, 3, 224, 224))
 
 
+def test_resize_pe():
+    m = CaiT.from_config("xxs_24", 224)
+    m(torch.randn(1, 3, 224, 224))
+    m.resize_pe(256)
+    m(torch.randn(1, 3, 256, 256))
+
+
 def test_from_pretrained():
     m = CaiT.from_config("xxs_24", 224, True).eval()
     x = torch.randn(1, 3, 224, 224)
diff --git a/tests/test_vit.py b/tests/test_vit.py
index add5ae8..47a2257 100644
--- a/tests/test_vit.py
+++ b/tests/test_vit.py
@@ -4,6 +4,11 @@
 from vision_toolbox.backbones import ViT
 
 
+def test_forward():
+    m = ViT.from_config("Ti", 16, 224)
+    m(torch.randn(1, 3, 224, 224))
+
+
 def test_resize_pe():
     m = ViT.from_config("Ti", 16, 224)
     m(torch.randn(1, 3, 224, 224))
diff --git a/vision_toolbox/backbones/cait.py b/vision_toolbox/backbones/cait.py
index 2e68119..88346a0 100644
--- a/vision_toolbox/backbones/cait.py
+++ b/vision_toolbox/backbones/cait.py
@@ -150,6 +150,15 @@ def forward(self, imgs: Tensor) -> Tensor:
         cls_token = block(patches, cls_token)
         return self.norm(cls_token.squeeze(1))
 
+    @torch.no_grad()
+    def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
+        old_size = int(self.pe.shape[1] ** 0.5)
+        new_size = size // self.patch_embed.weight.shape[2]
+        pe = self.pe.unflatten(1, (old_size, old_size)).permute(0, 3, 1, 2)
+        pe = F.interpolate(pe, (new_size, new_size), mode=interpolation_mode)
+        pe = pe.permute(0, 2, 3, 1).flatten(1, 2)
+        self.pe = nn.Parameter(pe)
+
     @staticmethod
     def from_config(variant: str, img_size: int, pretrained: bool = False) -> CaiT:
         variant, sa_depth = variant.split("_")