diff --git a/vision_toolbox/backbones/mlp_mixer.py b/vision_toolbox/backbones/mlp_mixer.py
index 124f7ce..6189c9d 100644
--- a/vision_toolbox/backbones/mlp_mixer.py
+++ b/vision_toolbox/backbones/mlp_mixer.py
@@ -12,14 +12,7 @@
 from ..utils import torch_hub_download
 from .base import _act, _norm
-
-
-class MLP(nn.Sequential):
-    def __init__(self, in_dim: int, hidden_dim: float, act: _act = nn.GELU) -> None:
-        super().__init__()
-        self.linear1 = nn.Linear(in_dim, hidden_dim)
-        self.act = act()
-        self.linear2 = nn.Linear(hidden_dim, in_dim)
+from .vit import MLP
 
 
 class MixerBlock(nn.Module):
diff --git a/vision_toolbox/backbones/vit.py b/vision_toolbox/backbones/vit.py
index 0929f66..620c25b 100644
--- a/vision_toolbox/backbones/vit.py
+++ b/vision_toolbox/backbones/vit.py
@@ -25,16 +25,16 @@ def __init__(self, d_model: int, n_heads: int, bias: bool = True, dropout: float
         self.dropout = dropout
         self.scale = (d_model // n_heads) ** (-0.5)
 
-    def forward(self, x: Tensor, attn_mask: Tensor | None = None) -> Tensor:
+    def forward(self, x: Tensor, attn_bias: Tensor | None = None) -> Tensor:
         qkv = self.in_proj(x)
         q, k, v = qkv.unflatten(-1, (3, self.n_heads, -1)).transpose(-2, -4).unbind(-3)
 
         if hasattr(F, "scaled_dot_product_attention"):
-            out = F.scaled_dot_product_attention(q, k, v, attn_mask, self.dropout)
+            out = F.scaled_dot_product_attention(q, k, v, attn_bias, self.dropout if self.training else 0.0)
         else:
             attn = torch.softmax(q @ (k * self.scale).transpose(-1, -2), -1)
-            if attn_mask is not None:
-                attn = attn + attn_mask
+            if attn_bias is not None:
+                attn = attn + attn_bias
             out = F.dropout(attn, self.dropout, self.training) @ v
 
         out = out.transpose(-2, -3).flatten(-2)
@@ -42,12 +42,21 @@ def forward(self, x: Tensor, attn_mask: Tensor | None = None) -> Tensor:
         return out
 
 
+class MLP(nn.Sequential):
+    def __init__(self, in_dim: int, hidden_dim: float, act: _act = nn.GELU) -> None:
+        super().__init__()
+        self.linear1 = nn.Linear(in_dim, hidden_dim)
+        self.act = act()
+        self.linear2 = nn.Linear(hidden_dim, in_dim)
+
+
 class ViTBlock(nn.Module):
     def __init__(
         self,
         d_model: int,
         n_heads: int,
         bias: bool = True,
+        mlp_ratio: float = 4.0,
         dropout: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
@@ -56,12 +65,7 @@ def __init__(
         self.norm1 = norm(d_model)
         self.mha = MHA(d_model, n_heads, bias, dropout)
         self.norm2 = norm(d_model)
-        self.mlp = nn.Sequential(
-            nn.Linear(d_model, d_model * 4, bias),
-            act(),
-            nn.Linear(d_model * 4, d_model, bias),
-            nn.Dropout(dropout),
-        )
+        self.mlp = MLP(d_model, int(d_model * mlp_ratio), act)
 
     def forward(self, x: Tensor) -> Tensor:
         x = x + self.mha(self.norm1(x))
@@ -79,6 +83,7 @@ def __init__(
         img_size: int,
         cls_token: bool = True,
         bias: bool = True,
+        mlp_ratio: float = 4.0,
         dropout: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
@@ -94,7 +99,9 @@ def __init__(
         self.pe = nn.Parameter(torch.empty(1, pe_size, d_model))
         nn.init.normal_(self.pe, 0, 0.02)
 
-        self.layers = nn.Sequential(*[ViTBlock(d_model, n_heads, bias, dropout, norm, act) for _ in range(n_layers)])
+        self.layers = nn.Sequential(
+            *[ViTBlock(d_model, n_heads, bias, mlp_ratio, dropout, norm, act) for _ in range(n_layers)]
+        )
         self.norm = norm(d_model)
 
     def forward(self, imgs: Tensor) -> Tensor:
@@ -177,10 +184,10 @@ def get_w(key: str) -> Tensor:
             layer.norm2.weight.copy_(get_w(prefix + "LayerNorm_2/scale"))
             layer.norm2.bias.copy_(get_w(prefix + "LayerNorm_2/bias"))
 
-            layer.mlp[0].weight.copy_(get_w(prefix + "MlpBlock_3/Dense_0/kernel").T)
-            layer.mlp[0].bias.copy_(get_w(prefix + "MlpBlock_3/Dense_0/bias"))
-            layer.mlp[2].weight.copy_(get_w(prefix + "MlpBlock_3/Dense_1/kernel").T)
-            layer.mlp[2].bias.copy_(get_w(prefix + "MlpBlock_3/Dense_1/bias"))
+            layer.mlp.linear1.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_0/kernel").T)
+            layer.mlp.linear1.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_0/bias"))
+            layer.mlp.linear2.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_1/kernel").T)
+            layer.mlp.linear2.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_1/bias"))
 
         self.norm.weight.copy_(get_w("Transformer/encoder_norm/scale"))
         self.norm.bias.copy_(get_w("Transformer/encoder_norm/bias"))
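For reference, a minimal usage sketch of the refactored API (not part of the patch). It assumes the package is importable as vision_toolbox.backbones.vit, per the file paths above, and shows the new mlp_ratio argument together with the named linear1/linear2 sub-modules that the JAX weight loader now targets:

# Illustrative only; import path and sizes are assumptions, not part of this diff.
import torch
from vision_toolbox.backbones.vit import MLP, ViTBlock

block = ViTBlock(d_model=384, n_heads=6, mlp_ratio=4.0)  # MLP hidden dim = int(384 * 4.0) = 1536
x = torch.randn(2, 197, 384)                             # (batch, tokens, d_model)
out = block(x)
assert out.shape == x.shape

# The shared MLP exposes named sub-modules, which is what the weight loader
# now relies on (layer.mlp.linear1 / layer.mlp.linear2 instead of indices).
assert isinstance(block.mlp, MLP)
print(block.mlp.linear1.weight.shape)  # torch.Size([1536, 384])
print(block.mlp.linear2.weight.shape)  # torch.Size([384, 1536])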