-
Notifications
You must be signed in to change notification settings - Fork 30
/
model.py
306 lines (269 loc) · 12.6 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, ReLU, Sigmoid, Dropout2d, Dropout, AvgPool2d, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module, Parameter
import torch.nn.functional as F
import torch
from collections import namedtuple
import math
import pdb
################################## Original Arcface Model #############################################################
class Flatten(Module):
    """Collapse every dimension after the first (batch) one into a single axis.

    ``reshape`` is used rather than ``view`` so that non-contiguous inputs
    (for example permuted or channels-last tensors) are flattened instead of
    raising ``RuntimeError``. Contiguous inputs still come back as a view.
    """

    def forward(self, input):
        # Keep dim 0 and let -1 absorb everything else.
        return input.reshape(input.size(0), -1)
def l2_norm(input, axis=1, eps=1e-12):
    """Scale ``input`` to unit L2 norm along ``axis``.

    Args:
        input: tensor to normalise.
        axis: dimension along which the norm is computed.
        eps: lower bound applied to the norm before dividing, so that an
            all-zero vector maps to zeros instead of NaN (0 / 0).

    Returns:
        A tensor with the same shape as ``input``.
    """
    norm = torch.norm(input, 2, axis, keepdim=True)
    # Without the floor a zero vector would poison everything downstream with NaN.
    return torch.div(input, norm.clamp_min(eps))
class SEModule(Module):
    """Squeeze-and-excitation gate that rescales each channel of its input.

    The input is globally average-pooled to 1x1, passed through a bias-free
    1x1 conv reducing to ``channels // reduction`` channels, a ReLU, a
    bias-free 1x1 conv restoring ``channels`` channels and a sigmoid. The
    resulting per-channel gate multiplies the original input.
    """

    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        squeezed = channels // reduction
        self.avg_pool = AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(channels, squeezed, kernel_size=1, padding=0, bias=False)
        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(squeezed, channels, kernel_size=1, padding=0, bias=False)
        self.sigmoid = Sigmoid()

    def forward(self, x):
        # Squeeze: one descriptor per channel.
        gate = self.avg_pool(x)
        # Excite: bottleneck MLP expressed as 1x1 convolutions.
        gate = self.fc2(self.relu(self.fc1(gate)))
        gate = self.sigmoid(gate)
        # Broadcast the (N, C, 1, 1) gate over the spatial dimensions.
        return x * gate
class bottleneck_IR(Module):
    """Residual unit: BN -> 3x3 conv -> PReLU -> 3x3 conv (strided) -> BN.

    The shortcut is a 1x1 max-pool with the given stride when the channel
    count is unchanged (pure sub-sampling), otherwise a strided 1x1 conv
    followed by batch norm. The output is the sum of both paths.
    """

    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR, self).__init__()
        if in_channel != depth:
            # Channel count changes: project the input onto `depth` channels.
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, kernel_size=(1, 1), stride=stride, bias=False),
                BatchNorm2d(depth),
            )
        else:
            # Same width: only match the spatial stride of the residual path.
            self.shortcut_layer = MaxPool2d(1, stride)
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
            PReLU(depth),
            Conv2d(depth, depth, kernel_size=(3, 3), stride=stride, padding=1, bias=False),
            BatchNorm2d(depth),
        )

    def forward(self, x):
        identity = self.shortcut_layer(x)
        return self.res_layer(x) + identity
class bottleneck_IR_SE(Module):
    """Residual unit like ``bottleneck_IR`` with an SE gate on the residual path.

    Residual path: BN -> 3x3 conv -> PReLU -> 3x3 conv (strided) -> BN ->
    ``SEModule(depth, 16)``. The shortcut is a strided 1x1 max-pool when
    ``in_channel == depth`` and a strided 1x1 conv + BN otherwise.
    """

    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR_SE, self).__init__()
        if in_channel != depth:
            # Width changes: learnable projection on the shortcut.
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, kernel_size=(1, 1), stride=stride, bias=False),
                BatchNorm2d(depth),
            )
        else:
            # Same width: sub-sample only.
            self.shortcut_layer = MaxPool2d(1, stride)
        residual_ops = [
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
            PReLU(depth),
            Conv2d(depth, depth, kernel_size=(3, 3), stride=stride, padding=1, bias=False),
            BatchNorm2d(depth),
            SEModule(depth, 16),
        ]
        self.res_layer = Sequential(*residual_ops)

    def forward(self, x):
        identity = self.shortcut_layer(x)
        return self.res_layer(x) + identity
