diff --git a/include/ylt/coro_io/coro_io.hpp b/include/ylt/coro_io/coro_io.hpp index 40cb4c8d4..d4d114ed6 100644 --- a/include/ylt/coro_io/coro_io.hpp +++ b/include/ylt/coro_io/coro_io.hpp @@ -52,7 +52,7 @@ class callback_awaitor_base { template class callback_awaitor_impl { public: - callback_awaitor_impl(Derived &awaitor, const Op &op) noexcept + callback_awaitor_impl(Derived &awaitor, Op &op) noexcept : awaitor(awaitor), op(op) {} constexpr bool await_ready() const noexcept { return false; } void await_suspend(std::coroutine_handle<> handle) noexcept { @@ -73,7 +73,7 @@ class callback_awaitor_base { private: Derived &awaitor; - const Op &op; + Op &op; }; public: @@ -101,7 +101,7 @@ class callback_awaitor_base { Derived *obj; }; template - callback_awaitor_impl await_resume(const Op &op) noexcept { + callback_awaitor_impl await_resume(Op &&op) noexcept { return callback_awaitor_impl{static_cast(*this), op}; } @@ -316,7 +316,7 @@ inline async_simple::coro::Lazy sleep_for(Duration d) { template struct post_helper { - void operator()(auto handler) const { + void operator()(auto handler) { asio::dispatch(e, [this, handler]() { try { if constexpr (std::is_same_v>) { diff --git a/include/ylt/coro_io/io_context_pool.hpp b/include/ylt/coro_io/io_context_pool.hpp index 81bce7d4a..07eb8e522 100644 --- a/include/ylt/coro_io/io_context_pool.hpp +++ b/include/ylt/coro_io/io_context_pool.hpp @@ -17,8 +17,8 @@ #include #include +#include #include -#include #include #include #include @@ -51,10 +51,10 @@ class ExecutorWrapper : public async_simple::Executor { virtual bool schedule(Func func) override { if constexpr (requires(ExecutorImpl e) { e.post(std::move(func)); }) { - executor_.post(std::move(func)); + executor_.dispatch(std::move(func)); } else { - asio::post(executor_, std::move(func)); + asio::dispatch(executor_, std::move(func)); } return true; @@ -67,7 +67,7 @@ class ExecutorWrapper : public async_simple::Executor { executor.post(std::move(func)); } else { - asio::post(executor, std::move(func)); + asio::dispatch(executor, std::move(func)); } return true; } diff --git a/include/ylt/coro_rpc/coro_rpc_context.hpp b/include/ylt/coro_rpc/coro_rpc_context.hpp index 2557e32bd..a50d6ad29 100644 --- a/include/ylt/coro_rpc/coro_rpc_context.hpp +++ b/include/ylt/coro_rpc/coro_rpc_context.hpp @@ -14,5 +14,4 @@ * limitations under the License. */ #pragma once - #include "impl/protocol/coro_rpc_protocol.hpp" diff --git a/include/ylt/coro_rpc/impl/context.hpp b/include/ylt/coro_rpc/impl/context.hpp index 37a688fb8..84fde1aa2 100644 --- a/include/ylt/coro_rpc/impl/context.hpp +++ b/include/ylt/coro_rpc/impl/context.hpp @@ -29,6 +29,7 @@ #include "coro_connection.hpp" #include "ylt/coro_rpc/impl/errno.h" #include "ylt/util/type_traits.h" +#include "ylt/util/utils.hpp" namespace coro_rpc { /*! @@ -43,14 +44,14 @@ class context_base { typename rpc_protocol::req_header &get_req_head() { return self_->req_head_; } bool check_status() { - auto old_flag = self_->has_response_.exchange(true); - if (old_flag != false) + auto old_flag = self_->status_.exchange(context_status::start_response); + if (old_flag != context_status::init) AS_UNLIKELY { ELOGV(ERROR, "response message more than one time"); return false; } - if (has_closed()) + if (self_->has_closed()) AS_UNLIKELY { ELOGV(DEBUG, "response_msg failed: connection has been closed"); return false; @@ -67,8 +68,7 @@ class context_base { context_base(std::shared_ptr> context_info) : self_(std::move(context_info)) { if (self_->conn_) { - self_->conn_->set_rpc_call_type( - coro_connection::rpc_call_type::callback_started); + self_->conn_->set_rpc_return_by_callback(); } }; context_base() = default; @@ -79,8 +79,10 @@ class context_base { std::string_view error_msg) { if (!check_status()) AS_UNLIKELY { return; }; - self_->conn_->template response_error( - error_code, error_msg, self_->req_head_, self_->is_delay_); + ELOGI << "rpc error in function:" << self_->get_rpc_function_name() + << ". error code:" << error_code.ec << ". message : " << error_msg; + self_->conn_->template response_error(error_code, error_msg, + self_->req_head_); } void response_error(coro_rpc::err_code error_code) { response_error(error_code, error_code.message()); @@ -98,16 +100,15 @@ class context_base { */ template void response_msg(Args &&...args) { + if (!check_status()) + AS_UNLIKELY { return; }; if constexpr (std::is_same_v) { static_assert(sizeof...(args) == 0, "illegal args"); - if (!check_status()) - AS_UNLIKELY { return; }; std::visit( [&](const serialize_proto &) { self_->conn_->template response_msg( serialize_proto::serialize(), - std::move(self_->resp_attachment_), self_->req_head_, - self_->is_delay_); + std::move(self_->resp_attachment_), self_->req_head_); }, *rpc_protocol::get_serialize_protocol(self_->req_head_)); } @@ -115,85 +116,24 @@ class context_base { static_assert( requires { return_msg_type{std::forward(args)...}; }, "constructed return_msg_type failed by illegal args"); - - if (!check_status()) - AS_UNLIKELY { return; }; - return_msg_type ret{std::forward(args)...}; std::visit( [&](const serialize_proto &) { self_->conn_->template response_msg( serialize_proto::serialize(ret), - std::move(self_->resp_attachment_), self_->req_head_, - self_->is_delay_); + std::move(self_->resp_attachment_), self_->req_head_); }, *rpc_protocol::get_serialize_protocol(self_->req_head_)); // response_handler_(std::move(conn_), std::move(ret)); } + /*finish here*/ + self_->status_ = context_status::finish_response; } - - /*! - * Check connection closed or not - * - * @return true if closed, otherwise false - */ - bool has_closed() const { return self_->conn_->has_closed(); } - - /*! - * Close connection - */ - void close() { return self_->conn_->async_close(); } - - /*! - * Get the unique connection ID - * @return connection id - */ - uint64_t get_connection_id() const noexcept { - return self_->conn_->get_connection_id(); - } - - /*! - * Set the response_attachment - * @return a ref of response_attachment - */ - void set_response_attachment(std::string attachment) { - set_response_attachment([attachment = std::move(attachment)] { - return std::string_view{attachment}; - }); - } - - /*! - * Set the response_attachment - * @return a ref of response_attachment - */ - void set_response_attachment(std::function attachment) { - self_->resp_attachment_ = std::move(attachment); - } - - /*! - * Get the request attachment - * @return connection id - */ - std::string_view get_request_attachment() const { - return self_->req_attachment_; - } - - /*! - * Release the attachment - * @return connection id - */ - std::string release_request_attachment() { - return std::move(self_->req_attachment_); - } - - void set_delay() { - self_->is_delay_ = true; - self_->conn_->set_rpc_call_type( - coro_connection::rpc_call_type::callback_with_delay); + const context_info_t *get_context() const noexcept { + return self_.get(); } - std::any &tag() { return self_->conn_->tag(); } - const std::any &tag() const { return self_->conn_->tag(); } + context_info_t *get_context() noexcept { return self_.get(); } }; template diff --git a/include/ylt/coro_rpc/impl/coro_connection.hpp b/include/ylt/coro_rpc/impl/coro_connection.hpp index 5aeb52532..f0be2b297 100644 --- a/include/ylt/coro_rpc/impl/coro_connection.hpp +++ b/include/ylt/coro_rpc/impl/coro_connection.hpp @@ -26,21 +26,29 @@ #include #include #include +#include #include #include +#include "async_simple/Common.h" #include "ylt/coro_io/coro_io.hpp" #include "ylt/coro_rpc/impl/errno.h" +#include "ylt/util/utils.hpp" #ifdef UNIT_TEST_INJECT #include "inject_action.hpp" #endif namespace coro_rpc { - class coro_connection; using rpc_conn = std::shared_ptr; +enum class context_status : int { init, start_response, finish_response }; template struct context_info_t { +#ifndef CORO_RPC_TEST + private: +#endif + typename rpc_protocol::route_key_t key_; + typename rpc_protocol::router &router_; std::shared_ptr conn_; typename rpc_protocol::req_header req_head_; std::string req_body_; @@ -48,11 +56,46 @@ struct context_info_t { std::function resp_attachment_ = [] { return std::string_view{}; }; - std::atomic has_response_ = false; - bool is_delay_ = false; - context_info_t(std::shared_ptr &&conn) - : conn_(std::move(conn)) {} + std::atomic status_ = context_status::init; + + public: + template + friend class context_base; + friend class coro_connection; + context_info_t(typename rpc_protocol::router &r, + std::shared_ptr &&conn) + : router_(r), conn_(std::move(conn)) {} + context_info_t(typename rpc_protocol::router &r, + std::shared_ptr &&conn, + std::string &&req_body_buf, std::string &&req_attachment_buf) + : router_(r), + conn_(std::move(conn)), + req_body_(std::move(req_body_buf)), + req_attachment_(std::move(req_attachment_buf)) {} + uint64_t get_connection_id() noexcept; + uint64_t has_closed() const noexcept; + void close(); + uint64_t get_connection_id() const noexcept; + void set_response_attachment(std::string_view attachment); + void set_response_attachment(std::string attachment); + void set_response_attachment(std::function attachment); + std::string_view get_request_attachment() const; + std::string release_request_attachment(); + std::any &tag() noexcept; + const std::any &tag() const noexcept; + asio::ip::tcp::endpoint get_local_endpoint() const noexcept; + asio::ip::tcp::endpoint get_remote_endpoint() const noexcept; + uint64_t get_request_id() const noexcept; + std::string_view get_rpc_function_name() const { + return router_.get_name(key_); + } }; + +namespace detail { +template +context_info_t *&set_context(); +} + /*! * TODO: add doc */ @@ -70,13 +113,6 @@ struct context_info_t { class coro_connection : public std::enable_shared_from_this { public: - enum rpc_call_type { - non_callback, - callback_with_delay, - callback_finished, - callback_started - }; - /*! * * @param io_context @@ -89,7 +125,6 @@ class coro_connection : public std::enable_shared_from_this { std::chrono::seconds(0)) : executor_(executor), socket_(std::move(socket)), - resp_err_(), timer_(executor->get_asio_executor()) { if (timeout_duration == std::chrono::seconds(0)) { return; @@ -148,16 +183,13 @@ class coro_connection : public std::enable_shared_from_this { template async_simple::coro::Lazy start_impl( typename rpc_protocol::router &router, Socket &socket) noexcept { - auto context_info = - std::make_shared>(shared_from_this()); - std::string resp_error_msg; + auto context_info = std::make_shared>( + router, shared_from_this()); + reset_timer(); while (true) { - auto &req_head = context_info->req_head_; - auto &body = context_info->req_body_; - auto &req_attachment = context_info->req_attachment_; - reset_timer(); - auto ec = co_await rpc_protocol::read_head(socket, req_head); - cancel_timer(); + typename rpc_protocol::req_header req_head_tmp; + // timer will be reset after rpc call response + auto ec = co_await rpc_protocol::read_head(socket, req_head_tmp); // `co_await async_read` uses asio::async_read underlying. // If eof occurred, the bytes_transferred of `co_await async_read` must // less than RPC_HEAD_LEN. Incomplete data will be discarded. @@ -169,7 +201,7 @@ class coro_connection : public std::enable_shared_from_this { } #ifdef UNIT_TEST_INJECT - client_id_ = req_head.seq_num; + client_id_ = req_head_tmp.seq_num; ELOGV(INFO, "conn_id %d, client_id %d", conn_id_, client_id_); #endif @@ -183,6 +215,28 @@ class coro_connection : public std::enable_shared_from_this { break; } #endif + + // try to reuse context + if (is_rpc_return_by_callback_) { + // cant reuse context,make shared new one + is_rpc_return_by_callback_ = false; + if (context_info->status_ != context_status::finish_response) { + // cant reuse buffer + context_info = std::make_shared>( + router, shared_from_this()); + } + else { + // reuse string buffer + context_info = std::make_shared>( + router, shared_from_this(), std::move(context_info->req_body_), + std::move(context_info->req_attachment_)); + } + } + auto &req_head = context_info->req_head_; + auto &body = context_info->req_body_; + auto &req_attachment = context_info->req_attachment_; + auto &key = context_info->key_; + req_head = std::move(req_head_tmp); auto serialize_proto = rpc_protocol::get_serialize_protocol(req_head); if (!serialize_proto.has_value()) @@ -197,6 +251,7 @@ class coro_connection : public std::enable_shared_from_this { ec = co_await rpc_protocol::read_payload(socket, req_head, body, req_attachment); + cancel_timer(); payload = std::string_view{body}; if (ec) @@ -207,85 +262,77 @@ class coro_connection : public std::enable_shared_from_this { break; } - std::pair pair{}; - - auto key = rpc_protocol::get_route_key(req_head); + key = rpc_protocol::get_route_key(req_head); auto handler = router.get_handler(key); + ++rpc_processing_cnt_; if (!handler) { auto coro_handler = router.get_coro_handler(key); - pair = co_await router.route_coro(coro_handler, payload, context_info, - serialize_proto.value(), key); + set_rpc_return_by_callback(); + router.route_coro(coro_handler, payload, serialize_proto.value(), key) + .via(executor_) + .setLazyLocal((void *)context_info.get()) + .start([context_info](auto &&result) mutable { + std::pair &ret = result.value(); + if (ret.first) + AS_UNLIKELY { + ELOGI << "rpc error in function:" + << context_info->get_rpc_function_name() + << ". error code:" << ret.first.ec + << ". message : " << ret.second; + } + auto executor = context_info->conn_->get_executor(); + executor->schedule([context_info = std::move(context_info), + ret = std::move(ret)]() mutable { + context_info->conn_->template direct_response_msg( + ret.first, ret.second, context_info->req_head_, + std::move(context_info->resp_attachment_)); + }); + }); } else { - pair = router.route(handler, payload, context_info, - serialize_proto.value(), key); - } - - auto &[resp_err, resp_buf] = pair; - switch (rpc_call_type_) { - default: - unreachable(); - case rpc_call_type::non_callback: - break; - case rpc_call_type::callback_with_delay: - ++delay_resp_cnt; - rpc_call_type_ = rpc_call_type::non_callback; - continue; - case rpc_call_type::callback_finished: - continue; - case rpc_call_type::callback_started: - coro_io::callback_awaitor awaitor; - rpc_call_type_ = rpc_call_type::callback_finished; - co_await awaitor.await_resume([this](auto handler) { - this->callback_awaitor_handler_ = std::move(handler); - }); - context_info->has_response_ = false; - context_info->resp_attachment_ = []() -> std::string_view { - return {}; - }; - rpc_call_type_ = rpc_call_type::non_callback; - continue; - } - resp_error_msg.clear(); - if (!!resp_err) - AS_UNLIKELY { std::swap(resp_buf, resp_error_msg); } - std::string header_buf = rpc_protocol::prepare_response( - resp_buf, req_head, 0, resp_err, resp_error_msg); - + coro_rpc::detail::set_context() = context_info.get(); + auto &&[resp_err, resp_buf] = router.route( + handler, payload, context_info, serialize_proto.value(), key); + if (is_rpc_return_by_callback_) { + if (!resp_err) { + continue; + } + else { + ELOGI << "rpc error in function:" + << context_info->get_rpc_function_name() + << ". error code:" << resp_err.ec + << ". message : " << resp_buf; + is_rpc_return_by_callback_ = false; + } + } #ifdef UNIT_TEST_INJECT - if (g_action == inject_action::close_socket_after_send_length) { - ELOGV(WARN, - "inject action: close_socket_after_send_length conn_id %d, " - "client_id %d", - conn_id_, client_id_); - co_await coro_io::async_write(socket, asio::buffer(header_buf)); - close(); - break; - } - if (g_action == inject_action::server_send_bad_rpc_result) { - ELOGV(WARN, + if (g_action == inject_action::close_socket_after_send_length) { + ELOGV(WARN, "inject action: close_socket_after_send_length", conn_id_, + client_id_); + std::string header_buf = rpc_protocol::prepare_response( + resp_buf, req_head, 0, resp_err, ""); + co_await coro_io::async_write(socket, asio::buffer(header_buf)); + close(); + break; + } + if (g_action == inject_action::server_send_bad_rpc_result) { + ELOGV( + WARN, "inject action: server_send_bad_rpc_result conn_id %d, client_id " "%d", conn_id_, client_id_); - resp_buf[0] = resp_buf[0] + 1; - } -#endif - if (!resp_err_) - AS_LIKELY { - if (!resp_err) - AS_UNLIKELY { resp_err_ = resp_err; } - write_queue_.emplace_back(std::move(header_buf), std::move(resp_buf), - [] { - return std::string_view{}; - }); - if (write_queue_.size() == 1) { - send_data().start([self = shared_from_this()](auto &&) { - }); - } - if (!!resp_err) - AS_UNLIKELY { break; } + resp_buf[0] = resp_buf[0] + 1; } +#endif + direct_response_msg( + resp_err, resp_buf, req_head, + std::move(context_info->resp_attachment_)); + context_info->resp_attachment_ = [] { + return std::string_view{}; + }; + } } + cancel_timer(); } /*! * send `ret` to RPC client @@ -293,37 +340,54 @@ class coro_connection : public std::enable_shared_from_this { * @tparam R message type * @param ret object of message type */ + template + void direct_response_msg(coro_rpc::err_code &resp_err, std::string &resp_buf, + const typename rpc_protocol::req_header &req_head, + std::function &&attachment) { + std::string resp_error_msg; + if (resp_err) { + resp_error_msg = std::move(resp_buf); + resp_buf = {}; + ELOGV(WARNING, "rpc route/execute error, error msg: %s", + resp_error_msg.data()); + } + std::string header_buf = rpc_protocol::prepare_response( + resp_buf, req_head, attachment().length(), resp_err, resp_error_msg); + + response(std::move(header_buf), std::move(resp_buf), std::move(attachment), + nullptr) + .start([](auto &&) { + }); + } template void response_msg(std::string &&body_buf, std::function &&resp_attachment, - const typename rpc_protocol::req_header &req_head, - bool is_delay) { + const typename rpc_protocol::req_header &req_head) { std::string header_buf = rpc_protocol::prepare_response( body_buf, req_head, resp_attachment().size()); response(std::move(header_buf), std::move(body_buf), - std::move(resp_attachment), shared_from_this(), is_delay) + std::move(resp_attachment), shared_from_this()) .via(executor_) .detach(); } template void response_error(coro_rpc::errc ec, std::string_view error_msg, - const typename rpc_protocol::req_header &req_head, - bool is_delay) { + const typename rpc_protocol::req_header &req_head) { std::function attach_ment = []() -> std::string_view { return {}; }; std::string body_buf; - std::string header_buf = rpc_protocol::prepare_response( - body_buf, req_head, 0, ec, error_msg, true); + std::string header_buf = + rpc_protocol::prepare_response(body_buf, req_head, 0, ec, error_msg); response(std::move(header_buf), std::move(body_buf), std::move(attach_ment), - shared_from_this(), is_delay) + shared_from_this()) .via(executor_) .detach(); } - void set_rpc_call_type(enum rpc_call_type r) { rpc_call_type_ = r; } + void set_rpc_return_by_callback() { is_rpc_return_by_callback_ = true; } /*! * Check the connection has closed or not @@ -359,13 +423,21 @@ class coro_connection : public std::enable_shared_from_this { std::any &tag() { return tag_; } const std::any &tag() const { return tag_; } - auto &get_executor() { return *executor_; } + auto get_executor() { return executor_; } + + asio::ip::tcp::endpoint get_remote_endpoint() { + return socket_.remote_endpoint(); + } + + asio::ip::tcp::endpoint get_local_endpoint() { + return socket_.local_endpoint(); + } private: async_simple::coro::Lazy response( std::string header_buf, std::string body_buf, - std::function resp_attachment, rpc_conn self, - bool is_delay) noexcept { + std::function resp_attachment, + rpc_conn self) noexcept { if (has_closed()) AS_UNLIKELY { ELOGV(DEBUG, "response_msg failed: connection has been closed"); @@ -379,25 +451,14 @@ class coro_connection : public std::enable_shared_from_this { #endif write_queue_.emplace_back(std::move(header_buf), std::move(body_buf), std::move(resp_attachment)); - if (is_delay) { - --delay_resp_cnt; - assert(delay_resp_cnt >= 0); - reset_timer(); - } + --rpc_processing_cnt_; + assert(rpc_processing_cnt_ >= 0); + reset_timer(); if (write_queue_.size() == 1) { + if (self == nullptr) + self = shared_from_this(); co_await send_data(); } - if (!is_delay) { - if (rpc_call_type_ == rpc_call_type::callback_finished) { - // the function start_impl is waiting for resume. - callback_awaitor_handler_.resume(); - } - else { - assert(rpc_call_type_ == rpc_call_type::callback_started); - // the function start_impl is not waiting for resume. - rpc_call_type_ = rpc_call_type::callback_finished; - } - } } async_simple::coro::Lazy send_data() { @@ -456,12 +517,6 @@ class coro_connection : public std::enable_shared_from_this { } write_queue_.pop_front(); } - if (!!resp_err_) - AS_UNLIKELY { - ELOGV(ERROR, "%s, %s", make_error_message(resp_err_), "resp_err_"); - close(); - co_return; - } #ifdef UNIT_TEST_INJECT if (g_action == inject_action::close_socket_after_send_length) { ELOGV(INFO, @@ -477,6 +532,7 @@ class coro_connection : public std::enable_shared_from_this { } void close() { + ELOGV(TRACE, "connection closed"); if (has_closed_) { return; } @@ -490,7 +546,7 @@ class coro_connection : public std::enable_shared_from_this { } void reset_timer() { - if (!enable_check_timeout_ || delay_resp_cnt != 0) { + if (!enable_check_timeout_ || rpc_processing_cnt_ != 0) { return; } @@ -518,17 +574,13 @@ class coro_connection : public std::enable_shared_from_this { asio::error_code ec; timer_.cancel(ec); } - - coro_io::callback_awaitor::awaitor_handler callback_awaitor_handler_{ - nullptr}; async_simple::Executor *executor_; asio::ip::tcp::socket socket_; // FIXME: queue's performance can be imporved. std::deque< std::tuple>> write_queue_; - coro_rpc::errc resp_err_; - rpc_call_type rpc_call_type_{non_callback}; + bool is_rpc_return_by_callback_{false}; // if don't get any message in keep_alive_timeout_duration_, the connection // will be closed when enable_check_timeout_ is true. @@ -539,7 +591,7 @@ class coro_connection : public std::enable_shared_from_this { QuitCallback quit_callback_{nullptr}; uint64_t conn_id_{0}; - uint64_t delay_resp_cnt{0}; + uint64_t rpc_processing_cnt_{0}; std::any tag_; @@ -553,4 +605,85 @@ class coro_connection : public std::enable_shared_from_this { #endif }; +template +uint64_t context_info_t::get_connection_id() noexcept { + return conn_->get_connection_id(); +} + +template +uint64_t context_info_t::has_closed() const noexcept { + return conn_->has_closed(); +} + +template +void context_info_t::close() { + return conn_->async_close(); +} + +template +uint64_t context_info_t::get_connection_id() const noexcept { + return conn_->get_connection_id(); +} + +template +void context_info_t::set_response_attachment( + std::string attachment) { + set_response_attachment([attachment = std::move(attachment)] { + return std::string_view{attachment}; + }); +} + +template +void context_info_t::set_response_attachment( + std::string_view attachment) { + set_response_attachment([attachment] { + return attachment; + }); +} + +template +void context_info_t::set_response_attachment( + std::function attachment) { + resp_attachment_ = std::move(attachment); +} + +template +std::string_view context_info_t::get_request_attachment() const { + return req_attachment_; +} + +template +std::string context_info_t::release_request_attachment() { + return std::move(req_attachment_); +} + +template +std::any &context_info_t::tag() noexcept { + return conn_->tag(); +} + +template +const std::any &context_info_t::tag() const noexcept { + return conn_->tag(); +} + +template +asio::ip::tcp::endpoint context_info_t::get_local_endpoint() + const noexcept { + return conn_->get_local_endpoint(); +} + +template +asio::ip::tcp::endpoint context_info_t::get_remote_endpoint() + const noexcept { + return conn_->get_remote_endpoint(); +} +namespace protocol { +template +uint64_t get_request_id(const typename rpc_protocol::req_header &) noexcept; +} +template +uint64_t context_info_t::get_request_id() const noexcept { + return coro_rpc::protocol::get_request_id(req_head_); +} } // namespace coro_rpc diff --git a/include/ylt/coro_rpc/impl/coro_rpc_client.hpp b/include/ylt/coro_rpc/impl/coro_rpc_client.hpp index f775f457f..b3d924f7a 100644 --- a/include/ylt/coro_rpc/impl/coro_rpc_client.hpp +++ b/include/ylt/coro_rpc/impl/coro_rpc_client.hpp @@ -109,8 +109,8 @@ class coro_rpc_client { using coro_rpc_protocol = coro_rpc::protocol::coro_rpc_protocol; public: - const inline static coro_rpc_protocol::rpc_error connect_error = { - errc::io_error, "client has been closed"}; + const inline static rpc_error connect_error = {errc::io_error, + "client has been closed"}; struct config { uint32_t client_id = 0; std::chrono::milliseconds timeout_duration = @@ -289,8 +289,8 @@ class coro_rpc_client { ELOGV(ERROR, "client has been closed, please re-connect"); auto ret = rpc_result{ unexpect_t{}, - coro_rpc_protocol::rpc_error{ - errc::io_error, "client has been closed, please re-connect"}}; + rpc_error{errc::io_error, + "client has been closed, please re-connect"}}; co_return ret; } @@ -299,9 +299,8 @@ class coro_rpc_client { if (!ssl_init_ret_) { ret = rpc_result{ unexpect_t{}, - coro_rpc_protocol::rpc_error{ - errc::not_connected, - std::string{make_error_message(errc::not_connected)}}}; + rpc_error{errc::not_connected, + std::string{make_error_message(errc::not_connected)}}}; co_return ret; } #endif @@ -328,8 +327,7 @@ class coro_rpc_client { if (is_timeout_) { ret = rpc_result{ - unexpect_t{}, - coro_rpc_protocol::rpc_error{errc::timed_out, "rpc call timed out"}}; + unexpect_t{}, rpc_error{errc::timed_out, "rpc call timed out"}}; } #ifdef UNIT_TEST_INJECT @@ -542,9 +540,8 @@ class coro_rpc_client { rpc_result r{}; if (buffer.empty()) { r = rpc_result{ - unexpect_t{}, - coro_rpc_protocol::rpc_error{errc::message_too_large, - "rpc body serialize size too big"}}; + unexpect_t{}, rpc_error{errc::message_too_large, + "rpc body serialize size too big"}}; co_return r; } #ifdef GENERATE_BENCHMARK_DATA @@ -565,8 +562,7 @@ class coro_rpc_client { ELOGV(INFO, "client_id %d close socket", config_.client_id); close(); r = rpc_result{ - unexpect_t{}, - coro_rpc_protocol::rpc_error{errc::io_error, ret.first.message()}}; + unexpect_t{}, rpc_error{errc::io_error, ret.first.message()}}; co_return r; } else if (g_action == @@ -577,8 +573,7 @@ class coro_rpc_client { ELOGV(INFO, "client_id %d close socket", config_.client_id); close(); r = rpc_result{ - unexpect_t{}, - coro_rpc_protocol::rpc_error{errc::io_error, ret.first.message()}}; + unexpect_t{}, rpc_error{errc::io_error, ret.first.message()}}; co_return r; } else if (g_action == @@ -588,8 +583,7 @@ class coro_rpc_client { ELOGV(INFO, "client_id %d shutdown", config_.client_id); socket_->shutdown(asio::ip::tcp::socket::shutdown_send); r = rpc_result{ - unexpect_t{}, - coro_rpc_protocol::rpc_error{errc::io_error, ret.first.message()}}; + unexpect_t{}, rpc_error{errc::io_error, ret.first.message()}}; co_return r; } else { @@ -614,8 +608,7 @@ class coro_rpc_client { ELOGV(INFO, "client_id %d client_close_socket_after_send_payload", config_.client_id); r = rpc_result{ - unexpect_t{}, - coro_rpc_protocol::rpc_error{errc::io_error, ret.first.message()}}; + unexpect_t{}, rpc_error{errc::io_error, ret.first.message()}}; close(); co_return r; } @@ -668,14 +661,12 @@ class coro_rpc_client { #endif if (is_timeout_) { r = rpc_result{ - unexpect_t{}, - coro_rpc_protocol::rpc_error{.code = errc::timed_out, .msg = {}}}; + unexpect_t{}, rpc_error{.code = errc::timed_out, .msg = {}}}; } else { r = rpc_result{ unexpect_t{}, - coro_rpc_protocol::rpc_error{.code = errc::io_error, - .msg = ret.first.message()}}; + rpc_error{.code = errc::io_error, .msg = ret.first.message()}}; } close(); co_return r; @@ -733,7 +724,7 @@ class coro_rpc_client { bool &error_happen) { rpc_return_type_t ret; struct_pack::err_code ec; - coro_rpc_protocol::rpc_error err; + rpc_error err; if (rpc_errc == 0) AS_LIKELY { ec = struct_pack::deserialize_to(ret, buffer); @@ -747,10 +738,11 @@ class coro_rpc_client { } } else { - err.val() = rpc_errc; if (rpc_errc != UINT8_MAX) { + err.val() = rpc_errc; ec = struct_pack::deserialize_to(err.msg, buffer); if SP_LIKELY (!ec) { + ELOGV(WARNING, "deserilaize rpc result failed"); error_happen = true; return rpc_result{unexpect_t{}, std::move(err)}; } @@ -758,13 +750,14 @@ class coro_rpc_client { else { ec = struct_pack::deserialize_to(err, buffer); if SP_LIKELY (!ec) { + ELOGV(WARNING, "deserilaize rpc result failed"); return rpc_result{unexpect_t{}, std::move(err)}; } } } error_happen = true; // deserialize failed. - err = {errc::invalid_argument, "failed to deserialize rpc return value"}; + err = {errc::invalid_rpc_result, "failed to deserialize rpc return value"}; return rpc_result{unexpect_t{}, std::move(err)}; } diff --git a/include/ylt/coro_rpc/impl/coro_rpc_server.hpp b/include/ylt/coro_rpc/impl/coro_rpc_server.hpp index 7f37a47d9..b74bb7a8c 100644 --- a/include/ylt/coro_rpc/impl/coro_rpc_server.hpp +++ b/include/ylt/coro_rpc/impl/coro_rpc_server.hpp @@ -334,7 +334,7 @@ class coro_rpc_server_base { ec.message().data()); acceptor_.cancel(ec); acceptor_.close(ec); - return coro_rpc::errc::address_in_use; + return coro_rpc::errc::address_in_used; } #ifdef _MSC_VER acceptor_.set_option(tcp::acceptor::reuse_address(true)); @@ -352,7 +352,7 @@ class coro_rpc_server_base { if (ec) { ELOGV(ERROR, "get local endpoint port %d error : %s", port_.load(), ec.message().data()); - return coro_rpc::errc::address_in_use; + return coro_rpc::errc::address_in_used; } port_ = end_point.port(); @@ -401,7 +401,7 @@ class coro_rpc_server_base { std::unique_lock lock(conns_mtx_); conns_.emplace(conn_id, conn); } - start_one(conn).via(&conn->get_executor()).detach(); + start_one(conn).via(conn->get_executor()).detach(); } } diff --git a/include/ylt/coro_rpc/impl/errno.h b/include/ylt/coro_rpc/impl/errno.h index 5514da5c3..0d7c1b6bd 100644 --- a/include/ylt/coro_rpc/impl/errno.h +++ b/include/ylt/coro_rpc/impl/errno.h @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include - #include +#include #pragma once namespace coro_rpc { enum class errc : uint16_t { @@ -23,33 +22,34 @@ enum class errc : uint16_t { io_error, not_connected, timed_out, - invalid_argument, - address_in_use, + invalid_rpc_arguments, + address_in_used, bad_address, open_error, listen_error, operation_canceled, - interrupted, + rpc_throw_exception, function_not_registered, protocol_error, unknown_protocol_version, message_too_large, server_has_ran, + invalid_rpc_result, }; inline constexpr std::string_view make_error_message(errc ec) noexcept { switch (ec) { case errc::ok: return "ok"; case errc::io_error: - return "io_error"; + return "io error"; case errc::not_connected: - return "not_connected"; + return "not connected"; case errc::timed_out: - return "timed_out"; - case errc::invalid_argument: - return "invalid_argument"; - case errc::address_in_use: - return "address_in_use"; + return "time out"; + case errc::invalid_rpc_arguments: + return "invalid rpc arg"; + case errc::address_in_used: + return "address in used"; case errc::bad_address: return "bad_address"; case errc::open_error: @@ -57,19 +57,21 @@ inline constexpr std::string_view make_error_message(errc ec) noexcept { case errc::listen_error: return "listen_error"; case errc::operation_canceled: - return "operation_canceled"; - case errc::interrupted: - return "interrupted"; + return "operation canceled"; + case errc::rpc_throw_exception: + return "rpc throw exception"; case errc::function_not_registered: - return "function_not_registered"; + return "function not registered"; case errc::protocol_error: - return "protocol_error"; + return "protocol error"; case errc::message_too_large: - return "message_too_large"; + return "message too large"; case errc::server_has_ran: - return "server_has_ran"; + return "server has ran"; + case errc::invalid_rpc_result: + return "invalid rpc result"; default: - return "unknown_user-defined_error"; + return "unknown user-defined error"; } } struct err_code { @@ -98,4 +100,12 @@ struct err_code { inline bool operator!(err_code ec) noexcept { return ec == errc::ok; } inline bool operator!(errc ec) noexcept { return ec == errc::ok; } +struct rpc_error { + coro_rpc::err_code code; //!< error code + std::string msg; //!< error message + uint16_t& val() { return *(uint16_t*)&(code.ec); } + const uint16_t& val() const { return *(uint16_t*)&(code.ec); } +}; +STRUCT_PACK_REFL(rpc_error, val(), msg); + }; // namespace coro_rpc \ No newline at end of file diff --git a/include/ylt/coro_rpc/impl/expected.hpp b/include/ylt/coro_rpc/impl/expected.hpp index e6d4c1082..4a0a7e91e 100644 --- a/include/ylt/coro_rpc/impl/expected.hpp +++ b/include/ylt/coro_rpc/impl/expected.hpp @@ -44,7 +44,12 @@ using unexpected = tl::unexpected; using unexpect_t = tl::unexpect_t; #endif -template -using rpc_result = expected; +namespace protocol { +struct coro_rpc_protocol; +} + +template +using rpc_result = expected; } // namespace coro_rpc \ No newline at end of file diff --git a/include/ylt/coro_rpc/impl/protocol/coro_rpc_protocol.hpp b/include/ylt/coro_rpc/impl/protocol/coro_rpc_protocol.hpp index ee20bd5b2..b74a566e3 100644 --- a/include/ylt/coro_rpc/impl/protocol/coro_rpc_protocol.hpp +++ b/include/ylt/coro_rpc/impl/protocol/coro_rpc_protocol.hpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -61,7 +62,7 @@ struct coro_rpc_protocol { uint32_t seq_num; //!< sequence number uint32_t function_id; //!< rpc function ID uint32_t length; //!< length of RPC body - uint32_t attach_length; //!< reserved field + uint32_t attach_length; //!< attachment length }; struct resp_header { @@ -71,7 +72,7 @@ struct coro_rpc_protocol { uint8_t msg_type; //!< message type uint32_t seq_num; //!< sequence number uint32_t length; //!< length of RPC body - uint32_t attach_length; //!< reserved field + uint32_t attach_length; //!< attachment length }; using supported_serialize_protocols = std::variant; @@ -133,8 +134,7 @@ struct coro_rpc_protocol { const req_header& req_header, std::size_t attachment_len, coro_rpc::errc rpc_err_code = {}, - std::string_view err_msg = {}, - bool is_user_defined_error = false) { + std::string_view err_msg = {}) { std::string err_msg_buf; std::string header_buf; header_buf.resize(RESP_HEAD_LEN); @@ -150,7 +150,6 @@ struct coro_rpc_protocol { err_msg_buf = "attachment larger than 4G:" + std::to_string(attachment_len) + "B"; err_msg = err_msg_buf; - is_user_defined_error = false; } else if (rpc_result.size() > UINT32_MAX) AS_UNLIKELY { @@ -160,12 +159,11 @@ struct coro_rpc_protocol { err_msg_buf = "body larger than 4G:" + std::to_string(attachment_len) + "B"; err_msg = err_msg_buf; - is_user_defined_error = false; } if (rpc_err_code != coro_rpc::errc{}) AS_UNLIKELY { rpc_result.clear(); - if (is_user_defined_error) { + if (static_cast(rpc_err_code) > UINT8_MAX) { struct_pack::serialize_to( rpc_result, std::pair{static_cast(rpc_err_code), err_msg}); @@ -186,12 +184,6 @@ struct coro_rpc_protocol { * The `rpc_error` struct holds the error code `code` and error message * `msg`. */ - struct rpc_error { - coro_rpc::err_code code; //!< error code - std::string msg; //!< error message - uint16_t& val() { return *(uint16_t*)&(code.ec); } - const uint16_t& val() const { return *(uint16_t*)&(code.ec); } - }; // internal variable constexpr static inline int8_t magic_number = 21; @@ -203,10 +195,40 @@ struct coro_rpc_protocol { static_assert(RESP_HEAD_LEN == 16); }; -STRUCT_PACK_REFL(coro_rpc_protocol::rpc_error, val(), msg); +template +uint64_t get_request_id( + const typename rpc_protocol::req_header& header) noexcept { + if constexpr (std::is_same_v) { + return header.seq_num; + } + else { + return 0; + } +} } // namespace protocol template using context = coro_rpc::context_base; -using rpc_error = protocol::coro_rpc_protocol::rpc_error; + +template +async_simple::coro::Lazy*> get_context_in_coro() { + auto* ctx = co_await async_simple::coro::LazyLocals{}; + assert(ctx != nullptr); + co_return (context_info_t*) ctx; +} + +namespace detail { +template +context_info_t*& set_context() { + thread_local static context_info_t* ctx; + return ctx; +} +} // namespace detail + +template +context_info_t* get_context() { + return detail::set_context(); +} + } // namespace coro_rpc \ No newline at end of file diff --git a/include/ylt/coro_rpc/impl/router.hpp b/include/ylt/coro_rpc/impl/router.hpp index c2439997c..1414b8e4e 100644 --- a/include/ylt/coro_rpc/impl/router.hpp +++ b/include/ylt/coro_rpc/impl/router.hpp @@ -30,6 +30,8 @@ #include #include "rpc_execute.hpp" +#include "ylt/coro_rpc/impl/expected.hpp" +#include "ylt/coro_rpc/impl/protocol/coro_rpc_protocol.hpp" namespace coro_rpc { @@ -46,21 +48,19 @@ template typename map_t = std::unordered_map> class router { - using router_handler_t = std::function( - std::string_view, rpc_context &context_info, - typename rpc_protocol::supported_serialize_protocols protocols)>; - - using coro_router_handler_t = - std::function>( + public: + using router_handler_t = + std::function( std::string_view, rpc_context &context_info, typename rpc_protocol::supported_serialize_protocols protocols)>; + using coro_router_handler_t = std::function< + async_simple::coro::Lazy>( + std::string_view, + typename rpc_protocol::supported_serialize_protocols protocols)>; + using route_key = typename rpc_protocol::route_key_t; - std::unordered_map handlers_; - std::unordered_map coro_handlers_; - std::unordered_map id2name_; - private: const std::string &get_name(const route_key &key) { static std::string empty_string; if (auto it = id2name_.find(key); it != id2name_.end()) { @@ -70,30 +70,33 @@ class router { return empty_string; } + private: + std::unordered_map handlers_; + std::unordered_map coro_handlers_; + std::unordered_map id2name_; + // See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=100611 // We use this struct instead of lambda for workaround template struct execute_visitor { std::string_view data; - rpc_context &context_info; Self *self; template - async_simple::coro::Lazy> operator()( - const serialize_protocol &) { + async_simple::coro::Lazy> + operator()(const serialize_protocol &) { return internal::execute_coro( - data, context_info, self); + data, self); } }; template struct execute_visitor { std::string_view data; - rpc_context &context_info; template - async_simple::coro::Lazy> operator()( - const serialize_protocol &) { + async_simple::coro::Lazy> + operator()(const serialize_protocol &) { return internal::execute_coro( - data, context_info); + data); } }; @@ -134,9 +137,9 @@ class router { auto it = coro_handlers_.emplace( key, [self]( - std::string_view data, rpc_context &context_info, + std::string_view data, typename rpc_protocol::supported_serialize_protocols protocols) { - execute_visitor visitor{data, context_info, self}; + execute_visitor visitor{data, self}; return std::visit(visitor, protocols); }); if (!it.second) { @@ -188,9 +191,9 @@ class router { async_simple::coro::Lazy>) { auto it = coro_handlers_.emplace( key, - [](std::string_view data, rpc_context &context_info, + [](std::string_view data, typename rpc_protocol::supported_serialize_protocols protocols) { - execute_visitor visitor{data, context_info}; + execute_visitor visitor{data}; return std::visit(visitor, protocols); }); if (!it.second) { @@ -232,11 +235,10 @@ class router { return nullptr; } - async_simple::coro::Lazy> route_coro( - auto handler, std::string_view data, - rpc_context &context_info, - typename rpc_protocol::supported_serialize_protocols protocols, - const typename rpc_protocol::route_key_t &route_key) { + async_simple::coro::Lazy> + route_coro(auto handler, std::string_view data, + typename rpc_protocol::supported_serialize_protocols protocols, + const typename rpc_protocol::route_key_t &route_key) { using namespace std::string_literals; if (handler) AS_LIKELY { @@ -246,41 +248,23 @@ class router { #endif // clang-format off - auto res = co_await (*handler)(data, context_info, protocols); - // clang-format on - if (res.has_value()) - AS_LIKELY { - co_return std::make_pair(coro_rpc::errc{}, - std::move(res.value())); - } - else { // deserialize failed - ELOGV(ERROR, "payload deserialize failed in rpc function: %s", - get_name(route_key).data()); - co_return std::make_pair(coro_rpc::errc::invalid_argument, - "invalid rpc function arguments"s); - } + co_return co_await (*handler)(data, protocols); + } catch (coro_rpc::rpc_error& err) { + co_return std::make_pair(err.code, std::move(err.msg)); } catch (const std::exception &e) { - ELOGV(ERROR, "exception: %s in rpc function: %s", e.what(), - get_name(route_key).data()); - co_return std::make_pair(coro_rpc::errc::interrupted, e.what()); + co_return std::make_pair(coro_rpc::errc::rpc_throw_exception, e.what()); } catch (...) { - ELOGV(ERROR, "unknown exception in rpc function: %s", - get_name(route_key).data()); - co_return std::make_pair(coro_rpc::errc::interrupted, - "unknown exception"s); + co_return std::make_pair(coro_rpc::errc::rpc_throw_exception, + "unknown rpc function exception"s); } } else { - std::ostringstream ss; - ss << route_key; - ELOGV(ERROR, "the rpc function not registered, function ID: %s", - ss.str().data()); co_return std::make_pair(coro_rpc::errc::function_not_registered, "the rpc function not registered"s); } } - std::pair route( + std::pair route( auto handler, std::string_view data, rpc_context &context_info, typename rpc_protocol::supported_serialize_protocols protocols, @@ -292,35 +276,20 @@ class router { #ifndef NDEBUG ELOGV(INFO, "route function name: %s", get_name(route_key).data()); #endif - auto res = (*handler)(data, context_info, protocols); - if (res.has_value()) - AS_LIKELY { - return std::make_pair(coro_rpc::errc{}, std::move(res.value())); - } - else { // deserialize failed - ELOGV(ERROR, "payload deserialize failed in rpc function: %s", - get_name(route_key).data()); - return std::make_pair(coro_rpc::errc::invalid_argument, - "invalid rpc function arguments"s); - } + return (*handler)(data, context_info, protocols); + } catch (coro_rpc::rpc_error& err) { + return std::make_pair(err.code, std::move(err.msg)); } catch (const std::exception &e) { - ELOGV(ERROR, "exception: %s in rpc function: %s", e.what(), - get_name(route_key).data()); - return std::make_pair(coro_rpc::errc::interrupted, e.what()); + return std::make_pair(err_code{coro_rpc::errc::rpc_throw_exception}, e.what()); } catch (...) { - ELOGV(ERROR, "unknown exception in rpc function: %s", - get_name(route_key).data()); - return std::make_pair(coro_rpc::errc::interrupted, + return std::make_pair(err_code{errc::rpc_throw_exception}, "unknown rpc function exception"s); } } else { - std::ostringstream ss; - ss << route_key; - ELOGV(ERROR, "the rpc function not registered, function ID: %s", - ss.str().data()); + using namespace std; return std::make_pair(coro_rpc::errc::function_not_registered, - "the rpc function not registered"s); + "the rpc function not registered"); } } diff --git a/include/ylt/coro_rpc/impl/rpc_execute.hpp b/include/ylt/coro_rpc/impl/rpc_execute.hpp index 0a93a2338..cdaf153d3 100644 --- a/include/ylt/coro_rpc/impl/rpc_execute.hpp +++ b/include/ylt/coro_rpc/impl/rpc_execute.hpp @@ -24,6 +24,8 @@ #include "context.hpp" #include "coro_connection.hpp" +#include "ylt/coro_rpc/impl/errno.h" +#include "ylt/easylog.hpp" #include "ylt/util/type_traits.h" namespace coro_rpc::internal { @@ -43,16 +45,16 @@ auto get_return_type() { return First{}; } } - template using rpc_context = std::shared_ptr>; using rpc_conn = std::shared_ptr; template -inline std::optional execute( +inline std::pair execute( std::string_view data, rpc_context &context_info, Self *self = nullptr) { + using namespace std::string_literals; using T = decltype(func); using param_type = util::function_parameters_t; using return_type = util::function_return_type_t; @@ -78,7 +80,10 @@ inline std::optional execute( } if (!is_ok) - AS_UNLIKELY { return std::nullopt; } + AS_UNLIKELY { + return std::pair{err_code{errc::invalid_rpc_arguments}, + "invalid rpc arg"s}; + } if constexpr (std::is_void_v) { if constexpr (std::is_void_v) { @@ -96,35 +101,37 @@ inline std::optional execute( } } else { - auto &o = *self; if constexpr (has_coro_conn_v) { - // call void o.func(coro_conn, args...) - std::apply(func, - std::tuple_cat( - std::forward_as_tuple( - o, context_base( - context_info)), - std::move(args))); + // call void self->func(coro_conn, args...) + std::apply( + func, std::tuple_cat( + std::forward_as_tuple( + *self, context_base( + context_info)), + std::move(args))); } else { - // call void o.func(args...) - std::apply(func, - std::tuple_cat(std::forward_as_tuple(o), std::move(args))); + // call void self->func(args...) + std::apply(func, std::tuple_cat(std::forward_as_tuple(*self), + std::move(args))); } } + return std::pair{err_code{}, serialize_proto::serialize()}; } else { if constexpr (std::is_void_v) { // call return_type func(args...) - return serialize_proto::serialize(std::apply(func, std::move(args))); + return std::pair{err_code{}, serialize_proto::serialize( + std::apply(func, std::move(args)))}; } else { - auto &o = *self; - // call return_type o.func(args...) + // call return_type self->func(args...) - return serialize_proto::serialize(std::apply( - func, std::tuple_cat(std::forward_as_tuple(o), std::move(args)))); + return std::pair{err_code{}, + serialize_proto::serialize(std::apply( + func, std::tuple_cat(std::forward_as_tuple(*self), + std::move(args))))}; } } } @@ -136,24 +143,25 @@ inline std::optional execute( else { (self->*func)(); } + return std::pair{err_code{}, serialize_proto::serialize()}; } else { if constexpr (std::is_void_v) { - return serialize_proto::serialize(func()); + return std::pair{err_code{}, serialize_proto::serialize(func())}; } else { - return serialize_proto::serialize((self->*func)()); + return std::pair{err_code{}, + serialize_proto::serialize((self->*func)())}; } } } - return serialize_proto::serialize(); } template -inline async_simple::coro::Lazy> execute_coro( - std::string_view data, rpc_context &context_info, - Self *self = nullptr) { +inline async_simple::coro::Lazy> +execute_coro(std::string_view data, Self *self = nullptr) { + using namespace std::string_literals; using T = decltype(func); using param_type = util::function_parameters_t; using return_type = typename get_type_t< @@ -162,67 +170,46 @@ inline async_simple::coro::Lazy> execute_coro( if constexpr (!std::is_void_v) { using First = std::tuple_element_t<0, param_type>; constexpr bool is_conn = requires { typename First::return_type; }; - if constexpr (is_conn) { - static_assert(std::is_void_v, - "The return_type must be void"); - } - - using conn_return_type = decltype(get_return_type()); - constexpr bool has_coro_conn_v = - std::is_same_v, First>; - auto args = util::get_args(); + static_assert( + !is_conn, + "context is not allowed as parameter in coroutine function"); bool is_ok = true; - constexpr size_t size = std::tuple_size_v; + constexpr size_t size = std::tuple_size_v; + param_type args; if constexpr (size > 0) { is_ok = serialize_proto::deserialize_to(args, data); } - + if (!is_ok) + AS_UNLIKELY { + co_return std::make_pair(coro_rpc::errc::invalid_rpc_arguments, + "invalid rpc function arguments"s); + } if constexpr (std::is_void_v) { if constexpr (std::is_void_v) { - if constexpr (has_coro_conn_v) { - // call void func(coro_conn, args...) - co_await std::apply( - func, - std::tuple_cat(std::forward_as_tuple( - context_base( - context_info)), - std::move(args))); - } - else { - // call void func(args...) - co_await std::apply(func, std::move(args)); - } + // call void func(args...) + co_await std::apply(func, std::move(args)); } else { - auto &o = *self; - if constexpr (has_coro_conn_v) { - // call void o.func(coro_conn, args...) - co_await std::apply( - func, std::tuple_cat( - std::forward_as_tuple( - o, context_base( - context_info)), - std::move(args))); - } - else { - // call void o.func(args...) - co_await std::apply( - func, std::tuple_cat(std::forward_as_tuple(o), std::move(args))); - } + // call void self->func(args...) + co_await std::apply(func, std::tuple_cat(std::forward_as_tuple(*self), + std::move(args))); } + co_return std::pair{err_code{}, serialize_proto::serialize()}; } else { if constexpr (std::is_void_v) { // call return_type func(args...) - co_return serialize_proto::serialize( - co_await std::apply(func, std::move(args))); + co_return std::pair{err_code{}, + serialize_proto::serialize( + co_await std::apply(func, std::move(args)))}; } else { - auto &o = *self; - // call return_type o.func(args...) - co_return serialize_proto::serialize(co_await std::apply( - func, std::tuple_cat(std::forward_as_tuple(o), std::move(args)))); + // call return_type self->func(args...) + co_return std::pair{ + err_code{}, serialize_proto::serialize(co_await std::apply( + func, std::tuple_cat(std::forward_as_tuple(*self), + std::move(args))))}; } } } @@ -236,18 +223,19 @@ inline async_simple::coro::Lazy> execute_coro( co_await (self->*func)(); // clang-format on } + co_return std::pair{err_code{}, serialize_proto::serialize()}; } else { if constexpr (std::is_void_v) { - co_return serialize_proto::serialize(co_await func()); + co_return std::pair{err_code{}, + serialize_proto::serialize(co_await func())}; } else { // clang-format off - co_return serialize_proto::serialize(co_await (self->*func)()); + co_return std::pair{err_code{},serialize_proto::serialize(co_await (self->*func)())}; // clang-format on } } } - co_return serialize_proto::serialize(); } } // namespace coro_rpc::internal \ No newline at end of file diff --git a/include/ylt/easylog/record.hpp b/include/ylt/easylog/record.hpp index 9ae1f0294..8c8eee617 100644 --- a/include/ylt/easylog/record.hpp +++ b/include/ylt/easylog/record.hpp @@ -187,7 +187,7 @@ class record_t { else { std::stringstream ss; ss << data; - ss_.append(ss.str()); + ss_.append(std::move(ss).str()); } return *this; diff --git a/src/coro_io/tests/test_coro_channel.cpp b/src/coro_io/tests/test_coro_channel.cpp index bdd343ebf..384ce62ee 100644 --- a/src/coro_io/tests/test_coro_channel.cpp +++ b/src/coro_io/tests/test_coro_channel.cpp @@ -6,6 +6,19 @@ #include using namespace std::chrono_literals; +#ifndef __clang__ +#ifdef __GNUC__ +#include +#if __GNUC_PREREQ(10, 3) // If gcc_version >= 10.3 +#define IS_OK +#endif +#else +#define IS_OK +#endif +#else +#define IS_OK +#endif + async_simple::coro::Lazy test_coro_channel() { auto ch = coro_io::create_channel(1000); @@ -112,6 +125,7 @@ async_simple::coro::Lazy test_select_channel() { } void callback_lazy() { +#ifdef IS_OK using namespace async_simple::coro; auto test0 = []() mutable -> Lazy { co_return 41; @@ -144,6 +158,7 @@ void callback_lazy() { CHECK(result == 83); })); CHECK(index == 0); +#endif } TEST_CASE("test channel send recieve, test select channel and coroutine") { diff --git a/src/coro_rpc/benchmark/api/rpc_functions.hpp b/src/coro_rpc/benchmark/api/rpc_functions.hpp index 5566b2709..440e56fa3 100644 --- a/src/coro_rpc/benchmark/api/rpc_functions.hpp +++ b/src/coro_rpc/benchmark/api/rpc_functions.hpp @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -24,20 +25,30 @@ #include "Monster.h" #include "Rect.h" #include "ValidateRequest.h" +#include "ylt/coro_rpc/impl/protocol/coro_rpc_protocol.hpp" inline coro_io::io_context_pool pool(std::thread::hardware_concurrency()); - -inline std::string echo_4B(const std::string &str) { return str; } -inline std::string echo_50B(const std::string &str) { return str; } -inline std::string echo_100B(const std::string &str) { return str; } -inline std::string echo_500B(const std::string &str) { return str; } -inline std::string echo_1KB(const std::string &str) { return str; } -inline std::string echo_5KB(const std::string &str) { return str; } -inline std::string echo_10KB(const std::string &str) { return str; } +inline async_simple::coro::Lazy coroutine_async_echo( + std::string_view str) { + co_return str; +} +inline void callback_async_echo(coro_rpc::context conn, + std::string_view str) { + conn.response_msg(str); + return; +} +inline std::string_view echo_4B(std::string_view str) { return str; } +inline std::string_view echo_50B(std::string_view str) { return str; } +inline std::string_view echo_100B(std::string_view str) { return str; } +inline std::string_view echo_500B(std::string_view str) { return str; } +inline std::string_view echo_1KB(std::string_view str) { return str; } +inline std::string_view echo_5KB(std::string_view str) { return str; } +inline std::string_view echo_10KB(std::string_view str) { return str; } inline std::vector array_1K_int(std::vector ar) { return ar; } -inline std::vector array_1K_str_4B(std::vector ar) { +inline std::vector array_1K_str_4B( + std::vector ar) { return ar; } @@ -63,8 +74,8 @@ inline void heavy_calculate(coro_rpc::context conn, int a) { }); return; } - -inline std::string download_10KB(int a) { return std::string(10000, 'A'); } +std::string s(10000, 'A'); +inline std::string_view download_10KB(int a) { return std::string_view{s}; } inline void long_tail_heavy_calculate(coro_rpc::context conn, int a) { static std::atomic g_index = 0; diff --git a/src/coro_rpc/benchmark/data_gen.cpp b/src/coro_rpc/benchmark/data_gen.cpp index 25e318bba..da330a656 100644 --- a/src/coro_rpc/benchmark/data_gen.cpp +++ b/src/coro_rpc/benchmark/data_gen.cpp @@ -59,8 +59,8 @@ int main() { coro_rpc::benchmark_file_path = "./test_data/complex_test/"; std::filesystem::create_directories(coro_rpc::benchmark_file_path); syncAwait(client.call(std::vector(1000, 42))); - syncAwait( - client.call(std::vector(1000, std::string{4, 'A'}))); + syncAwait(client.call( + std::vector(1000, std::string_view{"AAAA"}))); syncAwait(client.call( std::vector(1000, rect{.p1 = {1.2, 3.4}, .p2 = {2.5, 4.6}}))); syncAwait( @@ -111,6 +111,11 @@ int main() { syncAwait(client.call(42)); syncAwait(client.call(42)); + coro_rpc::benchmark_file_path = "./test_data/async_test/"; + std::filesystem::create_directories(coro_rpc::benchmark_file_path); + syncAwait(client.call("echo")); + syncAwait(client.call("echo")); + server.stop(); started->wait(); diff --git a/src/coro_rpc/benchmark/server.hpp b/src/coro_rpc/benchmark/server.hpp index 046c7abf0..878867cd8 100644 --- a/src/coro_rpc/benchmark/server.hpp +++ b/src/coro_rpc/benchmark/server.hpp @@ -24,7 +24,8 @@ inline void register_handlers(coro_rpc::coro_rpc_server& server) { echo_4B, echo_100B, echo_500B, echo_1KB, echo_5KB, echo_10KB, async_io, block_io, heavy_calculate, long_tail_async_io, long_tail_block_io, long_tail_heavy_calculate, array_1K_int, array_1K_str_4B, array_1K_rect, - monsterFunc, ValidateRequestFunc, download_10KB>(); + monsterFunc, ValidateRequestFunc, download_10KB, callback_async_echo, + coroutine_async_echo>(); server.register_handler< many_argument>(); diff --git a/src/coro_rpc/examples/base_examples/client.cpp b/src/coro_rpc/examples/base_examples/client.cpp index 9460a2b78..bc6ca3057 100644 --- a/src/coro_rpc/examples/base_examples/client.cpp +++ b/src/coro_rpc/examples/base_examples/client.cpp @@ -32,25 +32,21 @@ Lazy show_rpc_call() { [[maybe_unused]] auto ec = co_await client.connect("127.0.0.1", "8801"); assert(!ec); - auto ret = co_await client.call(); - assert(ret.value() == "hello_world"s); + auto ret = co_await client.call("hello"); + assert(ret.value() == "hello"); - client.set_req_attachment("This is attachment."); - auto ret_void = co_await client.call(); - assert(client.get_resp_attachment() == "This is attachment."); + ret = co_await client.call("42"); + assert(ret.value() == "42"); - client.set_req_attachment("This is attachment2."); - ret_void = co_await client.call(); - assert(client.get_resp_attachment() == "This is attachment2."); + ret = co_await client.call("hi"); + assert(ret.value() == "hi"); - auto ret_int = co_await client.call(12, 30); - assert(ret_int.value() == 42); + ret = co_await client.call("hey"); + assert(ret.value() == "hey"); - ret = co_await client.call("coro_echo"); - assert(ret.value() == "coro_echo"s); - - ret = co_await client.call("hello_with_delay"s); - assert(ret.value() == "hello_with_delay"s); + client.set_req_attachment("This is attachment."); + auto ret_void = co_await client.call(); + assert(client.get_resp_attachment() == "This is attachment."); ret = co_await client.call("nested_echo"s); assert(ret.value() == "nested_echo"s); @@ -58,14 +54,18 @@ Lazy show_rpc_call() { ret = co_await client.call<&HelloService::hello>(); assert(ret.value() == "HelloService::hello"s); - ret = co_await client.call<&HelloService::hello_with_delay>( - "HelloService::hello_with_delay"s); - assert(ret.value() == "HelloService::hello_with_delay"s); + ret_void = co_await client.call(); + assert(ret_void); + + // TODO: fix return error + // ret_void = co_await client.call(); + + // assert(ret.error().code.val() == 404); + // assert(ret.error().msg == "404 Not Found."); - ret = co_await client.call(); + // ret_void = co_await client.call(); - assert(ret.error().code == 404); - assert(ret.error().msg == "404 Not Found."); + // assert(ret.error().code.val() == 404); ret = co_await client.call(); assert(ret.value() == "1"); diff --git a/src/coro_rpc/examples/base_examples/rpc_service.cpp b/src/coro_rpc/examples/base_examples/rpc_service.cpp index 5b45d7c26..27f04db59 100644 --- a/src/coro_rpc/examples/base_examples/rpc_service.cpp +++ b/src/coro_rpc/examples/base_examples/rpc_service.cpp @@ -21,96 +21,112 @@ #include #include +#include "async_simple/coro/Lazy.h" +#include "async_simple/coro/Sleep.h" +#include "ylt/coro_io/client_pool.hpp" +#include "ylt/coro_io/coro_io.hpp" +#include "ylt/coro_rpc/impl/coro_rpc_client.hpp" #include "ylt/coro_rpc/impl/errno.h" +#include "ylt/coro_rpc/impl/expected.hpp" +#include "ylt/coro_rpc/impl/protocol/coro_rpc_protocol.hpp" using namespace coro_rpc; +using namespace async_simple::coro; +using namespace std::chrono_literals; -std::string hello_world() { - ELOGV(INFO, "call helloworld"); - return "hello_world"; +std::string_view echo(std::string_view data) { + ELOGV(INFO, "call echo"); + return data; } -bool return_bool_hello_world() { return true; } - -int A_add_B(int a, int b) { - ELOGV(INFO, "call A+B"); - return a + b; -} - -void echo_with_attachment(coro_rpc::context conn) { - ELOGV(INFO, "call echo_with_attachment"); - std::string str = conn.release_request_attachment(); - conn.set_response_attachment(std::move(str)); - conn.response_msg(); +Lazy coroutine_echo(std::string_view data) { + ELOGV(INFO, "call coroutine_echo"); + co_await coro_io::sleep_for(1s); + co_return data; } -void echo_with_attachment2(coro_rpc::context conn) { - ELOGV(INFO, "call echo_with_attachment2"); - std::string_view str = conn.get_request_attachment(); - // The live time of attachment is same as coro_rpc::context - conn.set_response_attachment([str, conn] { - return str; +void async_echo_by_callback( + coro_rpc::context conn, + std::string_view /*rpc request data here*/ data) { + ELOGV(INFO, "call async_echo_by_callback"); + /* rpc function runs in global io thread pool */ + coro_io::post([conn, data]() mutable { + /* send work to global non-io thread pool */ + auto *ctx = conn.get_context(); + conn.response_msg(data); /*response here*/ + }).start([](auto &&) { }); - conn.response_msg(); } -std::string echo(std::string_view sv) { return std::string{sv}; } +Lazy async_echo_by_coroutine(std::string_view sv) { + ELOGV(INFO, "call async_echo_by_coroutine"); + co_await coro_io::sleep_for(std::chrono::milliseconds(100)); // sleeping + co_return sv; +} -async_simple::coro::Lazy coro_echo(std::string_view sv) { - ELOGV(INFO, "call coro_echo"); - co_await coro_io::sleep_for(std::chrono::milliseconds(100)); - ELOGV(INFO, "after sleep for a while"); - co_return std::string{sv}; +Lazy get_ctx_info() { + ELOGV(INFO, "call get_ctx_info"); + auto *ctx = co_await coro_rpc::get_context_in_coro(); + if (ctx->has_closed()) { + throw std::runtime_error("connection is close!"); + } + ELOGV(INFO, "call function echo_with_attachment, conn ID:%d, request ID:%d", + ctx->get_connection_id(), ctx->get_request_id()); + ELOGI << "remote endpoint: " << ctx->get_remote_endpoint() << "local endpoint" + << ctx->get_local_endpoint(); + std::string sv{ctx->get_request_attachment()}; + auto str = ctx->release_request_attachment(); + if (sv != str) { + ctx->close(); + throw rpc_error{coro_rpc::errc::io_error, "attachment error!"}; + co_return; + } + ctx->set_response_attachment(std::move(str)); + co_await coro_io::sleep_for(514ms, coro_io::get_global_executor()); + ELOGV(INFO, "response in another executor"); + co_return; + co_return; } -void hello_with_delay(context conn, - std::string hello) { - ELOGV(INFO, "call HelloServer hello_with_delay"); - // create a new thread - std::thread([conn = std::move(conn), hello = std::move(hello)]() mutable { - // do some heavy work in this thread that won't block the io-thread, - std::cout << "running heavy work..." << std::endl; - std::this_thread::sleep_for(std::chrono::seconds{1}); - // Remember response before connection destruction! Or the connect will - // be closed. - conn.response_msg(hello); - }).detach(); +void echo_with_attachment() { + ELOGV(INFO, "call echo_with_attachment"); + auto ctx = coro_rpc::get_context(); + ctx->set_response_attachment( + ctx->get_request_attachment()); /*zero-copy by string_view*/ } -async_simple::coro::Lazy nested_echo(std::string_view sv) { +Lazy nested_echo(std::string_view sv) { ELOGV(INFO, "start nested echo"); - coro_rpc::coro_rpc_client client(co_await coro_io::get_current_executor()); - [[maybe_unused]] auto ec = co_await client.connect("127.0.0.1", "8802"); - assert(!ec); + /*get a client by global client pool*/ + auto client = + coro_io::g_clients_pool().at("127.0.0.1:8802"); + assert(client != nullptr); ELOGV(INFO, "connect another server"); - auto ret = co_await client.call(sv); - assert(ret.value() == sv); - ELOGV(INFO, "get echo result from another server"); - co_return std::string{sv}; + auto ret = co_await client->send_request([sv](coro_rpc_client &client) { + return client.call(sv); + }); + co_return ret.value().value(); } -std::string HelloService::hello() { +std::string_view HelloService::hello() { ELOGV(INFO, "call HelloServer::hello"); return "HelloService::hello"; } -void HelloService::hello_with_delay( - coro_rpc::context conn, std::string hello) { - ELOGV(INFO, "call HelloServer::hello_with_delay"); - std::thread([conn = std::move(conn), hello = std::move(hello)]() mutable { - conn.response_msg("HelloService::hello_with_delay"); - }).detach(); - return; +void return_error_by_context(coro_rpc::context conn) { + conn.response_error(coro_rpc::err_code{404}, "404 Not Found."); } -void return_error(coro_rpc::context conn) { - conn.response_error(coro_rpc::err_code{404}, "404 Not Found."); +void return_error_by_exception() { + throw coro_rpc::rpc_error{coro_rpc::errc{404}, "rpc not found."}; } -void rpc_with_state_by_tag(coro_rpc::context conn) { - if (!conn.tag().has_value()) { - conn.tag() = uint64_t{0}; + +Lazy rpc_with_state_by_tag() { + auto *ctx = co_await coro_rpc::get_context_in_coro(); + if (!ctx->tag().has_value()) { + ctx->tag() = std::uint64_t{0}; } - auto &cnter = std::any_cast(conn.tag()); + auto &cnter = std::any_cast(ctx->tag()); ELOGV(INFO, "call count: %d", ++cnter); - conn.response_msg(std::to_string(cnter)); + co_return std::to_string(cnter); } \ No newline at end of file diff --git a/src/coro_rpc/examples/base_examples/rpc_service.h b/src/coro_rpc/examples/base_examples/rpc_service.h index 1198009b7..368568f06 100644 --- a/src/coro_rpc/examples/base_examples/rpc_service.h +++ b/src/coro_rpc/examples/base_examples/rpc_service.h @@ -20,21 +20,22 @@ #include #include -std::string hello_world(); -bool return_bool_hello_world(); -int A_add_B(int a, int b); -void hello_with_delay(coro_rpc::context conn, std::string hello); -std::string echo(std::string_view sv); -void echo_with_attachment(coro_rpc::context conn); -void echo_with_attachment2(coro_rpc::context conn); -void return_error(coro_rpc::context conn); -void rpc_with_state_by_tag(coro_rpc::context conn); -async_simple::coro::Lazy coro_echo(std::string_view sv); -async_simple::coro::Lazy nested_echo(std::string_view sv); +std::string_view echo(std::string_view data); +async_simple::coro::Lazy coroutine_echo( + std::string_view data); +void async_echo_by_callback( + coro_rpc::context conn, + std::string_view /*rpc request data here*/ data); +async_simple::coro::Lazy async_echo_by_coroutine( + std::string_view sv); +void echo_with_attachment(); +async_simple::coro::Lazy nested_echo(std::string_view sv); +void return_error_by_context(coro_rpc::context conn); +void return_error_by_exception(); +async_simple::coro::Lazy get_ctx_info(); class HelloService { public: - std::string hello(); - void hello_with_delay(coro_rpc::context conn, std::string hello); + std::string_view hello(); }; - +async_simple::coro::Lazy rpc_with_state_by_tag(); #endif // CORO_RPC_RPC_API_HPP diff --git a/src/coro_rpc/examples/base_examples/server.cpp b/src/coro_rpc/examples/base_examples/server.cpp index e6cedb2dc..c7afd710a 100644 --- a/src/coro_rpc/examples/base_examples/server.cpp +++ b/src/coro_rpc/examples/base_examples/server.cpp @@ -26,19 +26,15 @@ int main() { coro_rpc_server server2{/*thread=*/1, /*port=*/8802}; - server.register_handler(); - // regist normal function for rpc - server.register_handler(); + server.register_handler< + echo, coroutine_echo, async_echo_by_callback, async_echo_by_coroutine, + echo_with_attachment, nested_echo, return_error_by_context, + return_error_by_exception, rpc_with_state_by_tag, get_ctx_info>(); // regist member function for rpc HelloService hello_service; - server - .register_handler<&HelloService::hello, &HelloService::hello_with_delay>( - &hello_service); + server.register_handler<&HelloService::hello>(&hello_service); server2.register_handler(); // async start server diff --git a/src/coro_rpc/examples/file_transfer/rpc_service.cpp b/src/coro_rpc/examples/file_transfer/rpc_service.cpp index e375fc2b2..e034218a5 100644 --- a/src/coro_rpc/examples/file_transfer/rpc_service.cpp +++ b/src/coro_rpc/examples/file_transfer/rpc_service.cpp @@ -6,18 +6,19 @@ std::string echo(std::string str) { return str; } void upload_file(coro_rpc::context conn, file_part part) { - if (!conn.tag().has_value()) { + auto &ctx = *conn.get_context(); + if (!ctx.tag().has_value()) { auto filename = std::to_string(std::time(0)) + std::filesystem::path(part.filename).extension().string(); - conn.tag() = std::make_shared( + ctx.tag() = std::make_shared( filename, std::ios::binary | std::ios::app); } - auto stream = std::any_cast>(conn.tag()); + auto stream = std::any_cast>(ctx.tag()); std::cout << "file part content size=" << part.content.size() << "\n"; stream->write(part.content.data(), part.content.size()); if (part.eof) { stream->close(); - conn.tag().reset(); + ctx.tag().reset(); std::cout << "file upload finished\n"; } @@ -26,7 +27,8 @@ void upload_file(coro_rpc::context conn, file_part part) { void download_file(coro_rpc::context conn, std::string filename) { - if (!conn.tag().has_value()) { + auto &ctx = *conn.get_context(); + if (!ctx.tag().has_value()) { std::string actual_filename = std::filesystem::path(filename).filename().string(); if (!std::filesystem::is_regular_file(actual_filename) || @@ -34,10 +36,10 @@ void download_file(coro_rpc::context conn, conn.response_msg(response_part{std::errc::invalid_argument}); return; } - conn.tag() = + ctx.tag() = std::make_shared(actual_filename, std::ios::binary); } - auto stream = std::any_cast>(conn.tag()); + auto stream = std::any_cast>(ctx.tag()); char buf[1024]; @@ -47,7 +49,7 @@ void download_file(coro_rpc::context conn, if (stream->eof()) { stream->close(); - conn.tag().reset(); + ctx.tag().reset(); } } diff --git a/src/coro_rpc/tests/rpc_api.cpp b/src/coro_rpc/tests/rpc_api.cpp index d5f09d36b..7ed6c6bc2 100644 --- a/src/coro_rpc/tests/rpc_api.cpp +++ b/src/coro_rpc/tests/rpc_api.cpp @@ -15,9 +15,12 @@ */ #include "rpc_api.hpp" +#include #include #include +#include "ylt/coro_rpc/impl/errno.h" + using namespace coro_rpc; using namespace std::chrono_literals; using namespace std::string_literals; @@ -46,11 +49,69 @@ int long_run_func(int val) { } void echo_with_attachment(coro_rpc::context conn) { - ELOGV(INFO, "conn ID:%d", conn.get_connection_id()); - auto str = conn.release_request_attachment(); - conn.set_response_attachment(std::move(str)); + ELOGV(INFO, "call function echo_with_attachment, conn ID:%d", + conn.get_context()->get_connection_id()); + auto str = conn.get_context()->release_request_attachment(); + conn.get_context()->set_response_attachment(std::move(str)); conn.response_msg(); } +template +void test_ctx_impl(T *ctx, std::string_view name) { + if (ctx->has_closed()) { + throw std::runtime_error("connection is close!"); + } + ELOGV(INFO, "call function echo_with_attachment, conn ID:%d, request ID:%d", + ctx->get_connection_id(), ctx->get_request_id()); + ELOGI << "remote endpoint: " << ctx->get_remote_endpoint() << "local endpoint" + << ctx->get_local_endpoint(); + if (ctx->get_rpc_function_name() != name) { + throw std::runtime_error("get error rpc function name!"); + } + ELOGI << "rpc function name:" << ctx->get_rpc_function_name(); + std::string sv{ctx->get_request_attachment()}; + auto str = ctx->release_request_attachment(); + if (sv != str) { + throw std::runtime_error("coro_rpc::errc::rpc_throw_exception"); + } + ctx->set_response_attachment(std::move(str)); +} +void test_context() { + auto *ctx = coro_rpc::get_context(); + test_ctx_impl(ctx, "test_context"); + return; +} +void test_callback_context(coro_rpc::context conn) { + auto *ctx = conn.get_context(); + test_ctx_impl(ctx, "test_callback_context"); + [](coro_rpc::context conn) -> async_simple::coro::Lazy { + co_await coro_io::sleep_for(514ms); + ELOGV(INFO, "response in another executor"); + conn.response_msg(); + }(std::move(conn)) + .via(coro_io::get_global_executor()) + .detach(); + return; +} +using namespace async_simple::coro; + +Lazy test_lazy_context() { + auto *ctx = co_await coro_rpc::get_context_in_coro(); + test_ctx_impl(ctx, "test_lazy_context"); + co_await coro_io::sleep_for(514ms, coro_io::get_global_executor()); + ELOGV(INFO, "response in another executor"); + co_return; +} + +void test_response_error5() { + throw coro_rpc::rpc_error{coro_rpc::errc::address_in_used, + "error with user-defined msg"}; + return; +} + +Lazy test_response_error6() { + throw coro_rpc::rpc_error{coro_rpc::errc::address_in_used, + "error with user-defined msg"}; +} void coro_fun_with_user_define_connection_type(my_context conn) { conn.ctx_.response_msg(); @@ -91,7 +152,6 @@ void coro_fun_with_delay_return_string_twice( } void fun_with_delay_return_void_cost_long_time(coro_rpc::context conn) { - conn.set_delay(); std::thread([conn = std::move(conn)]() mutable { std::this_thread::sleep_for(700ms); conn.response_msg(); diff --git a/src/coro_rpc/tests/rpc_api.hpp b/src/coro_rpc/tests/rpc_api.hpp index 66fca1981..7d905531f 100644 --- a/src/coro_rpc/tests/rpc_api.hpp +++ b/src/coro_rpc/tests/rpc_api.hpp @@ -37,8 +37,13 @@ struct my_context { }; void echo_with_attachment(coro_rpc::context conn); inline void error_with_context(coro_rpc::context conn) { - conn.response_error(coro_rpc::err_code{104}, "My Error."); + conn.response_error(coro_rpc::err_code{1004}, "My Error."); } +void test_context(); +void test_callback_context(coro_rpc::context conn); +async_simple::coro::Lazy test_lazy_context(); +void test_response_error5(); +async_simple::coro::Lazy test_response_error6(); void coro_fun_with_user_define_connection_type(my_context conn); void coro_fun_with_delay_return_void(coro_rpc::context conn); void coro_fun_with_delay_return_string(coro_rpc::context conn); @@ -48,19 +53,13 @@ void coro_fun_with_delay_return_string_twice( void coro_fun_with_delay_return_void_cost_long_time( coro_rpc::context conn); inline async_simple::coro::Lazy coro_func_return_void(int i) { + auto ctx = co_await coro_rpc::get_context_in_coro(); + ELOGV(INFO, + "call function coro_func_return_void, connection id:%d,request id:%d", + ctx->get_connection_id(), ctx->get_request_id()); co_return; } inline async_simple::coro::Lazy coro_func(int i) { co_return i; } -inline async_simple::coro::Lazy coro_func_delay_return_int( - coro_rpc::context conn, int i) { - conn.response_msg(i); - co_return; -} -inline async_simple::coro::Lazy coro_func_delay_return_void( - coro_rpc::context conn, int i) { - conn.response_msg(); - co_return; -} class HelloService { public: @@ -68,16 +67,6 @@ class HelloService { static std::string static_hello(); async_simple::coro::Lazy coro_func(int i) { co_return i; } async_simple::coro::Lazy coro_func_return_void(int i) { co_return; } - async_simple::coro::Lazy coro_func_delay_return_int( - coro_rpc::context conn, int i) { - conn.response_msg(i); - co_return; - } - async_simple::coro::Lazy coro_func_delay_return_void( - coro_rpc::context conn, int i) { - conn.response_msg(); - co_return; - } private: }; diff --git a/src/coro_rpc/tests/test_coro_rpc_client.cpp b/src/coro_rpc/tests/test_coro_rpc_client.cpp index 77df32dc3..a820f8e73 100644 --- a/src/coro_rpc/tests/test_coro_rpc_client.cpp +++ b/src/coro_rpc/tests/test_coro_rpc_client.cpp @@ -420,7 +420,7 @@ TEST_CASE("testing client with context response user-defined error") { server.register_handler(); auto ret = client.sync_call(); REQUIRE(!ret.has_value()); - CHECK(ret.error().code == coro_rpc::errc{104}); + CHECK(ret.error().code == coro_rpc::errc{1004}); CHECK(ret.error().msg == "My Error."); CHECK(client.has_closed() == false); auto ret2 = client.sync_call(); diff --git a/src/coro_rpc/tests/test_coro_rpc_server.cpp b/src/coro_rpc/tests/test_coro_rpc_server.cpp index 866815c23..6cf512043 100644 --- a/src/coro_rpc/tests/test_coro_rpc_server.cpp +++ b/src/coro_rpc/tests/test_coro_rpc_server.cpp @@ -120,6 +120,8 @@ struct CoroServerTester : ServerTester { test_start_new_server_with_same_port(); test_server_send_bad_rpc_result(); test_server_send_no_body(); + test_context_func(); + test_return_err_by_throw_exception(); this->test_call_with_delay_func(); this->test_call_with_delay_func< coro_fun_with_user_define_connection_type>(); @@ -143,6 +145,10 @@ struct CoroServerTester : ServerTester { server.register_handler<&ns_login::LoginService::login>(&login_service_); server.register_handler<&HelloService::hello>(&hello_service_); server.register_handler(); + server.register_handler(); + server.register_handler(); + server.register_handler(); server.register_handler(); server.register_handler(); server.register_handler(); @@ -151,17 +157,50 @@ struct CoroServerTester : ServerTester { server.register_handler(); server.register_handler(); server.register_handler(); - server.register_handler(); - server.register_handler(); server.register_handler<&HelloService::coro_func, - &HelloService::coro_func_return_void, - &HelloService::coro_func_delay_return_void, - &HelloService::coro_func_delay_return_int>( + &HelloService::coro_func_return_void>( &hello_service_); server.register_handler(); server.register_handler<&CoroServerTester::get_value>(this); } + void test_context_func() { + auto client = create_client(); + ELOGV(INFO, "run %s, client_id %d", __func__, client->get_client_id()); + client->set_req_attachment("1234567890987654321234567890"); + auto result = syncAwait(client->call()); + CHECK(result); + CHECK(client->get_resp_attachment() == "1234567890987654321234567890"); + + client->set_req_attachment("12345678909876543212345678901"); + result = syncAwait(client->call()); + CHECK(result); + CHECK(client->get_resp_attachment() == "12345678909876543212345678901"); + + client->set_req_attachment("01234567890987654321234567890"); + result = syncAwait(client->call()); + CHECK(result); + CHECK(client->get_resp_attachment() == "01234567890987654321234567890"); + } + void test_return_err_by_throw_exception() { + { + auto client = create_client(); + ELOGV(INFO, "run %s, client_id %d", __func__, client->get_client_id()); + auto result = syncAwait(client->call()); + REQUIRE(!result); + CHECK(result.error().code == coro_rpc::errc::address_in_used); + CHECK(result.error().msg == "error with user-defined msg"); + } + { + auto client = create_client(); + ELOGV(INFO, "run %s, client_id %d", __func__, client->get_client_id()); + auto result = syncAwait(client->call()); + REQUIRE(!result); + CHECK(result.error().code == coro_rpc::errc::address_in_used); + CHECK(result.error().msg == "error with user-defined msg"); + } + } + void test_function_not_registered() { g_action = {}; auto client = create_client(); @@ -193,7 +232,7 @@ struct CoroServerTester : ServerTester { auto new_server = coro_rpc_server(2, std::stoi(this->port_)); auto ec = new_server.async_start(); REQUIRE(!ec); - REQUIRE_MESSAGE(ec.error() == coro_rpc::errc::address_in_use, + REQUIRE_MESSAGE(ec.error() == coro_rpc::errc::address_in_used, ec.error().message()); } ELOGV(INFO, "OH NO"); @@ -203,7 +242,7 @@ struct CoroServerTester : ServerTester { ELOGV(INFO, "run %s, client_id %d", __func__, client->get_client_id()); auto ret = this->call(client); CHECK_MESSAGE( - ret.error().code == coro_rpc::errc::invalid_argument, + ret.error().code == coro_rpc::errc::invalid_rpc_result, std::to_string(client->get_client_id()).append(ret.error().msg)); g_action = {}; } @@ -239,20 +278,6 @@ struct CoroServerTester : ServerTester { auto ret5 = this->template call<&HelloService::coro_func_return_void>(client, 42); CHECK(ret5.has_value()); - - auto ret6 = this->template call<&HelloService::coro_func_delay_return_void>( - client, 42); - CHECK(ret6.has_value()); - - auto ret7 = this->template call<&HelloService::coro_func_delay_return_int>( - client, 42); - CHECK(ret7.value() == 42); - - auto ret8 = this->template call(client, 42); - CHECK(ret8.has_value()); - - auto ret9 = this->template call(client, 42); - CHECK(ret9.value() == 42); } coro_rpc_server server; std::thread thd; diff --git a/src/coro_rpc/tests/test_router.cpp b/src/coro_rpc/tests/test_router.cpp index b2c6a2503..9c7aaac0d 100644 --- a/src/coro_rpc/tests/test_router.cpp +++ b/src/coro_rpc/tests/test_router.cpp @@ -28,6 +28,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#define CORO_RPC_TEST #include #include #include @@ -74,7 +75,7 @@ get_result(const auto &pair) { coro_rpc_protocol>; rpc_return_type_t ret; struct_pack::err_code ec; - coro_rpc_protocol::rpc_error err; + rpc_error err; if (!rpc_errc) { ec = struct_pack::deserialize_to(ret, buffer); if (!ec) { @@ -94,7 +95,7 @@ get_result(const auto &pair) { } } // deserialize failed. - err = {coro_rpc::errc::invalid_argument, + err = {coro_rpc::errc::invalid_rpc_arguments, "failed to deserialize rpc return value"}; return return_type{unexpect_t{}, std::move(err)}; } @@ -107,7 +108,7 @@ void check_result(const auto &pair, size_t offset = 0) { typename RPC_trait::return_type r; auto res = struct_pack::deserialize_to(r, data); if (res) { - coro_rpc_protocol::rpc_error r; + rpc_error r; auto res = struct_pack::deserialize_to(r, data); CHECK(!res); } @@ -166,7 +167,8 @@ void bar3(int val) { std::cout << "bar3 val=" << val << "\n"; } using namespace test_util; auto ctx = std::make_shared< - coro_rpc::context_info_t>(nullptr); + coro_rpc::context_info_t>(router, + nullptr); struct person { int id; @@ -223,7 +225,7 @@ TEST_CASE("testing coro_handler") { async_simple::coro::syncAwait(router.route_coro( handler, std::string_view{buf.data() + g_head_offset, buf.size() - g_tail_offset}, - ctx, std::variant{}, id)); + std::variant{}, id)); } TEST_CASE("testing not registered func") { @@ -258,16 +260,16 @@ TEST_CASE("testing invalid arguments") { CHECK(!pair.first); pair = test_route<&test_class::plus_one>(ctx); - CHECK(pair.first == coro_rpc::errc::invalid_argument); + CHECK(pair.first == coro_rpc::errc::invalid_rpc_arguments); pair = test_route<&test_class::plus_one>(ctx, 42, 42); - CHECK(pair.first == coro_rpc::errc::invalid_argument); + CHECK(pair.first == coro_rpc::errc::invalid_rpc_arguments); pair = test_route<&test_class::plus_one>(ctx, "test"); - CHECK(pair.first == coro_rpc::errc::invalid_argument); + CHECK(pair.first == coro_rpc::errc::invalid_rpc_arguments); pair = test_route<&test_class::get_str>(ctx, "test"); - CHECK(pair.first == coro_rpc::errc::invalid_argument); + CHECK(pair.first == coro_rpc::errc::invalid_rpc_arguments); pair = test_route<&test_class::get_str>(ctx, std::string("test")); CHECK(!pair.first); @@ -280,7 +282,7 @@ TEST_CASE("testing invalid arguments") { router.register_handler(); pair = test_route(ctx, "test"); - CHECK(pair.first == coro_rpc::errc::invalid_argument); + CHECK(pair.first == coro_rpc::errc::invalid_rpc_arguments); pair = test_route(ctx, std::string("test")); CHECK(!pair.first); @@ -288,16 +290,16 @@ TEST_CASE("testing invalid arguments") { CHECK(r.value() == "test"); pair = test_route(ctx, 42, 42); - CHECK(pair.first == coro_rpc::errc::invalid_argument); + CHECK(pair.first == coro_rpc::errc::invalid_rpc_arguments); pair = test_route(ctx); - CHECK(pair.first == coro_rpc::errc::invalid_argument); + CHECK(pair.first == coro_rpc::errc::invalid_rpc_arguments); pair = test_route(ctx, 42); CHECK(!pair.first); pair = test_route(ctx, std::string("invalid arguments")); - CHECK(pair.first == coro_rpc::errc::invalid_argument); + CHECK(pair.first == coro_rpc::errc::invalid_rpc_arguments); // register_handler(); // test_route(ctx, 42); // will crash @@ -310,12 +312,12 @@ TEST_CASE("testing invalid buffer") { g_head_offset = 2; pair = test_route(ctx, 42); - CHECK(pair.first == coro_rpc::errc::invalid_argument); + CHECK(pair.first == coro_rpc::errc::invalid_rpc_arguments); g_head_offset = 0; g_tail_offset = 2; pair = test_route(ctx, 42); - CHECK(pair.first == coro_rpc::errc::invalid_argument); + CHECK(pair.first == coro_rpc::errc::invalid_rpc_arguments); g_tail_offset = 0; } @@ -328,12 +330,12 @@ TEST_CASE("testing exceptions") { std::pair pair{}; pair = test_route(ctx); - CHECK(pair.first == coro_rpc::errc::interrupted); + CHECK(pair.first == coro_rpc::errc::rpc_throw_exception); auto r = get_result(pair); std::cout << r.error().msg << "\n"; pair = test_route(ctx); - CHECK(pair.first == coro_rpc::errc::interrupted); + CHECK(pair.first == coro_rpc::errc::rpc_throw_exception); r = get_result(pair); std::cout << r.error().msg << "\n"; } diff --git a/src/coro_rpc/tests/test_variadic.cpp b/src/coro_rpc/tests/test_variadic.cpp index 1f6feba63..79a8b0153 100644 --- a/src/coro_rpc/tests/test_variadic.cpp +++ b/src/coro_rpc/tests/test_variadic.cpp @@ -55,6 +55,5 @@ TEST_CASE("test varadic param") { CHECK(ret); if (ret) { CHECK(ret == "1145142.000000Hello coro_rpc!hellohiwhat"); - std::cout << ret.value() << std::endl; } }