From a41efbe60e309e590cb5f58054110af13a91ff29 Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Wed, 9 Aug 2023 11:42:17 +0800
Subject: [PATCH] add dropout to MLP

---
 vision_toolbox/backbones/mlp_mixer.py | 10 +++++++---
 vision_toolbox/backbones/vit.py       |  5 +++--
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/vision_toolbox/backbones/mlp_mixer.py b/vision_toolbox/backbones/mlp_mixer.py
index 6189c9d..3e07665 100644
--- a/vision_toolbox/backbones/mlp_mixer.py
+++ b/vision_toolbox/backbones/mlp_mixer.py
@@ -21,15 +21,16 @@ def __init__(
         n_tokens: int,
         d_model: int,
         mlp_ratio: tuple[int, int] = (0.5, 4.0),
+        dropout: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
     ) -> None:
         tokens_mlp_dim, channels_mlp_dim = [int(d_model * ratio) for ratio in mlp_ratio]
         super().__init__()
         self.norm1 = norm(d_model)
-        self.token_mixing = MLP(n_tokens, tokens_mlp_dim, act)
+        self.token_mixing = MLP(n_tokens, tokens_mlp_dim, dropout, act)
         self.norm2 = norm(d_model)
-        self.channel_mixing = MLP(d_model, channels_mlp_dim, act)
+        self.channel_mixing = MLP(d_model, channels_mlp_dim, dropout, act)
 
     def forward(self, x: Tensor) -> Tensor:
         # x -> (B, n_tokens, d_model)
@@ -46,6 +47,7 @@ def __init__(
         patch_size: int,
         img_size: int,
         mlp_ratio: tuple[float, float] = (0.5, 4.0),
+        dropout: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
     ) -> None:
@@ -53,7 +55,9 @@
         super().__init__()
         self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
         n_tokens = (img_size // patch_size) ** 2
-        self.layers = nn.Sequential(*[MixerBlock(n_tokens, d_model, mlp_ratio, norm, act) for _ in range(n_layers)])
+        self.layers = nn.Sequential(
+            *[MixerBlock(n_tokens, d_model, mlp_ratio, dropout, norm, act) for _ in range(n_layers)]
+        )
         self.norm = norm(d_model)
 
     def forward(self, x: Tensor) -> Tensor:
diff --git a/vision_toolbox/backbones/vit.py b/vision_toolbox/backbones/vit.py
index 620c25b..81cb1ff 100644
--- a/vision_toolbox/backbones/vit.py
+++ b/vision_toolbox/backbones/vit.py
@@ -43,11 +43,12 @@ def forward(self, x: Tensor, attn_bias: Tensor | None = None) -> Tensor:
 
 
 class MLP(nn.Sequential):
-    def __init__(self, in_dim: int, hidden_dim: float, act: _act = nn.GELU) -> None:
+    def __init__(self, in_dim: int, hidden_dim: float, dropout: float = 0.0, act: _act = nn.GELU) -> None:
         super().__init__()
         self.linear1 = nn.Linear(in_dim, hidden_dim)
         self.act = act()
         self.linear2 = nn.Linear(hidden_dim, in_dim)
+        self.dropout = nn.Dropout(dropout)
 
 
 class ViTBlock(nn.Module):
@@ -65,7 +66,7 @@ def __init__(
         self.norm1 = norm(d_model)
         self.mha = MHA(d_model, n_heads, bias, dropout)
         self.norm2 = norm(d_model)
-        self.mlp = MLP(d_model, int(d_model * mlp_ratio), act)
+        self.mlp = MLP(d_model, int(d_model * mlp_ratio), dropout, act)
 
     def forward(self, x: Tensor) -> Tensor:
         x = x + self.mha(self.norm1(x))
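
Note (illustration only, not part of the patch): below is a minimal standalone sketch of
the patched MLP from vit.py, assuming only what the diff shows (the repo's _act alias is
replaced by a plain default). Because nn.Sequential calls its registered submodules in
registration order, the new nn.Dropout is applied once, after linear2:

    import torch
    from torch import nn

    class MLP(nn.Sequential):
        # Mirrors the patched vit.py MLP: attribute assignment registers submodules,
        # and nn.Sequential runs them in that order, so dropout comes after linear2.
        def __init__(self, in_dim: int, hidden_dim: int, dropout: float = 0.0, act=nn.GELU) -> None:
            super().__init__()
            self.linear1 = nn.Linear(in_dim, hidden_dim)
            self.act = act()
            self.linear2 = nn.Linear(hidden_dim, in_dim)
            self.dropout = nn.Dropout(dropout)

    x = torch.randn(2, 196, 768)          # (B, n_tokens, d_model)
    mlp = MLP(768, 768 * 4, dropout=0.1)
    print(mlp(x).shape)                   # torch.Size([2, 196, 768])

With this patch, MLPMixer gains a dropout argument that reaches both the token-mixing and
channel-mixing MLPs, and ViTBlock's existing dropout argument now also applies to its MLP,
e.g. something like MLPMixer(..., dropout=0.1), assuming the keyword names match the
patched signatures.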