From fa6460235c0c38aa3238abbba82f4ae552c12964 Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Tue, 8 Aug 2023 23:24:14 +0800
Subject: [PATCH] add initial window attention

---
 vision_toolbox/backbones/swin_transformer.py | 33 ++++++++++++++++++++
 vision_toolbox/backbones/vit.py              |  9 ++++---
 2 files changed, 39 insertions(+), 3 deletions(-)
 create mode 100644 vision_toolbox/backbones/swin_transformer.py

diff --git a/vision_toolbox/backbones/swin_transformer.py b/vision_toolbox/backbones/swin_transformer.py
new file mode 100644
index 0000000..a9adb43
--- /dev/null
+++ b/vision_toolbox/backbones/swin_transformer.py
@@ -0,0 +1,33 @@
+# https://arxiv.org/abs/2103.14030
+# https://github.com/microsoft/Swin-Transformer
+
+from __future__ import annotations
+
+from functools import partial
+from typing import Mapping
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+from ..utils import torch_hub_download
+from .base import _act, _norm
+from .vit import MHA
+
+
+class WindowAttention(MHA):
+    def __init__(self, d_model: int, window_size: int, n_heads: int, bias: bool = True, dropout: float = 0.0) -> None:
+        super().__init__(d_model, n_heads, bias, dropout)
+        self.relative_pe_table = nn.Parameter(torch.empty(n_heads, (2 * window_size - 1) ** 2))
+        nn.init.trunc_normal_(self.relative_pe_table, 0, 0.02)
+
+        xy = torch.cartesian_prod(torch.arange(window_size), torch.arange(window_size))  # all possible (x,y) pairs
+        diff = xy.unsqueeze(1) - xy.unsqueeze(0)  # difference between all (x,y) pairs
+        index = (diff[:, :, 0] + window_size - 1) * (2 * window_size - 1) + diff[:, :, 1] + window_size - 1
+        self.register_buffer("relative_pe_index", index, False)  # (window_size^2, window_size^2)
+        self.relative_pe_index: Tensor
+
+    def forward(self, x: Tensor) -> Tensor:
+        attn_bias = self.relative_pe_table[:, self.relative_pe_index].unsqueeze(0)  # (1, n_heads, window_size^2, window_size^2)
+        return super().forward(x, attn_bias)
diff --git a/vision_toolbox/backbones/vit.py b/vision_toolbox/backbones/vit.py
index becda7d..0929f66 100644
--- a/vision_toolbox/backbones/vit.py
+++ b/vision_toolbox/backbones/vit.py
@@ -25,14 +25,17 @@ def __init__(self, d_model: int, n_heads: int, bias: bool = True, dropout: float
         self.dropout = dropout
         self.scale = (d_model // n_heads) ** (-0.5)
 
-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, x: Tensor, attn_mask: Tensor | None = None) -> Tensor:
         qkv = self.in_proj(x)
         q, k, v = qkv.unflatten(-1, (3, self.n_heads, -1)).transpose(-2, -4).unbind(-3)
 
         if hasattr(F, "scaled_dot_product_attention"):
-            out = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout)
+            out = F.scaled_dot_product_attention(q, k, v, attn_mask, self.dropout)
         else:
-            attn = torch.softmax(q @ (k * self.scale).transpose(-1, -2), -1)
+            attn = q @ (k * self.scale).transpose(-1, -2)
+            if attn_mask is not None:
+                attn = attn + attn_mask  # add the bias to the logits before softmax, matching the SDPA branch
+            attn = torch.softmax(attn, -1)
             out = F.dropout(attn, self.dropout, self.training) @ v
 
         out = out.transpose(-2, -3).flatten(-2)
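
Note (outside the patch): a minimal standalone sketch of the relative position index that WindowAttention.__init__ builds, assuming a 7x7 window as in the Swin-T defaults. It needs only torch, mirrors the cartesian_prod / pairwise-difference logic above, and the printed bounds show why the bias table has (2 * window_size - 1) ** 2 entries per head.

import torch

window_size = 7  # assumption for illustration; any window size works the same way

# all (x, y) coordinates inside one window, shape (window_size**2, 2)
xy = torch.cartesian_prod(torch.arange(window_size), torch.arange(window_size))

# pairwise coordinate differences between every pair of positions, shape (window_size**2, window_size**2, 2)
diff = xy.unsqueeze(1) - xy.unsqueeze(0)

# map each (dx, dy) in [-(window_size - 1), window_size - 1]^2 to a unique id in [0, (2 * window_size - 1)**2)
index = (diff[:, :, 0] + window_size - 1) * (2 * window_size - 1) + diff[:, :, 1] + window_size - 1

print(index.shape)                             # torch.Size([49, 49])
print(index.min().item(), index.max().item())  # 0 168 -> indexes a (n_heads, 169) relative_pe_table

Indexing the (n_heads, 169) relative_pe_table with this (49, 49) tensor yields a per-head (49, 49) bias, which forward() unsqueezes to (1, n_heads, 49, 49) and passes to MHA.forward as attn_mask.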