From d9c90071f52c5dfb9c26a0fc0b122e3d6ebf5ea7 Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Sun, 20 Aug 2023 12:14:03 +0800
Subject: [PATCH] re-use ViT modules

---
 vision_toolbox/backbones/cait.py | 84 ++++++++++++--------------------
 vision_toolbox/backbones/vit.py  |  5 +-
 2 files changed, 35 insertions(+), 54 deletions(-)

diff --git a/vision_toolbox/backbones/cait.py b/vision_toolbox/backbones/cait.py
index 8696e1b..9415597 100644
--- a/vision_toolbox/backbones/cait.py
+++ b/vision_toolbox/backbones/cait.py
@@ -1,29 +1,20 @@
 # https://arxiv.org/abs/2103.17239
 # https://github.com/facebookresearch/deit
+
 from __future__ import annotations
 
+from functools import partial
+
 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
 
-from ..components import LayerScale, StochasticDepth
 from .base import _act, _norm
-from .vit import MLP
+from .vit import MHA, ViTBlock
 
 
 # basically attention pooling
-class ClassAttention(nn.Module):
-    def __init__(self, d_model: int, n_heads: int, bias: bool = True, dropout: float = 0.0) -> None:
-        super().__init__()
-        self.q_proj = nn.Linear(d_model, d_model, bias)
-        self.k_proj = nn.Linear(d_model, d_model, bias)
-        self.v_proj = nn.Linear(d_model, d_model, bias)
-        self.out_proj = nn.Linear(d_model, d_model, bias)
-
-        self.n_heads = n_heads
-        self.dropout = dropout
-        self.scale = (d_model // n_heads) ** (-0.5)
-
+class ClassAttention(MHA):
     def forward(self, x: Tensor) -> None:
         q = self.q_proj(x[:, 0]).unflatten(-1, (self.n_heads, -1)).unsqueeze(2)  # (B, n_heads, 1, head_dim)
         k = self.k_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)  # (B, n_heads, L, head_dim)
@@ -39,21 +30,15 @@ def forward(self, x: Tensor) -> None:
 
 
 # does not support flash attention
-class TalkingHeadAttention(nn.Module):
+class TalkingHeadAttention(MHA):
     def __init__(self, d_model: int, n_heads: int, bias: bool = True, dropout: float = 0.0) -> None:
-        super().__init__()
-        self.q_proj = nn.Linear(d_model, d_model, bias)
-        self.k_proj = nn.Linear(d_model, d_model, bias)
-        self.v_proj = nn.Linear(d_model, d_model, bias)
-        self.out_proj = nn.Linear(d_model, d_model, bias)
+        super().__init__(d_model, n_heads, bias, dropout)
         self.talking_head_proj = nn.Sequential(
             nn.Conv2d(n_heads, n_heads, 1),  # impl as 1x1 conv to avoid permutating data
             nn.Softmax(-1),
             nn.Conv2d(n_heads, n_heads, 1),
             nn.Dropout(dropout),
         )
-        self.n_heads = n_heads
-        self.scale = (d_model // n_heads) ** (-0.5)
 
     def forward(self, x: Tensor) -> Tensor:
         q = self.q_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)  # (B, n_heads, L, head_dim)
@@ -67,7 +52,7 @@ def forward(self, x: Tensor) -> Tensor:
         return out
 
 
-class CaiTCABlock(nn.Module):
+class CaiTCABlock(ViTBlock):
     def __init__(
         self,
         d_model: int,
@@ -80,18 +65,17 @@ def __init__(
         norm: _norm = nn.LayerNorm,
         act: _act = nn.GELU,
     ) -> None:
-        super().__init__()
-        self.mha = nn.Sequential(
-            norm(d_model),
-            ClassAttention(d_model, n_heads, bias, dropout),
-            LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
-            StochasticDepth(stochastic_depth),
-        )
-        self.mlp = nn.Sequential(
-            norm(d_model),
-            MLP(d_model, int(d_model * mlp_ratio), dropout, act),
-            LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
-            StochasticDepth(stochastic_depth),
+        super().__init__(
+            d_model,
+            n_heads,
+            bias,
+            mlp_ratio,
+            dropout,
+            layer_scale_init,
+            stochastic_depth,
+            norm,
+            act,
+            partial(ClassAttention, d_model, n_heads, bias, dropout),
         )
 
     def forward(self, x: Tensor, cls_token: Tensor) -> Tensor:
@@ -100,7 +84,7 @@ def forward(self, x: Tensor, cls_token: Tensor) -> Tensor:
         return cls_token
 
 
-class CaiTSABlock(nn.Module):
+class CaiTSABlock(ViTBlock):
     def __init__(
         self,
         d_model: int,
@@ -113,24 +97,18 @@ def __init__(
         norm: _norm = nn.LayerNorm,
         act: _act = nn.GELU,
     ) -> None:
-        super().__init__()
-        self.mha = nn.Sequential(
-            norm(d_model),
-            TalkingHeadAttention(d_model, n_heads, bias, dropout),
-            LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
-            StochasticDepth(stochastic_depth),
+        super().__init__(
+            d_model,
+            n_heads,
+            bias,
+            mlp_ratio,
+            dropout,
+            layer_scale_init,
+            stochastic_depth,
+            norm,
+            act,
+            partial(TalkingHeadAttention, d_model, n_heads, bias, dropout),
         )
-        self.mlp = nn.Sequential(
-            norm(d_model),
-            MLP(d_model, int(d_model * mlp_ratio), dropout, act),
-            LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
-            StochasticDepth(stochastic_depth),
-        )
-
-    def forward(self, x: Tensor) -> Tensor:
-        x = x + self.mha(x)
-        x = x + self.mlp(x)
-        return x
 
 
 class CaiT(nn.Module):
diff --git a/vision_toolbox/backbones/vit.py b/vision_toolbox/backbones/vit.py
index 9f8564a..778d39c 100644
--- a/vision_toolbox/backbones/vit.py
+++ b/vision_toolbox/backbones/vit.py
@@ -67,11 +67,14 @@ def __init__(
         stochastic_depth: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
+        attention: type[nn.Module] | None = None,
     ) -> None:
+        if attention is None:
+            attention = partial(MHA, d_model, n_heads, bias, dropout)
         super().__init__()
         self.mha = nn.Sequential(
             norm(d_model),
-            MHA(d_model, n_heads, bias, dropout),
+            attention(),
             LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
             StochasticDepth(stochastic_depth),
         )
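
Not part of the patch: a minimal usage sketch of the `attention` hook this change adds to `ViTBlock`. The argument order (d_model, n_heads, bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm, act, attention) is taken from the super().__init__() calls above; the sizes and hyperparameter values below are hypothetical, and it assumes ViTBlock.forward is the usual pre-norm residual pair (x + mha(x), then x + mlp(x)) that the removed CaiTSABlock.forward duplicated.

    from functools import partial

    import torch
    from torch import nn

    from vision_toolbox.backbones.cait import TalkingHeadAttention
    from vision_toolbox.backbones.vit import ViTBlock

    d_model, n_heads = 192, 4  # hypothetical sizes

    # attention=None: the block builds its default MHA internally, as before.
    vit_block = ViTBlock(d_model, n_heads, True, 4.0, 0.0, None, 0.0)

    # Same block body, but with CaiT's talking-head attention injected via the new hook.
    cait_block = ViTBlock(
        d_model, n_heads, True, 4.0, 0.0, 1e-5, 0.0,
        partial(nn.LayerNorm, eps=1e-6), nn.GELU,
        partial(TalkingHeadAttention, d_model, n_heads, True, 0.0),
    )

    x = torch.randn(2, 196, d_model)  # (batch, seq_len, d_model)
    assert vit_block(x).shape == cait_block(x).shape == x.shape  # residual blocks keep the shape

ClassAttention is left out of this sketch on purpose: it attends from the class token only, so it is wired through CaiTCABlock's forward(x, cls_token) rather than the plain ViTBlock residual path.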