seecoder_utils.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
import math
import copy


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string."""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


def _is_power_of_2(n):
    if (not isinstance(n, int)) or (n < 0):
        raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
    return (n & (n - 1) == 0) and n != 0
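

# Minimal usage sketch for the helpers above (illustrative only): the layer type,
# the clone count, and the head count below are arbitrary example values, not
# anything prescribed by this module.
def _example_helpers():
    layer = nn.TransformerEncoderLayer(d_model=256, nhead=8)
    layers = _get_clones(layer, 6)      # 6 independent deep copies in a ModuleList
    act = _get_activation_fn("gelu")    # F.gelu
    assert _is_power_of_2(8)            # head counts are often powers of two
    return layers, act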


def c2_xavier_fill(module):
    """Caffe2-style XavierFill: Kaiming-uniform with a=1, plus a zeroed bias."""
    nn.init.kaiming_uniform_(module.weight, a=1)
    if module.bias is not None:
        nn.init.constant_(module.bias, 0)


def with_pos_embed(x, pos):
    return x if pos is None else x + pos


class PositionEmbeddingSine(nn.Module):
    """Sinusoidal 2D positional embedding, centered and normalized by the shorter side."""

    def __init__(self, num_pos_feats=64, temperature=256, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x, mask=None):
        # x: (B, C, H, W); mask is True at padded positions, None means no padding.
        if mask is None:
            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
        not_mask = ~mask
        h, w = not_mask.shape[-2:]
        minlen = min(h, w)
        # Cumulative sums over valid positions give per-pixel row/column coordinates.
        h_embed = not_mask.cumsum(1, dtype=torch.float32)
        w_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            # Center the coordinates and normalize by the shorter spatial side.
            h_embed = (h_embed - h / 2) / (minlen + eps) * self.scale
            w_embed = (w_embed - w / 2) / (minlen + eps) * self.scale
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
        pos_w = w_embed[:, :, :, None] / dim_t
        pos_h = h_embed[:, :, :, None] / dim_t
        # Interleave sine and cosine along the feature dimension.
        pos_w = torch.stack(
            (pos_w[:, :, :, 0::2].sin(), pos_w[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_h = torch.stack(
            (pos_h[:, :, :, 0::2].sin(), pos_h[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        # (B, H, W, 2 * num_pos_feats) -> (B, 2 * num_pos_feats, H, W)
        pos = torch.cat((pos_h, pos_w), dim=3).permute(0, 3, 1, 2)
        return pos

    def __repr__(self, _repr_indent=4):
        head = "Positional encoding " + self.__class__.__name__
        body = [
            "num_pos_feats: {}".format(self.num_pos_feats),
            "temperature: {}".format(self.temperature),
            "normalize: {}".format(self.normalize),
            "scale: {}".format(self.scale),
        ]
        lines = [head] + [" " * _repr_indent + line for line in body]
        return "\n".join(lines)
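

# Minimal usage sketch for PositionEmbeddingSine (illustrative only): the feature
# map shape (2, 256, 32, 48) and num_pos_feats=128 below are arbitrary example
# values, chosen so that the embedding has 2 * num_pos_feats = 256 channels and
# can be added to the features with with_pos_embed.
def _example_position_embedding():
    feats = torch.randn(2, 256, 32, 48)
    pos_layer = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
    pos = pos_layer(feats)                 # (2, 256, 32, 48)
    return with_pos_embed(feats, pos)      # features plus positional encoding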


class Conv2d_Convenience(nn.Conv2d):
    """nn.Conv2d wrapper that also applies optional norm and activation layers."""

    def __init__(self, *args, **kwargs):
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)
        self.norm = norm
        self.activation = activation

    def forward(self, x):
        if not torch.jit.is_scripting():
            # Empty batches break SyncBatchNorm during training, so guard against them.
            if x.numel() == 0 and self.training:
                assert not isinstance(
                    self.norm, torch.nn.SyncBatchNorm
                ), "SyncBatchNorm does not support empty inputs!"
        x = F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x
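

# Minimal usage sketch for Conv2d_Convenience with c2_xavier_fill (illustrative
# only): the channel counts and the GroupNorm/ReLU choices below are arbitrary
# example values, not anything prescribed by this module.
def _example_convenience_conv():
    conv = Conv2d_Convenience(
        256, 256, kernel_size=3, padding=1, bias=False,
        norm=nn.GroupNorm(32, 256), activation=F.relu,
    )
    c2_xavier_fill(conv)                   # Caffe2-style XavierFill initialization
    x = torch.randn(2, 256, 64, 64)
    return conv(x)                         # conv -> GroupNorm -> ReLU, shape preserved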