Skip to content

Commit

Permalink
add resize pe
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Aug 20, 2023
1 parent 59d487c commit 592d41b
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tests/test_cait.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ def test_forward():
m(torch.randn(1, 3, 224, 224))


def test_resize_pe():
    """Smoke-test CaiT positional-embedding resizing.

    Runs a forward pass at the configured 224 resolution, calls
    ``resize_pe(256)``, then runs a forward pass on a 256x256 input.
    The test passes as long as no call raises.
    """
    model = CaiT.from_config("xxs_24", 224)

    # Forward pass at the resolution the model was built for.
    model(torch.randn(1, 3, 224, 224))

    # Forward pass at the new resolution after resizing the embedding.
    model.resize_pe(256)
    model(torch.randn(1, 3, 256, 256))


def test_from_pretrained():
m = CaiT.from_config("xxs_24", 224, True).eval()
x = torch.randn(1, 3, 224, 224)
Expand Down
5 changes: 5 additions & 0 deletions tests/test_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
from vision_toolbox.backbones import ViT


def test_forward():
    """Smoke-test a ViT forward pass on a single random 224x224 RGB image."""
    model = ViT.from_config("Ti", 16, 224)
    model(torch.randn(1, 3, 224, 224))


def test_resize_pe():
m = ViT.from_config("Ti", 16, 224)
m(torch.randn(1, 3, 224, 224))
Expand Down
9 changes: 9 additions & 0 deletions vision_toolbox/backbones/cait.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,15 @@ def forward(self, imgs: Tensor) -> Tensor:
cls_token = block(patches, cls_token)
return self.norm(cls_token.squeeze(1))

@torch.no_grad()
def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
    """Resample the positional embedding ``self.pe`` for a new image size.

    ``self.pe`` is treated as ``(1, tokens, dim)`` where ``tokens`` forms a
    square grid. The grid is interpolated to ``size // patch_size`` cells per
    side and stored back as a new ``nn.Parameter``.

    Args:
        size: New input image side length in pixels. The number of grid
            cells per side is ``size // patch_size`` (floor division).
        interpolation_mode: ``mode`` argument forwarded to ``F.interpolate``.

    Note:
        ``self.pe`` is replaced by a new parameter object, so anything
        holding a reference to the old one (e.g. an optimizer created
        before the call) will not see the resized embedding.
    """
    old_pe = self.pe
    old_grid = int(old_pe.shape[1] ** 0.5)
    # presumably patch_embed is a conv whose kernel height equals the patch size
    new_grid = size // self.patch_embed.weight.shape[2]

    # (1, tokens, dim) -> (1, dim, grid, grid), the layout F.interpolate expects.
    grid_pe = old_pe.unflatten(1, (old_grid, old_grid)).permute(0, 3, 1, 2)
    grid_pe = F.interpolate(grid_pe, (new_grid, new_grid), mode=interpolation_mode)

    # Back to (1, new_tokens, dim).
    new_pe = grid_pe.permute(0, 2, 3, 1).flatten(1, 2)

    # Keep the original trainability: nn.Parameter defaults to
    # requires_grad=True, which would silently unfreeze a frozen embedding.
    self.pe = nn.Parameter(new_pe, requires_grad=old_pe.requires_grad)

@staticmethod
def from_config(variant: str, img_size: int, pretrained: bool = False) -> CaiT:
variant, sa_depth = variant.split("_")
Expand Down

0 comments on commit 592d41b

Please sign in to comment.