diff --git a/vision_toolbox/backbones/convnext.py b/vision_toolbox/backbones/convnext.py
index 67bb132..d9b04f2 100644
--- a/vision_toolbox/backbones/convnext.py
+++ b/vision_toolbox/backbones/convnext.py
@@ -10,7 +10,7 @@
 import torch
 from torch import Tensor, nn
 
-from ..components import Permute, StochasticDepth
+from ..components import LayerScale, Permute, StochasticDepth
 from .base import BaseBackbone, _act, _norm
 
 
@@ -34,12 +34,14 @@ def __init__(
         d_model: int,
         expansion_ratio: float = 4.0,
         bias: bool = True,
-        layer_scale_init: float = 1e-6,
+        layer_scale_init: float | None = 1e-6,
         stochastic_depth: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
         v2: bool = False,
     ) -> None:
+        if v2:
+            layer_scale_init = None
         super().__init__()
         hidden_dim = int(d_model * expansion_ratio)
         self.layers = nn.Sequential(
@@ -51,17 +53,12 @@ def __init__(
             act(),
             GlobalResponseNorm(hidden_dim) if v2 else nn.Identity(),
             nn.Linear(hidden_dim, d_model, bias=bias),
+            LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
+            StochasticDepth(stochastic_depth),
         )
-        self.layer_scale = (
-            nn.Parameter(torch.full((d_model,), layer_scale_init)) if layer_scale_init > 0 and not v2 else None
-        )
-        self.drop = StochasticDepth(stochastic_depth)
 
     def forward(self, x: Tensor) -> Tensor:
-        out = self.layers(x)
-        if self.layer_scale is not None:
-            out = out * self.layer_scale
-        return x + self.drop(out)
+        return x + self.layers(x)
 
 
 class ConvNeXt(BaseBackbone):
@@ -71,7 +68,7 @@ def __init__(
         depths: tuple[int, ...],
         expansion_ratio: float = 4.0,
         bias: bool = True,
-        layer_scale_init: float = 1e-6,
+        layer_scale_init: float | None = 1e-6,
         stochastic_depth: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
@@ -187,8 +184,8 @@ def copy_(m: nn.Conv2d | nn.Linear | nn.LayerNorm, prefix: str):
                 block.layers[6].beta.copy_(state_dict.pop(prefix + "grn.beta").squeeze())
                 copy_(block.layers[7], prefix + "pwconv2")
 
-                if block.layer_scale is not None:
-                    block.layer_scale.copy_(state_dict.pop(prefix + "gamma"))
+                if isinstance(block.layers[8], LayerScale):
+                    block.layers[8].gamma.copy_(state_dict.pop(prefix + "gamma"))
 
         # FCMAE checkpoints don't contain head norm
         if "norm.weight" in state_dict:
diff --git a/vision_toolbox/backbones/vit.py b/vision_toolbox/backbones/vit.py
index c1e4ed8..5ca4049 100644
--- a/vision_toolbox/backbones/vit.py
+++ b/vision_toolbox/backbones/vit.py
@@ -12,6 +12,7 @@
 import torch.nn.functional as F
 from torch import Tensor, nn
 
+from ..components import LayerScale, StochasticDepth
 from ..utils import torch_hub_download
 from .base import _act, _norm
 
@@ -58,18 +59,28 @@ def __init__(
         bias: bool = True,
         mlp_ratio: float = 4.0,
         dropout: float = 0.0,
+        layer_scale_init: float | None = None,
+        stochastic_depth: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
     ) -> None:
         super().__init__()
-        self.norm1 = norm(d_model)
-        self.mha = MHA(d_model, n_heads, bias, dropout)
-        self.norm2 = norm(d_model)
-        self.mlp = MLP(d_model, int(d_model * mlp_ratio), dropout, act)
+        self.mha = nn.Sequential(
+            norm(d_model),
+            MHA(d_model, n_heads, bias, dropout),
+            LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
+            StochasticDepth(stochastic_depth),
+        )
+        self.mlp = nn.Sequential(
+            norm(d_model),
+            MLP(d_model, int(d_model * mlp_ratio), dropout, act),
+            LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
+            StochasticDepth(stochastic_depth),
+        )
 
     def forward(self, x: Tensor) -> Tensor:
-        x = x + self.mha(self.norm1(x))
-        x = x + self.mlp(self.norm2(x))
+        x = x + self.mha(x)
+        x = x + self.mlp(x)
         return x
 
 
@@ -85,6 +96,8 @@ def __init__(
         bias: bool = True,
         mlp_ratio: float = 4.0,
         dropout: float = 0.0,
+        layer_scale_init: float | None = None,
+        stochastic_depth: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
     ) -> None:
@@ -99,9 +112,11 @@
         self.pe = nn.Parameter(torch.empty(1, pe_size, d_model))
         nn.init.normal_(self.pe, 0, 0.02)
 
-        self.layers = nn.Sequential(
-            *[ViTBlock(d_model, n_heads, bias, mlp_ratio, dropout, norm, act) for _ in range(n_layers)]
-        )
+        self.layers = nn.Sequential()
+        for _ in range(n_layers):
+            block = ViTBlock(d_model, n_heads, bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm, act)
+            self.layers.append(block)
+
         self.norm = norm(d_model)
 
     def forward(self, imgs: Tensor) -> Tensor:
@@ -173,21 +188,21 @@ def get_w(key: str) -> Tensor:
             prefix = f"Transformer/encoderblock_{idx}/"
             mha_prefix = prefix + "MultiHeadDotProductAttention_1/"
 
-            layer.norm1.weight.copy_(get_w(prefix + "LayerNorm_0/scale"))
-            layer.norm1.bias.copy_(get_w(prefix + "LayerNorm_0/bias"))
+            layer.mha[0].weight.copy_(get_w(prefix + "LayerNorm_0/scale"))
+            layer.mha[0].bias.copy_(get_w(prefix + "LayerNorm_0/bias"))
             w = torch.stack([get_w(mha_prefix + x + "/kernel") for x in ["query", "key", "value"]], 1)
             b = torch.stack([get_w(mha_prefix + x + "/bias") for x in ["query", "key", "value"]], 0)
-            layer.mha.in_proj.weight.copy_(w.flatten(1).T)
-            layer.mha.in_proj.bias.copy_(b.flatten())
-            layer.mha.out_proj.weight.copy_(get_w(mha_prefix + "out/kernel").flatten(0, 1).T)
-            layer.mha.out_proj.bias.copy_(get_w(mha_prefix + "out/bias"))
-
-            layer.norm2.weight.copy_(get_w(prefix + "LayerNorm_2/scale"))
-            layer.norm2.bias.copy_(get_w(prefix + "LayerNorm_2/bias"))
-            layer.mlp.linear1.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_0/kernel").T)
-            layer.mlp.linear1.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_0/bias"))
-            layer.mlp.linear2.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_1/kernel").T)
-            layer.mlp.linear2.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_1/bias"))
+            layer.mha[1].in_proj.weight.copy_(w.flatten(1).T)
+            layer.mha[1].in_proj.bias.copy_(b.flatten())
+            layer.mha[1].out_proj.weight.copy_(get_w(mha_prefix + "out/kernel").flatten(0, 1).T)
+            layer.mha[1].out_proj.bias.copy_(get_w(mha_prefix + "out/bias"))
+
+            layer.mlp[0].weight.copy_(get_w(prefix + "LayerNorm_2/scale"))
+            layer.mlp[0].bias.copy_(get_w(prefix + "LayerNorm_2/bias"))
+            layer.mlp[1].linear1.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_0/kernel").T)
+            layer.mlp[1].linear1.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_0/bias"))
+            layer.mlp[1].linear2.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_1/kernel").T)
+            layer.mlp[1].linear2.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_1/bias"))
 
         self.norm.weight.copy_(get_w("Transformer/encoder_norm/scale"))
         self.norm.bias.copy_(get_w("Transformer/encoder_norm/bias"))
diff --git a/vision_toolbox/components.py b/vision_toolbox/components.py
index d813166..455cacd 100644
--- a/vision_toolbox/components.py
+++ b/vision_toolbox/components.py
@@ -178,3 +178,15 @@ def forward(self, x: Tensor) -> Tensor:
 
     def extra_repr(self) -> str:
         return f"p={self.p}"
+
+
+class LayerScale(nn.Module):
+    def __init__(self, dim: int, init: float) -> None:
+        super().__init__()
+        self.gamma = nn.Parameter(torch.full((dim,), init))
+
+    def forward(self, x: Tensor) -> Tensor:
+        return x * self.gamma
+
+    def extra_repr(self) -> str:
+        return f"dim={self.gamma.numel()}"
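
Sanity-check sketch (not part of the patch): a minimal exercise of the new modules, assuming `LayerScale` and `ViTBlock` are importable from the modules touched above; the shapes, widths, and rates here are arbitrary.

    import torch
    from vision_toolbox.components import LayerScale
    from vision_toolbox.backbones.vit import ViTBlock

    # At init, LayerScale is a uniform elementwise scale of `init`.
    scale = LayerScale(dim=8, init=1e-6)
    x = torch.randn(2, 4, 8)
    assert torch.allclose(scale(x), x * 1e-6)

    # Both residual branches are now plain nn.Sequential, so the block
    # stays shape-preserving with LayerScale/StochasticDepth enabled.
    block = ViTBlock(8, n_heads=2, layer_scale_init=1e-6, stochastic_depth=0.1)
    assert block(x).shape == x.shape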