From 057709fb535dea037c6555f73f3605684f5cf380 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 20 Aug 2023 15:02:25 +0800 Subject: [PATCH] minor changes --- vision_toolbox/backbones/cait.py | 9 ++------- vision_toolbox/backbones/deit.py | 6 +++--- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/vision_toolbox/backbones/cait.py b/vision_toolbox/backbones/cait.py index 88346a0..c7e18f4 100644 --- a/vision_toolbox/backbones/cait.py +++ b/vision_toolbox/backbones/cait.py @@ -10,7 +10,7 @@ from torch import Tensor, nn from .base import _act, _norm -from .vit import MHA, ViTBlock +from .vit import MHA, ViT, ViTBlock # basically attention pooling @@ -152,12 +152,7 @@ def forward(self, imgs: Tensor) -> Tensor: @torch.no_grad() def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None: - old_size = int(self.pe.shape[1] ** 0.5) - new_size = size // self.patch_embed.weight.shape[2] - pe = self.pe.unflatten(1, (old_size, old_size)).permute(0, 3, 1, 2) - pe = F.interpolate(pe, (new_size, new_size), mode=interpolation_mode) - pe = pe.permute(0, 2, 3, 1).flatten(1, 2) - self.pe = nn.Parameter(pe) + ViT.resize_pe(self, size, interpolation_mode) @staticmethod def from_config(variant: str, img_size: int, pretrained: bool = False) -> CaiT: diff --git a/vision_toolbox/backbones/deit.py b/vision_toolbox/backbones/deit.py index 27fe7f1..e70157c 100644 --- a/vision_toolbox/backbones/deit.py +++ b/vision_toolbox/backbones/deit.py @@ -40,13 +40,13 @@ def __init__( self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model)) def forward(self, imgs: Tensor) -> Tensor: - out = self.patch_embed(imgs).flatten(2).transpose(1, 2) + self.pe # (N, C, H, W) -> (N, H*W, C) - out = torch.cat([self.cls_token, self.dist_token, out], 1) + out = self.patch_embed(imgs).flatten(2).transpose(1, 2) # (N, C, H, W) -> (N, H*W, C) + out = torch.cat([self.cls_token, self.dist_token, out + self.pe], 1) out = self.layers(out) return self.norm(out[:, :2]).mean(1) @staticmethod - def from_config(variant: str, img_size: int, version: bool = False, pretrained: bool = False) -> DeiT: + def from_config(variant: str, img_size: int, pretrained: bool = False) -> DeiT: variant, patch_size = variant.split("_") d_model, depth, n_heads = dict(