diff --git a/vision_toolbox/backbones/mlp_mixer.py b/vision_toolbox/backbones/mlp_mixer.py index f109d7c..b16800d 100644 --- a/vision_toolbox/backbones/mlp_mixer.py +++ b/vision_toolbox/backbones/mlp_mixer.py @@ -27,11 +27,11 @@ def __init__( self, n_tokens: int, d_model: int, - tokens_mlp_dim: int, - channels_mlp_dim: int, + mlp_ratio: tuple[int, int] = (0.5, 4.0), norm: _norm = partial(nn.LayerNorm, eps=1e-6), act: _act = nn.GELU, ) -> None: + tokens_mlp_dim, channels_mlp_dim = [int(d_model * ratio) for ratio in mlp_ratio] super().__init__() self.norm1 = norm(d_model) self.token_mixing = MLP(n_tokens, tokens_mlp_dim, act) @@ -52,8 +52,7 @@ def __init__( d_model: int, patch_size: int, img_size: int, - tokens_mlp_dim: int, - channels_mlp_dim: int, + mlp_ratio: tuple[float, float] = (0.5, 4.0), norm: _norm = partial(nn.LayerNorm, eps=1e-6), act: _act = nn.GELU, ) -> None: @@ -61,9 +60,7 @@ def __init__( super().__init__() self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size) n_tokens = (img_size // patch_size) ** 2 - self.layers = nn.Sequential( - *[MixerBlock(n_tokens, d_model, tokens_mlp_dim, channels_mlp_dim, norm, act) for _ in range(n_layers)] - ) + self.layers = nn.Sequential(*[MixerBlock(n_tokens, d_model, mlp_ratio, norm, act) for _ in range(n_layers)]) self.norm = norm(d_model) def forward(self, x: Tensor) -> Tensor: @@ -76,13 +73,8 @@ def forward(self, x: Tensor) -> Tensor: @staticmethod def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool = False) -> MLPMixer: # Table 1 in https://arxiv.org/pdf/2105.01601.pdf - n_layers, d_model, tokens_mlp_dim, channels_mlp_dim = dict( - S=(8, 512, 256, 2048), - B=(12, 768, 384, 3072), - L=(24, 1024, 512, 4096), - H=(32, 1280, 640, 5120), - )[variant] - m = MLPMixer(n_layers, d_model, patch_size, img_size, tokens_mlp_dim, channels_mlp_dim) + n_layers, d_model = dict(S=(8, 512), B=(12, 768), L=(24, 1024), H=(32, 1280))[variant] + m = MLPMixer(n_layers, d_model, patch_size, img_size) if pretrained: ckpt = { ("S", 8): "gsam/Mixer-S_8.npz",