-
Notifications
You must be signed in to change notification settings - Fork 552
/
Copy pathyolov6_to_mmyolo.py
115 lines (106 loc) · 4.3 KB
/
yolov6_to_mmyolo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import argparse
from collections import OrderedDict
import torch
# Ordered rename rules for YOLOv6 state-dict keys. Each entry is
# (marker, old, new, fix_inner): the first rule whose ``marker`` occurs in the
# key wins, ``old`` is replaced by ``new`` in the key, and ``fix_inner`` says
# whether the ``.cv*`` / ``.m.`` sub-module names are renamed as well.
_KEY_RULES = (
    ('ERBlock_2', 'ERBlock_2', 'stage1.0', True),
    ('ERBlock_3', 'ERBlock_3', 'stage2.0', True),
    ('ERBlock_4', 'ERBlock_4', 'stage3.0', True),
    ('ERBlock_5', 'ERBlock_5', 'stage4.0', True),
    ('reduce_layer0', 'reduce_layer0', 'reduce_layers.2', False),
    ('Rep_p4', 'Rep_p4', 'top_down_layers.0.0', True),
    ('reduce_layer1', 'reduce_layer1', 'top_down_layers.0.1', True),
    ('Rep_p3', 'Rep_p3', 'top_down_layers.1', True),
    ('upsample0', 'upsample0.upsample_transpose', 'upsample_layers.0',
     False),
    ('upsample1', 'upsample1.upsample_transpose', 'upsample_layers.1',
     False),
    ('Rep_n3', 'Rep_n3', 'bottom_up_layers.0', True),
    ('Rep_n4', 'Rep_n4', 'bottom_up_layers.1', True),
    ('downsample2', 'downsample2', 'downsample_layers.0', False),
    ('downsample1', 'downsample1', 'downsample_layers.1', False),
)


def _convert_key(key):
    """Map one YOLOv6 state-dict key to its MMYOLO name (None = drop)."""
    name = key
    if 'detect' in key:
        # ``proj`` entries of the detection head are not carried over.
        if 'proj' in key:
            return None
        name = key.replace('detect', 'bbox_head.head_module')
    # Anchor entries are not carried over either.
    if 'anchors' in key or 'anchor_grid' in key:
        return None

    for marker, old, new, fix_inner in _KEY_RULES:
        if marker not in key:
            continue
        # As in the original chain, the rename restarts from the raw key.
        name = key.replace(old, new)
        if fix_inner:
            if '.cv' in key:
                name = name.replace('.cv', '.conv')
            if '.m.' in key:
                name = name.replace('.m.', '.block.')
        # Child ``2`` of ERBlock_5 is stored as ``stage4.1`` instead of
        # ``stage4.0.2``; any remaining ``cv`` in that key becomes ``conv``.
        if marker == 'ERBlock_5' and 'stage4.0.2' in name:
            name = name.replace('stage4.0.2', 'stage4.1')
            name = name.replace('cv', 'conv')
        break
    return name


def convert(src, dst):
    """Convert an official YOLOv6 checkpoint to MMYOLO key names.

    The checkpoint at ``src`` is loaded on CPU, the model stored under
    ``'ema'`` (when present and truthy) or ``'model'`` is cast to float, its
    state-dict keys are renamed, and the result is saved to ``dst`` as
    ``{'state_dict': OrderedDict}``.

    Args:
        src: Path of the YOLOv6 checkpoint to read.
        dst: Path the converted checkpoint is written to.

    Raises:
        RuntimeError: If loading the checkpoint raises
            ``ModuleNotFoundError``, i.e. the model code needed to unpickle
            it cannot be imported.
    """
    import inspect
    import sys

    # Unpickling the checkpoint imports model code looked up via this entry.
    if 'yolov6' not in sys.path:
        sys.path.append('yolov6')

    load_kwargs = {'map_location': torch.device('cpu')}
    # The checkpoint stores a whole pickled model, which PyTorch >= 2.6
    # refuses to load under its default ``weights_only=True``. Only pass the
    # flag when the installed ``torch.load`` knows it (older releases do not).
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False

    try:
        # SECURITY: this unpickles ``src`` and can execute arbitrary code;
        # only convert checkpoints obtained from a trusted source.
        ckpt = torch.load(src, **load_kwargs)
    except ModuleNotFoundError as err:
        raise RuntimeError(
            'This script must be placed under the meituan/YOLOv6 repo,'
            ' because loading the official pretrained model need'
            ' some python files to build model.') from err

    # The saved model is the model before reparameterization
    model = ckpt['ema' if ckpt.get('ema') else 'model'].float()

    new_state_dict = OrderedDict()
    for key, value in model.state_dict().items():
        name = _convert_key(key)
        if name is not None:
            new_state_dict[name] = value

    torch.save({'state_dict': new_state_dict}, dst)
# Note: This script must be placed under the meituan/YOLOv6 repo to run,
# because loading the official checkpoint needs that repo's Python files.
def main():
    """Read ``--src`` / ``--dst`` from the command line and run convert."""
    cli = argparse.ArgumentParser(description='Convert model keys')
    # (flag, default, help) for every supported option, in display order.
    option_specs = (
        ('--src', 'yolov6s.pt', 'src yolov6 model path'),
        ('--dst', 'mmyolov6.pt', 'save path'),
    )
    for flag, default, help_text in option_specs:
        cli.add_argument(flag, default=default, help=help_text)
    options = cli.parse_args()
    convert(options.src, options.dst)


# Run the conversion only when executed as a script, not on import.
if __name__ == '__main__':
    main()