make LayerScale a separate component. add LayerScale and StochasticDepth to ViT
gau-nernst committed Aug 20, 2023
1 parent b3bfb54 commit efe4c9e
Showing 3 changed files with 59 additions and 35 deletions.
23 changes: 10 additions & 13 deletions vision_toolbox/backbones/convnext.py
@@ -10,7 +10,7 @@
 import torch
 from torch import Tensor, nn

-from ..components import Permute, StochasticDepth
+from ..components import LayerScale, Permute, StochasticDepth
 from .base import BaseBackbone, _act, _norm


@@ -34,12 +34,14 @@ def __init__(
         d_model: int,
         expansion_ratio: float = 4.0,
         bias: bool = True,
-        layer_scale_init: float = 1e-6,
+        layer_scale_init: float | None = 1e-6,
         stochastic_depth: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
         v2: bool = False,
     ) -> None:
+        if v2:
+            layer_scale_init = None
         super().__init__()
         hidden_dim = int(d_model * expansion_ratio)
         self.layers = nn.Sequential(
@@ -51,17 +53,12 @@ def __init__(
             act(),
             GlobalResponseNorm(hidden_dim) if v2 else nn.Identity(),
             nn.Linear(hidden_dim, d_model, bias=bias),
+            LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
+            StochasticDepth(stochastic_depth),
         )
-        self.layer_scale = (
-            nn.Parameter(torch.full((d_model,), layer_scale_init)) if layer_scale_init > 0 and not v2 else None
-        )
-        self.drop = StochasticDepth(stochastic_depth)

     def forward(self, x: Tensor) -> Tensor:
-        out = self.layers(x)
-        if self.layer_scale is not None:
-            out = out * self.layer_scale
-        return x + self.drop(out)
+        return x + self.layers(x)


 class ConvNeXt(BaseBackbone):
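With this refactor the entire residual branch, LayerScale and StochasticDepth included, lives inside self.layers, so forward collapses to x + self.layers(x); the v2 guard above disables layer scale for ConvNeXt-V2 blocks, which use GRN instead. A rough usage sketch follows; the block class name ConvNeXtBlock is an assumption (the diff shows only its __init__ and forward), and it presumes the package is importable as vision_toolbox.

    # Sketch only: ConvNeXtBlock is an assumed class name, not shown in this diff.
    from vision_toolbox.backbones.convnext import ConvNeXtBlock

    v1_block = ConvNeXtBlock(d_model=96)           # default layer_scale_init=1e-6
    v2_block = ConvNeXtBlock(d_model=96, v2=True)  # layer_scale_init is forced to None

    print(type(v1_block.layers[8]).__name__)  # LayerScale, the index the weight loader checks
    print(type(v2_block.layers[8]).__name__)  # Identity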
@@ -71,7 +68,7 @@ def __init__(
         depths: tuple[int, ...],
         expansion_ratio: float = 4.0,
         bias: bool = True,
-        layer_scale_init: float = 1e-6,
+        layer_scale_init: float | None = 1e-6,
         stochastic_depth: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
@@ -187,8 +184,8 @@ def copy_(m: nn.Conv2d | nn.Linear | nn.LayerNorm, prefix: str):
                 block.layers[6].beta.copy_(state_dict.pop(prefix + "grn.beta").squeeze())

                 copy_(block.layers[7], prefix + "pwconv2")
-                if block.layer_scale is not None:
-                    block.layer_scale.copy_(state_dict.pop(prefix + "gamma"))
+                if isinstance(block.layers[8], LayerScale):
+                    block.layers[8].gamma.copy_(state_dict.pop(prefix + "gamma"))

         # FCMAE checkpoints don't contain head norm
         if "norm.weight" in state_dict:
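On the checkpoint side, the loader now copies the official gamma weights into the LayerScale module at layers[8] instead of a free-standing layer_scale parameter. Numerically nothing changes: applying LayerScale inside the Sequential is the same per-channel multiply the old forward did explicitly. A small self-contained check of that equivalence (a sketch, not the library's test code, assuming the package is importable as vision_toolbox):

    import torch
    from torch import nn
    from vision_toolbox.components import LayerScale

    d_model = 8
    branch = nn.Linear(d_model, d_model)   # stands in for the conv/MLP part of the block
    scale = LayerScale(d_model, 1e-6)

    x = torch.randn(2, 4, 4, d_model)      # channels-last, as inside the block between the Permute layers
    old_out = x + branch(x) * scale.gamma  # old forward: explicit multiply by the gamma parameter
    new_out = x + scale(branch(x))         # new forward: LayerScale applied inside self.layers
    assert torch.allclose(old_out, new_out)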
59 changes: 37 additions & 22 deletions vision_toolbox/backbones/vit.py
@@ -12,6 +12,7 @@
 import torch.nn.functional as F
 from torch import Tensor, nn

+from ..components import LayerScale, StochasticDepth
 from ..utils import torch_hub_download
 from .base import _act, _norm

@@ -58,18 +59,28 @@ def __init__(
         bias: bool = True,
         mlp_ratio: float = 4.0,
         dropout: float = 0.0,
+        layer_scale_init: float | None = None,
+        stochastic_depth: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
     ) -> None:
         super().__init__()
-        self.norm1 = norm(d_model)
-        self.mha = MHA(d_model, n_heads, bias, dropout)
-        self.norm2 = norm(d_model)
-        self.mlp = MLP(d_model, int(d_model * mlp_ratio), dropout, act)
+        self.mha = nn.Sequential(
+            norm(d_model),
+            MHA(d_model, n_heads, bias, dropout),
+            LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
+            StochasticDepth(stochastic_depth),
+        )
+        self.mlp = nn.Sequential(
+            norm(d_model),
+            MLP(d_model, int(d_model * mlp_ratio), dropout, act),
+            LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
+            StochasticDepth(stochastic_depth),
+        )

     def forward(self, x: Tensor) -> Tensor:
-        x = x + self.mha(self.norm1(x))
-        x = x + self.mlp(self.norm2(x))
+        x = x + self.mha(x)
+        x = x + self.mlp(x)
         return x
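ViTBlock now mirrors the ConvNeXt block: each residual branch is a pre-norm nn.Sequential ending in an optional LayerScale and a StochasticDepth, and forward is just two residual additions. A rough standalone sketch of one such branch, using a plain MLP as a stand-in sublayer (the real block uses this file's MHA and MLP classes; assumes the package is importable as vision_toolbox):

    import torch
    from torch import nn
    from vision_toolbox.components import LayerScale, StochasticDepth

    d_model = 64

    def residual_branch(sublayer: nn.Module, layer_scale_init: float | None = 1e-6, p: float = 0.1) -> nn.Sequential:
        # pre-norm -> sublayer -> optional LayerScale -> StochasticDepth, as in the new ViTBlock
        return nn.Sequential(
            nn.LayerNorm(d_model, eps=1e-6),
            sublayer,
            LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
            StochasticDepth(p),
        )

    mlp = nn.Sequential(nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Linear(d_model * 4, d_model))
    x = torch.randn(2, 16, d_model)
    x = x + residual_branch(mlp)(x)  # one of the two residual additions in ViTBlock.forward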


@@ -85,6 +96,8 @@ def __init__(
         bias: bool = True,
         mlp_ratio: float = 4.0,
         dropout: float = 0.0,
+        layer_scale_init: float | None = None,
+        stochastic_depth: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
     ) -> None:
@@ -99,9 +112,11 @@ def __init__(
         self.pe = nn.Parameter(torch.empty(1, pe_size, d_model))
         nn.init.normal_(self.pe, 0, 0.02)

-        self.layers = nn.Sequential(
-            *[ViTBlock(d_model, n_heads, bias, mlp_ratio, dropout, norm, act) for _ in range(n_layers)]
-        )
+        self.layers = nn.Sequential()
+        for _ in range(n_layers):
+            block = ViTBlock(d_model, n_heads, bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm, act)
+            self.layers.append(block)
+
         self.norm = norm(d_model)

     def forward(self, imgs: Tensor) -> Tensor:
@@ -173,21 +188,21 @@ def get_w(key: str) -> Tensor:
             prefix = f"Transformer/encoderblock_{idx}/"
             mha_prefix = prefix + "MultiHeadDotProductAttention_1/"

-            layer.norm1.weight.copy_(get_w(prefix + "LayerNorm_0/scale"))
-            layer.norm1.bias.copy_(get_w(prefix + "LayerNorm_0/bias"))
+            layer.mha[0].weight.copy_(get_w(prefix + "LayerNorm_0/scale"))
+            layer.mha[0].bias.copy_(get_w(prefix + "LayerNorm_0/bias"))
             w = torch.stack([get_w(mha_prefix + x + "/kernel") for x in ["query", "key", "value"]], 1)
             b = torch.stack([get_w(mha_prefix + x + "/bias") for x in ["query", "key", "value"]], 0)
-            layer.mha.in_proj.weight.copy_(w.flatten(1).T)
-            layer.mha.in_proj.bias.copy_(b.flatten())
-            layer.mha.out_proj.weight.copy_(get_w(mha_prefix + "out/kernel").flatten(0, 1).T)
-            layer.mha.out_proj.bias.copy_(get_w(mha_prefix + "out/bias"))
-
-            layer.norm2.weight.copy_(get_w(prefix + "LayerNorm_2/scale"))
-            layer.norm2.bias.copy_(get_w(prefix + "LayerNorm_2/bias"))
-            layer.mlp.linear1.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_0/kernel").T)
-            layer.mlp.linear1.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_0/bias"))
-            layer.mlp.linear2.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_1/kernel").T)
-            layer.mlp.linear2.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_1/bias"))
+            layer.mha[1].in_proj.weight.copy_(w.flatten(1).T)
+            layer.mha[1].in_proj.bias.copy_(b.flatten())
+            layer.mha[1].out_proj.weight.copy_(get_w(mha_prefix + "out/kernel").flatten(0, 1).T)
+            layer.mha[1].out_proj.bias.copy_(get_w(mha_prefix + "out/bias"))
+
+            layer.mlp[0].weight.copy_(get_w(prefix + "LayerNorm_2/scale"))
+            layer.mlp[0].bias.copy_(get_w(prefix + "LayerNorm_2/bias"))
+            layer.mlp[1].linear1.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_0/kernel").T)
+            layer.mlp[1].linear1.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_0/bias"))
+            layer.mlp[1].linear2.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_1/kernel").T)
+            layer.mlp[1].linear2.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_1/bias"))

         self.norm.weight.copy_(get_w("Transformer/encoder_norm/scale"))
         self.norm.bias.copy_(get_w("Transformer/encoder_norm/bias"))
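Because the norms moved inside the two branches, the JAX-checkpoint loader now addresses submodules by position: mha[0] and mlp[0] are the LayerNorms, mha[1] and mlp[1] are the MHA and MLP modules. A quick way to sanity-check that layout (a sketch, assuming the package is importable as vision_toolbox):

    from vision_toolbox.backbones.vit import ViTBlock

    block = ViTBlock(192, 3)  # d_model=192, n_heads=3; other arguments keep their defaults
    for name, child in block.mha.named_children():
        print("mha", name, type(child).__name__)  # 0: LayerNorm, 1: MHA, 2: Identity, 3: StochasticDepth
    for name, child in block.mlp.named_children():
        print("mlp", name, type(child).__name__)  # 0: LayerNorm, 1: MLP, 2: Identity, 3: StochasticDepth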
12 changes: 12 additions & 0 deletions vision_toolbox/components.py
@@ -178,3 +178,15 @@ def forward(self, x: Tensor) -> Tensor:

     def extra_repr(self) -> str:
         return f"p={self.p}"
+
+
+class LayerScale(nn.Module):
+    def __init__(self, dim: int, init: float) -> None:
+        super().__init__()
+        self.gamma = nn.Parameter(torch.full((dim,), init))
+
+    def forward(self, x: Tensor) -> Tensor:
+        return x * self.gamma
+
+    def extra_repr(self) -> str:
+        return f"gamma={self.gamma}"
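The component itself is just a learnable per-channel gain, initialised to a small constant so that a residual branch starts close to the identity (the layer-scale trick used by CaiT and ConvNeXt). A quick usage sketch, assuming the package is importable as vision_toolbox:

    import torch
    from vision_toolbox.components import LayerScale

    ls = LayerScale(dim=8, init=1e-6)
    x = torch.randn(2, 4, 8)         # gamma broadcasts over the trailing channel dimension
    print(ls(x).abs().max().item())  # on the order of 1e-6: the branch contributes almost nothing at init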
