@@ -186,63 +186,33 @@ namespace dd
186
186
}
187
187
188
188
APIData ad_output = ad.getobj (" parameters" ).getobj (" output" );
189
-
190
- // Get bbox
191
- bool bbox = false ;
192
- if (ad_output.has (" bbox" ))
193
- bbox = ad_output.get (" bbox" ).get <bool >();
194
-
195
- // Ctc model
196
- bool ctc = false ;
197
- int blank_label = -1 ;
198
- if (ad_output.has (" ctc" ))
199
- {
200
- ctc = ad_output.get (" ctc" ).get <bool >();
201
- if (ctc)
202
- {
203
- if (ad_output.has (" blank_label" ))
204
- blank_label = ad_output.get (" blank_label" ).get <int >();
205
- }
206
- }
189
+ auto output_params
190
+ = ad_output.createSharedDTO <PredictOutputParametersDto>();
207
191
208
192
// Extract detection or classification
209
- int ret = 0 ;
210
193
std::string out_blob;
211
194
if (_init_dto->outputBlob != nullptr )
212
195
out_blob = _init_dto->outputBlob ->std_str ();
213
196
214
197
if (out_blob.empty ())
215
198
{
216
- if (bbox == true )
199
+ if (output_params-> bbox == true )
217
200
out_blob = " detection_out" ;
218
- else if (ctc == true )
201
+ else if (output_params-> ctc == true )
219
202
out_blob = " probs" ;
220
203
else if (_timeserie)
221
204
out_blob = " rnn_pred" ;
222
205
else
223
206
out_blob = " prob" ;
224
207
}
225
208
226
- std::vector<APIData> vrad;
227
-
228
- // Get confidence_threshold
229
- float confidence_threshold = 0.0 ;
230
- if (ad_output.has (" confidence_threshold" ))
231
- {
232
- apitools::get_float (ad_output, " confidence_threshold" ,
233
- confidence_threshold);
234
- }
235
-
236
209
// Get best
237
- int best = -1 ;
238
- if (ad_output.has (" best" ))
239
- {
240
- best = ad_output.get (" best" ).get <int >();
241
- }
242
- if (best == -1 || best > _init_dto->nclasses )
243
- best = _init_dto->nclasses ;
210
+ if (output_params->best == -1 || output_params->best > _init_dto->nclasses )
211
+ output_params->best = _init_dto->nclasses ;
212
+
213
+ std::vector<APIData> vrad;
244
214
245
- // for loop around batch size
215
+ // for loop around batch size
246
216
#pragma omp parallel for num_threads(*_init_dto->threads)
247
217
for (size_t b = 0 ; b < inputc._ids .size (); b++)
248
218
{
@@ -256,13 +226,13 @@ namespace dd
256
226
ex.set_num_threads (_init_dto->threads );
257
227
ex.input (_init_dto->inputBlob ->c_str (), inputc._in .at (b));
258
228
259
- ret = ex.extract (out_blob.c_str (), inputc._out .at (b));
229
+ int ret = ex.extract (out_blob.c_str (), inputc._out .at (b));
260
230
if (ret == -1 )
261
231
{
262
232
throw MLLibInternalException (" NCNN internal error" );
263
233
}
264
234
265
- if (bbox == true )
235
+ if (output_params-> bbox == true )
266
236
{
267
237
std::string uri = inputc._ids .at (b);
268
238
auto bit = inputc._imgs_size .find (uri);
@@ -282,7 +252,7 @@ namespace dd
282
252
for (int i = 0 ; i < inputc._out .at (b).h ; i++)
283
253
{
284
254
const float *values = inputc._out .at (b).row (i);
285
- if (values[1 ] < confidence_threshold)
255
+ if (values[1 ] < output_params-> confidence_threshold )
286
256
break ; // output is sorted by confidence
287
257
288
258
cats.push_back (this ->_mlmodel .get_hcorresp (values[0 ]));
@@ -300,7 +270,7 @@ namespace dd
300
270
bboxes.push_back (ad_bbox);
301
271
}
302
272
}
303
- else if (ctc == true )
273
+ else if (output_params-> ctc == true )
304
274
{
305
275
int alphabet = inputc._out .at (b).w ;
306
276
int time_step = inputc._out .at (b).h ;
@@ -313,11 +283,11 @@ namespace dd
313
283
}
314
284
315
285
std::vector<int > pred_label_seq;
316
- int prev = blank_label;
286
+ int prev = output_params-> blank_label ;
317
287
for (int t = 0 ; t < time_step; ++t)
318
288
{
319
289
int cur = pred_label_seq_with_blank[t];
320
- if (cur != prev && cur != blank_label)
290
+ if (cur != prev && cur != output_params-> blank_label )
321
291
pred_label_seq.push_back (cur);
322
292
prev = cur;
323
293
}
@@ -365,12 +335,13 @@ namespace dd
365
335
vec[i] = std::make_pair (cls_scores[i], i);
366
336
}
367
337
368
- std::partial_sort (vec.begin (), vec.begin () + best, vec.end (),
338
+ std::partial_sort (vec.begin (), vec.begin () + output_params->best ,
339
+ vec.end (),
369
340
std::greater<std::pair<float , int >>());
370
341
371
- for (int i = 0 ; i < best; i++)
342
+ for (int i = 0 ; i < output_params-> best ; i++)
372
343
{
373
- if (vec[i].first < confidence_threshold)
344
+ if (vec[i].first < output_params-> confidence_threshold )
374
345
continue ;
375
346
cats.push_back (this ->_mlmodel .get_hcorresp (vec[i].second ));
376
347
probs.push_back (vec[i].first );
@@ -380,7 +351,7 @@ namespace dd
380
351
rad.add (" uri" , inputc._ids .at (b));
381
352
rad.add (" loss" , 0.0 );
382
353
rad.add (" cats" , cats);
383
- if (bbox == true )
354
+ if (output_params-> bbox == true )
384
355
rad.add (" bboxes" , bboxes);
385
356
if (_timeserie)
386
357
{
@@ -402,7 +373,7 @@ namespace dd
402
373
tout.add_results (vrad);
403
374
int nclasses = this ->_init_dto ->nclasses ;
404
375
out.add (" nclasses" , nclasses);
405
- if (bbox == true )
376
+ if (output_params-> bbox == true )
406
377
out.add (" bbox" , true );
407
378
out.add (" roi" , false );
408
379
out.add (" multibox_rois" , false );
0 commit comments