diff --git a/include/ylt/coro_rpc/impl/context.hpp b/include/ylt/coro_rpc/impl/context.hpp index b4748b10e..daac940fe 100644 --- a/include/ylt/coro_rpc/impl/context.hpp +++ b/include/ylt/coro_rpc/impl/context.hpp @@ -18,8 +18,8 @@ #include #include -#include #include +#include #include #include #include @@ -89,7 +89,8 @@ class context_base { std::visit( [&](const serialize_proto &) { self_->conn_->template response_msg( - serialize_proto::serialize(), self_->req_head_, + serialize_proto::serialize(), + std::move(self_->resp_attachment_), self_->req_head_, self_->is_delay_); }, *rpc_protocol::get_serialize_protocol(self_->req_head_)); @@ -116,13 +117,17 @@ class context_base { std::visit( [&](const serialize_proto &) { self_->conn_->template response_msg( - serialize_proto::serialize(ret), self_->req_head_, + serialize_proto::serialize(ret), + std::move(self_->resp_attachment_), self_->req_head_, self_->is_delay_); }, *rpc_protocol::get_serialize_protocol(self_->req_head_)); // response_handler_(std::move(conn_), std::move(ret)); } + self_->resp_attachment_ = [] { + return std::string_view{}; + }; } /*! @@ -132,6 +137,51 @@ class context_base { */ bool has_closed() const { return self_->conn_->has_closed(); } + /*! + * Close connection + */ + void close() { return self_->conn_->async_close(); } + + /*! + * Get the unique connection ID + * @return connection id + */ + uint64_t get_connection_id() { return self_->conn_->conn_id_; } + + /*! + * Set the response_attachment + * @return a ref of response_attachment + */ + void set_response_attachment(std::string attachment) { + set_response_attachment([attachment = std::move(attachment)] { + return std::string_view{attachment}; + }); + } + + /*! + * Set the response_attachment + * @return a ref of response_attachment + */ + void set_response_attachment(std::function attachment) { + self_->resp_attachment_ = std::move(attachment); + } + + /*! + * Get the request attachment + * @return connection id + */ + std::string_view get_request_attachment() const { + return self_->req_attachment_; + } + + /*! + * Release the attachment + * @return connection id + */ + std::string release_request_attachment() { + return std::move(self_->req_attachment_); + } + void set_delay() { self_->is_delay_ = true; self_->conn_->set_rpc_call_type( diff --git a/include/ylt/coro_rpc/impl/coro_connection.hpp b/include/ylt/coro_rpc/impl/coro_connection.hpp index cf57cb38f..eefed397f 100644 --- a/include/ylt/coro_rpc/impl/coro_connection.hpp +++ b/include/ylt/coro_rpc/impl/coro_connection.hpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -37,12 +38,6 @@ #endif namespace coro_rpc { -template -concept apply_user_buf = requires() { - requires std::is_same_v::buffer_type>; -}; - class coro_connection; using rpc_conn = std::shared_ptr; @@ -51,7 +46,11 @@ struct context_info_t { constexpr static size_t body_default_size_ = 256; std::shared_ptr conn_; typename rpc_protocol::req_header req_head_; - std::vector body_; + std::string body_; + std::string req_attachment_; + std::function resp_attachment_ = [] { + return std::string_view{}; + }; std::atomic has_response_ = false; bool is_delay_ = false; context_info_t(std::shared_ptr &&conn) @@ -158,6 +157,7 @@ class coro_connection : public std::enable_shared_from_this { while (true) { auto &req_head = context_info->req_head_; auto &body = context_info->body_; + auto &req_attachment = context_info->req_attachment_; reset_timer(); auto ec = co_await rpc_protocol::read_head(socket, req_head); cancel_timer(); @@ -195,16 +195,12 @@ class coro_connection : public std::enable_shared_from_this { break; } - std::string_view payload{}; + std::string_view payload; // rpc_protocol::buffer_type maybe from user, default from framework. - constexpr bool apply_user_buf_v = apply_user_buf; - if constexpr (apply_user_buf_v) { - ec = co_await rpc_protocol::read_payload(socket, req_head, payload); - } - else { - ec = co_await rpc_protocol::read_payload(socket, req_head, body); - payload = std::string_view{body.data(), body.size()}; - } + + ec = co_await rpc_protocol::read_payload(socket, req_head, body, + req_attachment); + payload = std::string_view{body}; if (ec) AS_UNLIKELY { @@ -256,7 +252,7 @@ class coro_connection : public std::enable_shared_from_this { if (resp_err != std::errc{}) AS_UNLIKELY { std::swap(resp_buf, resp_error_msg); } std::string header_buf = rpc_protocol::prepare_response( - resp_buf, req_head, resp_err, resp_error_msg); + resp_buf, req_head, 0, resp_err, resp_error_msg); #ifdef UNIT_TEST_INJECT if (g_action == inject_action::close_socket_after_send_length) { @@ -280,7 +276,10 @@ class coro_connection : public std::enable_shared_from_this { AS_LIKELY { if (resp_err != std::errc{}) AS_UNLIKELY { resp_err_ = resp_err; } - write_queue_.emplace_back(std::move(header_buf), std::move(resp_buf)); + write_queue_.emplace_back(std::move(header_buf), std::move(resp_buf), + [] { + return std::string_view{}; + }); if (write_queue_.size() == 1) { send_data().start([self = shared_from_this()](auto &&) { }); @@ -299,11 +298,13 @@ class coro_connection : public std::enable_shared_from_this { template void response_msg(std::string &&body_buf, + std::function &&resp_attachment, const typename rpc_protocol::req_header &req_head, bool is_delay) { - std::string header_buf = rpc_protocol::prepare_response(body_buf, req_head); - response(std::move(header_buf), std::move(body_buf), shared_from_this(), - is_delay) + std::string header_buf = rpc_protocol::prepare_response( + body_buf, req_head, resp_attachment().size()); + response(std::move(header_buf), std::move(body_buf), + std::move(resp_attachment), shared_from_this(), is_delay) .via(executor_) .detach(); } @@ -351,9 +352,10 @@ class coro_connection : public std::enable_shared_from_this { auto &get_executor() { return *executor_; } private: - async_simple::coro::Lazy response(std::string header_buf, - std::string body_buf, rpc_conn self, - bool is_delay) noexcept { + async_simple::coro::Lazy response( + std::string header_buf, std::string body_buf, + std::function resp_attachment, rpc_conn self, + bool is_delay) noexcept { if (has_closed()) AS_UNLIKELY { ELOGV(DEBUG, "response_msg failed: connection has been closed"); @@ -365,7 +367,8 @@ class coro_connection : public std::enable_shared_from_this { body_buf.clear(); } #endif - write_queue_.emplace_back(std::move(header_buf), std::move(body_buf)); + write_queue_.emplace_back(std::move(header_buf), std::move(body_buf), + std::move(resp_attachment)); if (is_delay) { --delay_resp_cnt; assert(delay_resp_cnt >= 0); @@ -402,19 +405,38 @@ class coro_connection : public std::enable_shared_from_this { co_return; } #endif - std::array buffers{asio::buffer(msg.first), - asio::buffer(msg.second)}; + auto attachment = std::get<2>(msg)(); + if (attachment.empty()) { + std::array buffers{ + asio::buffer(std::get<0>(msg)), asio::buffer(std::get<1>(msg))}; +#ifdef YLT_ENABLE_SSL + if (use_ssl_) { + assert(ssl_stream_); + ret = co_await coro_io::async_write(*ssl_stream_, buffers); + } + else { +#endif + ret = co_await coro_io::async_write(socket_, buffers); #ifdef YLT_ENABLE_SSL - if (use_ssl_) { - assert(ssl_stream_); - ret = co_await coro_io::async_write(*ssl_stream_, buffers); + } +#endif } else { + std::array buffers{ + asio::buffer(std::get<0>(msg)), asio::buffer(std::get<1>(msg)), + asio::buffer(attachment)}; +#ifdef YLT_ENABLE_SSL + if (use_ssl_) { + assert(ssl_stream_); + ret = co_await coro_io::async_write(*ssl_stream_, buffers); + } + else { #endif - ret = co_await coro_io::async_write(socket_, buffers); + ret = co_await coro_io::async_write(socket_, buffers); #ifdef YLT_ENABLE_SSL - } + } #endif + } if (ret.first) AS_UNLIKELY { ELOGV(ERROR, "%s, %s", ret.first.message().data(), @@ -449,11 +471,13 @@ class coro_connection : public std::enable_shared_from_this { if (has_closed_) { return; } - close_socket(); + has_closed_ = true; + asio::error_code ignored_ec; + socket_.shutdown(asio::ip::tcp::socket::shutdown_both, ignored_ec); + socket_.close(ignored_ec); if (quit_callback_) { quit_callback_(conn_id_); } - has_closed_ = true; } void reset_timer() { @@ -472,22 +496,11 @@ class coro_connection : public std::enable_shared_from_this { ELOGV(INFO, "close timeout client conn_id %d", conn_id_); #endif - close_socket(); + close(); } }); } - void close_socket() { - if (has_closed_) { - return; - } - - asio::error_code ignored_ec; - socket_.shutdown(asio::ip::tcp::socket::shutdown_both, ignored_ec); - socket_.close(ignored_ec); - has_closed_ = true; - } - void cancel_timer() { if (!enable_check_timeout_) { return; @@ -502,7 +515,9 @@ class coro_connection : public std::enable_shared_from_this { async_simple::Executor *executor_; asio::ip::tcp::socket socket_; // FIXME: queue's performance can be imporved. - std::deque> write_queue_; + std::deque< + std::tuple>> + write_queue_; std::errc resp_err_; rpc_call_type rpc_call_type_{non_callback}; diff --git a/include/ylt/coro_rpc/impl/coro_rpc_client.hpp b/include/ylt/coro_rpc/impl/coro_rpc_client.hpp index 1a7529f45..2a35995c5 100644 --- a/include/ylt/coro_rpc/impl/coro_rpc_client.hpp +++ b/include/ylt/coro_rpc/impl/coro_rpc_client.hpp @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -36,7 +37,9 @@ #include #include +#include "asio/buffer.hpp" #include "asio/dispatch.hpp" +#include "asio/registered_buffer.hpp" #include "common_service.hpp" #include "context.hpp" #include "expected.hpp" @@ -44,6 +47,7 @@ #include "ylt/coro_io/coro_io.hpp" #include "ylt/coro_io/io_context_pool.hpp" #include "ylt/struct_pack.hpp" +#include "ylt/struct_pack/util.h" #include "ylt/util/function_name.h" #include "ylt/util/type_traits.h" #include "ylt/util/utils.hpp" @@ -127,7 +131,6 @@ class coro_rpc_client { : executor(executor), socket_(std::make_shared(executor)) { config_.client_id = client_id; - read_buf_.resize(default_read_buf_size_); } /*! @@ -141,7 +144,6 @@ class coro_rpc_client { socket_(std::make_shared( executor.get_asio_executor())) { config_.client_id = client_id; - read_buf_.resize(default_read_buf_size_); } std::string_view get_host() const { return config_.host; } @@ -347,12 +349,24 @@ class coro_rpc_client { if (has_closed_) { return; } - + has_closed_ = true; ELOGV(INFO, "client_id %d close", config_.client_id); - close_socket(socket_); + } - has_closed_ = true; + bool set_req_attachment(std::string_view attachment) { + if (attachment.size() > UINT32_MAX) { + ELOGV(ERROR, "too large rpc attachment"); + return false; + } + req_attachment_ = attachment; + return true; + } + + std::string_view get_resp_attachment() const { return resp_attachment_buf_; } + + std::string release_resp_attachment() { + return std::move(resp_attachment_buf_); } template @@ -534,7 +548,15 @@ class coro_rpc_client { using R = decltype(get_return_type()); auto buffer = prepare_buffer(std::move(args)...); + rpc_result r{}; + if (buffer.empty()) { + r = rpc_result{ + unexpect_t{}, + coro_rpc_protocol::rpc_error{std::errc::message_size, + "rpc body serialize size too big"}}; + co_return r; + } #ifdef GENERATE_BENCHMARK_DATA std::ofstream file( benchmark_file_path + std::string{get_func_name()} + ".in", @@ -581,12 +603,20 @@ class coro_rpc_client { co_return r; } else { - ret = co_await coro_io::async_write( - socket, asio::buffer(buffer.data(), buffer.size())); +#endif + if (req_attachment_.empty()) { + ret = co_await coro_io::async_write( + socket, asio::buffer(buffer.data(), buffer.size())); + } + else { + std::array iov{ + asio::const_buffer{buffer.data(), buffer.size()}, + asio::const_buffer{req_attachment_.data(), req_attachment_.size()}}; + ret = co_await coro_io::async_write(socket, iov); + req_attachment_ = {}; + } +#ifdef UNIT_TEST_INJECT } -#else - ret = co_await coro_io::async_write( - socket, asio::buffer(buffer.data(), buffer.size())); #endif if (!ret.first) { #ifdef UNIT_TEST_INJECT @@ -606,11 +636,21 @@ class coro_rpc_client { asio::buffer((char *)&header, coro_rpc_protocol::RESP_HEAD_LEN)); if (!ret.first) { uint32_t body_len = header.length; - if (body_len > read_buf_.size()) { - read_buf_.resize(body_len); + struct_pack::detail::resize(read_buf_, body_len); + if (header.attach_length == 0) { + ret = co_await coro_io::async_read( + socket, asio::buffer(read_buf_.data(), body_len)); + resp_attachment_buf_.clear(); + } + else { + struct_pack::detail::resize(resp_attachment_buf_, + header.attach_length); + std::array iov{ + asio::mutable_buffer{read_buf_.data(), body_len}, + asio::mutable_buffer{resp_attachment_buf_.data(), + resp_attachment_buf_.size()}}; + ret = co_await coro_io::async_read(socket, iov); } - ret = co_await coro_io::async_read( - socket, asio::buffer(read_buf_.data(), body_len)); if (!ret.first) { #ifdef GENERATE_BENCHMARK_DATA std::ofstream file( @@ -618,11 +658,11 @@ class coro_rpc_client { std::ofstream::binary | std::ofstream::out); file << std::string_view{(char *)&header, coro_rpc_protocol::RESP_HEAD_LEN}; - file << std::string_view{(char *)read_buf_.data(), body_len}; + file << read_buf_; + file << resp_attachment_buf_; file.close(); #endif - r = handle_response_buffer(read_buf_.data(), ret.second, - std::errc{header.err_code}); + r = handle_response_buffer(read_buf_, std::errc{header.err_code}); if (!r) { close(); } @@ -673,6 +713,7 @@ class coro_rpc_client { header = {}; header.magic = coro_rpc_protocol::magic_number; header.function_id = func_id(); + header.attach_length = req_attachment_.size(); #ifdef UNIT_TEST_INJECT header.seq_num = config_.client_id; if (g_action == inject_action::client_send_bad_magic_num) { @@ -683,7 +724,12 @@ class coro_rpc_client { } else { #endif - header.length = buffer.size() - coro_rpc_protocol::REQ_HEAD_LEN; + auto sz = buffer.size() - coro_rpc_protocol::REQ_HEAD_LEN; + if (sz > UINT32_MAX) { + ELOGV(ERROR, "too large rpc body"); + return {}; + } + header.length = sz; #ifdef UNIT_TEST_INJECT } #endif @@ -691,13 +737,13 @@ class coro_rpc_client { } template - rpc_result handle_response_buffer( - const std::byte *buffer, std::size_t len, std::errc rpc_errc) { + rpc_result handle_response_buffer(std::string &buffer, + std::errc rpc_errc) { rpc_return_type_t ret; struct_pack::errc ec; coro_rpc_protocol::rpc_error err; if (rpc_errc == std::errc{}) { - ec = struct_pack::deserialize_to(ret, (const char *)buffer, len); + ec = struct_pack::deserialize_to(ret, buffer); if (ec == struct_pack::errc::ok) { if constexpr (std::is_same_v) { return {}; @@ -709,7 +755,7 @@ class coro_rpc_client { } else { err.code = rpc_errc; - ec = struct_pack::deserialize_to(err.msg, (const char *)buffer, len); + ec = struct_pack::deserialize_to(err.msg, buffer); if (ec == struct_pack::errc::ok) { return rpc_result{unexpect_t{}, std::move(err)}; } @@ -773,7 +819,8 @@ class coro_rpc_client { private: coro_io::ExecutorWrapper<> executor; std::shared_ptr socket_; - std::vector read_buf_; + std::string read_buf_, resp_attachment_buf_; + std::string_view req_attachment_; config config_; constexpr static std::size_t default_read_buf_size_ = 256; #ifdef YLT_ENABLE_SSL diff --git a/include/ylt/coro_rpc/impl/protocol/coro_rpc_protocol.hpp b/include/ylt/coro_rpc/impl/protocol/coro_rpc_protocol.hpp index e5ad8f99a..ee884c91b 100644 --- a/include/ylt/coro_rpc/impl/protocol/coro_rpc_protocol.hpp +++ b/include/ylt/coro_rpc/impl/protocol/coro_rpc_protocol.hpp @@ -16,12 +16,16 @@ #pragma once #include +#include +#include #include +#include #include #include #include #include +#include "asio/buffer.hpp" #include "struct_pack_protocol.hpp" #include "ylt/coro_io/coro_io.hpp" #include "ylt/coro_rpc/impl/context.hpp" @@ -55,21 +59,19 @@ struct coro_rpc_protocol { uint32_t seq_num; //!< sequence number uint32_t function_id; //!< rpc function ID uint32_t length; //!< length of RPC body - uint32_t reserved; //!< reserved field + uint32_t attach_length; //!< reserved field }; struct resp_header { - uint8_t magic; //!< magic number - uint8_t version; //!< rpc protocol version - uint8_t err_code; //!< rpc error type - uint8_t msg_type; //!< message type - uint32_t seq_num; //!< sequence number - uint32_t length; //!< length of RPC body - uint32_t reserved; //!< reserved field + uint8_t magic; //!< magic number + uint8_t version; //!< rpc protocol version + uint8_t err_code; //!< rpc error type + uint8_t msg_type; //!< message type + uint32_t seq_num; //!< sequence number + uint32_t length; //!< length of RPC body + uint32_t attach_length; //!< reserved field }; - using buffer_type = std::vector; - using supported_serialize_protocols = std::variant; using route_key_t = uint32_t; using router = coro_rpc::protocol::router; @@ -102,25 +104,57 @@ struct coro_rpc_protocol { template static async_simple::coro::Lazy read_payload( - Socket& socket, req_header& req_head, buffer_type& buffer) { - buffer.resize(req_head.length); - auto [ec, _] = co_await coro_io::async_read(socket, asio::buffer(buffer)); - co_return ec; + Socket& socket, req_header& req_head, std::string& buffer, + std::string& attchment) { + uint64_t total = req_head.length + req_head.attach_length; + struct_pack::detail::resize(buffer, req_head.length); + if (req_head.attach_length > 0) { + struct_pack::detail::resize(attchment, req_head.attach_length); + std::array iov = { + asio::mutable_buffer{buffer.data(), buffer.size()}, + asio::mutable_buffer{attchment.data(), attchment.size()}}; + auto [ec, _] = co_await coro_io::async_read(socket, iov); + co_return ec; + } + else { + auto [ec, _] = co_await coro_io::async_read(socket, asio::buffer(buffer)); + co_return ec; + } } static std::string prepare_response(std::string& rpc_result, const req_header& req_header, + std::size_t attachment_len, std::errc rpc_err_code = {}, std::string_view err_msg = {}) { + std::string err_msg_buf; + if (attachment_len > UINT32_MAX) + AS_UNLIKELY { + ELOGV(ERROR, "attachment larger than 4G:%d", attachment_len); + rpc_err_code = std::errc::message_size; + err_msg_buf = + "attachment larger than 4G:" + std::to_string(attachment_len) + "B"; + err_msg = err_msg_buf; + } + else if (rpc_result.size() > UINT32_MAX) + AS_UNLIKELY { + auto sz = rpc_result.size(); + ELOGV(ERROR, "body larger than 4G:%d", sz); + rpc_err_code = std::errc::message_size; + err_msg_buf = + "body larger than 4G:" + std::to_string(attachment_len) + "B"; + err_msg = err_msg_buf; + } std::string header_buf; header_buf.resize(RESP_HEAD_LEN); auto& resp_head = *(resp_header*)header_buf.data(); resp_head.magic = magic_number; resp_head.seq_num = req_header.seq_num; resp_head.err_code = static_cast(rpc_err_code); + resp_head.attach_length = attachment_len; if (rpc_err_code != std::errc{}) AS_UNLIKELY { - assert(rpc_result.empty()); + rpc_result.clear(); struct_pack::serialize_to(rpc_result, err_msg); } resp_head.length = rpc_result.size(); diff --git a/src/coro_rpc/examples/base_examples/client.cpp b/src/coro_rpc/examples/base_examples/client.cpp index 4435ac3bb..c64ceb309 100644 --- a/src/coro_rpc/examples/base_examples/client.cpp +++ b/src/coro_rpc/examples/base_examples/client.cpp @@ -29,12 +29,27 @@ Lazy show_rpc_call() { [[maybe_unused]] auto ec = co_await client.connect("127.0.0.1", "8801"); assert(ec == std::errc{}); + auto ret = co_await client.call(); if (!ret) { std::cout << "hello_world err: " << ret.error().msg << std::endl; } assert(ret.value() == "hello_world"s); + client.set_req_attachment("This is attachment."); + auto ret_void = co_await client.call(); + if (!ret) { + std::cout << "hello_world err: " << ret.error().msg << std::endl; + } + assert(client.get_resp_attachment() == "This is attachment."); + + client.set_req_attachment("This is attachment2."); + ret_void = co_await client.call(); + if (!ret) { + std::cout << "hello_world err: " << ret.error().msg << std::endl; + } + assert(client.get_resp_attachment() == "This is attachment2."); + auto ret_int = co_await client.call(12, 30); if (!ret_int) { std::cout << "A_add_B err: " << ret_int.error().msg << std::endl; diff --git a/src/coro_rpc/examples/base_examples/rpc_service.cpp b/src/coro_rpc/examples/base_examples/rpc_service.cpp index 54d70b4ce..00127d8e6 100644 --- a/src/coro_rpc/examples/base_examples/rpc_service.cpp +++ b/src/coro_rpc/examples/base_examples/rpc_service.cpp @@ -32,6 +32,23 @@ int A_add_B(int a, int b) { return a + b; } +void echo_with_attachment(coro_rpc::context conn) { + ELOGV(INFO, "call echo_with_attachment"); + std::string str = conn.release_request_attachment(); + conn.set_response_attachment(std::move(str)); + conn.response_msg(); +} + +void echo_with_attachment2(coro_rpc::context conn) { + ELOGV(INFO, "call echo_with_attachment2"); + std::string_view str = conn.get_request_attachment(); + // The live time of attachment is same as coro_rpc::context + conn.set_response_attachment([str, conn] { + return str; + }); + conn.response_msg(); +} + std::string echo(std::string_view sv) { return std::string{sv}; } async_simple::coro::Lazy coro_echo(std::string_view sv) { diff --git a/src/coro_rpc/examples/base_examples/rpc_service.h b/src/coro_rpc/examples/base_examples/rpc_service.h index 1db3b756d..897d6f287 100644 --- a/src/coro_rpc/examples/base_examples/rpc_service.h +++ b/src/coro_rpc/examples/base_examples/rpc_service.h @@ -24,6 +24,8 @@ std::string hello_world(); int A_add_B(int a, int b); void hello_with_delay(coro_rpc::context conn, std::string hello); std::string echo(std::string_view sv); +void echo_with_attachment(coro_rpc::context conn); +void echo_with_attachment2(coro_rpc::context conn); async_simple::coro::Lazy coro_echo(std::string_view sv); async_simple::coro::Lazy nested_echo(std::string_view sv); class HelloService { diff --git a/src/coro_rpc/examples/base_examples/server.cpp b/src/coro_rpc/examples/base_examples/server.cpp index 23ad38551..267935ee2 100644 --- a/src/coro_rpc/examples/base_examples/server.cpp +++ b/src/coro_rpc/examples/base_examples/server.cpp @@ -28,7 +28,8 @@ int main() { // regist normal function for rpc server.register_handler(); + nested_echo, coro_echo, echo_with_attachment, + echo_with_attachment2>(); // regist member function for rpc HelloService hello_service; diff --git a/src/coro_rpc/tests/rpc_api.cpp b/src/coro_rpc/tests/rpc_api.cpp index c26b394c8..eb244a434 100644 --- a/src/coro_rpc/tests/rpc_api.cpp +++ b/src/coro_rpc/tests/rpc_api.cpp @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "rpc_api.hpp" + #include #include -#include "rpc_api.hpp" - using namespace coro_rpc; using namespace std::chrono_literals; using namespace std::string_literals; @@ -45,6 +45,12 @@ int long_run_func(int val) { return val; } +void echo_with_attachment(coro_rpc::context conn) { + auto str = conn.release_request_attachment(); + conn.set_response_attachment(std::move(str)); + conn.response_msg(); +} + void coro_fun_with_user_define_connection_type(my_context conn) { conn.ctx_.response_msg(); } diff --git a/src/coro_rpc/tests/rpc_api.hpp b/src/coro_rpc/tests/rpc_api.hpp index acf6a2176..655a6e99e 100644 --- a/src/coro_rpc/tests/rpc_api.hpp +++ b/src/coro_rpc/tests/rpc_api.hpp @@ -15,9 +15,9 @@ */ #ifndef CORO_RPC_RPC_API_HPP #define CORO_RPC_RPC_API_HPP -#include #include #include +#include void hi(); std::string hello(); @@ -33,7 +33,7 @@ struct my_context { my_context(coro_rpc::context&& ctx) : ctx_(std::move(ctx)) {} using return_type = void; }; - +void echo_with_attachment(coro_rpc::context conn); void coro_fun_with_user_define_connection_type(my_context conn); void coro_fun_with_delay_return_void(coro_rpc::context conn); void coro_fun_with_delay_return_string(coro_rpc::context conn); diff --git a/src/coro_rpc/tests/test_coro_rpc_client.cpp b/src/coro_rpc/tests/test_coro_rpc_client.cpp index 4aee088cb..d1b76d742 100644 --- a/src/coro_rpc/tests/test_coro_rpc_client.cpp +++ b/src/coro_rpc/tests/test_coro_rpc_client.cpp @@ -380,6 +380,31 @@ TEST_CASE("testing client with eof") { ret = client.sync_call(); REQUIRE_MESSAGE(ret.error().code == std::errc::io_error, ret.error().msg); } +TEST_CASE("testing client with attachment") { + g_action = {}; + coro_rpc_server server(2, 8801); + + auto res = server.async_start(); + REQUIRE_MESSAGE(res, "server start failed"); + coro_rpc_client client(*coro_io::get_global_executor(), g_client_id++); + auto ec = client.sync_connect("127.0.0.1", "8801"); + REQUIRE_MESSAGE(ec == std::errc{}, make_error_code(ec).message()); + + server.register_handler(); + + auto ret = client.sync_call(); + CHECK(ret.has_value()); + CHECK(client.get_resp_attachment() == ""); + + client.set_req_attachment("hellohi"); + ret = client.sync_call(); + CHECK(ret.has_value()); + CHECK(client.get_resp_attachment() == "hellohi"); + + ret = client.sync_call(); + CHECK(ret.has_value()); + CHECK(client.get_resp_attachment() == ""); +} TEST_CASE("testing client with shutdown") { g_action = {};