
Commit

add window partition and unpartition
gau-nernst committed Aug 9, 2023
1 parent a41efbe commit 443f3c1
Showing 1 changed file with 71 additions and 3 deletions.
vision_toolbox/backbones/swin_transformer.py (71 additions, 3 deletions)
@@ -13,12 +13,33 @@

from ..utils import torch_hub_download
from .base import _act, _norm
-from .vit import MHA
+from .vit import MHA, MLP


def window_partition(x: Tensor, window_size: int) -> tuple[Tensor, int, int]:
    B, H, W, C = x.shape
    nH, nW = H // window_size, W // window_size
    x = x.view(B, nH, window_size, nW, window_size, C)
    x = x.transpose(2, 3).reshape(B * nH * nW, window_size * window_size, C)
    return x, nH, nW


def window_unpartition(x: Tensor, window_size: int, nH: int, nW: int) -> Tensor:
    B = x.shape[0] // (nH * nW)
    C = x.shape[2]
    x = x.view(B, nH, nW, window_size, window_size, C)
    x = x.transpose(2, 3).reshape(B, nH * window_size, nW * window_size, C)
    return x
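
As a quick sanity check of the two helpers above (a minimal sketch, not part of this commit), partitioning followed by unpartitioning reproduces the input exactly whenever H and W are multiples of window_size. The shapes below are an arbitrary example and assume the package is importable:

import torch
from vision_toolbox.backbones.swin_transformer import window_partition, window_unpartition

x = torch.randn(2, 56, 56, 96)                        # (B, H, W, C), channels-last as both helpers expect
windows, nH, nW = window_partition(x, 7)              # one row of 7 * 7 = 49 tokens per window
assert windows.shape == (2 * 8 * 8, 7 * 7, 96)        # 56 / 7 = 8 windows along each spatial axis
assert torch.equal(window_unpartition(windows, 7, nH, nW), x)   # exact round trip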


class WindowAttention(MHA):
-    def __init__(self, d_model: int, window_size: int, n_heads: int, bias: bool = True, dropout: float = 0.0) -> None:
+    def __init__(
+        self, d_model: int, window_size: int, shift: int, n_heads: int, bias: bool = True, dropout: float = 0.0
+    ) -> None:
        super().__init__(d_model, n_heads, bias, dropout)
        self.window_size = window_size
        self.shift = shift

        self.relative_pe_table = nn.Parameter(torch.empty(n_heads, (2 * window_size - 1) ** 2))
        nn.init.trunc_normal_(self.relative_pe_table, 0, 0.02)
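
The diff collapses a few lines of __init__ at this point; they presumably build and register the relative_pe_index buffer that forward uses below. As a point of reference only, a common way to compute that index for a square window, following the Swin Transformer paper rather than this repository's code, is:

import torch

window_size = 7
coords = torch.stack(torch.meshgrid(torch.arange(window_size), torch.arange(window_size), indexing="ij"))
coords = coords.flatten(1)                                     # (2, window_size ** 2) positions inside the window
rel = coords[:, :, None] - coords[:, None, :]                  # pairwise (dy, dx) offsets between positions
rel = rel + (window_size - 1)                                  # shift offsets into [0, 2 * window_size - 2]
relative_pe_index = rel[0] * (2 * window_size - 1) + rel[1]    # flat index into the (2 * window_size - 1) ** 2 table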

@@ -29,5 +50,52 @@ def __init__(self, d_model: int, window_size: int, n_heads: int, bias: bool = Tr
        self.relative_pe_index: Tensor

    def forward(self, x: Tensor) -> Tensor:
        if self.shift > 0:
            x = x.roll((self.shift, self.shift), (1, 2))
        x, nH, nW = window_partition(x, self.window_size)  # (B * nH * nW, win_size * win_size, C)

        attn_bias = self.relative_pe_table[:, self.relative_pe_index].unsqueeze(0)
-        return super().forward(x, attn_bias)
+        x = super().forward(x, attn_bias)

        x = window_unpartition(x, self.window_size, nH, nW)
        if self.shift > 0:
            x = x.roll((-self.shift, -self.shift), (1, 2))  # undo the cyclic shift so the output stays aligned with the input
        return x
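
A minimal usage sketch (again not part of the commit, and assuming the collapsed part of __init__ registers relative_pe_index): the module takes a channels-last feature map whose height and width are divisible by window_size and returns a tensor of the same shape, forming and merging windows internally.

import torch
from vision_toolbox.backbones.swin_transformer import WindowAttention

attn = WindowAttention(d_model=96, window_size=7, shift=3, n_heads=3)
x = torch.randn(2, 56, 56, 96)       # (B, H, W, C)
out = attn(x)                        # cyclic shift, window partition, MHA with relative PE bias, unpartition
assert out.shape == x.shape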


class SwinBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        window_size: int,
        shift: int,
        mlp_ratio: float = 4.0,
        bias: bool = True,
        dropout: float = 0.0,
        norm: _norm = nn.LayerNorm,
        act: _act = nn.GELU,
    ) -> None:
        super().__init__()
        self.norm1 = norm(d_model)
        self.mha = WindowAttention(d_model, window_size, shift, n_heads, bias, dropout)
        self.norm2 = norm(d_model)
        self.mlp = MLP(d_model, int(d_model * mlp_ratio), act)

    def forward(self, x: Tensor) -> Tensor:
        x = x + self.mha(self.norm1(x))  # pre-norm residual connection around window attention
        x = x + self.mlp(self.norm2(x))  # pre-norm residual connection around the MLP
        return x
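
Since each block preserves the (B, H, W, C) shape, alternating shift = 0 and shift = window_size // 2 blocks can be stacked directly. A short sketch under the same assumptions as above (and assuming MLP from .vit maps d_model back to d_model):

import torch
from vision_toolbox.backbones.swin_transformer import SwinBlock

block = SwinBlock(d_model=96, n_heads=3, window_size=7, shift=0)
shifted_block = SwinBlock(d_model=96, n_heads=3, window_size=7, shift=3)   # 7 // 2 = 3

x = torch.randn(1, 56, 56, 96)
assert shifted_block(block(x)).shape == x.shape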


class SwinTransformer(nn.Module):
    def __init__(self, d_model: int, n_layers: tuple[int, int, int, int]) -> None:
        super().__init__()

    @staticmethod
    def from_config(variant: str, pretrained: bool = False) -> SwinTransformer:
        d_model, n_layers = dict(
            T=(96, (2, 2, 6, 2)),
            S=(96, (2, 2, 18, 2)),
            B=(128, (2, 2, 18, 2)),
            L=(192, (2, 2, 18, 2)),
        )[variant]
        m = SwinTransformer(d_model, n_layers)
        return m
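
The variant table matches the published Swin-T/S/B/L configurations (embedding dim and number of blocks per stage). In this commit the class body is still a stub and the pretrained flag is not used yet, so a call like the following only exercises the lookup and construction:

model = SwinTransformer.from_config("T")    # Swin-T: d_model=96, depths (2, 2, 6, 2)
model = SwinTransformer.from_config("B")    # Swin-B: d_model=128, depths (2, 2, 18, 2)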
