Skip to content

Commit 3189798

Browse files
authored
implement onnx export for inception3/4, resnext, mobilenetv2 (#346)
* add inceptionv4 backbone/training settings * add converted backbone, top-1 acc 80.08
1 parent 8c90cea commit 3189798

File tree

14 files changed

+761
-48
lines changed

14 files changed

+761
-48
lines changed

configs/classification/imagenet/inception/inceptionv3_b32x8_100e.py

+26-9
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,31 @@
44
# model settings
55
model = dict(
66
type='Classification',
7-
backbone=dict(type='Inception3'),
8-
head=dict(
9-
type='ClsHead',
10-
with_avg_pool=True,
11-
in_channels=2048,
12-
loss_config=dict(
13-
type='CrossEntropyLossWithLabelSmooth',
14-
label_smooth=0,
7+
backbone=dict(type='Inception3', num_classes=1000),
8+
head=[
9+
dict(
10+
type='ClsHead',
11+
with_fc=False,
12+
in_channels=2048,
13+
loss_config=dict(
14+
type='CrossEntropyLossWithLabelSmooth',
15+
label_smooth=0,
16+
),
17+
num_classes=num_classes,
18+
input_feature_index=[1],
1519
),
16-
num_classes=num_classes))
20+
dict(
21+
type='ClsHead',
22+
with_fc=False,
23+
in_channels=768,
24+
loss_config=dict(
25+
type='CrossEntropyLossWithLabelSmooth',
26+
label_smooth=0,
27+
),
28+
num_classes=num_classes,
29+
input_feature_index=[0],
30+
)
31+
])
1732

1833
class_list = [
1934
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13',
@@ -196,3 +211,5 @@
196211
interval=10,
197212
hooks=[dict(type='TextLoggerHook'),
198213
dict(type='TensorboardLoggerHook')])
214+
215+
export = dict(export_type='raw', export_neck=True)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
_base_ = 'configs/classification/imagenet/inception/inceptionv3_b32x8_100e.py'
2+
3+
num_classes = 1000
4+
# model settings
5+
model = dict(
6+
type='Classification',
7+
backbone=dict(type='Inception4', num_classes=num_classes),
8+
head=[
9+
dict(
10+
type='ClsHead',
11+
with_fc=False,
12+
in_channels=1536,
13+
loss_config=dict(
14+
type='CrossEntropyLossWithLabelSmooth',
15+
label_smooth=0,
16+
),
17+
num_classes=num_classes,
18+
input_feature_index=[1],
19+
),
20+
dict(
21+
type='ClsHead',
22+
with_fc=False,
23+
in_channels=768,
24+
loss_config=dict(
25+
type='CrossEntropyLossWithLabelSmooth',
26+
label_smooth=0,
27+
),
28+
num_classes=num_classes,
29+
input_feature_index=[0],
30+
)
31+
])
32+
33+
img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# A config with the optimization settings from https://arxiv.org/pdf/1602.07261
2+
# May run with 20 GPUs
3+
_base_ = 'configs/classification/imagenet/inception/inceptionv3_b32x8_100e.py'
4+
5+
num_classes = 1000
6+
# model settings
7+
model = dict(
8+
type='Classification',
9+
backbone=dict(type='Inception4', num_classes=num_classes),
10+
head=[
11+
dict(
12+
type='ClsHead',
13+
with_fc=False,
14+
in_channels=1536,
15+
loss_config=dict(
16+
type='CrossEntropyLossWithLabelSmooth',
17+
label_smooth=0,
18+
),
19+
num_classes=num_classes,
20+
input_feature_index=[1],
21+
),
22+
dict(
23+
type='ClsHead',
24+
with_fc=False,
25+
in_channels=768,
26+
loss_config=dict(
27+
type='CrossEntropyLossWithLabelSmooth',
28+
label_smooth=0,
29+
),
30+
num_classes=num_classes,
31+
input_feature_index=[0],
32+
)
33+
])
34+
35+
img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
36+
37+
# optimizer
38+
optimizer = dict(
39+
type='RMSprop', lr=0.045, momentum=0.9, weight_decay=0.9, eps=1.0)
40+
41+
# learning policy
42+
lr_config = dict(policy='exp', gamma=0.96954) # gamma**2 ~ 0.94
43+
checkpoint_config = dict(interval=10)
44+
45+
# runtime settings
46+
total_epochs = 200

configs/classification/imagenet/mobilenet/mobilenetv2.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
type='CrossEntropyLossWithLabelSmooth',
1414
label_smooth=0,
1515
),
16-
num_classes=num_classes))
16+
num_classes=num_classes),
17+
pretrained=True)
1718

1819
# optimizer
1920
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
@@ -25,4 +26,4 @@
2526
# runtime settings
2627
total_epochs = 100
2728
checkpoint_sync_export = True
28-
export = dict(export_neck=True)
29+
export = dict(export_type='raw', export_neck=True)

