From 7c4d93fcb26bb79cc56e7f340bd708e75fe087c4 Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Mon, 7 Aug 2023 23:16:54 +0800
Subject: [PATCH] fix norm eps

---
 vision_toolbox/backbones/base.py      |  6 +++-
 vision_toolbox/backbones/mlp_mixer.py | 44 +++++++++++++++++++--------
 2 files changed, 37 insertions(+), 13 deletions(-)

diff --git a/vision_toolbox/backbones/base.py b/vision_toolbox/backbones/base.py
index 14e5f3a..eea0850 100644
--- a/vision_toolbox/backbones/base.py
+++ b/vision_toolbox/backbones/base.py
@@ -1,12 +1,16 @@
 from __future__ import annotations
 
 from abc import ABCMeta, abstractmethod
-from typing import Any
+from typing import Callable
 
 import torch
 from torch import Tensor, nn
 
 
+_norm = Callable[[int], nn.Module]
+_act = Callable[[], nn.Module]
+
+
 class BaseBackbone(nn.Module, metaclass=ABCMeta):
     # subclass only needs to implement this method
     @abstractmethod
diff --git a/vision_toolbox/backbones/mlp_mixer.py b/vision_toolbox/backbones/mlp_mixer.py
index ebc6411..d69689e 100644
--- a/vision_toolbox/backbones/mlp_mixer.py
+++ b/vision_toolbox/backbones/mlp_mixer.py
@@ -1,7 +1,9 @@
+# https://arxiv.org/abs/2105.01601
 # https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_mixer.py
 from __future__ import annotations
 
+from functools import partial
 from typing import Mapping
 
 import numpy as np
 import torch
@@ -9,23 +11,32 @@ from torch import Tensor, nn
 
 from ..utils import torch_hub_download
+from .base import _act, _norm
 
 
 class MLP(nn.Sequential):
-    def __init__(self, d_model: int, mlp_dim: float) -> None:
+    def __init__(self, d_model: int, mlp_dim: float, act: _act = nn.GELU) -> None:
         super().__init__()
         self.linear1 = nn.Linear(d_model, mlp_dim)
-        self.act = nn.GELU()
+        self.act = act()
         self.linear2 = nn.Linear(mlp_dim, d_model)
 
 
 class MixerBlock(nn.Module):
-    def __init__(self, n_tokens: int, d_model: int, tokens_mlp_dim: int, channels_mlp_dim: int) -> None:
+    def __init__(
+        self,
+        n_tokens: int,
+        d_model: int,
+        tokens_mlp_dim: int,
+        channels_mlp_dim: int,
+        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
+        act: _act = nn.GELU,
+    ) -> None:
         super().__init__()
-        self.norm1 = nn.LayerNorm(d_model)
-        self.token_mixing = MLP(n_tokens, tokens_mlp_dim)
-        self.norm2 = nn.LayerNorm(d_model)
-        self.channel_mixing = MLP(d_model, channels_mlp_dim)
+        self.norm1 = norm(d_model)
+        self.token_mixing = MLP(n_tokens, tokens_mlp_dim, act)
+        self.norm2 = norm(d_model)
+        self.channel_mixing = MLP(d_model, channels_mlp_dim, act)
 
     def forward(self, x: Tensor) -> Tensor:
         # x -> (B, n_tokens, d_model)
@@ -36,26 +47,35 @@ def forward(self, x: Tensor) -> Tensor:
 
 class MLPMixer(nn.Module):
     def __init__(
-        self, n_layers: int, d_model: int, patch_size: int, img_size: int, tokens_mlp_dim: int, channels_mlp_dim: int
+        self,
+        n_layers: int,
+        d_model: int,
+        patch_size: int,
+        img_size: int,
+        tokens_mlp_dim: int,
+        channels_mlp_dim: int,
+        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
+        act: _act = nn.GELU,
     ) -> None:
         assert img_size % patch_size == 0
         super().__init__()
         self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
         n_tokens = (img_size // patch_size) ** 2
         self.layers = nn.Sequential(
-            *[MixerBlock(n_tokens, d_model, tokens_mlp_dim, channels_mlp_dim) for _ in range(n_layers)]
+            *[MixerBlock(n_tokens, d_model, tokens_mlp_dim, channels_mlp_dim, norm, act) for _ in range(n_layers)]
         )
-        self.norm = nn.LayerNorm(d_model)
+        self.norm = norm(d_model)
 
     def forward(self, x: Tensor) -> Tensor:
-        x = self.patch_embed(x).flatten(-2).transpose(-1, -2)
+        x = self.patch_embed(x).flatten(2).transpose(1, 2)  # (N, C, H, W) -> (N, H*W, C)
         x = self.layers(x)
         x = self.norm(x)
-        x = x.mean(-2)
+        x = x.mean(1)
         return x
 
     @staticmethod
     def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool = False) -> MLPMixer:
+        # Table 1 in https://arxiv.org/pdf/2105.01601.pdf
         n_layers, d_model, tokens_mlp_dim, channels_mlp_dim = dict(
             S=(8, 512, 256, 2048), B=(12, 768, 384, 3072), L=(24, 1024, 512, 4096), H=(32, 1280, 640, 5120)
         )[variant]
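
For reference, a minimal usage sketch of the norm/act hooks this patch introduces (not part of the patch itself; the Mixer-S sizes come from Table 1 of the paper, while img_size=224, patch_size=16, and constructing the model directly rather than via from_config are assumptions for illustration):

    from functools import partial

    import torch
    from torch import nn

    from vision_toolbox.backbones.mlp_mixer import MLPMixer

    # Mixer-S geometry with the paper's LayerNorm eps of 1e-6 (now the default).
    model = MLPMixer(
        n_layers=8,
        d_model=512,
        patch_size=16,
        img_size=224,
        tokens_mlp_dim=256,
        channels_mlp_dim=2048,
        norm=partial(nn.LayerNorm, eps=1e-6),
        act=nn.GELU,
    )
    out = model(torch.randn(1, 3, 224, 224))  # (1, 512): token features mean-pooled over the sequence

Any norm constructor that accepts the feature dimension (e.g. partial(nn.LayerNorm, eps=1e-5)) and any zero-argument activation factory can be swapped in through the same arguments.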