diff --git a/CMakeLists.txt b/CMakeLists.txt index adf86b303c..3122c82897 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -87,7 +87,7 @@ if (NOT EXISTS ${CMAKE_BINARY_DIR}/src) COMMAND bash -c "mkdir ${CMAKE_BINARY_DIR}/src") endif() -set(CMAKE_CXX_FLAGS "-g -O2 -Wall -Wextra -fopenmp -fPIC -std=c++14 -DUSE_OPENCV -DUSE_LMDB") +set(CMAKE_CXX_FLAGS "-g -O2 -Wall -Wextra -fopenmp -fPIC -std=c++14 -DUSE_OPENCV -DUSE_LMDB -fmax-errors=4") if(WARNING) string(APPEND CMAKE_CXX_FLAGS " -Werror") diff --git a/src/apidata.h b/src/apidata.h index 90389a53e8..d93992eb2c 100644 --- a/src/apidata.h +++ b/src/apidata.h @@ -298,6 +298,20 @@ namespace dd .getPtr(); } + template + inline static APIData fromDTO(const oatpp::Object &dto) + { + std::shared_ptr object_mapper + = oatpp::parser::json::mapping::ObjectMapper::createShared(); + + oatpp::String json = object_mapper->writeToString(dto); + APIData ad; + rapidjson::Document d; + d.Parse(json->c_str()); + ad.fromRapidJson(d); + return ad; + } + public: /** * \brief render Mustache template based on this APIData object diff --git a/src/backends/ncnn/ncnnlib.cc b/src/backends/ncnn/ncnnlib.cc index 81f3bf4ce4..cf5166109d 100644 --- a/src/backends/ncnn/ncnnlib.cc +++ b/src/backends/ncnn/ncnnlib.cc @@ -30,6 +30,8 @@ #include "net.h" #include +#include "dto/mllib.hpp" + namespace dd { template void NCNNLib::init_mllib(const APIData &ad) + TMLModel>::init_mllib(const oatpp::Object &init_dto) { - _init_dto = ad.createSharedDTO(); + _init_dto = init_dto; - bool use_fp32 = (ad.has("datatype") - && ad.get("datatype").get() - == "fp32"); // default is fp16 + bool use_fp32 = (_init_dto->datatype == "fp32"); _net->opt.use_fp16_packed = !use_fp32; _net->opt.use_fp16_storage = !use_fp32; _net->opt.use_fp16_arithmetic = !use_fp32; diff --git a/src/backends/ncnn/ncnnlib.h b/src/backends/ncnn/ncnnlib.h index 63f4fea481..2f74f215da 100644 --- a/src/backends/ncnn/ncnnlib.h +++ b/src/backends/ncnn/ncnnlib.h @@ -25,7 +25,7 @@ #include "apidata.h" #include "utils/utils.hpp" -#include "dto/ncnn.hpp" +#include "dto/mllib.hpp" // NCNN #include "net.h" @@ -44,7 +44,7 @@ namespace dd ~NCNNLib(); /*- from mllib -*/ - void init_mllib(const APIData &ad); + void init_mllib(const oatpp::Object &init_dto); void clear_mllib(const APIData &ad); @@ -59,7 +59,7 @@ namespace dd bool _timeserie = false; private: - std::shared_ptr _init_dto; + oatpp::Object _init_dto; static ncnn::UnlockedPoolAllocator _blob_pool_allocator; static ncnn::PoolAllocator _workspace_pool_allocator; diff --git a/src/backends/ncnn/ncnnmodel.h b/src/backends/ncnn/ncnnmodel.h index f4bc75590a..d3c8d9d68c 100644 --- a/src/backends/ncnn/ncnnmodel.h +++ b/src/backends/ncnn/ncnnmodel.h @@ -25,6 +25,8 @@ #include "dd_spdlog.h" #include "mlmodel.h" #include "apidata.h" +#include "dto/model.hpp" +#include "dto/service_create.hpp" namespace dd { @@ -34,12 +36,13 @@ namespace dd NCNNModel() : MLModel() { } - NCNNModel(const APIData &ad, APIData &adg, + NCNNModel(const oatpp::Object &model_dto, + const oatpp::Object &service_dto, const std::shared_ptr &logger) - : MLModel(ad, adg, logger) + : MLModel(model_dto, service_dto, logger) { - if (ad.has("repository")) - this->_repo = ad.get("repository").get(); + if (model_dto->repository) + this->_repo = model_dto->repository->std_str(); read_from_repository(spdlog::get("api")); read_corresp_file(); } diff --git a/src/dto/img_connector.hpp b/src/dto/input_connector.hpp similarity index 87% rename from src/dto/img_connector.hpp rename to src/dto/input_connector.hpp index cada6d85ea..5130f33ad5 100644 --- a/src/dto/img_connector.hpp +++ b/src/dto/input_connector.hpp @@ -19,8 +19,8 @@ * along with deepdetect. If not, see . */ -#ifndef HTTP_DTO_IMG_CONNECTOR_HPP -#define HTTP_DTO_IMG_CONNECTOR_HPP +#ifndef DTO_INPUT_CONNECTOR_HPP +#define DTO_INPUT_CONNECTOR_HPP #include "dd_config.h" #include "oatpp/core/Types.hpp" @@ -32,10 +32,13 @@ namespace dd { #include OATPP_CODEGEN_BEGIN(DTO) - class ImgInputConnectorParameters : public oatpp::DTO + class InputConnector : public oatpp::DTO { - DTO_INIT(ImgInputConnectorParameters, DTO /* extends */) + DTO_INIT(InputConnector, DTO /* extends */) + // Connector type + DTO_FIELD(String, connector); + // IMG Input Connector DTO_FIELD(Int32, width); DTO_FIELD(Int32, height); DTO_FIELD(Int32, crop_width); diff --git a/src/dto/mllib.hpp b/src/dto/mllib.hpp new file mode 100644 index 0000000000..5a9ea738b1 --- /dev/null +++ b/src/dto/mllib.hpp @@ -0,0 +1,85 @@ +/** + * DeepDetect + * Copyright (c) 2021 Jolibrain SASU + * Author: Mehdi Abaakouk + * + * This file is part of deepdetect. + * + * deepdetect is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * deepdetect is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with deepdetect. If not, see . + */ + +#ifndef DTO_MLLIB_H +#define DTO_MLLIB_H + +#include "dd_config.h" +#include "utils/utils.hpp" +#include "oatpp/core/Types.hpp" +#include "oatpp/core/macro/codegen.hpp" + +namespace dd +{ + namespace DTO + { +#include OATPP_CODEGEN_BEGIN(DTO) ///< Begin DTO codegen section + +class MLLib : public oatpp::DTO +{ + DTO_INIT(MLLib, DTO /* extends */) + + // NCNN Options + DTO_FIELD_INFO(nclasses) + { + info->description = "number of output classes (`supervised` service " + "type), classification only"; + }; + DTO_FIELD(Int32, nclasses) = 0; + + DTO_FIELD_INFO(threads) + { + info->description = "number of threads"; + }; + DTO_FIELD(Int32, threads) = dd::dd_utils::my_hardware_concurrency(); + + DTO_FIELD_INFO(lightmode) + { + info->description = "enable light mode"; + }; + DTO_FIELD(Boolean, lightmode) = true; + + DTO_FIELD_INFO(inputBlob) + { + info->description = "network input blob name"; + }; + DTO_FIELD(String, inputBlob) = "data"; + + DTO_FIELD_INFO(outputBlob) + { + info->description = "network output blob name (default depends on " + "network type(ie prob or " + "rnn_pred or probs or detection_out)"; + }; + DTO_FIELD(String, outputBlob); + + DTO_FIELD_INFO(datatype) + { + info->description = "fp16 or fp32"; + }; + + DTO_FIELD(String, datatype) = "fp16"; +}; +#include OATPP_CODEGEN_END(DTO) ///< End DTO codegen section + + } +} +#endif diff --git a/src/dto/model.hpp b/src/dto/model.hpp new file mode 100644 index 0000000000..1dc11b2f1f --- /dev/null +++ b/src/dto/model.hpp @@ -0,0 +1,49 @@ +/** + * DeepDetect + * Copyright (c) 2021 Jolibrain SASU + * Author: Mehdi Abaakouk + * + * This file is part of deepdetect. + * + * deepdetect is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * deepdetect is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with deepdetect. If not, see . + */ + +#ifndef DTO_MODEL +#define DTO_MODEL + +#include "oatpp/core/Types.hpp" +#include "oatpp/core/macro/codegen.hpp" + +namespace dd +{ + namespace DTO + { +#include OATPP_CODEGEN_BEGIN(DTO) ///< Begin DTO codegen section + + +class Model: public oatpp::DTO +{ + DTO_INIT(Model, DTO /* extends */) + DTO_FIELD(String, repository); + DTO_FIELD(String, init); + DTO_FIELD(Boolean, create_repository) = false; + DTO_FIELD(Boolean, index_preload) = false; +}; + +#include OATPP_CODEGEN_END(DTO) ///< End DTO codegen section + + } +} + +#endif diff --git a/src/dto/output_connector.hpp b/src/dto/output_connector.hpp new file mode 100644 index 0000000000..17bd545411 --- /dev/null +++ b/src/dto/output_connector.hpp @@ -0,0 +1,44 @@ +/** + * DeepDetect + * Copyright (c) 2021 Jolibrain SASU + * Author: Mehdi Abaakouk + * + * This file is part of deepdetect. + * + * deepdetect is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * deepdetect is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with deepdetect. If not, see . + */ + +#ifndef DTO_OUTPUT_CONNECTOR_H +#define DTO_OUTPUT_CONNECTOR_H + +#include "oatpp/core/Types.hpp" +#include "oatpp/core/macro/codegen.hpp" + +namespace dd +{ + namespace DTO + { +#include OATPP_CODEGEN_BEGIN(DTO) ///< Begin DTO codegen section + +class OutputConnector: public oatpp::DTO +{ + DTO_INIT(OutputConnector, DTO /* extends */) + DTO_FIELD(Boolean, store_config); +}; + +#include OATPP_CODEGEN_END(DTO) ///< End DTO codegen section +} +} + +#endif diff --git a/src/dto/parameters.hpp b/src/dto/parameters.hpp new file mode 100644 index 0000000000..9172709cbf --- /dev/null +++ b/src/dto/parameters.hpp @@ -0,0 +1,49 @@ +/** + * DeepDetect + * Copyright (c) 2021 Jolibrain SASU + * Author: Mehdi Abaakouk + * + * This file is part of deepdetect. + * + * deepdetect is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * deepdetect is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with deepdetect. If not, see . + */ + +#ifndef DTO_PARAMETERS_H +#define DTO_PARAMETERS_H + +#include "oatpp/core/Types.hpp" +#include "oatpp/core/macro/codegen.hpp" +#include "dto/mllib.hpp" +#include "dto/input_connector.hpp" +#include "dto/output_connector.hpp" + +namespace dd +{ + namespace DTO + { +#include OATPP_CODEGEN_BEGIN(DTO) ///< Begin DTO codegen section + +class Parameters: public oatpp::DTO +{ + DTO_INIT(Parameters, DTO /* extends */) + DTO_FIELD(Object, input) = InputConnector::createShared(); + DTO_FIELD(Object, mllib) = MLLib::createShared(); + DTO_FIELD(Object, output) = OutputConnector::createShared(); +}; + +#include OATPP_CODEGEN_END(DTO) ///< End DTO codegen section +} +} + +#endif diff --git a/src/dto/service_create.hpp b/src/dto/service_create.hpp new file mode 100644 index 0000000000..294a9acc56 --- /dev/null +++ b/src/dto/service_create.hpp @@ -0,0 +1,51 @@ +/** + * DeepDetect + * Copyright (c) 2021 Jolibrain SASU + * Author: Mehdi Abaakouk + * + * This file is part of deepdetect. + * + * deepdetect is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * deepdetect is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with deepdetect. If not, see . + */ + +#ifndef DTO_SERVICE_CREATE_H +#define DTO_SERVICE_CREATE_H + +#include "oatpp/core/Types.hpp" +#include "oatpp/core/macro/codegen.hpp" +#include "dto/model.hpp" +#include "dto/parameters.hpp" + +namespace dd +{ + namespace DTO + { +#include OATPP_CODEGEN_BEGIN(DTO) ///< Begin DTO codegen section + +class ServiceCreate : public oatpp::DTO +{ + DTO_INIT(ServiceCreate, DTO /* extends */) + + DTO_FIELD(String, mllib); + DTO_FIELD(String, description) = ""; + DTO_FIELD(String, type) = "supervised"; + DTO_FIELD(Object, parameters) = Parameters::createShared(); + DTO_FIELD(Object, model) = Model::createShared(); +}; + +#include OATPP_CODEGEN_END(DTO) ///< End DTO codegen section +} +} + +#endif diff --git a/src/http/controller.hpp b/src/http/controller.hpp index 2393abc546..a8854ec354 100644 --- a/src/http/controller.hpp +++ b/src/http/controller.hpp @@ -35,6 +35,7 @@ #include "apidata.h" #include "oatppjsonapi.h" +#include "dto/service_create.hpp" #include "http/dto/info.hpp" #include OATPP_CODEGEN_BEGIN(ApiController) @@ -109,10 +110,10 @@ class DedeController : public oatpp::web::server::api::ApiController } ENDPOINT("POST", "services/{service-name}", create_service, PATH(oatpp::String, service_name, "service-name"), - BODY_STRING(oatpp::String, service_data)) + BODY_DTO(oatpp::Object, service_dto)) { auto janswer = _oja->service_create(service_name.get()->std_str(), - service_data.get()->std_str()); + service_dto); return _oja->jdoc_to_response(janswer); } @@ -123,10 +124,10 @@ class DedeController : public oatpp::web::server::api::ApiController } ENDPOINT("PUT", "services/{service-name}", update_service, PATH(oatpp::String, service_name, "service-name"), - BODY_STRING(oatpp::String, service_data)) + BODY_DTO(oatpp::Object, service_dto)) { auto janswer = _oja->service_create(service_name.get()->std_str(), - service_data.get()->std_str()); + service_dto); return _oja->jdoc_to_response(janswer); } ENDPOINT_INFO(delete_service) diff --git a/src/imginputfileconn.h b/src/imginputfileconn.h index bd6c692037..481ad3fdfb 100644 --- a/src/imginputfileconn.h +++ b/src/imginputfileconn.h @@ -43,7 +43,7 @@ #include "utils/apitools.h" #include -#include "dto/img_connector.hpp" +#include "dto/input_connector.hpp" namespace dd { @@ -457,7 +457,7 @@ namespace dd void fillup_parameters(const APIData &ad) { - auto params = ad.createSharedDTO(); + auto params = ad.createSharedDTO(); // optional parameters. if (params->width) diff --git a/src/jsonapi.cc b/src/jsonapi.cc index 6f5d9a11fd..f927dd95bd 100644 --- a/src/jsonapi.cc +++ b/src/jsonapi.cc @@ -84,6 +84,8 @@ namespace dd return dd_internal_error_500(); } + std::shared_ptr json_object_mapper = oatpp::parser::json::mapping::ObjectMapper::createShared(); + std::vector calls_output; std::string line; int lines = 0; @@ -106,7 +108,8 @@ namespace dd { std::string sname = elts.at(1); std::string body = elts.at(2); - calls_output.push_back(service_create(sname, body)); + auto service_dto = json_object_mapper->readFromString>(body.c_str()); + calls_output.push_back(service_create(sname, service_dto)); if (calls_output.back() != dd_created_201()) { _logger->error("Service creation failed for {}", sname); @@ -435,8 +438,7 @@ namespace dd return jinfo; } - JDoc JsonAPI::service_create(const std::string &sname, - const std::string &jstr) + JDoc JsonAPI::service_create(const std::string &sname, const oatpp::Object &service_dto) { if (sname.empty()) { @@ -450,62 +452,26 @@ namespace dd return dd_service_already_exists_1014(); } - rapidjson::Document d; - d.Parse(jstr.c_str()); - if (d.HasParseError()) - { - _logger->error("JSON parsing error on string: {}", jstr); - return dd_bad_request_400(); - } - std::string mllib, input; - std::string type, description; - bool store_config = false; - APIData ad, ad_model; - try - { - // mandatory parameters. - mllib = d["mllib"].GetString(); - input = d["parameters"]["input"]["connector"].GetString(); + if (!service_dto->mllib) { + return dd_bad_request_400("mllib required"); + } + if (!service_dto->parameters || !service_dto->parameters->input || !service_dto->parameters->input->connector) { + return dd_bad_request_400("parameters/input/connector required"); + } - // optional parameters. - if (d.HasMember("type")) - type = d["type"].GetString(); - else - type = "supervised"; // default - if (d.HasMember("description")) - description = d["description"].GetString(); - - // model parameters (mandatory). - ad.fromRapidJson(d); - ad_model = ad.getobj("model"); - APIData ad_param = ad.getobj("parameters"); - if (ad_param.has("output")) - { - APIData ad_output = ad_param.getobj("output"); - if (ad_output.has("store_config")) - store_config = ad_output.get("store_config").get(); - } - } - catch (RapidjsonException &e) - { - _logger->error("JSON error {}", e.what()); - return dd_bad_request_400(e.what()); - } - catch (...) - { - return dd_bad_request_400(); - } + std::shared_ptr json_object_mapper = oatpp::parser::json::mapping::ObjectMapper::createShared(); + std::string input = service_dto->parameters->input->connector->std_str(); // create service. try { - if (mllib.empty()) + if (!service_dto->mllib) { return dd_unknown_library_1000(); } #ifdef USE_CAFFE - else if (mllib == "caffe") + else if (service_dto->mllib == "caffe") { CaffeModel cmodel(ad_model, ad, _logger); read_metrics_json(cmodel._repo, ad); @@ -675,35 +641,37 @@ namespace dd #endif // USE_CAFFE2 #ifdef USE_NCNN - else if (mllib == "ncnn") + else if (service_dto->mllib == "ncnn") { - NCNNModel ncnnmodel(ad_model, ad, _logger); - read_metrics_json(ncnnmodel._repo, ad); - if (type == "supervised") + NCNNModel ncnnmodel(service_dto->model, service_dto, _logger); + // TODO(sileht): Create a metrics DTO first + //read_metrics_json(ncnnmodel._repo, ad); + if (service_dto->type == "supervised") { if (input == "image") - add_service( + add_service( sname, std::move(MLService( - sname, ncnnmodel, description)), - ad); + sname, ncnnmodel, service_dto->description->std_str())), + service_dto); else if (input == "csv_ts" || input == "csvts") add_service( sname, std::move(MLService( - sname, ncnnmodel, description)), - ad); + sname, ncnnmodel, service_dto->description->std_str())), + service_dto); else return dd_input_connector_not_found_1004(); - if (JsonAPI::store_json_blob(ncnnmodel._repo, jstr)) + + if (JsonAPI::store_json_blob(ncnnmodel._repo, service_dto)) _logger->error( "couldn't write {} file in model repository {}", JsonAPI::_json_blob_fname, ncnnmodel._repo); // store model configuration json blob - if (store_config - && JsonAPI::store_json_config_blob(ncnnmodel._repo, jstr)) + if (service_dto->parameters->output->store_config + && JsonAPI::store_json_config_blob(ncnnmodel._repo, service_dto)) { _logger->error( "couldn't write {} file in model repository {}", @@ -1616,8 +1584,17 @@ namespace dd return 0; } + int JsonAPI::store_json_blob(const std::string &model_repo, + const oatpp::Object &dto, + const std::string &jfilename) + { + std::shared_ptr object_mapper = oatpp::parser::json::mapping::ObjectMapper::createShared(); + oatpp::String json = object_mapper->writeToString(dto); + return JsonAPI::store_json_blob(model_repo, json->std_str(), jfilename); + } + int JsonAPI::store_json_config_blob(const std::string &model_repo, - const std::string &jstr) + const std::string &jstr) { std::ofstream outf; outf.open(model_repo + "/" + JsonAPI::_json_config_blob_fname, @@ -1628,6 +1605,14 @@ namespace dd return 0; } + int JsonAPI::store_json_config_blob(const std::string &model_repo, + const oatpp::Object &dto) + { + std::shared_ptr object_mapper = oatpp::parser::json::mapping::ObjectMapper::createShared(); + oatpp::String json = object_mapper->writeToString(dto); + return JsonAPI::store_json_config_blob(model_repo, json->std_str()); + } + // read_json file blob to apidata int JsonAPI::read_json_blob(const std::string &model_repo, const std::string &jfilename, APIData &ad) @@ -1659,6 +1644,7 @@ namespace dd return 0; } + // FIXME(sileht) create a metrics DTO // read_json file blob to apidata void JsonAPI::read_metrics_json(const std::string &model_repo, APIData &ad) { diff --git a/src/jsonapi.h b/src/jsonapi.h index 9e71466b71..4afbc0548a 100644 --- a/src/jsonapi.h +++ b/src/jsonapi.h @@ -92,7 +92,7 @@ namespace dd // resources // return a JSON document for every API call JDoc info(const std::string &jstr) const; - JDoc service_create(const std::string &sname, const std::string &jstr); + JDoc service_create(const std::string &sname, const oatpp::Object &service_dto); JDoc service_status(const std::string &sname); JDoc service_delete(const std::string &sname, const std::string &jstr); @@ -111,6 +111,13 @@ namespace dd static int store_json_config_blob(const std::string &model_repo, const std::string &jstr); + static int store_json_blob(const std::string &model_repo, + const oatpp::Object &dto, + const std::string &jfilename = ""); + + static int store_json_config_blob(const std::string &model_repo, + const oatpp::Object &dto); + static int read_json_blob(const std::string &model_repo, const std::string &jfilename, APIData &ad); diff --git a/src/mllibstrategy.h b/src/mllibstrategy.h index 6a61e2a0b6..0c993c2a9b 100644 --- a/src/mllibstrategy.h +++ b/src/mllibstrategy.h @@ -23,6 +23,7 @@ #define MLLIBSTRATEGY_H #include "apidata.h" +#include "dto/mllib.hpp" #include "service_stats.h" #include "utils/fileops.hpp" #include "dd_spdlog.h" @@ -30,6 +31,7 @@ #include #include + namespace dd { /** @@ -112,9 +114,9 @@ namespace dd /** * \brief initializes ML lib - * @param ad data object for "parameters/mllib" + * @param data transfert object for "parameters/mllib" */ - void init_mllib(const APIData &ad); + void init_mllib(const DTO::MLLib &dto); /** * \brief clear the lib service from local model files etc... diff --git a/src/mlmodel.h b/src/mlmodel.h index 5daec250cb..0e359ea2b5 100644 --- a/src/mlmodel.h +++ b/src/mlmodel.h @@ -30,11 +30,14 @@ #include #include #include "apidata.h" +#include "dto/model.hpp" #include "utils/fileops.hpp" #ifndef WIN32 #include "utils/httpclient.hpp" #endif #include "mllibstrategy.h" +#include "dto/service_create.hpp" +#include "dto/parameters.hpp" namespace dd { @@ -45,26 +48,27 @@ namespace dd { } - MLModel(const APIData &ad, APIData &adg, + MLModel(const oatpp::Object &model_dto, + const oatpp::Object &service_dto, const std::shared_ptr &logger) { - init_repo_dir(ad, logger.get()); - if (ad.has("init")) - read_config_json(adg, logger); + init_repo_dir(model_dto, logger.get()); + if (model_dto->init) + read_config_json(service_dto); } - MLModel(const APIData &ad) + MLModel(const oatpp::Object &model_dto) { - init_repo_dir(ad, nullptr); + init_repo_dir(model_dto, nullptr); } MLModel(const std::string &repo) : _repo(repo) { } - MLModel(const APIData &ad, const std::string &repo) : _repo(repo) + MLModel(const oatpp::Object &model_dto, const std::string &repo) : _repo(repo) { - init_repo_dir(ad, nullptr); + init_repo_dir(model_dto, nullptr); } ~MLModel() @@ -210,12 +214,11 @@ namespace dd #endif private: - void init_repo_dir(const APIData &ad, spdlog::logger *logger) + void init_repo_dir(const oatpp::Object &model_dto, spdlog::logger *logger) { // auto-creation of model directory - _repo = ad.get("repository").get(); - bool create = ad.has("create_repository") - && ad.get("create_repository").get(); + _repo = model_dto->repository->std_str(); + bool isDir; bool exists = fileops::file_exists(_repo, isDir); if (exists && !isDir) @@ -225,7 +228,7 @@ namespace dd logger->error(errmsg); throw MLLibBadParamException(errmsg); } - if (!exists && create) + if (!exists && model_dto->create_repository) fileops::create_dir(_repo, 0775); if (!fileops::is_directory_writable(_repo)) @@ -237,13 +240,13 @@ namespace dd } #ifdef USE_SIMSEARCH - if (ad.has("index_preload") && ad.get("index_preload").get()) + if (model_dto->index_preload) _index_preload = true; #endif // auto-install from model archive - if (ad.has("init")) + if (model_dto->init) { - std::string compressedf = ad.get("init").get(); + std::string compressedf = model_dto->init->std_str(); // check whether already in the directory std::string base_model_fname @@ -301,8 +304,7 @@ namespace dd } } - void read_config_json(APIData &adg, - const std::shared_ptr &logger) + void read_config_json(const oatpp::Object service_create_dto) { const std::string cf = _repo + "/config.json"; if (!fileops::file_exists(cf)) @@ -310,26 +312,13 @@ namespace dd std::ifstream is(cf); std::stringstream jbuf; jbuf << is.rdbuf(); - rapidjson::Document d; - d.Parse(jbuf.str().c_str()); - if (d.HasParseError()) - { - logger->error("config.json parsing error on string: {}", jbuf.str()); - throw MLLibBadParamException("Failed parsing config file " + cf); - } - APIData adcj; - try - { - adcj.fromRapidJson(d); - } - catch (RapidjsonException &e) - { - logger->error("JSON error {}", e.what()); - throw MLLibBadParamException( - "Failed converting JSON file to internal data format"); - } - APIData adcj_parameters = adcj.getobj("parameters"); - adg.add("parameters", adcj_parameters); + + // FIXME(sileht): Replacing the user provided data here doesn't look good + // to me + std::shared_ptr objectMapper + = oatpp::parser::json::mapping::ObjectMapper::createShared(); + service_create_dto->parameters = objectMapper->readFromString>( + jbuf.str().c_str()); } }; } diff --git a/src/mlservice.h b/src/mlservice.h index 25c6b2988e..4267f981cd 100644 --- a/src/mlservice.h +++ b/src/mlservice.h @@ -146,20 +146,23 @@ namespace dd * - init of ML library * @param ad root data object */ - void init(const APIData &ad) + void init(const oatpp::Object &service_dto) { - this->_inputc._model_repo - = ad.getobj("model").get("repository").get(); + this->_inputc._model_repo = service_dto->model->repository->std_str(); if (this->_inputc._model_repo.empty()) throw MLLibBadParamException("empty repository"); this->_inputc._logger = this->_logger; this->_outputc._logger = this->_logger; - _init_parameters = ad.getobj("parameters"); - this->_inputc.init(_init_parameters.getobj("input")); - this->_outputc.init(_init_parameters.getobj("output")); - this->init_mllib(_init_parameters.getobj("mllib")); - this->fillup_measures_history(ad); + _init_parameters = service_dto->parameters; + + // NOTE(sileht): Beware using fromDTO is a bit ricky because anything + // passed in the JSON but not present in the DTO description will not + // appear in the generated APIData neither + this->_inputc.init(APIData::fromDTO(_init_parameters->input)); + this->_outputc.init(APIData::fromDTO(_init_parameters->output)); + this->init_mllib(_init_parameters->mllib); + this->fillup_measures_history(APIData::fromDTO(service_dto)); } /** @@ -288,7 +291,7 @@ namespace dd // platform use the new name ad.add("model_stats", stats); ad.add("jobs", vad); - ad.add("parameters", _init_parameters); + ad.add("parameters", APIData::fromDTO(_init_parameters)); ad.add("repository", this->_inputc._model_repo); ad.add("mltype", this->_mltype); if (typeid(this->_outputc) == typeid(UnsupervisedOutput)) @@ -526,7 +529,7 @@ namespace dd std::string _sname; /**< service name. */ std::string _description; /**< optional description of the service. */ - APIData _init_parameters; /**< service creation parameters. */ + oatpp::Object _init_parameters; /**< service creation parameters. */ mutable std::mutex _tjobs_mutex; /**< mutex around training jobs. */ std::atomic _tjobs_counter = { 0 }; /**< training jobs counter. */ diff --git a/src/services.h b/src/services.h index c9cfb5b261..9883229ff8 100644 --- a/src/services.h +++ b/src/services.h @@ -273,16 +273,16 @@ namespace dd class v_init { public: - const APIData &_in; + const oatpp::Object &_service_dto; template void operator()(T &mllib) { - mllib.init(_in); + mllib.init(_service_dto); } }; - template static void init(T &mllib, const APIData &in) + template static void init(T &mllib, const oatpp::Object &service_dto) { - visitor_mllib::v_init v{ in }; + visitor_mllib::v_init v{ service_dto }; mapbox::util::apply_visitor(v, mllib); } @@ -384,7 +384,7 @@ namespace dd * @param ad optional root data object holding service's parameters */ void add_service(const std::string &sname, mls_variant_type &&mls, - const APIData &ad = APIData()) + const oatpp::Object &service_dto) { std::unordered_map::const_iterator hit; if ((hit = _mlservices.find(sname)) != _mlservices.end()) @@ -395,7 +395,7 @@ namespace dd auto llog = spdlog::get(sname); try { - visitor_mllib::init(mls, ad); + visitor_mllib::init(mls, service_dto); std::lock_guard lock(_mlservices_mtx); _mlservices.insert( std::pair(sname, std::move(mls))); diff --git a/tests/ut-ncnnapi.cc b/tests/ut-ncnnapi.cc index 82ea2f32d3..53213dc3e3 100644 --- a/tests/ut-ncnnapi.cc +++ b/tests/ut-ncnnapi.cc @@ -24,6 +24,7 @@ #include #include #include +#include "dto/service_create.hpp" using namespace dd; @@ -50,11 +51,12 @@ static std::string iterations_lstm = "200"; static std::string iterations_lstm = "20"; #endif +std::shared_ptr objectMapper + = oatpp::parser::json::mapping::ObjectMapper::createShared(); + TEST(ncnnapi, service_predict_bbox) { // create service - JsonAPI japi; - std::string sname = "imgserv"; std::string jstr = "{\"mllib\":\"ncnn\",\"description\":\"squeezenet-ssd\",\"type\":" "\"supervised\",\"model\":{\"repository\":\"" @@ -62,7 +64,15 @@ TEST(ncnnapi, service_predict_bbox) + "\"},\"parameters\":{\"input\":{\"connector\":\"image\",\"height\":" "300,\"width\":300}," "\"mllib\":{\"nclasses\":21}}}"; - std::string joutstr = japi.jrender(japi.service_create(sname, jstr)); + + JsonAPI japi; + std::string sname = "imgserv"; + auto service_create + = objectMapper->readFromString>( + jstr.c_str()); + + std::string joutstr + = japi.jrender(japi.service_create(sname, service_create)); ASSERT_EQ(created_str, joutstr); // predict @@ -151,6 +161,7 @@ TEST(ncnnapi, service_predict_classification) // create service JsonAPI japi; std::string sname = "imgserv"; + std::string jstr = "{\"mllib\":\"ncnn\",\"description\":\"squeezenet\",\"type\":" "\"supervised\",\"model\":{\"repository\":\"" @@ -158,7 +169,13 @@ TEST(ncnnapi, service_predict_classification) + "\"},\"parameters\":{\"input\":{\"connector\":\"image\",\"height\":" "224,\"width\":224,\"mean\":[128,128,128]}," "\"mllib\":{\"nclasses\":1000}}}"; - std::string joutstr = japi.jrender(japi.service_create(sname, jstr)); + + auto service_create + = objectMapper->readFromString>( + jstr.c_str()); + + std::string joutstr + = japi.jrender(japi.service_create(sname, service_create)); ASSERT_EQ(created_str, joutstr); // predict @@ -196,7 +213,14 @@ TEST(ncnnapi, service_lstm) "\"output\"]},\"mllib\":{\"template\":\"recurrent\",\"layers\":[" "\"L10\",\"L10\"],\"dropout\":[0.0,0.0,0.0],\"regression\":true," "\"sl1sigma\":100.0,\"loss\":\"L1\"}}}"; - std::string joutstr = japi.jrender(japi.service_create(sname, jstr)); + + auto service_create + = objectMapper->readFromString>( + jstr.c_str()); + + std::string joutstr + = japi.jrender(japi.service_create(sname, service_create)); + ASSERT_EQ(created_str, joutstr); // train @@ -272,7 +296,12 @@ TEST(ncnnapi, service_lstm) "\"csvts\",\"label\":[" "\"output\"]}" "}}"; - joutstr = japi.jrender(japi.service_create(sname, jstr)); + + auto service_create + = objectMapper->readFromString>( + jstr.c_str()); + joutstr = japi.jrender(japi.service_create(sname, service_create)); + ASSERT_EQ(created_str, joutstr); jpredictstr @@ -312,7 +341,13 @@ TEST(ncnnapi, ocr) + ocr_repo + "\"},\"parameters\":{\"input\":{\"connector\":\"image\",\"ctc\":" "true, \"height\":136,\"width\":220},\"mllib\":{\"nclasses\":69}}}"; - std::string joutstr = japi.jrender(japi.service_create(sname, jstr)); + + auto service_create + = objectMapper->readFromString>( + jstr.c_str()); + + std::string joutstr + = japi.jrender(japi.service_create(sname, service_create)); ASSERT_EQ(created_str, joutstr); // predict