Skip to content

Commit b294e30

Browse files
committed
feat: introduce predict output parameter
1 parent 1e3fe0a commit b294e30

File tree

2 files changed

+41
-64
lines changed

2 files changed

+41
-64
lines changed

src/backends/ncnn/ncnnlib.cc

+21-50
Original file line numberDiff line numberDiff line change
@@ -186,63 +186,33 @@ namespace dd
186186
}
187187

188188
APIData ad_output = ad.getobj("parameters").getobj("output");
189-
190-
// Get bbox
191-
bool bbox = false;
192-
if (ad_output.has("bbox"))
193-
bbox = ad_output.get("bbox").get<bool>();
194-
195-
// Ctc model
196-
bool ctc = false;
197-
int blank_label = -1;
198-
if (ad_output.has("ctc"))
199-
{
200-
ctc = ad_output.get("ctc").get<bool>();
201-
if (ctc)
202-
{
203-
if (ad_output.has("blank_label"))
204-
blank_label = ad_output.get("blank_label").get<int>();
205-
}
206-
}
189+
auto output_params
190+
= ad_output.createSharedDTO<PredictOutputParametersDto>();
207191

208192
// Extract detection or classification
209-
int ret = 0;
210193
std::string out_blob;
211194
if (_init_dto->outputBlob != nullptr)
212195
out_blob = _init_dto->outputBlob->std_str();
213196

214197
if (out_blob.empty())
215198
{
216-
if (bbox == true)
199+
if (output_params->bbox == true)
217200
out_blob = "detection_out";
218-
else if (ctc == true)
201+
else if (output_params->ctc == true)
219202
out_blob = "probs";
220203
else if (_timeserie)
221204
out_blob = "rnn_pred";
222205
else
223206
out_blob = "prob";
224207
}
225208

226-
std::vector<APIData> vrad;
227-
228-
// Get confidence_threshold
229-
float confidence_threshold = 0.0;
230-
if (ad_output.has("confidence_threshold"))
231-
{
232-
apitools::get_float(ad_output, "confidence_threshold",
233-
confidence_threshold);
234-
}
235-
236209
// Get best
237-
int best = -1;
238-
if (ad_output.has("best"))
239-
{
240-
best = ad_output.get("best").get<int>();
241-
}
242-
if (best == -1 || best > _init_dto->nclasses)
243-
best = _init_dto->nclasses;
210+
if (output_params->best == -1 || output_params->best > _init_dto->nclasses)
211+
output_params->best = _init_dto->nclasses;
212+
213+
std::vector<APIData> vrad;
244214

245-
// for loop around batch size
215+
// for loop around batch size
246216
#pragma omp parallel for num_threads(*_init_dto->threads)
247217
for (size_t b = 0; b < inputc._ids.size(); b++)
248218
{
@@ -256,13 +226,13 @@ namespace dd
256226
ex.set_num_threads(_init_dto->threads);
257227
ex.input(_init_dto->inputBlob->c_str(), inputc._in.at(b));
258228

259-
ret = ex.extract(out_blob.c_str(), inputc._out.at(b));
229+
int ret = ex.extract(out_blob.c_str(), inputc._out.at(b));
260230
if (ret == -1)
261231
{
262232
throw MLLibInternalException("NCNN internal error");
263233
}
264234

265-
if (bbox == true)
235+
if (output_params->bbox == true)
266236
{
267237
std::string uri = inputc._ids.at(b);
268238
auto bit = inputc._imgs_size.find(uri);
@@ -282,7 +252,7 @@ namespace dd
282252
for (int i = 0; i < inputc._out.at(b).h; i++)
283253
{
284254
const float *values = inputc._out.at(b).row(i);
285-
if (values[1] < confidence_threshold)
255+
if (values[1] < output_params->confidence_threshold)
286256
break; // output is sorted by confidence
287257

288258
cats.push_back(this->_mlmodel.get_hcorresp(values[0]));
@@ -300,7 +270,7 @@ namespace dd
300270
bboxes.push_back(ad_bbox);
301271
}
302272
}
303-
else if (ctc == true)
273+
else if (output_params->ctc == true)
304274
{
305275
int alphabet = inputc._out.at(b).w;
306276
int time_step = inputc._out.at(b).h;
@@ -313,11 +283,11 @@ namespace dd
313283
}
314284

315285
std::vector<int> pred_label_seq;
316-
int prev = blank_label;
286+
int prev = output_params->blank_label;
317287
for (int t = 0; t < time_step; ++t)
318288
{
319289
int cur = pred_label_seq_with_blank[t];
320-
if (cur != prev && cur != blank_label)
290+
if (cur != prev && cur != output_params->blank_label)
321291
pred_label_seq.push_back(cur);
322292
prev = cur;
323293
}
@@ -365,12 +335,13 @@ namespace dd
365335
vec[i] = std::make_pair(cls_scores[i], i);
366336
}
367337

