(tmm2023)多尺度膨胀注意力机制.py (Multi-Scale Dilated Attention, TMM 2023)
import torch
import torch.nn as nn
# Multi-Scale Dilated Attention (MSDA) from DilateFormer (TMM 2023).
# GitHub: https://github.com/JIAOJIAYUASD/dilateformer
# Paper: https://arxiv.org/abs/2302.01791
class DilateAttention(nn.Module):
    """Implementation of Dilate-attention for a single dilation rate."""

    def __init__(self, head_dim, qk_scale=None, attn_drop=0., kernel_size=3, dilation=1):
        super().__init__()
        self.head_dim = head_dim
        self.scale = qk_scale or head_dim ** -0.5
        self.kernel_size = kernel_size
        # Extract a (kernel_size x kernel_size) dilated neighborhood around every pixel;
        # the padding keeps the spatial resolution unchanged.
        self.unfold = nn.Unfold(kernel_size, dilation, dilation * (kernel_size - 1) // 2, 1)
        self.attn_drop = nn.Dropout(attn_drop)

    def forward(self, q, k, v):
        # q, k, v: B, d, H, W (d = channels assigned to this dilation branch)
        B, d, H, W = q.shape
        # Each query attends to the k*k keys inside its dilated local window.
        q = q.reshape([B, d // self.head_dim, self.head_dim, 1, H * W]).permute(0, 1, 4, 3, 2)  # B, h, N, 1, head_dim
        k = self.unfold(k).reshape(
            [B, d // self.head_dim, self.head_dim, self.kernel_size * self.kernel_size, H * W]
        ).permute(0, 1, 4, 2, 3)  # B, h, N, head_dim, k*k
        attn = (q @ k) * self.scale  # B, h, N, 1, k*k
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        v = self.unfold(v).reshape(
            [B, d // self.head_dim, self.head_dim, self.kernel_size * self.kernel_size, H * W]
        ).permute(0, 1, 4, 3, 2)  # B, h, N, k*k, head_dim
        x = (attn @ v).transpose(1, 2).reshape(B, H, W, d)
        return x
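
# A minimal shape sketch (illustrative addition, not part of the original repo): the
# hypothetical helper below only demonstrates DilateAttention's call signature. It
# takes q, k, v as B, C, H, W feature maps with C a multiple of head_dim and
# returns a B, H, W, C tensor.
def _dilate_attention_shape_demo():
    attn = DilateAttention(head_dim=16, kernel_size=3, dilation=2)
    q = k = v = torch.rand(2, 32, 8, 8)  # B=2, C=32 (2 heads x head_dim 16), H=W=8
    out = attn(q, k, v)
    return out.shape  # torch.Size([2, 8, 8, 32])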
class MultiDilatelocalAttention(nn.Module):
    """Implementation of Multi-Scale Dilated Attention: the heads are split across several dilation rates."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., kernel_size=3, dilation=[2, 3]):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.dilation = dilation
        self.kernel_size = kernel_size
        self.scale = qk_scale or head_dim ** -0.5
        self.num_dilation = len(dilation)
        assert num_heads % self.num_dilation == 0, \
            f"num_heads ({num_heads}) must be a multiple of num_dilation ({self.num_dilation})"
        self.qkv = nn.Conv2d(dim, dim * 3, 1, bias=qkv_bias)
        # One DilateAttention branch per dilation rate.
        self.dilate_attention = nn.ModuleList(
            [DilateAttention(head_dim, qk_scale, attn_drop, kernel_size, dilation[i])
             for i in range(self.num_dilation)])
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        # x: B, H, W, C
        B, H, W, C = x.shape
        x = x.permute(0, 3, 1, 2)  # B, C, H, W
        qkv = self.qkv(x).reshape(B, 3, self.num_dilation, C // self.num_dilation, H, W).permute(2, 1, 0, 3, 4, 5)
        # num_dilation, 3, B, C//num_dilation, H, W
        x = x.reshape(B, self.num_dilation, C // self.num_dilation, H, W).permute(1, 0, 3, 4, 2)
        # num_dilation, B, H, W, C//num_dilation
        for i in range(self.num_dilation):
            # Each branch attends over its own channel slice with its own dilation rate.
            x[i] = self.dilate_attention[i](qkv[i][0], qkv[i][1], qkv[i][2])  # B, H, W, C//num_dilation
        x = x.permute(1, 2, 3, 0, 4).reshape(B, H, W, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
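
# Illustrative wrapper (an assumption, not from the original file): many conv backbones
# keep feature maps in B, C, H, W layout, while MultiDilatelocalAttention expects
# B, H, W, C. The hypothetical DilateAttentionBlock below is a minimal sketch that
# handles the permutes and adds a pre-norm residual connection around the attention.
class DilateAttentionBlock(nn.Module):
    def __init__(self, dim, num_heads=8, kernel_size=3, dilation=[2, 3]):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = MultiDilatelocalAttention(dim, num_heads=num_heads,
                                              kernel_size=kernel_size, dilation=dilation)

    def forward(self, x):
        # x: B, C, H, W
        shortcut = x
        x = x.permute(0, 2, 3, 1)    # B, H, W, C
        x = self.attn(self.norm(x))  # B, H, W, C
        x = x.permute(0, 3, 1, 2)    # B, C, H, W
        return shortcut + x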
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.rand([3, 64, 64, 64]).to(device)  # input layout: B, H, W, C
    m = MultiDilatelocalAttention(64).to(device)
    y = m(x)
    print(y.shape)  # torch.Size([3, 64, 64, 64])
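
    # Sketch usage of the illustrative DilateAttentionBlock wrapper defined above,
    # which accepts the more common B, C, H, W layout.
    z = torch.rand([3, 64, 64, 64]).to(device)  # B, C, H, W
    block = DilateAttentionBlock(64).to(device)
    print(block(z).shape)  # torch.Size([3, 64, 64, 64])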