-
Notifications
You must be signed in to change notification settings - Fork 91
/
multi_attention_detector.py
473 lines (409 loc) · 22.7 KB
/
multi_attention_detector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
"""
# author: Kangran ZHAO
# email: [email protected]
# date: 2024-04-01
# description: Class for the Multi-attention Detector
Functions in the Class are summarized as:
1. __init__: Initialization
2. build_backbone: Backbone-building
3. build_loss: Loss-function-building
4. features: Feature-extraction
5. classifier: Classification
6. get_losses: Loss-computation
7. get_train_metrics: Training-metrics-computation
8. get_test_metrics: Testing-metrics-computation
9. forward: Forward-propagation
Reference:
@INPROCEEDINGS{9577592,
author={Zhao, Hanqing and Wei, Tianyi and Zhou, Wenbo and Zhang, Weiming and Chen, Dongdong and Yu, Nenghai},
booktitle={2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
title={Multi-attentional Deepfake Detection},
year={2021},
volume={},
number={},
pages={2185-2194},
keywords={Measurement;Semantics;Feature extraction;Forgery;Pattern recognition;Feeds;Task analysis},
doi={10.1109/CVPR46437.2021.00222}
}
Codes are modified based on GitHub repo https://github.com/yoctta/multiple-attention
"""
import random
import kornia
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from detectors import DETECTOR
from loss import LOSSFUNC
from metrics.base_metrics_class import calculate_metrics_for_train
from networks import BACKBONE
from sklearn import metrics
from .base_detector import AbstractDetector
@DETECTOR.register_module(module_name='multi_attention')
class MultiAttentionDetector(AbstractDetector):
    """Multi-attentional deepfake detector built on an EfficientNet backbone.

    Training has two phases, selected by the forward-call counter ``self.batch_cnt``:

    * while ``batch_cnt <= backbone_nEpochs * batch_per_epoch`` only the backbone's own
      classifier output (``layer_output["logit"]``) is used and the loss is the
      classification loss alone;
    * afterwards the multi-attention head is used (attention maps, texture enhancement,
      bilinear attention pooling, ensemble classifier, AGDA augmentation during training)
      and the "ril" loss term is added.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Indices into ``backbone.efficientnet._blocks`` whose outputs ``features`` exposes
        # under the names "b1" ... "b7".
        self.block_layer = {"b1": 1, "b2": 5, "b3": 9, "b4": 15, "b5": 21, "b6": 29, "b7": 31}
        self.mid_dim = config["mid_dim"]
        self.backbone = self.build_backbone(config)
        self.loss_func = self.build_loss(config)
        # Number of forward() calls so far; drives the backbone-only -> full-model switch.
        self.batch_cnt = 0

        # Probe the backbone with a dummy image to read off the channel counts of the layers
        # used below. ``no_grad`` alone does not stop BatchNorm layers in training mode from
        # updating their running statistics, so the probe runs in eval mode and the previous
        # mode is restored afterwards.
        was_training = self.backbone.training
        self.backbone.eval()
        with torch.no_grad():
            dummy = torch.zeros(1, 3, config["resolution"], config["resolution"])
            layer_outputs = self.features({"image": dummy})
        self.backbone.train(was_training)

        self.feature_layer = config["feature_layer"]
        self.attention_layer = config["attention_layer"]
        self.num_classes = config["backbone_config"]["num_classes"]
        self.num_shallow_features = layer_outputs[self.feature_layer].shape[1]
        self.num_attention_features = layer_outputs[self.attention_layer].shape[1]
        self.num_final_features = layer_outputs["final"].shape[1]
        self.num_attentions = config["num_attentions"]

        # Sub-modules are created in the same order as before so seeded weight
        # initialisation is unchanged.
        agda_cfg = config["AGDA"]
        self.AGDA = AGDA(kernel_size=agda_cfg["kernel_size"],
                         dilation=agda_cfg["dilation"],
                         sigma=agda_cfg["sigma"],
                         threshold=agda_cfg["threshold"],
                         zoom=agda_cfg["zoom"],
                         scale_factor=agda_cfg["scale_factor"],
                         noise_rate=agda_cfg["noise_rate"])
        self.attention_generation = AttentionMap(self.num_attention_features, self.num_attentions)
        self.attention_pooling = AttentionPooling()
        # TODO (kept from original): TextureEnhanceV2 is defined in this file but not selectable.
        self.texture_enhance = TextureEnhanceV1(self.num_shallow_features, self.num_attentions)
        self.num_enhanced_features = self.texture_enhance.output_features
        self.num_features_d = self.texture_enhance.output_features_d
        self.projection_local = nn.Sequential(nn.Linear(self.num_attentions * self.num_enhanced_features, self.mid_dim),
                                              nn.Hardswish(),
                                              nn.Linear(self.mid_dim, self.mid_dim),
                                              nn.Hardswish())
        self.projection_final = nn.Sequential(nn.Linear(self.num_final_features, self.mid_dim),
                                              nn.Hardswish())
        self.ensemble_classifier_fc = nn.Sequential(nn.Linear(self.mid_dim * 2, self.mid_dim),
                                                    nn.Hardswish(),
                                                    nn.Linear(self.mid_dim, self.num_classes))
        self.dropout = nn.Dropout(config["dropout_rate"], inplace=True)
        self.dropout_final = nn.Dropout(config["dropout_rate_final"], inplace=True)

    def build_backbone(self, config):
        """Instantiate the backbone named by ``config['backbone_name']`` from the BACKBONE registry.

        Note: writes ``config['pretrained']`` into ``config['backbone_config']`` (in place).
        """
        backbone_class = BACKBONE[config['backbone_name']]
        model_config = config['backbone_config']
        model_config['pretrained'] = self.config.get('pretrained', None)
        backbone = backbone_class(model_config)
        return backbone

    def build_loss(self, config):
        """Build the classification loss and the "ril" loss from the LOSSFUNC registry.

        Returns:
            dict with "cls" and "ril" loss callables and "weights", the per-term loss
            weights taken from ``config["loss_func"]["weights"]``.
        """
        cls_loss_class = LOSSFUNC[config["loss_func"]["cls_loss"]]
        ril_loss_class = LOSSFUNC[config["loss_func"]["ril_loss"]]
        cls_loss_func = cls_loss_class()
        ril_params = config["loss_func"]["ril_params"]
        ril_loss_func = ril_loss_class(M=config["num_attentions"],
                                       N=ril_params["N"],
                                       alpha=ril_params["alpha"],
                                       alpha_decay=ril_params["alpha_decay"],
                                       decay_batch=config["batch_per_epoch"],
                                       inter_margin=ril_params["inter_margin"],
                                       intra_margin=ril_params["intra_margin"])
        return {"cls": cls_loss_func, "ril": ril_loss_func, "weights": config["loss_func"]["weights"]}

    def _in_backbone_phase(self) -> bool:
        """Return True while only the backbone classifier is being trained."""
        return self.batch_cnt <= self.config["backbone_nEpochs"] * self.config["batch_per_epoch"]

    def features(self, data_dict: dict) -> dict:
        """Run the EfficientNet backbone and collect intermediate outputs.

        Args:
            data_dict: batch dict; only its "image" entry is read.

        Returns:
            dict with "b1".."b7" (outputs of the blocks listed in ``self.block_layer``),
            "final" (the feature map reaching ``_avg_pooling``) and "logit"
            (``self.backbone.last_layer`` applied to the globally average-pooled map).
        """
        x = data_dict["image"]
        layer_output = {}
        # Reverse lookup block index -> output name (indices in block_layer are unique).
        name_of_block = {index: name for name, index in self.block_layer.items()}
        # NOTE(review): iteration stops at `_avg_pooling`, so any child registered after it
        # (depending on the EfficientNet implementation) is never applied to "final"/"logit";
        # confirm this matches the intended backbone output.
        for name, module in self.backbone.efficientnet.named_children():
            if name == "_avg_pooling":
                layer_output["final"] = x
                break
            if name != "_blocks":
                x = module(x)
                continue
            for i, block in enumerate(module):
                x = block(x)
                if i in name_of_block:
                    layer_output[name_of_block[i]] = x
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.size(0), -1)
        layer_output["logit"] = self.backbone.last_layer(x)
        return layer_output

    def classifier(self, features: torch.tensor) -> torch.tensor:
        """Unused; the classification head is applied inside ``forward``."""
        pass  # do not overwrite this, since classifier structure has been written in self.forward()

    def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
        """Compute the training losses for the current phase.

        Returns:
            {"overall", "ce_loss"} during the backbone phase; afterwards also "ril_loss", with
            ``overall = weights[0] * ce_loss + weights[1] * ril_loss``.
        """
        label = data_dict["label"]
        pred = pred_dict["cls"]
        ce_loss = self.loss_func["cls"](pred, label)
        if self._in_backbone_phase():
            return {"overall": ce_loss, "ce_loss": ce_loss}
        ril_loss = self.loss_func["ril"](pred_dict["feature_maps_d"], pred_dict["attentions"], label)
        weights = self.loss_func["weights"]
        over_all_loss = weights[0] * ce_loss + weights[1] * ril_loss
        return {"overall": over_all_loss, "ce_loss": ce_loss, "ril_loss": ril_loss}

    def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
        """Batch-level acc/auc/eer/ap computed from labels and logits."""
        label = data_dict['label']
        pred = pred_dict['cls']
        auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach())
        metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap}
        return metric_batch_dict

    def forward(self, data_dict: dict, inference=False) -> dict:
        """Run one forward pass and advance the phase counter ``self.batch_cnt``.

        Args:
            data_dict: batch dict; "image" is read. In the full-model phase with
                ``inference=False`` the "image" entry is overwritten in place with the
                AGDA-augmented batch.
            inference: when False, AGDA augmentation is applied in the full-model phase.

        Returns:
            dict with "cls" (logits), "prob" (softmax probability of class index 1) and
            "feat" (the backbone "final" map); in the full-model phase also "attentions"
            and "feature_maps_d", which ``get_losses`` consumes.
        """
        self.batch_cnt += 1
        if self._in_backbone_phase():
            layer_output = self.features(data_dict)
            pred = layer_output["logit"]
            prob = torch.softmax(pred, dim=1)[:, 1]
            return {"cls": pred, "prob": prob, "feat": layer_output["final"]}

        if not inference:
            # Attention-guided data augmentation: attention maps from a gradient-free pass
            # select the image regions AGDA degrades.
            with torch.no_grad():
                layer_output = self.features(data_dict)
                raw_attentions = layer_output[self.attention_layer]
                attention_maps = self.attention_generation(raw_attentions)
                data_dict["image"], _ = self.AGDA.agda(data_dict["image"], attention_maps)

        # Attention maps for the (possibly augmented) batch.
        layer_output = self.features(data_dict)
        attention_maps = self.attention_generation(layer_output[self.attention_layer])

        # Textural feature matrix P: enhanced shallow features pooled by each attention map.
        shallow_features = layer_output[self.feature_layer]
        enhanced_features, feature_maps_d = self.texture_enhance(shallow_features, attention_maps)
        textural_feature_matrix_p = self.attention_pooling(enhanced_features, attention_maps)
        batch_size = textural_feature_matrix_p.size(0)
        feature_matrix = self.dropout(textural_feature_matrix_p).view(batch_size, -1)
        feature_matrix = self.projection_local(feature_matrix)

        # Global feature G: final backbone map pooled by the summed attention maps.
        final = layer_output["final"]
        attention_maps2 = attention_maps.sum(dim=1, keepdim=True)  # [B, 1, H_A, W_A]
        final = self.attention_pooling(final, attention_maps2, norm=1).squeeze(1)  # [B, C_F]
        # NOTE(review): self.dropout_final is built in __init__ but never applied, so
        # config["dropout_rate_final"] currently has no effect.
        final = self.projection_final(final)
        # projection_final already ends with Hardswish, so this applies it a second time;
        # kept as-is so outputs of already-trained weights do not change.
        final = F.hardswish(final)

        # Ensemble classifier over the concatenated local (P) and global (G) projections.
        feature_matrix = torch.cat((feature_matrix, final), dim=1)  # [B, 2 * mid_dim]
        pred = self.ensemble_classifier_fc(feature_matrix)  # [B, num_classes]
        prob = torch.softmax(pred, dim=1)[:, 1]
        return {"cls": pred,
                "prob": prob,
                "feat": layer_output["final"],
                "attentions": attention_maps,
                "feature_maps_d": feature_maps_d}
class AttentionMap(nn.Module):
    """Turn a deep feature map into ``num_attention`` non-negative attention maps.

    A fixed 24x24 mask with a zeroed 2-pixel border is resized (nearest) to the output
    resolution and multiplied in, so attention at the map border is suppressed.
    """

    def __init__(self, in_channels, num_attention):
        super().__init__()
        border_mask = torch.zeros([1, 1, 24, 24])
        border_mask[0, 0, 2:-2, 2:-2] = 1
        self.register_buffer('mask', border_mask)
        self.num_attentions = num_attention
        self.conv_extract = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv2 = nn.Conv2d(in_channels, num_attention, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(num_attention)

    def forward(self, x):
        """
        Args:
            x: deep feature map [B, in_channels, H, W].
        Returns:
            [B, num_attention, H, W] maps (elu + 1, so non-negative) times the border mask;
            a [B, 1, 1, 1] tensor of ones when ``num_attention == 0``.
        """
        if self.num_attentions == 0:
            return torch.ones([x.shape[0], 1, 1, 1], device=x.device)
        hidden = F.relu(self.bn1(self.conv_extract(x)), inplace=True)
        maps = F.elu(self.bn2(self.conv2(hidden))) + 1
        height, width = maps.shape[2], maps.shape[3]
        return maps * F.interpolate(self.mask, (height, width), mode='nearest')
class AttentionPooling(nn.Module):
    """Bilinear attention pooling: weight features by each attention map and sum over space."""

    def __init__(self):
        super().__init__()

    def forward(self, features, attentions, norm=2):
        """
        Bilinear Attention Pooling.

        Args:
            features: 4-D [B, C_F, H_F, W_F] feature maps, or 5-D [B, M, C, H_F, W_F]
                per-attention feature maps (the layout TextureEnhanceV2 returns).
            attentions: [B, M, H, W] attention maps; resized bilinearly to (H_F, W_F)
                when the spatial sizes differ.
            norm: 1 -> divide each pooled vector by its summed attention weight;
                2 -> L2-normalise each pooled vector along the last dimension;
                any other value -> no normalisation.
        Returns:
            [B, M, C] pooled feature matrix.
        """
        target_hw = features.size()[-2:]
        if attentions.size()[-2:] != target_hw:
            attentions = F.interpolate(attentions, size=target_hw, mode='bilinear', align_corners=True)
        # 4-D: one shared feature map per sample; 5-D: a separate feature map per attention.
        equation = 'imjk,injk->imn' if features.dim() == 4 else 'imjk,imnjk->imn'
        pooled = torch.einsum(equation, attentions, features)
        if norm == 1:
            pooled.div_(torch.sum(attentions + 1e-8, dim=(2, 3)).unsqueeze(-1))
        elif norm == 2:
            pooled = F.normalize(pooled, p=2, dim=-1)
        return pooled
class TextureEnhanceV1(nn.Module):
    """Texture enhancement block, variant 1.

    Subtracts a locally averaged copy of the shallow features (keeping the residual
    texture), then runs a three-stage dense block. Output has ``4 * num_features``
    channels. ``num_attentions`` is accepted (TextureEnhanceV2 takes it too) but unused.
    """

    def __init__(self, num_features, num_attentions):
        super().__init__()
        self.output_features = 4 * num_features
        self.output_features_d = num_features
        # Creation order is unchanged so seeded initialisation and state_dict keys match.
        self.conv0 = nn.Conv2d(num_features, num_features, 1)
        self.conv1 = nn.Conv2d(num_features, num_features, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features)
        self.conv2 = nn.Conv2d(2 * num_features, num_features, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(2 * num_features)
        self.conv3 = nn.Conv2d(3 * num_features, num_features, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(3 * num_features)
        self.conv_last = nn.Conv2d(4 * num_features, 4 * num_features, 1)
        self.bn4 = nn.BatchNorm2d(4 * num_features)
        self.bn_last = nn.BatchNorm2d(4 * num_features)

    def forward(self, feature_maps, attention_maps=(1, 1)):
        """
        Args:
            feature_maps: [B, C', H', W'] shallow features.
            attention_maps: [B, M, H_A, W_A] attention maps, whose spatial size sets the
                local-average grid; or a tuple (h_ratio, w_ratio) giving that grid as a
                fraction of (H', W').
        Returns:
            (enhanced [B, 4*C', H', W'], feature_maps_d [B, C', grid_h, grid_w]) where
            feature_maps_d is the locally averaged input.
        """
        _, _, height, width = feature_maps.shape
        if type(attention_maps) is tuple:
            grid = (int(height * attention_maps[0]), int(width * attention_maps[1]))
        else:
            grid = (attention_maps.shape[2], attention_maps.shape[3])
        feature_maps_d = F.adaptive_avg_pool2d(feature_maps, grid)
        residual = feature_maps - F.interpolate(feature_maps_d, (height, width), mode='nearest')
        dense = self.conv0(residual)
        # Each stage appends its output to everything produced so far (DenseNet style).
        for bn, conv in ((self.bn1, self.conv1), (self.bn2, self.conv2), (self.bn3, self.conv3)):
            dense = torch.cat([dense, conv(F.relu(bn(dense), inplace=True))], dim=1)
        enhanced = self.bn_last(self.conv_last(F.relu(self.bn4(dense), inplace=True)))
        return enhanced, feature_maps_d
class TextureEnhanceV2(nn.Module):
    """Texture enhancement block, variant 2: a grouped dense block with one group per attention map.

    Shallow features are weighted by tanh of each (detached) attention map and processed by
    grouped convolutions, so each of the M groups evolves independently.
    """

    def __init__(self, num_features, num_attentions):
        super().__init__()
        self.output_features = num_features
        self.output_features_d = num_features
        width = num_features * num_attentions  # channels of one full grouped tensor
        # Creation order is unchanged so seeded initialisation and state_dict keys match.
        self.conv_extract = nn.Conv2d(num_features, num_features, 3, padding=1)
        self.conv0 = nn.Conv2d(width, width, 5, padding=2, groups=num_attentions)
        self.conv1 = nn.Conv2d(width, width, 3, padding=1, groups=num_attentions)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(2 * width, width, 3, padding=1, groups=num_attentions)
        self.bn2 = nn.BatchNorm2d(2 * width)
        self.conv3 = nn.Conv2d(3 * width, width, 3, padding=1, groups=num_attentions)
        self.bn3 = nn.BatchNorm2d(3 * width)
        self.conv_last = nn.Conv2d(4 * width, width, 1, groups=num_attentions)
        self.bn4 = nn.BatchNorm2d(4 * width)
        self.bn_last = nn.BatchNorm2d(width)
        self.M = num_attentions

    def cat(self, a, b):
        """Concatenate two grouped [B, M*k, H, W] tensors channel-wise within each of the M groups."""
        B, C, H, W = a.shape
        grouped = [t.reshape(B, self.M, -1, H, W) for t in (a, b)]
        return torch.cat(grouped, dim=2).reshape(B, -1, H, W)

    def forward(self, feature_maps, attention_maps=(1, 1)):
        """
        Args:
            feature_maps: [B, N, H, W] shallow features.
            attention_maps: [B, M, H_A, W_A] attention maps, or a tuple (h_ratio, w_ratio)
                giving the local-average grid as a fraction of (H, W).
        Returns:
            (enhanced [B, M, N, H, W], feature_maps_d [B, N, grid_h, grid_w]).
        """
        B, N, H, W = feature_maps.shape
        ratio_given = type(attention_maps) is tuple
        if ratio_given:
            grid = (int(H * attention_maps[0]), int(W * attention_maps[1]))
        else:
            grid = (attention_maps.shape[2], attention_maps.shape[3])
        feature_maps = self.conv_extract(feature_maps)
        feature_maps_d = F.adaptive_avg_pool2d(feature_maps, grid)
        if feature_maps.size(2) > feature_maps_d.size(2):
            upsampled = F.interpolate(feature_maps_d, (feature_maps.shape[2], feature_maps.shape[3]), mode='nearest')
            feature_maps = feature_maps - upsampled
        if ratio_given:
            # NOTE(review): with a tuple the grouped tensor below has N channels, while conv0
            # expects N * num_attentions; this path only fits when num_attentions == 1.
            weights = 1
        else:
            resized = F.interpolate(attention_maps.detach(), (H, W), mode='bilinear', align_corners=True)
            weights = torch.tanh(resized).unsqueeze(2)  # [B, M, 1, H, W]
        grouped = (feature_maps.unsqueeze(1) * weights).reshape(B, -1, H, W)
        dense = self.conv0(grouped)
        for bn, conv in ((self.bn1, self.conv1), (self.bn2, self.conv2), (self.bn3, self.conv3)):
            dense = self.cat(dense, conv(F.relu(bn(dense), inplace=True)))
        enhanced = F.relu(self.bn_last(self.conv_last(F.relu(self.bn4(dense), inplace=True))), inplace=True)
        return enhanced.reshape(B, -1, N, H, W), feature_maps_d
class AGDA(nn.Module):
    """Attention-guided data augmentation.

    For each sample, one attention map is drawn, rescaled to [0, 1] by its maximum and passed
    through a soft sigmoid threshold; the image is then blended with a down-scaled, Gaussian-
    blurred, noise-perturbed copy of itself, weighted by that soft mask.
    """

    def __init__(self, kernel_size, dilation, sigma, threshold, zoom, scale_factor, noise_rate):
        """
        Args:
            kernel_size: side length of the Gaussian blur kernel.
            dilation: stored but not used by this class.
            sigma: standard deviation of the Gaussian blur.
            threshold: centre of the soft threshold; a scalar, or a (low, high) list/tuple
                from which a value is drawn uniformly on every call.
            zoom: steepness of the soft threshold; a scalar or a (low, high) list/tuple.
            scale_factor: down-scaling factor applied before blurring.
            noise_rate: standard deviation multiplier of the Gaussian noise added to the
                blurred copy.
        """
        super().__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.sigma = sigma
        self.noise_rate = noise_rate
        self.scale_factor = scale_factor
        self.threshold = threshold
        self.zoom = zoom
        self.filter = kornia.filters.GaussianBlur2d((self.kernel_size, self.kernel_size), (self.sigma, self.sigma))

    @staticmethod
    def _sample(value):
        """Draw uniformly from a (low, high) list/tuple range; return a scalar unchanged."""
        # Tuples are accepted as well as lists; previously a tuple fell through to the
        # scalar path and then failed in the tensor arithmetic of mod_func.
        if isinstance(value, (list, tuple)):
            return random.uniform(*value)
        return value

    def mod_func(self, x):
        """Soft threshold mapping x = 0 to 0 and increasing toward 1 for x >> threshold."""
        threshold = self._sample(self.threshold)
        zoom = self._sample(self.zoom)
        # Value of the sigmoid at x = 0, subtracted so that mod_func(0) == 0 exactly.
        bottom = torch.sigmoid((torch.tensor(0.) - threshold) * zoom)
        return (torch.sigmoid((x - threshold) * zoom) - bottom) / (1 - bottom)

    def soft_drop2(self, x, attention_map):
        """Blend ``x`` with its degraded copy using the soft-thresholded attention map as weight."""
        with torch.no_grad():
            attention_map = self.mod_func(attention_map)
            B, C, H, W = x.size()
            xs = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=True)
            xs = self.filter(xs)
            xs += torch.randn_like(xs) * self.noise_rate
            xs = F.interpolate(xs, (H, W), mode='bilinear', align_corners=True)
            x = x * (1 - attention_map) + xs * attention_map
            return x

    def agda(self, X, attention_map):
        """Augment a batch of images.

        Args:
            X: image batch [B, C, H, W].
            attention_map: attention maps [B, M, H_A, W_A].
        Returns:
            (augmented images shaped like X, index [B] of the attention map used per sample).
            Map m is chosen with probability proportional to sqrt(sum of map m + 1).
        """
        with torch.no_grad():
            attention_weight = torch.sum(attention_map, dim=(2, 3))
            attention_map = F.interpolate(attention_map, (X.size(2), X.size(3)), mode="bilinear", align_corners=True)
            attention_weight = torch.sqrt(attention_weight + 1)
            index = torch.distributions.categorical.Categorical(attention_weight).sample()
            index1 = index.view(-1, 1, 1, 1).repeat(1, 1, X.size(2), X.size(3))
            attention_map = torch.gather(attention_map, 1, index1)
            # Normalise each selected map by its maximum (epsilon avoids division by zero).
            atten_max = torch.max(attention_map.view(attention_map.shape[0], 1, -1), 2)[0] + 1e-8
            attention_map = attention_map / atten_max.view(attention_map.shape[0], 1, 1, 1)
            return self.soft_drop2(X, attention_map), index