add dropout to MLP
gau-nernst committed Aug 9, 2023
1 parent 6413aa7 commit a41efbe
Showing 2 changed files with 10 additions and 5 deletions.
vision_toolbox/backbones/mlp_mixer.py (10 changes: 7 additions & 3 deletions)
@@ -21,15 +21,16 @@ def __init__(
         n_tokens: int,
         d_model: int,
         mlp_ratio: tuple[int, int] = (0.5, 4.0),
+        dropout: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
     ) -> None:
         tokens_mlp_dim, channels_mlp_dim = [int(d_model * ratio) for ratio in mlp_ratio]
         super().__init__()
         self.norm1 = norm(d_model)
-        self.token_mixing = MLP(n_tokens, tokens_mlp_dim, act)
+        self.token_mixing = MLP(n_tokens, tokens_mlp_dim, dropout, act)
         self.norm2 = norm(d_model)
-        self.channel_mixing = MLP(d_model, channels_mlp_dim, act)
+        self.channel_mixing = MLP(d_model, channels_mlp_dim, dropout, act)

     def forward(self, x: Tensor) -> Tensor:
         # x -> (B, n_tokens, d_model)
@@ -46,14 +46,17 @@ def __init__(
         patch_size: int,
         img_size: int,
         mlp_ratio: tuple[float, float] = (0.5, 4.0),
+        dropout: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
     ) -> None:
         assert img_size % patch_size == 0
         super().__init__()
         self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
         n_tokens = (img_size // patch_size) ** 2
-        self.layers = nn.Sequential(*[MixerBlock(n_tokens, d_model, mlp_ratio, norm, act) for _ in range(n_layers)])
+        self.layers = nn.Sequential(
+            *[MixerBlock(n_tokens, d_model, mlp_ratio, dropout, norm, act) for _ in range(n_layers)]
+        )
         self.norm = norm(d_model)

     def forward(self, x: Tensor) -> Tensor:
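A quick usage sketch of the new argument (not part of the commit), assuming MixerBlock is importable from vision_toolbox.backbones.mlp_mixer and keeps the signature shown above:

# Hypothetical sanity check, not in this commit.
import torch
from vision_toolbox.backbones.mlp_mixer import MixerBlock

block = MixerBlock(n_tokens=196, d_model=768, mlp_ratio=(0.5, 4.0), dropout=0.1)
block.train()  # nn.Dropout only drops activations in training mode
x = torch.randn(2, 196, 768)  # (B, n_tokens, d_model), as noted in forward()
print(block(x).shape)  # expected to match the input shape

Leaving dropout at its default of 0.0 should keep the previous behaviour, since nn.Dropout(0.0) is an identity and carries no parameters.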
vision_toolbox/backbones/vit.py (5 changes: 3 additions & 2 deletions)
@@ -43,11 +43,12 @@ def forward(self, x: Tensor, attn_bias: Tensor | None = None) -> Tensor:


 class MLP(nn.Sequential):
-    def __init__(self, in_dim: int, hidden_dim: float, act: _act = nn.GELU) -> None:
+    def __init__(self, in_dim: int, hidden_dim: float, dropout: float = 0.0, act: _act = nn.GELU) -> None:
         super().__init__()
         self.linear1 = nn.Linear(in_dim, hidden_dim)
         self.act = act()
         self.linear2 = nn.Linear(hidden_dim, in_dim)
+        self.dropout = nn.Dropout(dropout)


 class ViTBlock(nn.Module):
@@ -65,7 +66,7 @@ def __init__(
         self.norm1 = norm(d_model)
         self.mha = MHA(d_model, n_heads, bias, dropout)
         self.norm2 = norm(d_model)
-        self.mlp = MLP(d_model, int(d_model * mlp_ratio), act)
+        self.mlp = MLP(d_model, int(d_model * mlp_ratio), dropout, act)

     def forward(self, x: Tensor) -> Tensor:
         x = x + self.mha(self.norm1(x))
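Because MLP subclasses nn.Sequential, its submodules run in assignment order, so the new nn.Dropout is applied after linear2. A small sketch (not part of the commit), assuming MLP is importable from vision_toolbox.backbones.vit:

# Hypothetical usage, not in this commit: dropout now follows the second linear layer.
import torch
from vision_toolbox.backbones.vit import MLP

mlp = MLP(in_dim=256, hidden_dim=1024, dropout=0.1)
mlp.train()  # dropout is a no-op in eval() mode
x = torch.randn(4, 256)
print(mlp(x).shape)  # expected: torch.Size([4, 256])

As the second hunk shows, ViTBlock now reuses the same dropout rate for both the MHA and the MLP.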