368-
std::partial_sort(vec.begin(), vec.begin() + best, vec.end(),
338+
std::partial_sort(vec.begin(), vec.begin() + output_params->best,
339+
vec.end(),
369340
std::greater<std::pair<float, int>>());
370341

371-
for (int i = 0; i < best; i++)
342+
for (int i = 0; i < output_params->best; i++)
372343
{
373-
if (vec[i].first < confidence_threshold)
344+
if (vec[i].first < output_params->confidence_threshold)
374345
continue;
375346
cats.push_back(this->_mlmodel.get_hcorresp(vec[i].second));
376347
probs.push_back(vec[i].first);
@@ -380,7 +351,7 @@ namespace dd
380351
rad.add("uri", inputc._ids.at(b));
381352
rad.add("loss", 0.0);
382353
rad.add("cats", cats);
383-
if (bbox == true)
354+
if (output_params->bbox == true)
384355
rad.add("bboxes", bboxes);
385356
if (_timeserie)
386357
{
@@ -402,7 +373,7 @@ namespace dd
402373
tout.add_results(vrad);
403374
int nclasses = this->_init_dto->nclasses;
404375
out.add("nclasses", nclasses);
405-
if (bbox == true)
376+
if (output_params->bbox == true)
406377
out.add("bbox", true);
407378
out.add("roi", false);
408379
out.add("multibox_rois", false);

src/supervisedoutputconnector.h

+20-14
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
#define SUPERVISEDOUTPUTCONNECTOR_H
2424
#define TS_METRICS_EPSILON 1E-2
2525

26+
#include "http/dto/predict.hpp"
27+
2628
template <typename T>
2729
bool SortScorePairDescend(const std::pair<double, T> &pair1,
2830
const std::pair<double, T> &pair2)
@@ -161,10 +163,11 @@ namespace dd
161163
void init(const APIData &ad)
162164
{
163165
APIData ad_out = ad.getobj("parameters").getobj("output");
164-
if (ad_out.has("best"))
165-
_best = ad_out.get("best").get<int>();
166+
auto output_params
167+
= ad_out.createSharedDTO<PredictOutputParametersDto>();
168+
_best = output_params->best;
166169
if (_best == -1)
167-
_best = ad_out.get("nclasses").get<int>();
170+
_best = output_params->nclasses;
168171
}
169172

170173
/**
@@ -242,13 +245,13 @@ namespace dd
242245
* @param ad_out output data object
243246
* @param bcats supervised output connector
244247
*/
245-
void best_cats(const APIData &ad_out, SupervisedOutput &bcats,
248+
void best_cats(SupervisedOutput &bcats, const int &output_param_best,
246249
const int &nclasses, const bool &has_bbox,
247250
const bool &has_roi, const bool &has_mask) const
248251
{
249252
int best = _best;
250-
if (ad_out.has("best"))
251-
best = ad_out.get("best").get<int>();
253+
if (output_param_best != -1)
254+
best = output_param_best;
252255
if (best == -1)
253256
best = nclasses;
254257
if (!has_bbox && !has_roi && !has_mask)
@@ -399,6 +402,8 @@ namespace dd
399402
*/
400403
void finalize(const APIData &ad_in, APIData &ad_out, MLModel *mlm)
401404
{
405+
auto output_params = ad_in.createSharedDTO<PredictOutputParametersDto>();
406+
402407
#ifndef USE_SIMSEARCH
403408
(void)mlm;
404409
#endif
@@ -443,12 +448,13 @@ namespace dd
443448
}
444449

445450
if (!timeseries)
446-
best_cats(ad_in, bcats, nclasses, has_bbox, has_roi, has_mask);
451+
best_cats(bcats, output_params->best, nclasses, has_bbox, has_roi,
452+
has_mask);
447453

448454
std::unordered_set<std::string> indexed_uris;
449455
#ifdef USE_SIMSEARCH
450456
// index
451-
if (ad_in.has("index") && ad_in.get("index").get<bool>())
457+
if (output_params->index == true)
452458
{
453459
// check whether index has been created
454460
if (!mlm->_se)
@@ -553,7 +559,7 @@ namespace dd
553559
}
554560

555561
// build index
556-
if (ad_in.has("build_index") && ad_in.get("build_index").get<bool>())
562+
if (output_params->build_index == true)
557563
{
558564
if (mlm->_se)
559565
mlm->build_index();
@@ -562,7 +568,7 @@ namespace dd
562568
}
563569

564570
// search
565-
if (ad_in.has("search") && ad_in.get("search").get<bool>())
571+
if (output_params->search == true)
566572
{
567573
// check whether index has been created
568574
if (!mlm->_se)
@@ -582,11 +588,11 @@ namespace dd
582588
int search_nn = _best;
583589
if (has_roi)
584590
search_nn = _search_nn;
585-
if (ad_in.has("search_nn"))
586-
search_nn = ad_in.get("search_nn").get<int>();
591+
if (output_params->search_nn != nullptr)
592+
search_nn = output_params->search_nn;
587593
#ifdef USE_FAISS
588-
if (ad_in.has("nprobe"))
589-
mlm->_se->_tse->_nprobe = ad_in.get("nprobe").get<int>();
594+
if (output_params->nprobe != nullptr)
595+
mlm->_se->_tse->_nprobe = output_params->nprobe;
590596
#endif
591597
if (!has_roi)
592598
{

0 commit comments

Comments
 (0)