
Commit

fix norm eps
gau-nernst committed Aug 7, 2023
1 parent d6b9340 commit 7c4d93f
Showing 2 changed files with 37 additions and 13 deletions.
vision_toolbox/backbones/base.py (6 changes: 5 additions & 1 deletion)
@@ -1,12 +1,16 @@
 from __future__ import annotations

 from abc import ABCMeta, abstractmethod
-from typing import Any
+from typing import Callable

 import torch
 from torch import Tensor, nn


+_norm = Callable[[int], nn.Module]
+_act = Callable[[], nn.Module]
+
+
 class BaseBackbone(nn.Module, metaclass=ABCMeta):
     # subclass only needs to implement this method
     @abstractmethod
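The new _norm and _act aliases describe factories: a callable that builds a normalization layer from a feature width, and a callable that builds an activation from no arguments. A minimal sketch of how such factories plug together, based only on what the diff shows (the norm_factory/act_factory names are illustrative, not part of the library):

from functools import partial

import torch
from torch import nn

# PyTorch's nn.LayerNorm defaults to eps=1e-5; the diff pins eps=1e-6 via partial,
# which is what the "fix norm eps" commit message refers to.
norm_factory = partial(nn.LayerNorm, eps=1e-6)  # matches _norm: (int) -> nn.Module
act_factory = nn.GELU                           # matches _act: () -> nn.Module

norm = norm_factory(768)        # LayerNorm over the last dim of size 768, eps=1e-6
act = act_factory()
x = torch.randn(2, 196, 768)
y = act(norm(x))                # shape is preserved: (2, 196, 768)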
vision_toolbox/backbones/mlp_mixer.py (44 changes: 32 additions & 12 deletions)
@@ -1,31 +1,42 @@
 # https://arxiv.org/abs/2105.01601
 # https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_mixer.py

 from __future__ import annotations

+from functools import partial
 from typing import Mapping

 import numpy as np
 import torch
 from torch import Tensor, nn

 from ..utils import torch_hub_download
+from .base import _act, _norm


 class MLP(nn.Sequential):
-    def __init__(self, d_model: int, mlp_dim: float) -> None:
+    def __init__(self, d_model: int, mlp_dim: float, act: _act = nn.GELU) -> None:
         super().__init__()
         self.linear1 = nn.Linear(d_model, mlp_dim)
-        self.act = nn.GELU()
+        self.act = act()
         self.linear2 = nn.Linear(mlp_dim, d_model)


 class MixerBlock(nn.Module):
-    def __init__(self, n_tokens: int, d_model: int, tokens_mlp_dim: int, channels_mlp_dim: int) -> None:
+    def __init__(
+        self,
+        n_tokens: int,
+        d_model: int,
+        tokens_mlp_dim: int,
+        channels_mlp_dim: int,
+        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
+        act: _act = nn.GELU,
+    ) -> None:
         super().__init__()
-        self.norm1 = nn.LayerNorm(d_model)
-        self.token_mixing = MLP(n_tokens, tokens_mlp_dim)
-        self.norm2 = nn.LayerNorm(d_model)
-        self.channel_mixing = MLP(d_model, channels_mlp_dim)
+        self.norm1 = norm(d_model)
+        self.token_mixing = MLP(n_tokens, tokens_mlp_dim, act)
+        self.norm2 = norm(d_model)
+        self.channel_mixing = MLP(d_model, channels_mlp_dim, act)

     def forward(self, x: Tensor) -> Tensor:
         # x -> (B, n_tokens, d_model)
@@ -36,26 +47,35 @@ def forward(self, x: Tensor) -> Tensor:

 class MLPMixer(nn.Module):
     def __init__(
-        self, n_layers: int, d_model: int, patch_size: int, img_size: int, tokens_mlp_dim: int, channels_mlp_dim: int
+        self,
+        n_layers: int,
+        d_model: int,
+        patch_size: int,
+        img_size: int,
+        tokens_mlp_dim: int,
+        channels_mlp_dim: int,
+        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
+        act: _act = nn.GELU,
     ) -> None:
         assert img_size % patch_size == 0
         super().__init__()
         self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
         n_tokens = (img_size // patch_size) ** 2
         self.layers = nn.Sequential(
-            *[MixerBlock(n_tokens, d_model, tokens_mlp_dim, channels_mlp_dim) for _ in range(n_layers)]
+            *[MixerBlock(n_tokens, d_model, tokens_mlp_dim, channels_mlp_dim, norm, act) for _ in range(n_layers)]
         )
-        self.norm = nn.LayerNorm(d_model)
+        self.norm = norm(d_model)

     def forward(self, x: Tensor) -> Tensor:
-        x = self.patch_embed(x).flatten(-2).transpose(-1, -2)
+        x = self.patch_embed(x).flatten(2).transpose(1, 2)  # (N, C, H, W) -> (N, H*W, C)
         x = self.layers(x)
         x = self.norm(x)
-        x = x.mean(-2)
+        x = x.mean(1)
         return x

     @staticmethod
     def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool = False) -> MLPMixer:
         # Table 1 in https://arxiv.org/pdf/2105.01601.pdf
         n_layers, d_model, tokens_mlp_dim, channels_mlp_dim = dict(
             S=(8, 512, 256, 2048), B=(12, 768, 384, 3072), L=(24, 1024, 512, 4096), H=(32, 1280, 640, 5120)
         )[variant]
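A short usage sketch of the updated module, calling the constructor directly with the Mixer-S sizes from the from_config dict above (8 layers, d_model 512, tokens_mlp_dim 256, channels_mlp_dim 2048) and relying on the new defaults (LayerNorm with eps=1e-6, GELU). It assumes the collapsed part of MixerBlock.forward preserves the (B, n_tokens, d_model) shape, as in the reference Mixer:

import torch

from vision_toolbox.backbones.mlp_mixer import MLPMixer

# n_layers=8, d_model=512, patch_size=16, img_size=224, tokens_mlp_dim=256, channels_mlp_dim=2048
model = MLPMixer(8, 512, 16, 224, 256, 2048)

x = torch.randn(1, 3, 224, 224)
out = model(x)
# patch_embed: (1, 512, 14, 14) -> flatten(2).transpose(1, 2): (1, 196, 512)
# mean(1) pools over the 196 tokens, so the output is (1, 512)
print(out.shape)  # torch.Size([1, 512])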
