move MLP to vit.py. fix dropout during inference
gau-nernst committed Aug 9, 2023
1 parent fa64602 commit 6413aa7
Showing 2 changed files with 23 additions and 23 deletions.
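Note on the inference fix: F.scaled_dot_product_attention takes dropout_p as a plain float and applies it whenever it is non-zero, with no knowledge of the owning module's train/eval state, so passing self.dropout unconditionally kept attention dropout active at inference time. The manual fallback path already handles this by passing self.training to F.dropout. A minimal sketch of the behavior, assuming a PyTorch version that provides F.scaled_dot_product_attention (shapes and the 0.1 rate are illustrative):

import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, seq_len, head_dim).
q = k = v = torch.randn(1, 4, 8, 16)

# dropout_p is applied whenever it is non-zero; the functional API has no
# notion of a module being in eval() mode.
out1 = F.scaled_dot_product_attention(q, k, v, dropout_p=0.1)
out2 = F.scaled_dot_product_attention(q, k, v, dropout_p=0.1)
print(torch.allclose(out1, out2))  # almost surely False: dropout is still active

# Gating the rate on the module's training flag, as the new code does with
# `self.dropout if self.training else 0.0`, makes eval-mode outputs deterministic.
training = False  # e.g. after model.eval()
p = 0.1 if training else 0.0
out1 = F.scaled_dot_product_attention(q, k, v, dropout_p=p)
out2 = F.scaled_dot_product_attention(q, k, v, dropout_p=p)
print(torch.allclose(out1, out2))  # True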
9 changes: 1 addition & 8 deletions vision_toolbox/backbones/mlp_mixer.py
@@ -12,14 +12,7 @@
 
 from ..utils import torch_hub_download
 from .base import _act, _norm
-
-
-class MLP(nn.Sequential):
-    def __init__(self, in_dim: int, hidden_dim: float, act: _act = nn.GELU) -> None:
-        super().__init__()
-        self.linear1 = nn.Linear(in_dim, hidden_dim)
-        self.act = act()
-        self.linear2 = nn.Linear(hidden_dim, in_dim)
+from .vit import MLP
 
 
 class MixerBlock(nn.Module):
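With this change, mlp_mixer.py reuses the MLP block now defined in vit.py (its definition appears in the vit.py diff below). A rough sketch of how that shared block behaves, with illustrative dimensions: subclassing nn.Sequential and assigning submodules as attributes registers them in assignment order, so calling the module runs linear1, then the activation, then linear2.

import torch
from torch import nn

# Mirror of the MLP class added to vit.py (see the diff below).
class MLP(nn.Sequential):
    def __init__(self, in_dim: int, hidden_dim: int, act=nn.GELU) -> None:
        super().__init__()
        self.linear1 = nn.Linear(in_dim, hidden_dim)
        self.act = act()
        self.linear2 = nn.Linear(hidden_dim, in_dim)

x = torch.randn(2, 16, 192)      # (batch, tokens, d_model), illustrative
mlp = MLP(192, int(192 * 4.0))   # same construction ViTBlock now uses: MLP(d_model, int(d_model * mlp_ratio), act)
print(mlp(x).shape)              # torch.Size([2, 16, 192])
print([name for name, _ in mlp.named_children()])  # ['linear1', 'act', 'linear2']

Addressing the sub-layers by name is also what lets the weight-loading code in vit.py switch from layer.mlp[0] / layer.mlp[2] to layer.mlp.linear1 / layer.mlp.linear2.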
37 changes: 22 additions & 15 deletions vision_toolbox/backbones/vit.py
@@ -25,29 +25,38 @@ def __init__(self, d_model: int, n_heads: int, bias: bool = True, dropout: float
         self.dropout = dropout
         self.scale = (d_model // n_heads) ** (-0.5)
 
-    def forward(self, x: Tensor, attn_mask: Tensor | None = None) -> Tensor:
+    def forward(self, x: Tensor, attn_bias: Tensor | None = None) -> Tensor:
         qkv = self.in_proj(x)
         q, k, v = qkv.unflatten(-1, (3, self.n_heads, -1)).transpose(-2, -4).unbind(-3)
 
         if hasattr(F, "scaled_dot_product_attention"):
-            out = F.scaled_dot_product_attention(q, k, v, attn_mask, self.dropout)
+            out = F.scaled_dot_product_attention(q, k, v, attn_bias, self.dropout if self.training else 0.0)
         else:
             attn = torch.softmax(q @ (k * self.scale).transpose(-1, -2), -1)
-            if attn_mask is not None:
-                attn = attn + attn_mask
+            if attn_bias is not None:
+                attn = attn + attn_bias
             out = F.dropout(attn, self.dropout, self.training) @ v
 
         out = out.transpose(-2, -3).flatten(-2)
         out = self.out_proj(out)
         return out
 
 
+class MLP(nn.Sequential):
+    def __init__(self, in_dim: int, hidden_dim: float, act: _act = nn.GELU) -> None:
+        super().__init__()
+        self.linear1 = nn.Linear(in_dim, hidden_dim)
+        self.act = act()
+        self.linear2 = nn.Linear(hidden_dim, in_dim)
+
+
 class ViTBlock(nn.Module):
     def __init__(
         self,
         d_model: int,
         n_heads: int,
         bias: bool = True,
+        mlp_ratio: float = 4.0,
         dropout: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
@@ -56,12 +65,7 @@ def __init__(
         self.norm1 = norm(d_model)
         self.mha = MHA(d_model, n_heads, bias, dropout)
         self.norm2 = norm(d_model)
-        self.mlp = nn.Sequential(
-            nn.Linear(d_model, d_model * 4, bias),
-            act(),
-            nn.Linear(d_model * 4, d_model, bias),
-            nn.Dropout(dropout),
-        )
+        self.mlp = MLP(d_model, int(d_model * mlp_ratio), act)
 
     def forward(self, x: Tensor) -> Tensor:
         x = x + self.mha(self.norm1(x))
@@ -79,6 +83,7 @@ def __init__(
         img_size: int,
         cls_token: bool = True,
         bias: bool = True,
+        mlp_ratio: float = 4.0,
         dropout: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
@@ -94,7 +99,9 @@ def __init__(
         self.pe = nn.Parameter(torch.empty(1, pe_size, d_model))
         nn.init.normal_(self.pe, 0, 0.02)
 
-        self.layers = nn.Sequential(*[ViTBlock(d_model, n_heads, bias, dropout, norm, act) for _ in range(n_layers)])
+        self.layers = nn.Sequential(
+            *[ViTBlock(d_model, n_heads, bias, mlp_ratio, dropout, norm, act) for _ in range(n_layers)]
+        )
         self.norm = norm(d_model)
 
     def forward(self, imgs: Tensor) -> Tensor:
@@ -177,10 +184,10 @@ def get_w(key: str) -> Tensor:
 
             layer.norm2.weight.copy_(get_w(prefix + "LayerNorm_2/scale"))
             layer.norm2.bias.copy_(get_w(prefix + "LayerNorm_2/bias"))
-            layer.mlp[0].weight.copy_(get_w(prefix + "MlpBlock_3/Dense_0/kernel").T)
-            layer.mlp[0].bias.copy_(get_w(prefix + "MlpBlock_3/Dense_0/bias"))
-            layer.mlp[2].weight.copy_(get_w(prefix + "MlpBlock_3/Dense_1/kernel").T)
-            layer.mlp[2].bias.copy_(get_w(prefix + "MlpBlock_3/Dense_1/bias"))
+            layer.mlp.linear1.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_0/kernel").T)
+            layer.mlp.linear1.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_0/bias"))
+            layer.mlp.linear2.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_1/kernel").T)
+            layer.mlp.linear2.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_1/bias"))
 
         self.norm.weight.copy_(get_w("Transformer/encoder_norm/scale"))
         self.norm.bias.copy_(get_w("Transformer/encoder_norm/bias"))
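For reference, the head split in MHA.forward above can be checked with a quick shape walkthrough. This assumes in_proj projects d_model to 3 * d_model (that definition is not shown in the hunks above) and an input of shape (batch, seq_len, d_model); the concrete numbers are illustrative.

import torch

B, L, d_model, n_heads = 2, 16, 768, 12

# qkv as produced by an assumed nn.Linear(d_model, 3 * d_model) projection.
qkv = torch.randn(B, L, 3 * d_model)

# (B, L, 3*D) -> (B, L, 3, H, D/H) -> (B, H, 3, L, D/H) -> three (B, H, L, D/H) tensors
q, k, v = qkv.unflatten(-1, (3, n_heads, -1)).transpose(-2, -4).unbind(-3)
print(q.shape)  # torch.Size([2, 12, 16, 64])

# After attention, the inverse reshape merges the heads back into d_model:
# (B, H, L, D/H) -> (B, L, H, D/H) -> (B, L, D)
out = q.transpose(-2, -3).flatten(-2)
print(out.shape)  # torch.Size([2, 16, 768])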
