Skip to content

Commit

Permalink
simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Jul 24, 2023
1 parent 8bcf593 commit ccbd0aa
Showing 1 changed file with 54 additions and 92 deletions.
146 changes: 54 additions & 92 deletions vision_toolbox/backbones/patchconvnet.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,31 @@
# https://arxiv.org/abs/2112.13692
# https://github.com/facebookresearch/deit/blob/main/patchconvnet_models.py
import warnings

import torch
from torch import Tensor, nn
from __future__ import annotations

from functools import partial

try:
from torchvision.ops import StochasticDepth
from torchvision.ops.misc import SqueezeExcitation
except ImportError:
warnings.warn("torchvision.ops.misc.SqueezeExcitation is not available. Please update your torchvision")
SqueezeExcitation = None
import torch
from torch import Tensor, nn
from torchvision.ops import StochasticDepth
from torchvision.ops.misc import SqueezeExcitation

from .base import BaseBackbone


__all__ = [
"AttentionPooling",
"PatchConvNet",
"patchconvnet_s60",
"patchconvnet_s120",
"patchconvnet_b60",
"patchconvnet_b120",
"patchconvnet_l60",
"patchconvnet_l120",
]


_base = {"mlp_ratio": 3, "drop_path": 0.3, "layer_scale_init": 1e-6}
_S_embed_dim = 384
_B_embed_dim = 768
_L_embed_dim = 1024
configs = {
"PatchConvNet-S60": dict(**_base, embed_dim=_S_embed_dim, depth=60),
"PatchConvNet-S120": dict(**_base, embed_dim=_S_embed_dim, depth=120),
"PatchConvNet-B60": dict(**_base, embed_dim=_B_embed_dim, depth=60),
"PatchConvNet-B120": dict(**_base, embed_dim=_B_embed_dim, depth=120),
"PatchConvNet-L60": dict(**_base, embed_dim=_L_embed_dim, depth=60),
"PatchConvNet-L120": dict(**_base, embed_dim=_L_embed_dim, depth=120),
}


class Permute(nn.Module):
def __init__(self, *dims):
super().__init__()
self.dims = dims

def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
return torch.permute(x, self.dims)


class PatchConvBlock(nn.Module):
def __init__(self, embed_dim, drop_path=0.3, layer_scale_init=1e-6, norm_type="bn"):
def __init__(
self, embed_dim: int, drop_path: float = 0.3, layer_scale_init: float = 1e-6, norm_type: str = "bn"
) -> None:
assert norm_type in ("bn", "ln")
super().__init__()
if norm_type == "ln":
Expand All @@ -69,7 +42,7 @@ def __init__(self, embed_dim, drop_path=0.3, layer_scale_init=1e-6, norm_type="b
Permute(0, 2, 3, 1), # (N, C, H, W) -> (N, H, W, C)
nn.Linear(embed_dim, embed_dim),
)
self.layer_scale = nn.Parameter(torch.ones(embed_dim) * layer_scale_init)
self.layer_scale = nn.Parameter(torch.full(embed_dim, layer_scale_init))

