Skip to content

Commit a06d628

Browse files
committed
add resnet50 for training
1 parent f7e3515 commit a06d628

23 files changed

+210
-125
lines changed

README.md

+25-21
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,30 @@
11
# RetinaFace in PyTorch
22

3-
A [PyTorch](https://pytorch.org/) implementation of [RetinaFace: Single-stage Dense Face Localisation in the Wild](https://arxiv.org/abs/1905.00641). Model size only 1.7M, when Retinaface use mobilenet0.25 as backbone net. The official code in Mxnet can be found [here](https://github.com/deepinsight/insightface/tree/master/RetinaFace).
3+
A [PyTorch](https://pytorch.org/) implementation of [RetinaFace: Single-stage Dense Face Localisation in the Wild](https://arxiv.org/abs/1905.00641). Model size only 1.7M, when Retinaface use mobilenet0.25 as backbone net. We also provide resnet50 as backbone net to get better result. The official code in Mxnet can be found [here](https://github.com/deepinsight/insightface/tree/master/RetinaFace).
44

55
## WiderFace Val Performance in single scale When using Resnet50 as backbone net.
66
| Style | easy | medium | hard |
77
|:-|:-:|:-:|:-:|
8-
| Pytorch (same parameter with Mxnet) | 94.47 % | 93.54% | 89.21% |
9-
| Pytorch (original image scale) | 95.55 % | 94.09% | 84.05% |
8+
| Pytorch (same parameter with Mxnet) | 94.82 % | 93.84% | 89.60% |
9+
| Pytorch (original image scale) | 95.48% | 94.04% | 84.43% |
1010
| Mxnet | 94.86% | 93.87% | 88.33% |
1111
| Mxnet(original image scale) | 94.97% | 93.89% | 82.27% |
1212

13-
ps: The resnet50-based demo will be updated recently.
14-
1513
## WiderFace Val Performance in single scale When using Mobilenet0.25 as backbone net.
1614
| Style | easy | medium | hard |
1715
|:-|:-:|:-:|:-:|
18-
| Pytorch (same parameter with Mxnet) | 86.85 % | 85.84% | 79.69% |
19-
| Pytorch (original image scale) | 90.58 % | 87.94% | 73.96% |
16+
| Pytorch (same parameter with Mxnet) | 88.67% | 87.09% | 80.99% |
17+
| Pytorch (original image scale) | 90.70% | 88.16% | 73.82% |
2018
| Mxnet | 88.72% | 86.97% | 79.19% |
2119
| Mxnet(original image scale) | 89.58% | 87.11% | 69.12% |
22-
<p align="center"><img src="curve/r_3.png" width="640"\></p>
20+
<p align="center"><img src="curve/Widerface.jpg" width="640"\></p>
2321

24-
## FDDB Performance When using Mobilenet0.25 as backbone net.
25-
| Dataset | performance |
22+
## FDDB Performance.
23+
| FDDB(pytorch) | performance |
2624
|:-|:-:|
27-
| FDDB(pytorch) | 97.93% |
28-
<p align="center"><img src="curve/FDDB_DiscROC.png" width="640"\></p>
25+
| Mobilenet0.25 | 98.64% |
26+
| Resnet50 | 99.22% |
27+
<p align="center"><img src="curve/FDDB.png" width="640"\></p>
2928

3029
### Contents
3130
- [Installation](#installation)
@@ -62,24 +61,31 @@ ps: wider_val.txt only include val file names but not label information.
6261
##### Data1
6362
We also provide the organized dataset we used as in the above directory structure.
6463

65-
Link: from [baidu cloud](https://pan.baidu.com/s/1jIp9t30oYivrAvrgUgIoLQ) Password: ruck
64+
Link: from [google cloud](https://drive.google.com/open?id=11UGV3nbVv1x9IC--_tK3Uxf7hA6rlbsS) or [baidu cloud](https://pan.baidu.com/s/1jIp9t30oYivrAvrgUgIoLQ) Password: ruck
6665

6766
## Training
68-
We trained Mobilenet0.25 on imagenet dataset and get 46.75% in top 1. We use it as pretrain model which has been put in repository named ``model_best.pth.tar``.
69-
1. Before training, you can check the mobilenet*0.25 network configuration (e.g. batch_size, min_sizes and steps etc..) in ``data/config.py and train.py``.
67+
We provide restnet50 or mobilenet0.25 as backbone network.
68+
We trained Mobilenet0.25 on imagenet dataset and get 46.58% in top 1. If you do not wish to train the model, we also provide trained model. Pretrain model and trained model are put in [google cloud](https://drive.google.com/open?id=1oZRSG0ZegbVkVwUd8wUIQx8W7yfZ_ki1) and [baidu cloud](https://pan.baidu.com/s/12h97Fy1RYuqMMIV-RpzdPg) Password: fstq . The model could be put as follows:
69+
'''Shell
70+
./weights/
71+
mobilenet0.25_Final.pth
72+
mobilenetV1X0.25_pretrain.tar
73+
Resnet50_Final.pth
74+
'''
75+
1. Before training, you can check network configuration (e.g. batch_size, min_sizes and steps etc..) in ``data/config.py and train.py``.
7076

7177
2. Train the model using WIDER FACE:
7278
```Shell
73-
python train.py
79+
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --network resnet50 or
80+
CUDA_VISIBLE_DEVICES=0 python train.py --network mobile0.25
7481
```
7582

76-
If you do not wish to train the model, we also provide trained model in `./weights/Final_Retinaface.pth`.
7783

7884
## Evaluation
7985
### Evaluation widerface val
8086
1. Generate txt file
8187
```Shell
82-
python test_widerface.py --trained_model weight_file
88+
python test_widerface.py --trained_model weight_file --network mobile0.25 or resnet50
8389
```
8490
2. Evaluate txt results. Demo come from [Here](https://github.com/wondervictor/WiderFace-Evaluation)
8591
```Shell
@@ -97,14 +103,12 @@ python evaluation.py
97103

98104
2. Evaluate the trained model using:
99105
```Shell
100-
python test.py --dataset FDDB
106+
python test_fddb.py --trained_model weight_file --network mobile0.25 or resnet50
101107
```
102108

103109
3. Download [eval_tool](https://bitbucket.org/marcopede/face-eval) to evaluate the performance.
104110

105-
## RetinaFace-MobileNet0.25
106111
<p align="center"><img src="curve/1.jpg" width="640"\></p>
107-
<p align="center"><img src="curve/2.jpg" width="640"\></p>
108112

109113
## References
110114
- [FaceBoxes](https://github.com/zisianw/FaceBoxes.PyTorch)

curve/1.jpg

-21.9 KB
Loading

curve/2.jpg

-55 KB
Binary file not shown.

curve/FDDB.png

83.7 KB
Loading

curve/FDDB_DiscROC.png

-6.37 KB
Binary file not shown.

curve/Widerface.jpg

221 KB
Loading

curve/o_1.png

-82.8 KB
Binary file not shown.

curve/o_2.png

-81.8 KB
Binary file not shown.

curve/o_3.png

-89.3 KB
Binary file not shown.

curve/r_1.png

-81.4 KB
Binary file not shown.

curve/r_2.png

-81 KB
Binary file not shown.

curve/r_3.png

-86.7 KB
Binary file not shown.

data/config.py

+34-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,42 @@
11
# config.py
22

3-
cfg = {
4-
'name': 'Retinaface',
3+
cfg_mnet = {
4+
'name': 'mobilenet0.25',
55
'min_sizes': [[16, 32], [64, 128], [256, 512]],
66
'steps': [8, 16, 32],
77
'variance': [0.1, 0.2],
88
'clip': False,
99
'loc_weight': 2.0,
10-
'gpu_train': True
10+
'gpu_train': True,
11+
'batch_size': 32,
12+
'ngpu': 1,
13+
'epoch': 250,
14+
'decay1': 190,
15+
'decay2': 220,
16+
'image_size': 640,
17+
'pretrain': True,
18+
'return_layers': {'stage1': 1, 'stage2': 2, 'stage3': 3},
19+
'in_channel': 32,
20+
'out_channel': 64
1121
}
22+
23+
cfg_re50 = {
24+
'name': 'Resnet50',
25+
'min_sizes': [[16, 32], [64, 128], [256, 512]],
26+
'steps': [8, 16, 32],
27+
'variance': [0.1, 0.2],
28+
'clip': False,
29+
'loc_weight': 2.0,
30+
'gpu_train': True,
31+
'batch_size': 24,
32+
'ngpu': 4,
33+
'epoch': 100,
34+
'decay1': 70,
35+
'decay2': 90,
36+
'image_size': 840,
37+
'pretrain': True,
38+
'return_layers': {'layer2': 1, 'layer3': 2, 'layer4': 3},
39+
'in_channel': 256,
40+
'out_channel': 256
41+
}
42+

data/data_augment.py

+4
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,14 @@ def _crop(image, boxes, labels, landm, img_dim):
99
pad_image_flag = True
1010

1111
for _ in range(250):
12+
"""
1213
if random.uniform(0, 1) <= 0.2:
1314
scale = 1.0
1415
else:
1516
scale = random.uniform(0.3, 1.0)
17+
"""
18+
PRE_SCALES = [0.3, 0.45, 0.6, 0.8, 1.0]
19+
scale = random.choice(PRE_SCALES)
1620
short_side = min(width, height)
1721
w = int(scale * short_side)
1822
h = w

detect.py

+16-10
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55
import torch.backends.cudnn as cudnn
66
import numpy as np
7-
from data import cfg
7+
from data import cfg_mnet, cfg_re50
88
from layers.functions.prior_box import PriorBox
99
from utils.nms.py_cpu_nms import py_cpu_nms
1010
import cv2
@@ -14,12 +14,13 @@
1414

1515
parser = argparse.ArgumentParser(description='Retinaface')
1616

17-
parser.add_argument('-m', '--trained_model', default='./weights/Final_Retinaface.pth',
17+
parser.add_argument('-m', '--trained_model', default='./weights/Resnet50_Final.pth',
1818
type=str, help='Trained state_dict file path to open')
19+
parser.add_argument('--network', default='resnet50', help='Backbone network mobile0.25 or resnet50')
1920
parser.add_argument('--cpu', action="store_true", default=False, help='Use cpu inference')
20-
parser.add_argument('--confidence_threshold', default=0.05, type=float, help='confidence_threshold')
21+
parser.add_argument('--confidence_threshold', default=0.02, type=float, help='confidence_threshold')
2122
parser.add_argument('--top_k', default=5000, type=int, help='top_k')
22-
parser.add_argument('--nms_threshold', default=0.3, type=float, help='nms_threshold')
23+
parser.add_argument('--nms_threshold', default=0.4, type=float, help='nms_threshold')
2324
parser.add_argument('--keep_top_k', default=750, type=int, help='keep_top_k')
2425
parser.add_argument('-s', '--save_image', action="store_true", default=True, help='show detection results')
2526
parser.add_argument('--vis_thres', default=0.6, type=float, help='visualization_threshold')
@@ -64,8 +65,13 @@ def load_model(model, pretrained_path, load_to_cpu):
6465

6566
if __name__ == '__main__':
6667
torch.set_grad_enabled(False)
68+
cfg = None
69+
if args.network == "mobile0.25":
70+
cfg = cfg_mnet
71+
elif args.network == "resnet50":
72+
cfg = cfg_re50
6773
# net and model
68-
net = RetinaFace(phase="test")
74+
net = RetinaFace(cfg=cfg, phase = 'test')
6975
net = load_model(net, args.trained_model, args.cpu)
7076
net.eval()
7177
print('Finished loading model!')
@@ -150,11 +156,11 @@ def load_model(model, pretrained_path, load_to_cpu):
150156
cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255))
151157

152158
# landms
153-
cv2.circle(img_raw, (b[5], b[6]), 4, (0, 0, 255), 4)
154-
cv2.circle(img_raw, (b[7], b[8]), 4, (0, 255, 255), 4)
155-
cv2.circle(img_raw, (b[9], b[10]), 4, (255, 0, 255), 4)
156-
cv2.circle(img_raw, (b[11], b[12]), 4, (0, 255, 0), 4)
157-
cv2.circle(img_raw, (b[13], b[14]), 4, (255, 0, 0), 4)
159+
cv2.circle(img_raw, (b[5], b[6]), 1, (0, 0, 255), 4)
160+
cv2.circle(img_raw, (b[7], b[8]), 1, (0, 255, 255), 4)
161+
cv2.circle(img_raw, (b[9], b[10]), 1, (255, 0, 255), 4)
162+
cv2.circle(img_raw, (b[11], b[12]), 1, (0, 255, 0), 4)
163+
cv2.circle(img_raw, (b[13], b[14]), 1, (255, 0, 0), 4)
158164
# save image
159165

160166
name = "test.jpg"

layers/modules/multibox_loss.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import torch.nn.functional as F
44
from torch.autograd import Variable
55
from utils.box_utils import match, log_sum_exp
6-
from data import cfg
7-
GPU = cfg['gpu_train']
6+
from data import cfg_mnet
7+
GPU = cfg_mnet['gpu_train']
88

99
class MultiBoxLoss(nn.Module):
1010
"""SSD Weighted Loss Function

model_best.pth.tar

-3.65 MB
Binary file not shown.

models/mobilev1.py models/net.py

+22-16
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
import torch.nn.functional as F
77
from torch.autograd import Variable
88

9-
def conv_bn(inp, oup, stride = 1):
9+
def conv_bn(inp, oup, stride = 1, leaky = 0):
1010
return nn.Sequential(
1111
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
1212
nn.BatchNorm2d(oup),
13-
nn.ReLU(inplace=True)
13+
nn.LeakyReLU(negative_slope=leaky, inplace=True)
1414
)
1515

1616
def conv_bn_no_relu(inp, oup, stride):
@@ -19,34 +19,37 @@ def conv_bn_no_relu(inp, oup, stride):
1919
nn.BatchNorm2d(oup),
2020
)
2121

22-
def conv_bn1X1(inp, oup, stride):
22+
def conv_bn1X1(inp, oup, stride, leaky=0):
2323
return nn.Sequential(
2424
nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False),
2525
nn.BatchNorm2d(oup),
26-
nn.ReLU(inplace=True)
26+
nn.LeakyReLU(negative_slope=leaky, inplace=True)
2727
)
2828

29-
def conv_dw(inp, oup, stride):
29+
def conv_dw(inp, oup, stride, leaky=0.1):
3030
return nn.Sequential(
3131
nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
3232
nn.BatchNorm2d(inp),
33-
nn.ReLU(inplace=True),
33+
nn.LeakyReLU(negative_slope= leaky,inplace=True),
3434

3535
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
3636
nn.BatchNorm2d(oup),
37-
nn.ReLU(inplace=True),
37+
nn.LeakyReLU(negative_slope= leaky,inplace=True),
3838
)
3939

4040
class SSH(nn.Module):
4141
def __init__(self, in_channel, out_channel):
4242
super(SSH, self).__init__()
4343
assert out_channel % 4 == 0
44+
leaky = 0
45+
if (out_channel <= 64):
46+
leaky = 0.1
4447
self.conv3X3 = conv_bn_no_relu(in_channel, out_channel//2, stride=1)
4548

46-
self.conv5X5_1 = conv_bn(in_channel, out_channel//4, stride=1)
49+
self.conv5X5_1 = conv_bn(in_channel, out_channel//4, stride=1, leaky = leaky)
4750
self.conv5X5_2 = conv_bn_no_relu(out_channel//4, out_channel//4, stride=1)
4851

49-
self.conv7X7_2 = conv_bn(out_channel//4, out_channel//4, stride=1)
52+
self.conv7X7_2 = conv_bn(out_channel//4, out_channel//4, stride=1, leaky = leaky)
5053
self.conv7x7_3 = conv_bn_no_relu(out_channel//4, out_channel//4, stride=1)
5154

5255
def forward(self, input):
@@ -65,15 +68,18 @@ def forward(self, input):
6568
class FPN(nn.Module):
6669
def __init__(self,in_channels_list,out_channels):
6770
super(FPN,self).__init__()
68-
self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride = 1)
69-
self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride = 1)
70-
self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride = 1)
71+
leaky = 0
72+
if (out_channels <= 64):
73+
leaky = 0.1
74+
self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride = 1, leaky = leaky)
75+
self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride = 1, leaky = leaky)
76+
self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride = 1, leaky = leaky)
7177

72-
self.merge1 = conv_bn(out_channels, out_channels)
73-
self.merge2 = conv_bn(out_channels, out_channels)
78+
self.merge1 = conv_bn(out_channels, out_channels, leaky = leaky)
79+
self.merge2 = conv_bn(out_channels, out_channels, leaky = leaky)
7480

7581
def forward(self, input):
76-
names = list(input.keys())
82+
# names = list(input.keys())
7783
input = list(input.values())
7884

7985
output1 = self.output1(input[0])
@@ -97,7 +103,7 @@ class MobileNetV1(nn.Module):
97103
def __init__(self):
98104
super(MobileNetV1, self).__init__()
99105
self.stage1 = nn.Sequential(
100-
conv_bn(3, 8, 2), # 3
106+
conv_bn(3, 8, 2, leaky = 0.1), # 3
101107
conv_dw(8, 16, 1), # 7
102108
conv_dw(16, 32, 2), # 11
103109
conv_dw(32, 32, 1), # 19

0 commit comments

Comments
 (0)