
Commit 8c90cea

add ONNX export for ResNet (#341)
* add checkpoint_sync_export for resnet config
1 parent 1d1ac8a commit 8c90cea

9 files changed: +166 -14 lines

configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py (+3)

@@ -6,3 +6,6 @@
         depth=50,
         out_indices=[4],  # 0: conv-1, x: stage-x
         norm_cfg=dict(type='BN')))
+
+checkpoint_sync_export = True
+export = dict(export_type='raw', export_neck=True)
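
The config keeps export_type='raw' by default, so the existing checkpoint export is unchanged; switching the same model to ONNX export only requires overriding that one key. A minimal sketch of the override (the 'onnx' branch it selects is added in easycv/apis/export.py below):

    # Config-level override: same config, ONNX export instead of a raw checkpoint.
    checkpoint_sync_export = True
    export = dict(export_type='onnx', export_neck=True)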

easycv/apis/export.py (+40 -3)

@@ -157,6 +157,37 @@ def _get_blade_model():
     torch.jit.save(blade_model, ofile)


+def _export_onnx_cls(model, model_config, cfg, filename, meta):
+
+    if model_config['backbone'].get(
+            'type', None) == 'ResNet' and model_config['backbone'].get(
+                'depth', None) == 50:
+        # save json config for the test pipeline and class info
+        with io.open(
+                filename +
+                '.config.json' if filename.endswith('onnx') else filename +
+                '.onnx.config.json', 'w') as ofile:
+            json.dump(meta, ofile)
+
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        model.eval()
+        model.to(device)
+        img_size = int(cfg.image_size2)
+        x_input = torch.randn((1, 3, img_size, img_size)).to(device)
+        torch.onnx.export(
+            model,
+            (x_input, 'onnx'),
+            filename if filename.endswith('onnx') else filename + '.onnx',
+            export_params=True,
+            opset_version=12,
+            do_constant_folding=True,
+            input_names=['input'],
+            output_names=['output'],
+        )
+    else:
+        raise ValueError('Only support export onnx model for ResNet now!')
+
+
 def _export_cls(model, cfg, filename):
     """ export cls (cls & metric learning) model and preprocess config
@@ -170,6 +201,7 @@ def _export_cls(model, cfg, filename):
     else:
         export_cfg = dict(export_neck=False)

+    export_type = export_cfg.get('export_type', 'raw')
     export_neck = export_cfg.get('export_neck', True)
     label_map_path = cfg.get('label_map_path', None)
     class_list = None
@@ -232,9 +264,14 @@ def _export_cls(model, cfg, filename):
         if export_neck and (k.startswith('neck') or k.startswith('head')):
             state_dict[k] = v

-    checkpoint = dict(state_dict=state_dict, meta=meta, author='EasyCV')
-    with io.open(filename, 'wb') as ofile:
-        torch.save(checkpoint, ofile)
+    if export_type == 'raw':
+        checkpoint = dict(state_dict=state_dict, meta=meta, author='EasyCV')
+        with io.open(filename, 'wb') as ofile:
+            torch.save(checkpoint, ofile)
+    elif export_type == 'onnx':
+        _export_onnx_cls(model, model_config, cfg, filename, config)
+    else:
+        raise ValueError('Only support export onnx/raw model!')


 def _export_yolox(model, cfg, filename):
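
With export_type='onnx', the classification export entry point writes an .onnx model plus an .onnx.config.json holding the preprocessing/meta information. A minimal usage sketch mirroring the new unit test; the import paths and file names are assumptions, not taken verbatim from this diff:

    # Sketch: export the ResNet-50 classifier to ONNX (paths are illustrative).
    from easycv.utils.config_tools import mmcv_config_fromfile
    from easycv.apis.export import export

    cfg = mmcv_config_fromfile(
        'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py')
    cfg.export.export_type = 'onnx'  # config default is 'raw'
    export(cfg, 'resnet50_imagenet.pth', 'imagenet_resnet50')
    # expected outputs: imagenet_resnet50.onnx and imagenet_resnet50.onnx.config.json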

easycv/models/classification/classification.py (+17)

@@ -151,6 +151,20 @@ def forward_backbone(self, img: torch.Tensor) -> List[torch.Tensor]:
         x = self.backbone(img)
         return x

+    def forward_onnx(self, img: torch.Tensor) -> Dict[str, torch.Tensor]:
+        """
+        forward_onnx generates class probabilities from an image; it only
+        supports one neck + one head.
+        """
+        x = self.forward_backbone(img)  # tuple
+
+        # if self.neck_num > 0:
+        if hasattr(self, 'neck_0'):
+            x = self.neck_0([i for i in x])
+
+        out = self.head_0(x)[0].cpu()
+        out = self.activate_fn(out)
+        return out
+
     @torch.jit.unused
     def forward_train(self, img, gt_labels) -> Dict[str, torch.Tensor]:
         """
@@ -290,6 +304,9 @@ def forward(
                 return self.forward_test_label(img, gt_labels)
             else:
                 return self.forward_test(img)
+        elif mode == 'onnx':
+            return self.forward_onnx(img)
+
         elif mode == 'extract':
             rd = self.forward_feature(img)
             rv = {}
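
The new 'onnx' forward mode is what torch.onnx.export traces: a single image tensor in, activated class probabilities on CPU out. A hedged sketch of exercising it directly, assuming `model` is an already-built Classification model (builder/config wiring not shown):

    # Sketch: call the new ONNX forward mode on a dummy batch.
    import torch

    model.eval()
    dummy = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        probs = model(dummy, mode='onnx')  # dispatches to forward_onnx
    print(probs.shape)  # e.g. torch.Size([1, 1000]) for an ImageNet head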

easycv/predictors/classifier.py (+58)

@@ -1,17 +1,26 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import glob
 import math
+import os

 import numpy as np
 import torch
 from PIL import Image

 from easycv.file import io
 from easycv.framework.errors import ValueError
+from easycv.utils.checkpoint import load_checkpoint
 from easycv.utils.misc import deprecated
 from .base import InputProcessor, OutputProcessor, Predictor, PredictorV2
 from .builder import PREDICTORS


+# onnx specific
+def onnx_to_numpy(tensor):
+    return tensor.detach().cpu().numpy(
+    ) if tensor.requires_grad else tensor.cpu().numpy()
+
+
 class ClsInputProcessor(InputProcessor):
     """Process inputs for classification models.

@@ -146,6 +155,20 @@ def __init__(self,
         self.pil_input = pil_input
         self.label_map_path = label_map_path

+        if model_path.endswith('onnx'):
+            self.model_type = 'onnx'
+            pwd_model = os.path.dirname(model_path)
+            raw_model = glob.glob(
+                os.path.join(pwd_model, '*.onnx.config.json'))
+            if len(raw_model) != 0:
+                config_file = raw_model[0]
+            else:
+                # fail loudly when the exported preprocessing config is missing
+                assert len(
+                    raw_model
+                ) != 0, 'Please have a file with the .onnx.config.json extension in your directory'
+        else:
+            self.model_type = 'raw'
+
         if self.pil_input:
             mode = 'RGB'
         super(ClassificationPredictor, self).__init__(
@@ -186,6 +209,41 @@ def get_output_processor(self):

         return ClsOutputProcessor(topk=self.topk, label_map=self.label_map)

+    def prepare_model(self):
+        """Build the model from the config file by default.
+        If the model is not loaded from a configuration file (e.g. a torch jit
+        model or an ONNX file), reimplement this method.
+        """
+        if self.model_type == 'raw':
+            model = self._build_model()
+            model.to(self.device)
+            model.eval()
+            load_checkpoint(model, self.model_path, map_location='cpu')
+            return model
+        else:
+            import onnxruntime
+            if onnxruntime.get_device() == 'GPU':
+                onnx_model = onnxruntime.InferenceSession(
+                    self.model_path, providers=['CUDAExecutionProvider'])
+            else:
+                onnx_model = onnxruntime.InferenceSession(self.model_path)
+
+            return onnx_model
+
+    def model_forward(self, inputs):
+        """Model forward.
+        If you need to refactor the model forward, reimplement this method.
+        """
+        with torch.no_grad():
+            if self.model_type == 'raw':
+                outputs = self.model(**inputs, mode='test')
+            else:
+                outputs = self.model.run(None, {
+                    self.model.get_inputs()[0].name:
+                    onnx_to_numpy(inputs['img'])
+                })[0]
+                outputs = dict(prob=torch.from_numpy(outputs))
+        return outputs
+

 try:
     from easy_vision.python.inference.predictor import PredictorInterface
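
On the inference side, pointing ClassificationPredictor at an .onnx file (with its .onnx.config.json in the same directory) transparently switches it to the onnxruntime path, as the new test_onnx_single test does. A minimal sketch with illustrative paths:

    # Sketch: run the exported ONNX classifier through the ordinary predictor API.
    from easycv.predictors.classifier import ClassificationPredictor

    predict_op = ClassificationPredictor(
        model_path='work_dir/imagenet_resnet50.onnx')
    results = predict_op(['test_images/catb.jpg'])[0]
    print(results['class'], results['class_name'], len(results['class_probs']))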

easycv/version.py (+2 -2)

@@ -2,5 +2,5 @@
 # GENERATED VERSION FILE
 # TIME: Thu Nov 5 14:17:50 2020

-__version__ = '0.11.6'
-short_version = '0.11.6'
+__version__ = '0.11.7'
+short_version = '0.11.7'

tests/test_apis/test_export.py (+22)

@@ -116,6 +116,7 @@ def test_export_cls_syncbn(self):
         cfg = mmcv_config_fromfile(config_file)
         cfg_options = {
             'model.backbone.norm_cfg.type': 'SyncBN',
+            'export.export_type': 'raw'
         }
         if cfg_options is not None:
             cfg.merge_from_dict(cfg_options)
@@ -210,6 +211,27 @@ def test_export_stgcn_jit(self):

         self.assertTrue(os.path.exists(filename + '.jit'))

+    def test_export_resnet_onnx(self):
+
+        ckpt_path = PRETRAINED_MODEL_RESNET50
+
+        easycv_dir = os.path.dirname(easycv.__file__)
+
+        if os.path.exists(os.path.join(easycv_dir, 'configs')):
+            config_dir = os.path.join(easycv_dir, 'configs')
+        else:
+            config_dir = os.path.join(os.path.dirname(easycv_dir), 'configs')
+        config_file = os.path.join(
+            config_dir,
+            'classification/imagenet/resnet/imagenet_resnet50_jpg.py')
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            cfg = mmcv_config_fromfile(config_file)
+            cfg.export.export_type = 'onnx'
+            filename = os.path.join(tmpdir, 'imagenet_resnet50')
+            export(cfg, ckpt_path, filename)
+            self.assertTrue(os.path.exists(filename + '.onnx'))
+

 if __name__ == '__main__':
     unittest.main()

tests/test_predictors/test_classifier.py (+13 -1)

@@ -11,7 +11,8 @@
 from easycv.predictors.classifier import ClassificationPredictor
 from easycv.utils.test_util import clean_up, get_tmp_dir
 from tests.ut_config import (PRETRAINED_MODEL_RESNET50_WITHOUTHEAD,
-                             IMAGENET_LABEL_TXT, TEST_IMAGES_DIR)
+                             IMAGENET_LABEL_TXT, TEST_IMAGES_DIR,
+                             PRETRAINED_MODEL_RESNET50_ONNX_WITHOUTHEAD)


 class ClassificationPredictorTest(unittest.TestCase):
@@ -33,6 +34,17 @@ def test_single(self):
         self.assertListEqual(results['class_name'], ['"Persian cat",'])
         self.assertEqual(len(results['class_probs']), 1000)

+    def test_onnx_single(self):
+        checkpoint = PRETRAINED_MODEL_RESNET50_ONNX_WITHOUTHEAD
+        predict_op = ClassificationPredictor(model_path=checkpoint)
+
+        img_path = os.path.join(TEST_IMAGES_DIR, 'catb.jpg')
+
+        results = predict_op([img_path])[0]
+        self.assertListEqual(results['class'], [578])
+        self.assertListEqual(results['class_name'], ['gown'])
+        self.assertEqual(len(results['class_probs']), 1000)
+
     def test_batch(self):
         checkpoint = PRETRAINED_MODEL_RESNET50_WITHOUTHEAD
         config_file = 'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py'

tests/test_predictors/test_pose_predictor.py (+8 -8)

@@ -54,10 +54,10 @@ def _base_test(self, predictor):

         assert_array_almost_equal(
             result0['bbox'],
-            np.array([[352.3085, 59.00325, 691.4247, 511.15814, 1.],
-                      [10.511196, 177.74883, 101.824326, 299.49966, 1.],
-                      [224.82036, 114.439865, 312.51306, 231.36348, 1.],
-                      [200.71407, 114.716736, 337.17535, 296.6651, 1.]],
+            np.array([[438.9, 59., 604.8, 511.2, 0.9],
+                      [10.5, 179.6, 101.8, 297.7, 0.9],
+                      [229.6, 114.4, 307.8, 231.4, 0.6],
+                      [229.4, 114.7, 308.5, 296.7, 0.6]],
                      dtype=np.float32),
             decimal=1)
         vis_result = predictor.show_result(img1, result0)
@@ -92,10 +92,10 @@ def _base_test(self, predictor):

         assert_array_almost_equal(
             result1['bbox'][:4],
-            np.array([[436.23096, 214.72766, 584.26013, 412.09985, 1.],
-                      [43.990044, 91.04126, 164.28406, 251.43329, 1.],
-                      [127.44148, 100.38604, 254.219, 269.42273, 1.],
-                      [190.08075, 117.31801, 311.22394, 278.8423, 1.]],
+            np.array([[470.6, 214.7, 549.9, 412.1, 0.9],
+                      [71.6, 91., 136.7, 251.4, 0.9],
+                      [159.7, 100.4, 221.9, 269.4, 0.9],
+                      [219.4, 117.3, 281.9, 278.8, 0.9]],
                      dtype=np.float32),
             decimal=1)
         vis_result = predictor.show_result(img2, result1)

tests/ut_config.py (+3)

@@ -179,6 +179,9 @@
 PRETRAINED_MODEL_RESNET50_WITHOUTHEAD = os.path.join(
     BASE_LOCAL_PATH,
     'pretrained_models/classification/resnet/resnet50_withhead.pth')
+PRETRAINED_MODEL_RESNET50_ONNX_WITHOUTHEAD = os.path.join(
+    BASE_LOCAL_PATH,
+    'pretrained_models/classification/resnet/imagenet_resnet50.onnx')
 PRETRAINED_MODEL_FACEID = os.path.join(BASE_LOCAL_PATH,
                                        'pretrained_models/faceid')
 PRETRAINED_MODEL_YOLOXS_EXPORT = os.path.join(
