
Commit

use mlp_ratio
gau-nernst committed Aug 8, 2023
1 parent d029bde commit ddd8f3e
Showing 1 changed file with 6 additions and 14 deletions.
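This commit replaces the explicit `tokens_mlp_dim` and `channels_mlp_dim` arguments of `MixerBlock` and `MLPMixer` with a single `mlp_ratio` pair, deriving both hidden sizes as `int(d_model * ratio)`. The defaults `(0.5, 4.0)` reproduce the dimensions in Table 1 of the MLP-Mixer paper, so `from_config` now only stores `n_layers` and `d_model` per variant.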
20 changes: 6 additions & 14 deletions vision_toolbox/backbones/mlp_mixer.py
@@ -27,11 +27,11 @@ def __init__(
         self,
         n_tokens: int,
         d_model: int,
-        tokens_mlp_dim: int,
-        channels_mlp_dim: int,
+        mlp_ratio: tuple[float, float] = (0.5, 4.0),
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
     ) -> None:
+        tokens_mlp_dim, channels_mlp_dim = [int(d_model * ratio) for ratio in mlp_ratio]
         super().__init__()
         self.norm1 = norm(d_model)
         self.token_mixing = MLP(n_tokens, tokens_mlp_dim, act)
@@ -52,18 +52,15 @@ def __init__(
         d_model: int,
         patch_size: int,
         img_size: int,
-        tokens_mlp_dim: int,
-        channels_mlp_dim: int,
+        mlp_ratio: tuple[float, float] = (0.5, 4.0),
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
     ) -> None:
         assert img_size % patch_size == 0
         super().__init__()
         self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
         n_tokens = (img_size // patch_size) ** 2
-        self.layers = nn.Sequential(
-            *[MixerBlock(n_tokens, d_model, tokens_mlp_dim, channels_mlp_dim, norm, act) for _ in range(n_layers)]
-        )
+        self.layers = nn.Sequential(*[MixerBlock(n_tokens, d_model, mlp_ratio, norm, act) for _ in range(n_layers)])
         self.norm = norm(d_model)
 
     def forward(self, x: Tensor) -> Tensor:
@@ -76,13 +73,8 @@ def forward(self, x: Tensor) -> Tensor:
     @staticmethod
     def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool = False) -> MLPMixer:
         # Table 1 in https://arxiv.org/pdf/2105.01601.pdf
-        n_layers, d_model, tokens_mlp_dim, channels_mlp_dim = dict(
-            S=(8, 512, 256, 2048),
-            B=(12, 768, 384, 3072),
-            L=(24, 1024, 512, 4096),
-            H=(32, 1280, 640, 5120),
-        )[variant]
-        m = MLPMixer(n_layers, d_model, patch_size, img_size, tokens_mlp_dim, channels_mlp_dim)
+        n_layers, d_model = dict(S=(8, 512), B=(12, 768), L=(24, 1024), H=(32, 1280))[variant]
+        m = MLPMixer(n_layers, d_model, patch_size, img_size)
         if pretrained:
             ckpt = {
                 ("S", 8): "gsam/Mixer-S_8.npz",
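For context, a minimal usage sketch of the refactored API. This is an illustration, not part of the commit; it assumes only the definitions visible in the diff above, and the derived dimensions follow from the default `mlp_ratio`:

    # Minimal sketch, assuming the definitions in vision_toolbox/backbones/mlp_mixer.py.
    # Mixer-B/16: d_model=768, so the default mlp_ratio=(0.5, 4.0) yields
    # tokens_mlp_dim = int(768 * 0.5) = 384 and channels_mlp_dim = int(768 * 4.0) = 3072,
    # matching Table 1 of https://arxiv.org/pdf/2105.01601.pdf
    model = MLPMixer.from_config("B", patch_size=16, img_size=224)

    # Non-default ratios (hypothetical values) can still be passed to the constructor:
    custom = MLPMixer(12, 768, patch_size=16, img_size=224, mlp_ratio=(0.25, 2.0))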
