diff --git a/src/mlservice.h b/src/mlservice.h
index ed9e5cd9a..a62f00866 100644
--- a/src/mlservice.h
+++ b/src/mlservice.h
@@ -187,6 +187,9 @@ namespace dd
             }
           ++hit;
         }
+
+      // wait for predict to finish
+      boost::unique_lock<boost::shared_mutex> lock2(_train_or_predict_mutex);
     }
 
     /**
@@ -323,7 +326,7 @@ namespace dd
                 // XXX: due to lock below, queued jobs may not
                 // start in requested order
                 boost::unique_lock<boost::shared_mutex> lock(
-                    _train_mutex);
+                    _train_or_predict_mutex);
                 APIData out;
                 int run_code = this->train(ad, out);
                 std::pair p(local_tcounter,
@@ -336,7 +339,8 @@ namespace dd
         }
       else
         {
-          boost::unique_lock<boost::shared_mutex> lock(_train_mutex);
+          boost::unique_lock<boost::shared_mutex> lock(
+              _train_or_predict_mutex);
           this->_has_predict = false;
           int status = this->train(ad, out);
           APIData ad_params_out = ad.getobj("parameters").getobj("output");
@@ -498,7 +502,7 @@ namespace dd
     oatpp::Object predict_job(const APIData &ad,
                               const bool &chain = false)
     {
-      if (!_train_mutex.try_lock_shared())
+      if (!_train_or_predict_mutex.try_lock_shared())
         throw MLServiceLockException(
             "Predict call while training with an offline learning algorithm");
 
@@ -513,13 +517,13 @@ namespace dd
         }
       catch (std::exception &e)
         {
-          _train_mutex.unlock_shared();
+          _train_or_predict_mutex.unlock_shared();
           this->_stats.predict_end(false);
           throw;
         }
       this->_stats.predict_end(true);
-      _train_mutex.unlock_shared();
+      _train_or_predict_mutex.unlock_shared();
       return out;
     }
 
@@ -533,7 +537,7 @@ namespace dd
         _training_jobs; // XXX: the futures' dtor blocks if the object is being
                         // terminated
     std::unordered_map _training_out;
-    boost::shared_mutex _train_mutex;
+    boost::shared_mutex _train_or_predict_mutex;
   };
 
 }
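The rename makes the shared mutex's role explicit: predict_job acquires it in shared mode via try_lock_shared(), training jobs acquire it exclusively, and the new lock2 in the destructor blocks until any in-flight predict has released its shared lock before the service is torn down. Below is a minimal, self-contained sketch of that reader/writer pattern, not the project's code: it uses std::shared_mutex in place of boost::shared_mutex, and the Service, train and predict names are illustrative.

```cpp
// Minimal sketch of the locking scheme in the diff (illustrative names:
// Service, train, predict). std::shared_mutex stands in for
// boost::shared_mutex; the member mirrors the renamed _train_or_predict_mutex.
#include <iostream>
#include <shared_mutex>
#include <thread>

class Service
{
public:
  ~Service()
  {
    // wait for any predict still holding the shared lock before tearing down
    std::unique_lock<std::shared_mutex> lock(_train_or_predict_mutex);
  }

  void train()
  {
    // exclusive lock: no predict may run while an offline training job runs
    std::unique_lock<std::shared_mutex> lock(_train_or_predict_mutex);
    std::cout << "training\n";
  }

  bool predict()
  {
    // non-blocking shared acquisition, mirroring try_lock_shared() in the diff
    if (!_train_or_predict_mutex.try_lock_shared())
      return false; // the real service throws MLServiceLockException here
    std::cout << "predicting\n";
    _train_or_predict_mutex.unlock_shared();
    return true;
  }

private:
  std::shared_mutex _train_or_predict_mutex;
};

int main()
{
  Service s;
  std::thread t([&s] { s.train(); });
  if (!s.predict())
    std::cout << "predict rejected while training holds the lock\n";
  t.join();
  return 0;
}
```

Using try_lock_shared() rather than a blocking shared lock means a predict issued during offline training fails fast (with MLServiceLockException in the real code) instead of queueing behind a potentially long training job, while concurrent predicts still proceed in parallel under the shared lock.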