
Commit

re-use ViT modules
gau-nernst committed Aug 20, 2023
1 parent 40241c5 commit d9c9007
Showing 2 changed files with 35 additions and 54 deletions.
84 changes: 31 additions & 53 deletions vision_toolbox/backbones/cait.py
@@ -1,29 +1,20 @@
 # https://arxiv.org/abs/2103.17239
 # https://github.com/facebookresearch/deit

 from __future__ import annotations

+from functools import partial
+
 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn

 from ..components import LayerScale, StochasticDepth
 from .base import _act, _norm
-from .vit import MLP
+from .vit import MHA, ViTBlock


 # basically attention pooling
-class ClassAttention(nn.Module):
-    def __init__(self, d_model: int, n_heads: int, bias: bool = True, dropout: float = 0.0) -> None:
-        super().__init__()
-        self.q_proj = nn.Linear(d_model, d_model, bias)
-        self.k_proj = nn.Linear(d_model, d_model, bias)
-        self.v_proj = nn.Linear(d_model, d_model, bias)
-        self.out_proj = nn.Linear(d_model, d_model, bias)
-
-        self.n_heads = n_heads
-        self.dropout = dropout
-        self.scale = (d_model // n_heads) ** (-0.5)
-
+class ClassAttention(MHA):
     def forward(self, x: Tensor) -> None:
         q = self.q_proj(x[:, 0]).unflatten(-1, (self.n_heads, -1)).unsqueeze(2)  # (B, n_heads, 1, head_dim)
         k = self.k_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)  # (B, n_heads, L, head_dim)
@@ -39,21 +30,15 @@ def forward(self, x: Tensor) -> None:


 # does not support flash attention
-class TalkingHeadAttention(nn.Module):
+class TalkingHeadAttention(MHA):
     def __init__(self, d_model: int, n_heads: int, bias: bool = True, dropout: float = 0.0) -> None:
-        super().__init__()
-        self.q_proj = nn.Linear(d_model, d_model, bias)
-        self.k_proj = nn.Linear(d_model, d_model, bias)
-        self.v_proj = nn.Linear(d_model, d_model, bias)
-        self.out_proj = nn.Linear(d_model, d_model, bias)
+        super().__init__(d_model, n_heads, bias, dropout)
         self.talking_head_proj = nn.Sequential(
             nn.Conv2d(n_heads, n_heads, 1),  # impl as 1x1 conv to avoid permutating data
             nn.Softmax(-1),
             nn.Conv2d(n_heads, n_heads, 1),
             nn.Dropout(dropout),
         )
-        self.n_heads = n_heads
-        self.scale = (d_model // n_heads) ** (-0.5)

     def forward(self, x: Tensor) -> Tensor:
         q = self.q_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)  # (B, n_heads, L, head_dim)
@@ -67,7 +52,7 @@ def forward(self, x: Tensor) -> Tensor:
         return out


-class CaiTCABlock(nn.Module):
+class CaiTCABlock(ViTBlock):
     def __init__(
         self,
         d_model: int,
@@ -80,18 +65,17 @@ def __init__(
         norm: _norm = nn.LayerNorm,
         act: _act = nn.GELU,
     ) -> None:
-        super().__init__()
-        self.mha = nn.Sequential(
-            norm(d_model),
-            ClassAttention(d_model, n_heads, bias, dropout),
-            LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
-            StochasticDepth(stochastic_depth),
-        )
-        self.mlp = nn.Sequential(
-            norm(d_model),
-            MLP(d_model, int(d_model * mlp_ratio), dropout, act),
-            LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
-            StochasticDepth(stochastic_depth),
+        super().__init__(
+            d_model,
+            n_heads,
+            bias,
+            mlp_ratio,
+            dropout,
+            layer_scale_init,
+            stochastic_depth,
+            norm,
+            act,
+            partial(ClassAttention, d_model, n_heads, bias, dropout),
         )

     def forward(self, x: Tensor, cls_token: Tensor) -> Tensor:
@@ -100,7 +84,7 @@ def forward(self, x: Tensor, cls_token: Tensor) -> Tensor:
         return cls_token


-class CaiTSABlock(nn.Module):
+class CaiTSABlock(ViTBlock):
     def __init__(
         self,
         d_model: int,
@@ -113,24 +97,18 @@ def __init__(
         norm: _norm = nn.LayerNorm,
         act: _act = nn.GELU,
     ) -> None:
-        super().__init__()
-        self.mha = nn.Sequential(
-            norm(d_model),
-            TalkingHeadAttention(d_model, n_heads, bias, dropout),
-            LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
-            StochasticDepth(stochastic_depth),
+        super().__init__(
+            d_model,
+            n_heads,
+            bias,
+            mlp_ratio,
+            dropout,
+            layer_scale_init,
+            stochastic_depth,
+            norm,
+            act,
+            partial(TalkingHeadAttention, d_model, n_heads, bias, dropout),
         )
-        self.mlp = nn.Sequential(
-            norm(d_model),
-            MLP(d_model, int(d_model * mlp_ratio), dropout, act),
-            LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
-            StochasticDepth(stochastic_depth),
-        )
-
-    def forward(self, x: Tensor) -> Tensor:
-        x = x + self.mha(x)
-        x = x + self.mlp(x)
-        return x


 class CaiT(nn.Module):
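The first hunk reduces ClassAttention to a forward pass over MHA's existing projections, with the query built from the CLS token alone — essentially attention pooling, as the code comment says. Below is a minimal, self-contained sketch of that pattern; the module name is hypothetical, and the use of F.scaled_dot_product_attention for the part of forward collapsed in this view is an assumption, not necessarily the repository's exact code.

# Minimal sketch of class attention ("attention pooling"): the query comes only
# from the CLS token, while keys and values come from all tokens.
# Hypothetical module name; scaled_dot_product_attention is assumed here.
import torch
import torch.nn.functional as F
from torch import Tensor, nn


class ClassAttentionSketch(nn.Module):
    def __init__(self, d_model: int, n_heads: int, bias: bool = True) -> None:
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model, bias)
        self.k_proj = nn.Linear(d_model, d_model, bias)
        self.v_proj = nn.Linear(d_model, d_model, bias)
        self.out_proj = nn.Linear(d_model, d_model, bias)
        self.n_heads = n_heads

    def forward(self, x: Tensor) -> Tensor:
        # x: (B, L, d_model) with the CLS token at index 0
        q = self.q_proj(x[:, 0]).unflatten(-1, (self.n_heads, -1)).unsqueeze(2)  # (B, n_heads, 1, head_dim)
        k = self.k_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)  # (B, n_heads, L, head_dim)
        v = self.v_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)  # (B, n_heads, L, head_dim)
        out = F.scaled_dot_product_attention(q, k, v)  # (B, n_heads, 1, head_dim)
        return self.out_proj(out.transpose(-2, -3).flatten(-2))  # (B, 1, d_model)


x = torch.randn(2, 197, 192)  # 2 images, 196 patch tokens + 1 CLS token, width 192
print(ClassAttentionSketch(192, 3)(x).shape)  # torch.Size([2, 1, 192])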
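TalkingHeadAttention likewise keeps MHA's projections but mixes information across heads with 1x1 convolutions applied to the (B, n_heads, L, L) attention logits before and after the softmax, which is why it cannot use flash attention. Its full forward is collapsed in this view, so the scoring below is a plausible sketch of the idea rather than a verbatim copy.

# Plausible sketch of the talking-heads step: treat heads as conv channels and
# mix them with 1x1 convs around the softmax over the attention logits.
import torch
from torch import nn

B, n_heads, L, head_dim = 2, 3, 197, 64
scale = head_dim ** -0.5

talking_head_proj = nn.Sequential(
    nn.Conv2d(n_heads, n_heads, 1),  # mix heads before softmax (1x1 conv avoids permuting data)
    nn.Softmax(-1),
    nn.Conv2d(n_heads, n_heads, 1),  # mix heads after softmax
    nn.Dropout(0.0),
)

q = torch.randn(B, n_heads, L, head_dim)
k = torch.randn(B, n_heads, L, head_dim)
v = torch.randn(B, n_heads, L, head_dim)

attn = talking_head_proj(q @ k.transpose(-1, -2) * scale)  # (B, n_heads, L, L)
out = attn @ v  # (B, n_heads, L, head_dim)
print(out.shape)  # torch.Size([2, 3, 197, 64])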
5 changes: 4 additions & 1 deletion vision_toolbox/backbones/vit.py
@@ -67,11 +67,14 @@ def __init__(
         stochastic_depth: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
+        attention: type[nn.Module] | None = None,
     ) -> None:
+        if attention is None:
+            attention = partial(MHA, d_model, n_heads, bias, dropout)
         super().__init__()
         self.mha = nn.Sequential(
             norm(d_model),
-            MHA(d_model, n_heads, bias, dropout),
+            attention(),
             LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
             StochasticDepth(stochastic_depth),
         )
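With this vit.py change, ViTBlock takes an optional zero-argument attention factory and falls back to MHA when none is given, which is what lets the CaiT blocks above re-use it. A hedged usage sketch, assuming the package is importable as vision_toolbox and that the constructor parameters collapsed above stochastic_depth keep their defaults:

from functools import partial

from vision_toolbox.backbones.cait import TalkingHeadAttention
from vision_toolbox.backbones.vit import ViTBlock

d_model, n_heads = 192, 3

# Default: attention=None falls back to partial(MHA, d_model, n_heads, bias, dropout).
default_block = ViTBlock(d_model, n_heads)

# Injected: the factory is called with no arguments inside the block, so bind the
# attention module's constructor arguments with functools.partial, as CaiTSABlock does.
talking_head_block = ViTBlock(d_model, n_heads, attention=partial(TalkingHeadAttention, d_model, n_heads))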