class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
    """Immutable spec of one residual unit: input channels, output channels, stride."""
def get_block(in_channel, depth, num_units, stride = 2):
    """Return the unit specs of one stage as a list of ``Bottleneck`` tuples.

    The first unit maps ``in_channel`` to ``depth`` with the given stride; the
    remaining ``num_units - 1`` units keep ``depth`` channels at stride 1.
    """
    units = [Bottleneck(in_channel, depth, stride)]
    units.extend(Bottleneck(depth, depth, 1) for _ in range(num_units - 1))
    return units
def get_blocks(num_layers):
    """Return the per-stage unit specs for a 50-, 100- or 152-layer backbone.

    Every supported depth uses the same four stages (64->64, 64->128,
    128->256, 256->512 channels) and differs only in the number of units
    per stage.

    Args:
        num_layers: network depth; one of 50, 100 or 152.

    Returns:
        A list with four entries, each the list produced by ``get_block``.

    Raises:
        ValueError: if ``num_layers`` is not a supported depth. Previously
            this fell through to an ``UnboundLocalError`` on ``return``.
    """
    units_per_stage = {
        50: (3, 4, 14, 3),
        100: (3, 13, 30, 3),
        152: (3, 8, 36, 3),
    }
    if num_layers not in units_per_stage:
        raise ValueError(
            'num_layers should be 50,100, or 152, got {!r}'.format(num_layers))
    stage_channels = ((64, 64), (64, 128), (128, 256), (256, 512))
    return [
        get_block(in_channel=in_c, depth=out_c, num_units=count)
        for (in_c, out_c), count in zip(stage_channels, units_per_stage[num_layers])
    ]
class Backbone(Module):
    """IR / IR-SE residual network that outputs L2-normalised 512-d embeddings.

    Args:
        num_layers: depth of the network; must be 50, 100 or 152.
        drop_ratio: dropout probability used in the output head.
        mode: ``'ir'`` for plain residual units, ``'ir_se'`` for units with a
            squeeze-and-excitation gate.

    The head's ``Linear(512 * 7 * 7, 512)`` requires the body to emit a
    512-channel 7x7 map.
    """

    def __init__(self, num_layers, drop_ratio, mode='ir'):
        super(Backbone, self).__init__()
        assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
        stage_specs = get_blocks(num_layers)
        unit_cls = {'ir': bottleneck_IR, 'ir_se': bottleneck_IR_SE}[mode]
        # Stem: 3 -> 64 channels, spatial size preserved.
        self.input_layer = Sequential(
            Conv2d(3, 64, (3, 3), 1, 1, bias=False),
            BatchNorm2d(64),
            PReLU(64),
        )
        # Head: BN -> dropout -> flatten -> FC -> BN.
        self.output_layer = Sequential(
            BatchNorm2d(512),
            Dropout(drop_ratio),
            Flatten(),
            Linear(512 * 7 * 7, 512),
            BatchNorm1d(512),
        )
        # Body: one residual unit per spec, stages concatenated in order.
        self.body = Sequential(*[
            unit_cls(spec.in_channel, spec.depth, spec.stride)
            for stage in stage_specs
            for spec in stage
        ])

    def forward(self, x):
        features = self.body(self.input_layer(x))
        return l2_norm(self.output_layer(features))
