
Commit

update patchconvnet
gau-nernst committed Jul 24, 2023
1 parent ccbd0aa commit 7ce7b4e
Showing 3 changed files with 47 additions and 41 deletions.
2 changes: 2 additions & 0 deletions tests/test_backbones.py
@@ -9,6 +9,7 @@
DarknetYOLOv5,
EfficientNetExtractor,
MobileNetExtractor,
PatchConvNet,
RegNetExtractor,
ResNetExtractor,
VoVNet,
@@ -27,6 +28,7 @@ def inputs():
partial(VoVNet.from_config, x, y, z)
for x, y, z in ((27, True, False), (39, False, False), (19, True, True), (57, False, True))
],
# partial(PatchConvNet.from_config, "S", 60),
partial(ResNetExtractor, "resnet18"),
partial(RegNetExtractor, "regnet_x_400mf"),
partial(MobileNetExtractor, "mobilenet_v2"),
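
The new PatchConvNet entry in the backbone factory list is still commented out, so this commit only wires up the import. A minimal smoke test in the spirit of that commented-out factory might look like the sketch below (hypothetical test code: the input resolution and batch size are assumptions; from_config("S", 60), get_feature_maps, and out_channels_list come from this diff).

import torch
from vision_toolbox.backbones import PatchConvNet

model = PatchConvNet.from_config("S", 60)   # variant "S", depth 60, as in the commented-out factory
x = torch.randn(1, 3, 224, 224)             # assumed input resolution
feats = model.get_feature_maps(x)           # new extractor-style interface returns a list of tensors
assert len(feats) == len(model.out_channels_list)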
2 changes: 1 addition & 1 deletion vision_toolbox/backbones/__init__.py
@@ -1,5 +1,5 @@
from .darknet import Darknet, DarknetYOLOv5
from .patchconvnet import *
from .patchconvnet import PatchConvNet
from .torchvision_models import EfficientNetExtractor, MobileNetExtractor, RegNetExtractor, ResNetExtractor
from .vit import ViT
from .vovnet import VoVNet
84 changes: 44 additions & 40 deletions vision_toolbox/backbones/patchconvnet.py
@@ -22,41 +22,43 @@ def forward(self, x: Tensor) -> Tensor:
return torch.permute(x, self.dims)


class PatchConvBlock(nn.Module):
def __init__(
self, embed_dim: int, drop_path: float = 0.3, layer_scale_init: float = 1e-6, norm_type: str = "bn"
) -> None:
assert norm_type in ("bn", "ln")
class PatchConvBlockLN(nn.Module):
def __init__(self, embed_dim: int, drop_path: float = 0.3, layer_scale_init: float = 1e-6) -> None:
super().__init__()
if norm_type == "ln":
# LayerNorm version. Primary format is (N, H, W, C)
# follow this approach https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py
self.layers = nn.Sequential(
nn.LayerNorm(embed_dim),
nn.Linear(embed_dim, embed_dim),
nn.GELU(),
Permute(0, 3, 1, 2), # (N, H, W, C) -> (N, C, H, W)
nn.Conv2d(embed_dim, embed_dim, 3, padding=1, groups=embed_dim), # dw-conv
nn.GELU(),
SqueezeExcitation(embed_dim, embed_dim // 4),
Permute(0, 2, 3, 1), # (N, C, H, W) -> (N, H, W, C)
nn.Linear(embed_dim, embed_dim),
)
self.layer_scale = nn.Parameter(torch.full(embed_dim, layer_scale_init))

else:
# BatchNorm version. Primary format is (N, C, H, W)
self.layers = nn.Sequential(
nn.BatchNorm2d(embed_dim),
nn.Conv2d(embed_dim, embed_dim, 1),
nn.GELU(),
nn.Conv2d(embed_dim, embed_dim, 3, padding=1, groups=embed_dim),
nn.GELU(),
SqueezeExcitation(embed_dim, embed_dim // 4),
nn.Conv2d(embed_dim, embed_dim, 1),
)
self.layer_scale = nn.Parameter(torch.full((embed_dim, 1, 1), layer_scale_init))
# LayerNorm version. Primary format is (N, H, W, C)
# follow this approach https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py
self.layers = nn.Sequential(
nn.LayerNorm(embed_dim),
nn.Linear(embed_dim, embed_dim),
nn.GELU(),
Permute(0, 3, 1, 2), # (N, H, W, C) -> (N, C, H, W)
nn.Conv2d(embed_dim, embed_dim, 3, padding=1, groups=embed_dim), # dw-conv
nn.GELU(),
SqueezeExcitation(embed_dim, embed_dim // 4),
Permute(0, 2, 3, 1), # (N, C, H, W) -> (N, H, W, C)
nn.Linear(embed_dim, embed_dim),
)
self.layer_scale = nn.Parameter(torch.full((embed_dim,), layer_scale_init))
self.drop_path = StochasticDepth(drop_path, "row") if drop_path > 0 else nn.Identity()

def forward(self, x: Tensor) -> Tensor:
return x + self.drop_path(self.layers(x) * self.layer_scale)


class PatchConvBlockBN(nn.Module):
def __init__(self, embed_dim: int, drop_path: float = 0.3, layer_scale_init: float = 1e-6) -> None:
super().__init__()
# BatchNorm version. Primary format is (N, C, H, W)
self.layers = nn.Sequential(
nn.BatchNorm2d(embed_dim),
nn.Conv2d(embed_dim, embed_dim, 1),
nn.GELU(),
nn.Conv2d(embed_dim, embed_dim, 3, padding=1, groups=embed_dim),
nn.GELU(),
SqueezeExcitation(embed_dim, embed_dim // 4),
nn.Conv2d(embed_dim, embed_dim, 1),
)
self.layer_scale = nn.Parameter(torch.full((embed_dim, 1, 1), layer_scale_init))
self.drop_path = StochasticDepth(drop_path, "row") if drop_path > 0 else nn.Identity()

def forward(self, x: Tensor) -> Tensor:
@@ -72,13 +74,13 @@ def __init__(

self.norm_1 = nn.LayerNorm(embed_dim)
self.attn = nn.MultiheadAttention(embed_dim, 1, batch_first=True)
self.layer_scale_1 = nn.Parameter(torch.ones(embed_dim) * layer_scale_init)
self.layer_scale_1 = nn.Parameter(torch.full((embed_dim,), layer_scale_init))
self.drop_path1 = StochasticDepth(drop_path, "row") if drop_path > 0 else nn.Identity()

self.norm_2 = nn.LayerNorm(embed_dim)
mlp_dim = int(embed_dim * mlp_ratio)
self.mlp = nn.Sequential(nn.Linear(embed_dim, mlp_dim), nn.GELU(), nn.Linear(mlp_dim, embed_dim))
self.layer_scale_2 = nn.Parameter(torch.ones(embed_dim) * layer_scale_init)
self.layer_scale_2 = nn.Parameter(torch.full((embed_dim,), layer_scale_init))
self.drop_path2 = StochasticDepth(drop_path, "row") if drop_path > 0 else nn.Identity()

self.norm_3 = nn.LayerNorm(embed_dim)
@@ -114,7 +116,8 @@ def __init__(
assert norm_type in ("bn", "ln")
super().__init__()
self.norm_type = norm_type
self.out_channels = (embed_dim,)
self.out_channels_list = (embed_dim,)
self.stride = 16

# stem has no bias and no last activation layer
# https://github.com/facebookresearch/deit/issues/151
@@ -129,9 +132,10 @@ def __init__(
conv3x3_s2(embed_dim // 2, embed_dim),
)

blk = PatchConvBlockLN if norm_type == "ln" else PatchConvBlockBN
self.trunk = nn.Sequential(
Permute(0, 2, 3, 1) if norm_type == "ln" else nn.Identity(),
*[PatchConvBlock(embed_dim, drop_path, layer_scale_init, norm_type) for _ in range(depth)],
*[blk(embed_dim, drop_path, layer_scale_init) for _ in range(depth)],
Permute(0, 2, 3, 1) if norm_type == "bn" else nn.Identity(),
)
self.pool = AttentionPooling(embed_dim, mlp_ratio, drop_path, layer_scale_init)
@@ -146,12 +150,12 @@ def __init__(
if m.bias is not None:
nn.init.zeros_(m.bias)

def forward(self, x: Tensor):
def get_feature_maps(self, x: Tensor) -> list[Tensor]:
out = self.stem(x)
out = self.trunk(x)
out = self.trunk(out)
out = out.flatten(1, 2)
out = self.pool(out)
return out
return [out]

@staticmethod
def from_config(variant: str, depth: int, pretrained: bool = False) -> PatchConvNet:
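
One detail behind the PatchConvBlockLN/PatchConvBlockBN split: the old LayerNorm branch built its layer scale with torch.full(embed_dim, layer_scale_init), but torch.full expects its size argument to be a sequence of ints, so that call would fail at construction time. The new classes pass (embed_dim,) and (embed_dim, 1, 1) respectively, and AttentionPooling's torch.ones(embed_dim) * layer_scale_init is replaced with an equivalent torch.full call. A standalone check (a sketch, not library code):

import torch

dim, init = 8, 1e-6
old_style = torch.ones(dim) * init    # previous AttentionPooling initialization
new_style = torch.full((dim,), init)  # form used throughout this commit
assert torch.equal(old_style, new_style)
# torch.full(dim, init) raises a TypeError: the size argument must be a sequence of ints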

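The permutes around the trunk keep each block variant in its preferred memory format while still handing the attention pooling a token sequence. A rough standalone sketch of the layout bookkeeping (the 64-channel, 14x14 stem output is an assumed example; the permute/flatten pattern mirrors the trunk and get_feature_maps above):

import torch

norm_type = "ln"
x = torch.randn(2, 64, 14, 14)  # stem output, (N, C, H, W)
if norm_type == "ln":
    x = x.permute(0, 2, 3, 1)   # (N, H, W, C): PatchConvBlockLN runs channels-last
    # ... PatchConvBlockLN blocks would run here ...
else:
    # ... PatchConvBlockBN blocks run channels-first ...
    x = x.permute(0, 2, 3, 1)   # (N, H, W, C) after the trunk
tokens = x.flatten(1, 2)        # (N, H*W, C), ready for AttentionPooling
print(tokens.shape)             # torch.Size([2, 196, 64])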