@@ -158,34 +158,49 @@ def _get_blade_model():
158
158
159
159
160
160
def _export_onnx_cls (model , model_config , cfg , filename , meta ):
161
+ support_backbones = {
162
+ 'ResNet' : {
163
+ 'depth' : [50 ]
164
+ },
165
+ 'MobileNetV2' : {},
166
+ 'Inception3' : {},
167
+ 'Inception4' : {},
168
+ 'ResNeXt' : {
169
+ 'depth' : [50 ]
170
+ }
171
+ }
172
+ if model_config ['backbone' ].get ('type' , None ) not in support_backbones :
173
+ tmp = ' ' .join (support_backbones .keys ())
174
+ info_str = f'Only support export onnx model for { tmp } now!'
175
+ raise ValueError (info_str )
176
+ configs = support_backbones [model_config ['backbone' ].get ('type' )]
177
+ for k , v in configs .items ():
178
+ if v [0 ].__class__ (model_config ['backbone' ].get (k , None )) not in v :
179
+ raise ValueError (
180
+ f"Unsupport config for { model_config ['backbone' ].get ('type' )} " )
181
+
182
+ # save json config for test_pipline and class
183
+ with io .open (
184
+ filename +
185
+ '.config.json' if filename .endswith ('onnx' ) else filename +
186
+ '.onnx.config.json' , 'w' ) as ofile :
187
+ json .dump (meta , ofile )
161
188
162
- if model_config ['backbone' ].get (
163
- 'type' , None ) == 'ResNet' and model_config ['backbone' ].get (
164
- 'depth' , None ) == 50 :
165
- # save json config for test_pipline and class
166
- with io .open (
167
- filename +
168
- '.config.json' if filename .endswith ('onnx' ) else filename +
169
- '.onnx.config.json' , 'w' ) as ofile :
170
- json .dump (meta , ofile )
171
-
172
- device = 'cuda' if torch .cuda .is_available () else 'cpu'
173
- model .eval ()
174
- model .to (device )
175
- img_size = int (cfg .image_size2 )
176
- x_input = torch .randn ((1 , 3 , img_size , img_size )).to (device )
177
- torch .onnx .export (
178
- model ,
179
- (x_input , 'onnx' ),
180
- filename if filename .endswith ('onnx' ) else filename + '.onnx' ,
181
- export_params = True ,
182
- opset_version = 12 ,
183
- do_constant_folding = True ,
184
- input_names = ['input' ],
185
- output_names = ['output' ],
186
- )
187
- else :
188
- raise ValueError ('Only support export onnx model for ResNet now!' )
189
+ device = 'cuda' if torch .cuda .is_available () else 'cpu'
190
+ model .eval ()
191
+ model .to (device )
192
+ img_size = int (cfg .image_size2 )
193
+ x_input = torch .randn ((1 , 3 , img_size , img_size )).to (device )
194
+ torch .onnx .export (
195
+ model ,
196
+ (x_input , 'onnx' ),
197
+ filename if filename .endswith ('onnx' ) else filename + '.onnx' ,
198
+ export_params = True ,
199
+ opset_version = 12 ,
200
+ do_constant_folding = True ,
201
+ input_names = ['input' ],
202
+ output_names = ['output' ],
203
+ )
189
204
190
205
191
206
def _export_cls (model , cfg , filename ):
0 commit comments