configs/classification/imagenet/resnext/resnext50-32x4d_b32x8_100e_jpg.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
type='CrossEntropyLossWithLabelSmooth',
2020
label_smooth=0,
2121
),
22-
num_classes=num_classes))
22+
num_classes=num_classes),
23+
pretrained=True)
2324

2425
# optimizer
2526
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
@@ -30,3 +31,4 @@
3031

3132
# runtime settings
3233
total_epochs = 100
34+
export = dict(export_type='raw', export_neck=True)

easycv/apis/export.py

+42-27
Original file line numberDiff line numberDiff line change
@@ -158,34 +158,49 @@ def _get_blade_model():
158158

159159

160160
def _export_onnx_cls(model, model_config, cfg, filename, meta):
161+
support_backbones = {
162+
'ResNet': {
163+
'depth': [50]
164+
},
165+
'MobileNetV2': {},
166+
'Inception3': {},
167+
'Inception4': {},
168+
'ResNeXt': {
169+
'depth': [50]
170+
}
171+
}
172+
if model_config['backbone'].get('type', None) not in support_backbones:
173+
tmp = ' '.join(support_backbones.keys())
174+
info_str = f'Only support export onnx model for {tmp} now!'
175+
raise ValueError(info_str)
176+
configs = support_backbones[model_config['backbone'].get('type')]
177+
for k, v in configs.items():
178+
if v[0].__class__(model_config['backbone'].get(k, None)) not in v:
179+
raise ValueError(
180+
f"Unsupport config for {model_config['backbone'].get('type')}")
181+
182+
# save json config for test_pipline and class
183+
with io.open(
184+
filename +
185+
'.config.json' if filename.endswith('onnx') else filename +
186+
'.onnx.config.json', 'w') as ofile:
187+
json.dump(meta, ofile)
161188

162-
if model_config['backbone'].get(
163-
'type', None) == 'ResNet' and model_config['backbone'].get(
164-
'depth', None) == 50:
165-
# save json config for test_pipline and class
166-
with io.open(
167-
filename +
168-
'.config.json' if filename.endswith('onnx') else filename +
169-
'.onnx.config.json', 'w') as ofile:
170-
json.dump(meta, ofile)
171-
172-
device = 'cuda' if torch.cuda.is_available() else 'cpu'
173-
model.eval()
174-
model.to(device)
175-
img_size = int(cfg.image_size2)
176-
x_input = torch.randn((1, 3, img_size, img_size)).to(device)
177-
torch.onnx.export(
178-
model,
179-
(x_input, 'onnx'),
180-
filename if filename.endswith('onnx') else filename + '.onnx',
181-
export_params=True,
182-
opset_version=12,
183-
do_constant_folding=True,
184-
input_names=['input'],
185-
output_names=['output'],
186-
)
187-
else:
188-
raise ValueError('Only support export onnx model for ResNet now!')
189+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
190+
model.eval()
191+
model.to(device)
192+
img_size = int(cfg.image_size2)
193+
x_input = torch.randn((1, 3, img_size, img_size)).to(device)
194+
torch.onnx.export(
195+
model,
196+
(x_input, 'onnx'),
197+
filename if filename.endswith('onnx') else filename + '.onnx',
198+
export_params=True,
199+
opset_version=12,
200+
do_constant_folding=True,
201+
input_names=['input'],
202+
output_names=['output'],
203+
)
189204

190205

191206
def _export_cls(model, cfg, filename):

easycv/models/backbones/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .genet import PlainNet
1111
from .hrnet import HRNet
1212
from .inceptionv3 import Inception3
13+
from .inceptionv4 import Inception4
1314
from .lighthrnet import LiteHRNet
1415
from .mae_vit_transformer import *
1516
from .mit import MixVisionTransformer

easycv/models/backbones/inceptionv3.py

+2-9
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22
r""" This model is taken from the official PyTorch model zoo.
33
- torchvision.models.inception.py on 31th Aug, 2019
44
"""
5-
6-
from collections import namedtuple
7-
85
import torch
96
import torch.nn as nn
107
import torch.nn.functional as F
@@ -16,8 +13,6 @@
1613

1714
__all__ = ['Inception3']
1815

19-
_InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])
20-
2116

2217
@BACKBONES.register_module
2318
class Inception3(nn.Module):
@@ -113,6 +108,7 @@ def forward(self, x):
113108
# N x 768 x 17 x 17
114109
x = self.Mixed_6e(x)
115110
# N x 768 x 17 x 17
111+
aux = None
116112
if self.training and self.aux_logits:
117113
aux = self.AuxLogits(x)
118114
# N x 768 x 17 x 17
@@ -132,10 +128,7 @@ def forward(self, x):
132128
if hasattr(self, 'fc'):
133129
x = self.fc(x)
134130

135-
# N x 1000 (num_classes)
136-
if self.training and self.aux_logits and hasattr(self, 'fc'):
137-
return [_InceptionOutputs(x, aux)]
138-
return [x]
131+
return [aux, x]
139132

140133

141134
class InceptionA(nn.Module):

0 commit comments

Comments
 (0)