################################## MobileFaceNet #############################################################
class Conv_block(Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` and a per-channel ``PReLU``."""

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(Conv_block, self).__init__()
        self.conv = Conv2d(
            in_c,
            out_channels=out_c,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=False,
        )
        self.bn = BatchNorm2d(out_c)
        self.prelu = PReLU(out_c)

    def forward(self, x):
        return self.prelu(self.bn(self.conv(x)))
class Linear_block(Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d``, with no activation."""

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(Linear_block, self).__init__()
        self.conv = Conv2d(
            in_c,
            out_channels=out_c,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=False,
        )
        self.bn = BatchNorm2d(out_c)

    def forward(self, x):
        return self.bn(self.conv(x))
class Depth_Wise(Module):
    """Expand (1x1) -> depthwise conv -> linear project (1x1), optionally residual.

    ``groups`` is the width of the expanded representation: the first 1x1
    conv maps ``in_c`` to ``groups`` channels, the depthwise conv runs with
    ``groups`` groups, and the final 1x1 projection maps to ``out_c``. When
    ``residual`` is true the block input is added to the projected output.
    """

    def __init__(self, in_c, out_c, residual = False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
        super(Depth_Wise, self).__init__()
        pointwise = dict(kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        self.conv = Conv_block(in_c, out_c=groups, **pointwise)
        self.conv_dw = Conv_block(
            groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride)
        self.project = Linear_block(groups, out_c, **pointwise)
        self.residual = residual

    def forward(self, x):
        projected = self.project(self.conv_dw(self.conv(x)))
        if not self.residual:
            return projected
        return x + projected
class Residual(Module):
    """Stack of ``num_block`` residual ``Depth_Wise`` units that keep ``c`` channels."""

    def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
        super(Residual, self).__init__()
        units = [
            Depth_Wise(c, c, residual=True, kernel=kernel, padding=padding,
                       stride=stride, groups=groups)
            for _ in range(num_block)
        ]
        self.model = Sequential(*units)

    def forward(self, x):
        return self.model(x)
class MobileFaceNet(Module):
    """Depthwise-separable network that outputs L2-normalised embeddings.

    Args:
        embedding_size: dimensionality of the returned embedding.

    ``conv_6_dw`` is an unpadded 7x7 depthwise conv and ``linear`` expects
    512 flattened features, so the map entering ``conv_6_dw`` must be 7x7.
    """

    def __init__(self, embedding_size):
        super(MobileFaceNet, self).__init__()
        # Shared 3x3 settings: `halve` sub-samples by 2, `same` keeps the size.
        halve = dict(kernel=(3, 3), stride=(2, 2), padding=(1, 1))
        same = dict(kernel=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv1 = Conv_block(3, 64, **halve)
        self.conv2_dw = Conv_block(64, 64, groups=64, **same)
        self.conv_23 = Depth_Wise(64, 64, groups=128, **halve)
        self.conv_3 = Residual(64, num_block=4, groups=128, **same)
        self.conv_34 = Depth_Wise(64, 128, groups=256, **halve)
        self.conv_4 = Residual(128, num_block=6, groups=256, **same)
        self.conv_45 = Depth_Wise(128, 128, groups=512, **halve)
        self.conv_5 = Residual(128, num_block=2, groups=256, **same)
        self.conv_6_sep = Conv_block(128, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
        # Depthwise 7x7 conv with no padding: collapses a 7x7 map to 1x1.
        self.conv_6_dw = Linear_block(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0))
        self.conv_6_flatten = Flatten()
        self.linear = Linear(512, embedding_size, bias=False)
        self.bn = BatchNorm1d(embedding_size)

    def forward(self, x):
        pipeline = (
            self.conv1,
            self.conv2_dw,
            self.conv_23,
            self.conv_3,
            self.conv_34,
            self.conv_4,
            self.conv_45,
            self.conv_5,
            self.conv_6_sep,
            self.conv_6_dw,
            self.conv_6_flatten,
            self.linear,
            self.bn,
        )
        out = x
        for stage in pipeline:
            out = stage(out)
        return l2_norm(out)
################################## Arcface head #############################################################
class Arcface(Module):
    """Additive angular margin (ArcFace) head, https://arxiv.org/abs/1801.07698.

    Produces scaled cosine logits in which the logit of the ground-truth
    class is replaced by ``cos(theta + m)``.

    Args:
        embedding_size: dimensionality of the incoming embeddings.
        classnum: number of classes (columns of the weight matrix).
        s: scale applied to all logits, see normface
            https://arxiv.org/abs/1704.06369.
        m: angular margin in radians.
    """

    def __init__(self, embedding_size=512, classnum=51332, s=64., m=0.5):
        super(Arcface, self).__init__()
        self.classnum = classnum
        self.kernel = Parameter(torch.Tensor(embedding_size, classnum))
        # initial kernel: uniform values, each column rescaled to L2 norm <= 1
        self.kernel.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.m = m  # the margin value, default is 0.5
        self.s = s  # scalar value default is 64
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.mm = self.sin_m * m  # penalty subtracted in the fallback branch
        self.threshold = math.cos(math.pi - m)

    def forward(self, embbedings, label):
        """Return ``(batch, classnum)`` logits with the margin on the target class.

        Args:
            embbedings: ``(batch, embedding_size)`` tensor; it is multiplied
                with the column-normalised kernel and the product is clamped
                to [-1, 1], so it is presumably expected to be L2-normalised.
            label: integer class index per sample, used to index the logits.
        """
        # weights norm (F.normalize also guards against zero-norm columns)
        kernel_norm = F.normalize(self.kernel, dim=0)
        cos_theta = torch.mm(embbedings, kernel_norm)
        cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability

        # The margin only affects the target logit, so evaluate it on those
        # entries alone. Taking sqrt over the full matrix gives 0/0 = NaN
        # gradients whenever *any* entry is exactly +-1, even a non-target one.
        rows = torch.arange(cos_theta.size(0), dtype=torch.long, device=cos_theta.device)
        target_cos = cos_theta[rows, label]
        # The floor keeps the sqrt gradient finite when target_cos is exactly +-1.
        target_sin = torch.sqrt((1.0 - target_cos.pow(2)).clamp_min(1e-12))
        # cos(theta + m)
        with_margin = target_cos * self.cos_m - target_sin * self.sin_m
        # theta + m must stay within [0, pi]; outside that range fall back to
        # a cosface-style penalty instead.
        fallback = target_cos - self.mm
        target_logit = torch.where(target_cos > self.threshold, with_margin, fallback)

        output = cos_theta.clone()  # do not modify cos_theta in place
        output[rows, label] = target_logit
        # scale up in order to make softmax work, first introduced in normface
        return output * self.s
################################## Cosface head #############################################################
class Am_softmax(Module):
    """Additive margin softmax head, https://arxiv.org/abs/1801.05599.

    Produces scaled cosine logits in which ``m`` is subtracted from the
    logit of the ground-truth class.

    Args:
        embedding_size: dimensionality of the incoming embeddings.
        classnum: number of classes (columns of the weight matrix).
        m: additive margin; 0.35 is the value recommended by the paper.
        s: scale applied to all logits, see normface
            https://arxiv.org/abs/1704.06369.
    """

    def __init__(self, embedding_size=512, classnum=51332, m=0.35, s=30.):
        super(Am_softmax, self).__init__()
        self.classnum = classnum
        self.kernel = Parameter(torch.Tensor(embedding_size, classnum))
        # initial kernel: uniform values, each column rescaled to L2 norm <= 1
        self.kernel.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.m = m
        self.s = s

    def forward(self, embbedings, label):
        """Return ``(batch, classnum)`` logits with the margin on the target class.

        Args:
            embbedings: ``(batch, embedding_size)`` tensor; it is multiplied
                with the column-normalised kernel and the product is clamped
                to [-1, 1], so it is presumably expected to be L2-normalised.
            label: integer class index per sample; any shape that flattens
                to ``(batch,)`` is accepted.
        """
        kernel_norm = F.normalize(self.kernel, dim=0)
        cos_theta = torch.mm(embbedings, kernel_norm)
        cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability
        label = label.view(-1)
        # Integer indexing replaces the deprecated uint8 mask indexing.
        rows = torch.arange(cos_theta.size(0), dtype=torch.long, device=cos_theta.device)
        output = cos_theta.clone()
        # only the ground-truth class gets the margin
        output[rows, label] = cos_theta[rows, label] - self.m
        # scale up in order to make softmax work, first introduced in normface
        return output * self.s