diff --git a/vision_toolbox/backbones/cait.py b/vision_toolbox/backbones/cait.py
index c7e18f4..644603a 100644
--- a/vision_toolbox/backbones/cait.py
+++ b/vision_toolbox/backbones/cait.py
@@ -9,7 +9,6 @@
 import torch.nn.functional as F
 from torch import Tensor, nn

-from .base import _act, _norm
 from .vit import MHA, ViT, ViTBlock


@@ -62,13 +61,12 @@ def __init__(
         dropout: float = 0.0,
         layer_scale_init: float | None = 1e-6,
         stochastic_depth: float = 0.0,
-        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
-        act: _act = nn.GELU,
+        norm_eps: float = 1e-6,
     ) -> None:
         # fmt: off
         super().__init__(
             d_model, n_heads, bias, mlp_ratio, dropout,
-            layer_scale_init, stochastic_depth, norm, act,
+            layer_scale_init, stochastic_depth, norm_eps,
             partial(ClassAttention, d_model, n_heads, bias, dropout),
         )
         # fmt: on
@@ -89,13 +87,12 @@ def __init__(
         dropout: float = 0.0,
         layer_scale_init: float | None = 1e-6,
         stochastic_depth: float = 0.0,
-        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
-        act: _act = nn.GELU,
+        norm_eps: float = 1e-6,
     ) -> None:
         # fmt: off
         super().__init__(
             d_model, n_heads, bias, mlp_ratio, dropout,
-            layer_scale_init, stochastic_depth, norm, act,
+            layer_scale_init, stochastic_depth, norm_eps,
             partial(TalkingHeadAttention, d_model, n_heads, bias, dropout),
         )
         # fmt: on
@@ -115,8 +112,7 @@ def __init__(
         dropout: float = 0.0,
         layer_scale_init: float | None = 1e-6,
         stochastic_depth: float = 0.0,
-        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
-        act: _act = nn.GELU,
+        norm_eps: float = 1e-6,
     ) -> None:
         assert img_size % patch_size == 0
         super().__init__()
@@ -127,19 +123,15 @@
         self.sa_layers = nn.Sequential()
         for _ in range(sa_depth):
-            block = CaiTSABlock(
-                d_model, n_heads, bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm, act
-            )
-            self.sa_layers.append(block)
+            blk = CaiTSABlock(d_model, n_heads, bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm_eps)
+            self.sa_layers.append(blk)

         self.ca_layers = nn.ModuleList()
         for _ in range(ca_depth):
-            block = CaiTCABlock(
-                d_model, n_heads, bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm, act
-            )
-            self.ca_layers.append(block)
+            blk = CaiTCABlock(d_model, n_heads, bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm_eps)
+            self.ca_layers.append(blk)

-        self.norm = norm(d_model)
+        self.norm = nn.LayerNorm(d_model, norm_eps)

     def forward(self, imgs: Tensor) -> Tensor:
         patches = self.patch_embed(imgs).flatten(2).transpose(1, 2)  # (N, C, H, W) -> (N, H*W, C)
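Note on the change above: the pluggable `norm`/`act` factories are gone; LayerNorm and GELU are now hard-coded and only the LayerNorm epsilon is exposed as `norm_eps`. A minimal before/after sketch (a Python snippet, not part of the patch; the leading `d_model`/`n_heads` positions and the `bias`/`mlp_ratio` defaults are assumed from the hunk context, not verified):

import torch
from vision_toolbox.backbones.cait import CaiTSABlock

# old API (removed): CaiTSABlock(192, 4, norm=partial(nn.LayerNorm, eps=1e-6), act=nn.GELU)
# new API: only the LayerNorm eps remains configurable
blk = CaiTSABlock(192, 4, norm_eps=1e-6)
out = blk(torch.randn(1, 196, 192))  # residual block, output shape matches input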
diff --git a/vision_toolbox/backbones/convnext.py b/vision_toolbox/backbones/convnext.py
index d9b04f2..40e5dd3 100644
--- a/vision_toolbox/backbones/convnext.py
+++ b/vision_toolbox/backbones/convnext.py
@@ -5,13 +5,11 @@
 from __future__ import annotations

-from functools import partial
-
 import torch
 from torch import Tensor, nn

 from ..components import LayerScale, Permute, StochasticDepth
-from .base import BaseBackbone, _act, _norm
+from .base import BaseBackbone


 class GlobalResponseNorm(nn.Module):
@@ -36,8 +34,7 @@ def __init__(
         bias: bool = True,
         layer_scale_init: float | None = 1e-6,
         stochastic_depth: float = 0.0,
-        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
-        act: _act = nn.GELU,
+        norm_eps: float = 1e-6,
         v2: bool = False,
     ) -> None:
         if v2:
@@ -48,9 +45,9 @@
             Permute(0, 3, 1, 2),
             nn.Conv2d(d_model, d_model, 7, padding=3, groups=d_model, bias=bias),
             Permute(0, 2, 3, 1),
-            norm(d_model),
+            nn.LayerNorm(d_model, norm_eps),
             nn.Linear(d_model, hidden_dim, bias=bias),
-            act(),
+            nn.GELU(),
             GlobalResponseNorm(hidden_dim) if v2 else nn.Identity(),
             nn.Linear(hidden_dim, d_model, bias=bias),
             LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
@@ -70,12 +67,11 @@ def __init__(
         bias: bool = True,
         layer_scale_init: float | None = 1e-6,
         stochastic_depth: float = 0.0,
-        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
-        act: _act = nn.GELU,
+        norm_eps: float = 1e-6,
         v2: bool = False,
     ) -> None:
         super().__init__()
-        self.stem = nn.Sequential(nn.Conv2d(3, d_model, 4, 4), Permute(0, 2, 3, 1), norm(d_model))
+        self.stem = nn.Sequential(nn.Conv2d(3, d_model, 4, 4), Permute(0, 2, 3, 1), nn.LayerNorm(d_model, norm_eps))

         stochastic_depth_rates = torch.linspace(0, stochastic_depth, sum(depths))
         self.stages = nn.Sequential()
@@ -85,7 +81,7 @@
             if stage_idx > 0:
                 # equivalent to PatchMerging in SwinTransformer
                 downsample = nn.Sequential(
-                    norm(d_model),
+                    nn.LayerNorm(d_model, norm_eps),
                     Permute(0, 3, 1, 2),
                     nn.Conv2d(d_model, d_model * 2, 2, 2),
                     Permute(0, 2, 3, 1),
@@ -97,12 +93,12 @@
             for block_idx in range(depth):
                 rate = stochastic_depth_rates[sum(depths[:stage_idx]) + block_idx]
-                block = ConvNeXtBlock(d_model, expansion_ratio, bias, layer_scale_init, rate, norm, act, v2)
+                block = ConvNeXtBlock(d_model, expansion_ratio, bias, layer_scale_init, rate, norm_eps, v2)
                 stage.append(block)

             self.stages.append(stage)

-        self.head_norm = norm(d_model)
+        self.norm = nn.LayerNorm(d_model, norm_eps)

     def get_feature_maps(self, x: Tensor) -> list[Tensor]:
         out = [self.stem(x)]
@@ -111,7 +107,7 @@ def get_feature_maps(self, x: Tensor) -> list[Tensor]:
         return out[-1:]

     def forward(self, x: Tensor) -> Tensor:
-        return self.head_norm(self.get_feature_maps(x)[-1].mean((1, 2)))
+        return self.norm(self.get_feature_maps(x)[-1].mean((1, 2)))

     @staticmethod
     def from_config(variant: str, v2: bool = False, pretrained: bool = False) -> ConvNeXt:
@@ -189,7 +185,7 @@ def copy_(m: nn.Conv2d | nn.Linear | nn.LayerNorm, prefix: str):
         # FCMAE checkpoints don't contain head norm
         if "norm.weight" in state_dict:
-            copy_(self.head_norm, "norm")
+            copy_(self.norm, "norm")
             assert len(state_dict) == 2
         else:
             assert len(state_dict) == 0
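Beyond the `norm_eps` plumbing, the rename `head_norm` -> `norm` aligns the attribute with the checkpoint key "norm", which is why the loading hunk can call `copy_(self.norm, "norm")` with no key remapping. A usage sketch, not part of the patch ("tiny" is a placeholder variant string; the keys `from_config` actually accepts are not visible in this diff):

import torch
from vision_toolbox.backbones.convnext import ConvNeXt

model = ConvNeXt.from_config("tiny", pretrained=True)        # signature as shown above
feats = model.get_feature_maps(torch.randn(1, 3, 224, 224))  # channels-last feature maps
pooled = model(torch.randn(1, 3, 224, 224))                  # spatial mean, then self.norm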
diff --git a/vision_toolbox/backbones/deit.py b/vision_toolbox/backbones/deit.py
index 102f60f..a5252c8 100644
--- a/vision_toolbox/backbones/deit.py
+++ b/vision_toolbox/backbones/deit.py
@@ -4,13 +4,10 @@
 from __future__ import annotations

-from functools import partial
-
 import torch
 from torch import Tensor, nn

 from ..components import LayerScale
-from .base import _act, _norm
 from .vit import ViT, ViTBlock


@@ -27,13 +24,12 @@ def __init__(
         dropout: float = 0.0,
         layer_scale_init: float | None = None,
         stochastic_depth: float = 0.0,
-        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
-        act: _act = nn.GELU,
+        norm_eps: float = 1e-6,
     ) -> None:
         # fmt: off
         super().__init__(
             d_model, depth, n_heads, patch_size, img_size, True, bias, mlp_ratio,
-            dropout, layer_scale_init, stochastic_depth, norm, act
+            dropout, layer_scale_init, stochastic_depth, norm_eps
         )
         # fmt: on
         self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))
@@ -133,13 +129,12 @@ def __init__(
         dropout: float = 0.0,
         layer_scale_init: float | None = 1e-6,
         stochastic_depth: float = 0.0,
-        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
-        act: _act = nn.GELU,
+        norm_eps: float = 1e-6,
     ):
         # fmt: off
         super().__init__(
             d_model, depth, n_heads, patch_size, img_size, cls_token, bias,
-            mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm, act
+            mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm_eps,
         )
         # fmt: on
diff --git a/vision_toolbox/backbones/mlp_mixer.py b/vision_toolbox/backbones/mlp_mixer.py
index 3e07665..85d0081 100644
--- a/vision_toolbox/backbones/mlp_mixer.py
+++ b/vision_toolbox/backbones/mlp_mixer.py
@@ -3,7 +3,6 @@
 from __future__ import annotations

-from functools import partial
 from typing import Mapping

 import numpy as np
@@ -11,7 +10,6 @@
 from torch import Tensor, nn

 from ..utils import torch_hub_download
-from .base import _act, _norm
 from .vit import MLP


@@ -22,15 +20,14 @@ def __init__(
         d_model: int,
         mlp_ratio: tuple[int, int] = (0.5, 4.0),
         dropout: float = 0.0,
-        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
-        act: _act = nn.GELU,
+        norm_eps: float = 1e-6,
     ) -> None:
         tokens_mlp_dim, channels_mlp_dim = [int(d_model * ratio) for ratio in mlp_ratio]
         super().__init__()
-        self.norm1 = norm(d_model)
-        self.token_mixing = MLP(n_tokens, tokens_mlp_dim, dropout, act)
-        self.norm2 = norm(d_model)
-        self.channel_mixing = MLP(d_model, channels_mlp_dim, dropout, act)
+        self.norm1 = nn.LayerNorm(d_model, norm_eps)
+        self.token_mixing = MLP(n_tokens, tokens_mlp_dim, dropout)
+        self.norm2 = nn.LayerNorm(d_model, norm_eps)
+        self.channel_mixing = MLP(d_model, channels_mlp_dim, dropout)

     def forward(self, x: Tensor) -> Tensor:
         # x -> (B, n_tokens, d_model)
@@ -48,17 +45,16 @@ def __init__(
         img_size: int,
         mlp_ratio: tuple[float, float] = (0.5, 4.0),
         dropout: float = 0.0,
-        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
-        act: _act = nn.GELU,
+        norm_eps: float = 1e-6,
     ) -> None:
         assert img_size % patch_size == 0
         super().__init__()
         self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
         n_tokens = (img_size // patch_size) ** 2
         self.layers = nn.Sequential(
-            *[MixerBlock(n_tokens, d_model, mlp_ratio, dropout, norm, act) for _ in range(n_layers)]
+            *[MixerBlock(n_tokens, d_model, mlp_ratio, dropout, norm_eps) for _ in range(n_layers)]
         )
-        self.norm = norm(d_model)
+        self.norm = nn.LayerNorm(d_model, norm_eps)

     def forward(self, x: Tensor) -> Tensor:
         x = self.patch_embed(x).flatten(2).transpose(1, 2)  # (N, C, H, W) -> (N, H*W, C)
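The Mixer changes mirror the rest of the patch: `MLP` (updated in vit.py below) no longer takes an `act` argument, so `token_mixing`/`channel_mixing` drop their trailing parameter, and both norms become plain `nn.LayerNorm(d_model, norm_eps)`. A construction sketch, not part of the patch; the keyword names `n_layers`, `d_model`, and `patch_size` are assumed from the constructor body rather than visible in the hunks:

import torch
from vision_toolbox.backbones.mlp_mixer import Mixer

# roughly Mixer-S/16 sizes; normalization is now configurable only via norm_eps
mixer = Mixer(n_layers=8, d_model=512, patch_size=16, img_size=224, norm_eps=1e-6)
out = mixer(torch.randn(1, 3, 224, 224))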
diff --git a/vision_toolbox/backbones/swin.py b/vision_toolbox/backbones/swin.py
index 1bd12d4..da0fe35 100644
--- a/vision_toolbox/backbones/swin.py
+++ b/vision_toolbox/backbones/swin.py
@@ -9,7 +9,7 @@
 import torch
 from torch import Tensor, nn

-from .base import BaseBackbone, _act, _norm
+from .base import BaseBackbone
 from .vit import MHA, ViTBlock
@@ -99,22 +99,21 @@ def __init__(
         dropout: float = 0.0,
         layer_scale_init: float | None = None,
         stochastic_depth: float = 0.0,
-        norm: _norm = nn.LayerNorm,
-        act: _act = nn.GELU,
+        norm_eps: float = 1e-5,
     ) -> None:
         # fmt: off
         super().__init__(
             d_model, n_heads, bias, mlp_ratio, dropout,
-            layer_scale_init, stochastic_depth, norm, act,
+            layer_scale_init, stochastic_depth, norm_eps,
             partial(WindowAttention, input_size, d_model, n_heads, window_size, shift, bias, dropout),
         )
         # fmt: on


 class PatchMerging(nn.Module):
-    def __init__(self, d_model: int, norm: _norm = nn.LayerNorm) -> None:
+    def __init__(self, d_model: int, norm_eps: float = 1e-5) -> None:
         super().__init__()
-        self.norm = norm(d_model * 4)
+        self.norm = nn.LayerNorm(d_model * 4, norm_eps)
         self.reduction = nn.Linear(d_model * 4, d_model * 2, False)

     def forward(self, x: Tensor) -> Tensor:
@@ -139,14 +138,13 @@ def __init__(
         dropout: float = 0.0,
         layer_scale_init: float | None = None,
         stochastic_depth: float = 0.0,
-        norm: _norm = nn.LayerNorm,
-        act: _act = nn.GELU,
+        norm_eps: float = 1e-5,
     ) -> None:
         assert img_size % patch_size == 0
         assert d_model % n_heads == 0
         super().__init__()
         self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
-        self.norm = norm(d_model)
+        self.patch_norm = nn.LayerNorm(d_model, norm_eps)
         self.dropout = nn.Dropout(dropout)

         input_size = img_size // patch_size
@@ -154,7 +152,7 @@
         for i, (depth, window_size) in enumerate(zip(depths, window_sizes)):
             stage = nn.Sequential()
             if i > 0:
-                downsample = PatchMerging(d_model, norm)
+                downsample = PatchMerging(d_model, norm_eps)
                 input_size //= 2
                 d_model *= 2
                 n_heads *= 2
@@ -167,23 +165,23 @@
                 # fmt: off
                 block = SwinBlock(
                     input_size, d_model, n_heads, window_size, shift, mlp_ratio,
-                    bias, dropout, layer_scale_init, stochastic_depth, norm, act,
+                    bias, dropout, layer_scale_init, stochastic_depth, norm_eps,
                 )
                 # fmt: on
                 stage.append(block)

             self.stages.append(stage)

-        self.head_norm = norm(d_model)
+        self.norm = nn.LayerNorm(d_model, norm_eps)

     def get_feature_maps(self, x: Tensor) -> list[Tensor]:
-        out = [self.dropout(self.norm(self.patch_embed(x).permute(0, 2, 3, 1)))]
+        out = [self.dropout(self.patch_norm(self.patch_embed(x).permute(0, 2, 3, 1)))]
         for stage in self.stages:
             out.append(stage(out[-1]))
         return out[1:]

     def forward(self, x: Tensor) -> Tensor:
-        return self.head_norm(self.get_feature_maps(x)[-1]).mean((1, 2))
+        return self.norm(self.get_feature_maps(x)[-1]).mean((1, 2))

     def resize_pe(self, img_size: int) -> None:
         raise NotImplementedError()
@@ -222,7 +220,7 @@ def copy_(m: nn.Linear | nn.LayerNorm, prefix: str) -> None:
             m.bias.copy_(state_dict.pop(prefix + ".bias"))

         copy_(self.patch_embed, "patch_embed.proj")
-        copy_(self.norm, "patch_embed.norm")
+        copy_(self.patch_norm, "patch_embed.norm")

         for stage_idx, stage in enumerate(self.stages):
             if stage_idx > 0:
@@ -261,5 +259,5 @@ def rearrange(p):
                 copy_(block.mlp[1].linear1, prefix + "mlp.fc1")
                 copy_(block.mlp[1].linear2, prefix + "mlp.fc2")

-        copy_(self.head_norm, "norm")
+        copy_(self.norm, "norm")
         assert len(state_dict) == 2  # head.weight and head.bias
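Two details in the Swin hunks: the two norms get distinct names (`patch_norm` after the patch embedding, `norm` before pooling), matching the checkpoint keys "patch_embed.norm" and "norm" used in the loading code, and the default here is `norm_eps=1e-5` rather than the 1e-6 used by the other backbones. A sketch of the updated `PatchMerging`, whose full signature is visible above; the channels-last layout is inferred from `get_feature_maps`, and the output shape is the expected one, not verified:

import torch
from vision_toolbox.backbones.swin import PatchMerging

merge = PatchMerging(96)        # builds nn.LayerNorm(384, 1e-5) and nn.Linear(384, 192, False)
x = torch.randn(1, 56, 56, 96)  # (N, H, W, C), as produced by the patch embedding
y = merge(x)                    # expected shape: (1, 28, 28, 192)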
diff --git a/vision_toolbox/backbones/vit.py b/vision_toolbox/backbones/vit.py
index bd14c2d..5ded0b0 100644
--- a/vision_toolbox/backbones/vit.py
+++ b/vision_toolbox/backbones/vit.py
@@ -14,7 +14,6 @@
 from ..components import LayerScale, StochasticDepth
 from ..utils import torch_hub_download
-from .base import _act, _norm


 class MHA(nn.Module):
@@ -47,10 +46,10 @@ def forward(self, x: Tensor, attn_bias: Tensor | None = None) -> Tensor:


 class MLP(nn.Sequential):
-    def __init__(self, in_dim: int, hidden_dim: float, dropout: float = 0.0, act: _act = nn.GELU) -> None:
+    def __init__(self, in_dim: int, hidden_dim: float, dropout: float = 0.0) -> None:
         super().__init__()
         self.linear1 = nn.Linear(in_dim, hidden_dim)
-        self.act = act()
+        self.act = nn.GELU()
         self.linear2 = nn.Linear(hidden_dim, in_dim)
         self.dropout = nn.Dropout(dropout)
@@ -65,22 +64,21 @@ def __init__(
         dropout: float = 0.0,
         layer_scale_init: float | None = None,
         stochastic_depth: float = 0.0,
-        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
-        act: _act = nn.GELU,
+        norm_eps: float = 1e-6,
         attention: type[nn.Module] | None = None,
     ) -> None:
         if attention is None:
             attention = partial(MHA, d_model, n_heads, bias, dropout)
         super().__init__()
         self.mha = nn.Sequential(
-            norm(d_model),
+            nn.LayerNorm(d_model, norm_eps),
             attention(),
             LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
             StochasticDepth(stochastic_depth),
         )
         self.mlp = nn.Sequential(
-            norm(d_model),
-            MLP(d_model, int(d_model * mlp_ratio), dropout, act),
+            nn.LayerNorm(d_model, norm_eps),
+            MLP(d_model, int(d_model * mlp_ratio), dropout),
             LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
             StochasticDepth(stochastic_depth),
         )
@@ -105,8 +103,7 @@ def __init__(
         dropout: float = 0.0,
         layer_scale_init: float | None = None,
         stochastic_depth: float = 0.0,
-        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
-        act: _act = nn.GELU,
+        norm_eps: float = 1e-6,
     ) -> None:
         assert img_size % patch_size == 0
         super().__init__()
@@ -117,10 +114,10 @@
         self.layers = nn.Sequential()
         for _ in range(depth):
-            block = ViTBlock(d_model, n_heads, bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm, act)
+            block = ViTBlock(d_model, n_heads, bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm_eps)
             self.layers.append(block)

-        self.norm = norm(d_model)
+        self.norm = nn.LayerNorm(d_model, norm_eps)

     def forward(self, imgs: Tensor) -> Tensor:
         out = self.patch_embed(imgs).flatten(2).transpose(1, 2) + self.pe  # (N, C, H, W) -> (N, H*W, C)
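vit.py is the root of the whole refactor: CaiT, DeiT, and Swin all build on `ViT`/`ViTBlock`, and `MLP` is reused by the Mixer, so swapping the `norm`/`act` parameters for `norm_eps` here is what forces every change above. A before/after sketch, not part of the patch (the `bias`/`mlp_ratio` defaults are assumed from the hunk context):

import torch
from vision_toolbox.backbones.vit import ViTBlock

# old API (removed): ViTBlock(768, 12, norm=partial(nn.LayerNorm, eps=1e-6), act=nn.GELU)
block = ViTBlock(768, 12, norm_eps=1e-6)
tokens = block(torch.randn(1, 197, 768))  # pre-norm residual block, shape-preserving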