diff --git a/vision_toolbox/backbones/deit.py b/vision_toolbox/backbones/deit.py index 3d39d6b..27fe7f1 100644 --- a/vision_toolbox/backbones/deit.py +++ b/vision_toolbox/backbones/deit.py @@ -12,10 +12,10 @@ from ..components import LayerScale from .base import _act, _norm -from .vit import ViTBlock +from .vit import ViT, ViTBlock -class DeiT(nn.Module): +class DeiT(ViT): def __init__( self, d_model: int, @@ -31,37 +31,20 @@ def __init__( norm: _norm = partial(nn.LayerNorm, eps=1e-6), act: _act = nn.GELU, ) -> None: - assert img_size % patch_size == 0 - super().__init__() - self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size) - self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model)) + # fmt: off + super().__init__( + d_model, depth, n_heads, patch_size, img_size, True, bias, mlp_ratio, + dropout, layer_scale_init, stochastic_depth, norm, act + ) + # fmt: on self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model)) - self.pe = nn.Parameter(torch.empty(1, (img_size // patch_size) ** 2 + 2, d_model)) - nn.init.normal_(self.pe, 0, 0.02) - - self.layers = nn.Sequential() - for _ in range(depth): - block = ViTBlock(d_model, n_heads, bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm, act) - self.layers.append(block) - - self.norm = norm(d_model) def forward(self, imgs: Tensor) -> Tensor: - out = self.patch_embed(imgs).flatten(2).transpose(1, 2) # (N, C, H, W) -> (N, H*W, C) - out = torch.cat([self.cls_token, self.dist_token, out], 1) + self.pe + out = self.patch_embed(imgs).flatten(2).transpose(1, 2) + self.pe # (N, C, H, W) -> (N, H*W, C) + out = torch.cat([self.cls_token, self.dist_token, out], 1) out = self.layers(out) return self.norm(out[:, :2]).mean(1) - @torch.no_grad() - def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None: - pe = self.pe[:, 2:] - old_size = int(pe.shape[1] ** 0.5) - new_size = size // self.patch_embed.weight.shape[2] - pe = pe.unflatten(1, (old_size, old_size)).permute(0, 3, 1, 2) - pe = F.interpolate(pe, (new_size, new_size), mode=interpolation_mode) - pe = pe.permute(0, 2, 3, 1).flatten(1, 2) - self.pe = nn.Parameter(torch.cat((self.pe[:, :2], pe), 1)) - @staticmethod def from_config(variant: str, img_size: int, version: bool = False, pretrained: bool = False) -> DeiT: variant, patch_size = variant.split("_") @@ -97,12 +80,16 @@ def copy_(m: nn.Linear | nn.LayerNorm, prefix: str): m.bias.copy_(state_dict.pop(prefix + ".bias")) copy_(self.patch_embed, "patch_embed.proj") + pe = state_dict.pop("pos_embed") + self.pe.copy_(pe[:, -self.pe.shape[1] :]) + self.cls_token.copy_(state_dict.pop("cls_token")) + self.cls_token.add_(pe[:, 0]) if self.dist_token is not None: self.dist_token.copy_(state_dict.pop("dist_token")) + self.dist_token.add_(pe[:, 1]) state_dict.pop("head_dist.weight") state_dict.pop("head_dist.bias") - self.pe.copy_(state_dict.pop("pos_embed")) for i, block in enumerate(self.layers): block: ViTBlock