else:
# BatchNorm version. Primary format is (N, C, H, W)
Expand All @@ -82,72 +55,86 @@ def __init__(self, embed_dim, drop_path=0.3, layer_scale_init=1e-6, norm_type="b
SqueezeExcitation(embed_dim, embed_dim // 4),
nn.Conv2d(embed_dim, embed_dim, 1),
)
self.layer_scale = nn.Parameter(torch.ones(embed_dim, 1, 1) * layer_scale_init)
self.layer_scale = nn.Parameter(torch.full((embed_dim, 1, 1), layer_scale_init))

self.drop_path = StochasticDepth(drop_path, "row") if drop_path > 0 else nn.Identity()

def forward(self, x: Tensor):
def forward(self, x: Tensor) -> Tensor:
return x + self.drop_path(self.layers(x) * self.layer_scale)


class AttentionPooling(nn.Module):
def __init__(self, embed_dim, mlp_ratio, drop_path=0.3, layer_scale_init=1e-6):
def __init__(
self, embed_dim: int, mlp_ratio: int = 3, drop_path: float = 0.3, layer_scale_init: float = 1e-6
) -> Tensor:
super().__init__()
self.drop_path = StochasticDepth(drop_path, "row") if drop_path > 0 else nn.Identity()
self.cls_token = nn.Parameter(torch.zeros(embed_dim))
self.norm_1 = nn.LayerNorm(embed_dim)

self.norm_1 = nn.LayerNorm(embed_dim)
self.attn = nn.MultiheadAttention(embed_dim, 1, batch_first=True)
self.layer_scale_1 = nn.Parameter(torch.ones(embed_dim) * layer_scale_init)
self.norm_2 = nn.LayerNorm(embed_dim)
self.drop_path1 = StochasticDepth(drop_path, "row") if drop_path > 0 else nn.Identity()

self.norm_2 = nn.LayerNorm(embed_dim)
mlp_dim = int(embed_dim * mlp_ratio)
self.mlp = nn.Sequential(nn.Linear(embed_dim, mlp_dim), nn.GELU(), nn.Linear(mlp_dim, embed_dim))
self.layer_scale_2 = nn.Parameter(torch.ones(embed_dim) * layer_scale_init)
self.drop_path2 = StochasticDepth(drop_path, "row") if drop_path > 0 else nn.Identity()

self.norm_3 = nn.LayerNorm(embed_dim)

def forward(self, x: Tensor):
def forward(self, x: Tensor) -> Tensor:
# (N, HW, C)
cls_token = self.cls_token.expand(x.shape[0], 1, -1)
out = torch.cat((cls_token, x), dim=1)

# attention pooling. q = cls_token. k = v = (cls_token, x)
out = self.norm_1(out)
out = self.attn(out[:, :1], out, out, need_weights=False)[0]
cls_token = cls_token + self.drop_path(out * self.layer_scale_1) # residual + layer scale + dropout
cls_token = cls_token + self.drop_path1(out * self.layer_scale_1) # residual + layer scale + dropout

# mlp
out = self.norm_2(cls_token)
out = self.mlp(out)
cls_token = cls_token + self.drop_path(out * self.layer_scale_2)
out = self.mlp(self.norm_2(cls_token))
cls_token = cls_token + self.drop_path2(out * self.layer_scale_2)

out = self.norm_3(cls_token).squeeze(1) # (N, 1, C) -> (N, C)
return out


class PatchConvNet(BaseBackbone):
def __init__(self, embed_dim, depth, mlp_ratio, drop_path, layer_scale_init, norm_type="bn"):
def __init__(
self,
embed_dim: int,
depth: int,
mlp_ratio: int = 3,
drop_path: float = 0.3,
layer_scale_init: float = 1e-6,
norm_type: str = "bn",
) -> None:
assert norm_type in ("bn", "ln")
super().__init__()
self.norm_type = norm_type
self.out_channels = (embed_dim,)

# stem has no bias and no last activation layer
# https://github.com/facebookresearch/deit/issues/151
kwargs = dict(kernel_size=3, stride=2, padding=1, bias=False)
conv3x3_s2 = partial(nn.Conv2d, kernel_size=3, stride=2, padding=1, bias=False)
self.stem = nn.Sequential(
nn.Conv2d(3, embed_dim // 8, **kwargs),
conv3x3_s2(3, embed_dim // 8),
nn.GELU(),
nn.Conv2d(embed_dim // 8, embed_dim // 4, **kwargs),
conv3x3_s2(embed_dim // 8, embed_dim // 4),
nn.GELU(),
nn.Conv2d(embed_dim // 4, embed_dim // 2, **kwargs),
conv3x3_s2(embed_dim // 4, embed_dim // 2),
nn.GELU(),
nn.Conv2d(embed_dim // 2, embed_dim, **kwargs),
conv3x3_s2(embed_dim // 2, embed_dim),
)

kwargs = dict(drop_path=drop_path, layer_scale_init=layer_scale_init)
self.trunk = nn.Sequential(*[PatchConvBlock(embed_dim, norm_type=norm_type, **kwargs) for _ in range(depth)])
self.pool = AttentionPooling(embed_dim, mlp_ratio, **kwargs)
self.trunk = nn.Sequential(
Permute(0, 2, 3, 1) if norm_type == "ln" else nn.Identity(),
*[PatchConvBlock(embed_dim, drop_path, layer_scale_init, norm_type) for _ in range(depth)],
Permute(0, 2, 3, 1) if norm_type == "bn" else nn.Identity(),
)
self.pool = AttentionPooling(embed_dim, mlp_ratio, drop_path, layer_scale_init)

# weight initialization
nn.init.trunc_normal_(self.pool.cls_token, std=0.02)
Expand All @@ -161,40 +148,15 @@ def __init__(self, embed_dim, depth, mlp_ratio, drop_path, layer_scale_init, nor

def forward(self, x: Tensor):
out = self.stem(x)

if self.norm_type == "ln":
# layer norm
out = torch.permute(out, (0, 2, 3, 1)) # (N, C, H, W) -> (N, H, W, C)
out = self.trunk(out)
out = torch.flatten(out, 1, 2) # (N, H, W, C) -> (N, HW, C)
else:
# batch norm
out = self.trunk(out)
out = out.flatten(2).transpose(1, 2) # (N, C, H, W) -> (N, HW, C)

out = self.trunk(x)
out = out.flatten(1, 2)
out = self.pool(out)
return out


def patchconvnet_s60(pretrained=False, **kwargs):
return PatchConvNet.from_config(configs["PatchConvNet-S60"], pretrained=pretrained, **kwargs)


def patchconvnet_s120(pretrained=False, **kwargs):
return PatchConvNet.from_config(configs["PatchConvNet-S120"], pretrained=pretrained, **kwargs)


def patchconvnet_b60(pretrained=False, **kwargs):
return PatchConvNet.from_config(configs["PatchConvNet-B60"], pretrained=pretrained, **kwargs)


def patchconvnet_b120(pretrained=False, **kwargs):
return PatchConvNet.from_config(configs["PatchConvNet-B120"], pretrained=pretrained, **kwargs)


def patchconvnet_l60(pretrained=False, **kwargs):
return PatchConvNet.from_config(configs["PatchConvNet-L60"], pretrained=pretrained, **kwargs)


def patchconvnet_l120(pretrained=False, **kwargs):
return PatchConvNet.from_config(configs["PatchConvNet-L120"], pretrained=pretrained, **kwargs)
@staticmethod
def from_config(variant: str, depth: int, pretrained: bool = False) -> PatchConvNet:
embed_dim = dict(S=384, B=768, L=1024)[variant]
m = PatchConvNet(embed_dim, depth)
if pretrained:
raise ValueError
return m

0 comments on commit ccbd0aa

Please sign in to comment.