Skip to content

Commit

Permalink
add S3 config
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Aug 10, 2023
1 parent bfc91e5 commit 7948762
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions vision_toolbox/backbones/swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,16 +134,15 @@ def __init__(
d_model: int,
n_heads: int,
depths: tuple[int, ...],
window_sizes: tuple[int, ...],
patch_size: int = 4,
window_size: int = 7,
mlp_ratio: float = 4.0,
bias: bool = True,
dropout: float = 0.0,
norm: _norm = nn.LayerNorm,
act: _act = nn.GELU,
) -> None:
assert img_size % patch_size == 0
assert img_size % window_size == 0
assert d_model % n_heads == 0
super().__init__()
self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
Expand All @@ -152,7 +151,7 @@ def __init__(

input_size = img_size // patch_size
self.stages = nn.Sequential()
for i, depth in enumerate(depths):
for i, (depth, window_size) in enumerate(zip(depths, window_sizes)):
stage = nn.Sequential()
for i in range(depth):
blk = SwinBlock(input_size, d_model, n_heads, window_size, i % 2, mlp_ratio, bias, dropout, norm, act)
Expand All @@ -176,12 +175,17 @@ def forward(self, x: Tensor) -> Tensor:

@staticmethod
def from_config(variant: str, img_size: int, pretrained: bool = False) -> SwinTransformer:
d_model, n_heads, depths = dict(
T=(96, 3, (2, 2, 6, 2)),
S=(96, 3, (2, 2, 18, 2)),
B=(128, 4, (2, 2, 18, 2)),
L=(192, 6, (2, 2, 18, 2)),
)[variant]
m = SwinTransformer(img_size, d_model, n_heads, depths)
d_model, n_heads, depths, window_sizes = {
# Sub-section 3.3 in https://arxiv.org/pdf/2103.14030.pdf
"T": (96, 3, (2, 2, 6, 2), (7, 7, 7, 7)),
"S": (96, 3, (2, 2, 18, 2), (7, 7, 7, 7)),
"B": (128, 4, (2, 2, 18, 2), (7, 7, 7, 7)),
"L": (192, 6, (2, 2, 18, 2), (7, 7, 7, 7)),
# https://github.com/microsoft/Cream/blob/main/AutoFormerV2/configs
"S3-T": (96, 3, (2, 2, 6, 2), (7, 7, 14, 7)),
"S3-S": (96, 3, (2, 2, 18, 2), (14, 14, 14, 14)),
"S3-B": (96, 3, (2, 2, 30, 2), (7, 7, 14, 7)),
}[variant]
m = SwinTransformer(img_size, d_model, n_heads, depths, window_sizes)

return m

0 comments on commit 7948762

Please sign in to comment.