add initial window attention
gau-nernst committed Aug 8, 2023
1 parent d1acdc3 commit fa64602
Showing 2 changed files with 37 additions and 2 deletions.
33 changes: 33 additions & 0 deletions vision_toolbox/backbones/swin_transformer.py
@@ -0,0 +1,33 @@
+# https://arxiv.org/abs/2103.14030
+# https://github.com/microsoft/Swin-Transformer
+
+from __future__ import annotations
+
+from functools import partial
+from typing import Mapping
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+from ..utils import torch_hub_download
+from .base import _act, _norm
+from .vit import MHA
+
+
+class WindowAttention(MHA):
+    def __init__(self, d_model: int, window_size: int, n_heads: int, bias: bool = True, dropout: float = 0.0) -> None:
+        super().__init__(d_model, n_heads, bias, dropout)
+        self.relative_pe_table = nn.Parameter(torch.empty(n_heads, (2 * window_size - 1) ** 2))
+        nn.init.trunc_normal_(self.relative_pe_table, 0, 0.02)
+
+        xy = torch.cartesian_prod(torch.arange(window_size), torch.arange(window_size))  # all possible (x,y) pairs
+        diff = xy.unsqueeze(1) - xy.unsqueeze(0)  # pairwise offsets, shape (window_size^2, window_size^2, 2)
+        index = (diff[:, :, 0] + window_size - 1) * (2 * window_size - 1) + diff[:, :, 1] + window_size - 1  # unique index in [0, (2*window_size-1)^2)
+        self.register_buffer("relative_pe_index", index, False)  # kept 2D (window_size^2, window_size^2) so the gathered bias matches the attention matrix
+        self.relative_pe_index: Tensor  # annotation for type checkers
+
+    def forward(self, x: Tensor) -> Tensor:
+        attn_bias = self.relative_pe_table[:, self.relative_pe_index].unsqueeze(0)  # (1, n_heads, window_size^2, window_size^2)
+        return super().forward(x, attn_bias)
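The relative position bias machinery can be sanity-checked in isolation. Below is a minimal sketch, not part of the commit, assuming WindowAttention is fed the tokens of a single window (sequence length window_size ** 2) and that MHA projects back to d_model; the values d_model=96, window_size=7, n_heads=3 and the batch size are illustrative only.

# Sketch only: sanity-check the relative position bias shapes.
# d_model, window_size, n_heads below are illustrative, not from the commit.
import torch

from vision_toolbox.backbones.swin_transformer import WindowAttention

d_model, window_size, n_heads = 96, 7, 3
attn = WindowAttention(d_model, window_size, n_heads)

# one learnable bias per head for each of the (2W-1)^2 possible 2D offsets
assert attn.relative_pe_table.shape == (n_heads, (2 * window_size - 1) ** 2)
# one precomputed table index for every pair of the W^2 tokens inside a window
assert attn.relative_pe_index.shape == (window_size**2, window_size**2)

# window attention operates on (batch * num_windows, window_size^2, d_model) tokens
x = torch.randn(8, window_size**2, d_model)
out = attn(x)
print(out.shape)  # expected (8, 49, 96), assuming MHA's output projection keeps d_model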
6 changes: 4 additions & 2 deletions vision_toolbox/backbones/vit.py
@@ -25,14 +25,16 @@ def __init__(self, d_model: int, n_heads: int, bias: bool = True, dropout: float
         self.dropout = dropout
         self.scale = (d_model // n_heads) ** (-0.5)
 
-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, x: Tensor, attn_mask: Tensor | None = None) -> Tensor:
         qkv = self.in_proj(x)
         q, k, v = qkv.unflatten(-1, (3, self.n_heads, -1)).transpose(-2, -4).unbind(-3)
 
         if hasattr(F, "scaled_dot_product_attention"):
-            out = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout)
+            out = F.scaled_dot_product_attention(q, k, v, attn_mask, self.dropout)
         else:
-            attn = torch.softmax(q @ (k * self.scale).transpose(-1, -2), -1)
+            attn = q @ (k * self.scale).transpose(-1, -2)
+            if attn_mask is not None:
+                attn = attn + attn_mask  # add the bias to the logits, as scaled_dot_product_attention does
+            attn = torch.softmax(attn, -1)
             out = F.dropout(attn, self.dropout, self.training) @ v
 
         out = out.transpose(-2, -3).flatten(-2)
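The additive attn_mask is meant to behave like the attn_mask argument of F.scaled_dot_product_attention, i.e. it is added to the attention logits before the softmax. A small self-contained check of that equivalence, using assumed illustrative shapes and independent of the classes above:

# Sketch only: the additive mask is applied to the logits, so the manual path
# should match F.scaled_dot_product_attention (dropout disabled for the comparison).
import torch
import torch.nn.functional as F

B, n_heads, L, head_dim = 2, 3, 49, 32  # illustrative shapes
q, k, v = torch.randn(3, B, n_heads, L, head_dim).unbind(0)
bias = torch.randn(1, n_heads, L, L)  # e.g. a relative position bias

ref = F.scaled_dot_product_attention(q, k, v, bias)

scale = head_dim ** -0.5
attn = torch.softmax(q @ (k * scale).transpose(-1, -2) + bias, -1)
out = attn @ v

torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)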
