diff --git a/include/coro_rpc/coro_rpc/coro_rpc_client.hpp b/include/coro_rpc/coro_rpc/coro_rpc_client.hpp index ca7bcb2f..79d0cbb2 100644 --- a/include/coro_rpc/coro_rpc/coro_rpc_client.hpp +++ b/include/coro_rpc/coro_rpc/coro_rpc_client.hpp @@ -27,6 +27,7 @@ #include #include +#include "asio/io_context.hpp" #include "asio_util/asio_coro_util.hpp" #include "async_simple/coro/SyncAwait.h" #include "common_service.hpp" @@ -104,14 +105,17 @@ class coro_rpc_client { * Create client */ coro_rpc_client(uint32_t client_id = 0) - : io_context_ptr_(inner_io_context_.get()), + : inner_io_context_(std::make_unique()), + io_context_ptr_(inner_io_context_.get()), executor_(*inner_io_context_), socket_(*inner_io_context_), client_id_(client_id) { std::promise promise; thd_ = std::thread([this, &promise] { - asio::io_context::work work(*inner_io_context_); - promise.set_value(); + work_ = std::make_unique(*inner_io_context_); + asio::post(*io_context_ptr_, [&] { + promise.set_value(); + }); inner_io_context_->run(); }); promise.get_future().wait(); @@ -227,7 +231,10 @@ class coro_rpc_client { } #endif - ~coro_rpc_client() { sync_close(); } + ~coro_rpc_client() { + close(); + stop_inner_io_context(); + } /*! * Call RPC function with default timeout (5 second) @@ -338,7 +345,7 @@ class coro_rpc_client { } is_timeout_ = is_timeout; - sync_close(false); + close_socket(); promise.setValue(async_simple::Unit()); co_return true; } @@ -409,7 +416,7 @@ class coro_rpc_client { ret = co_await asio_util::async_write( socket, asio::buffer(buffer.data(), REQ_HEAD_LEN)); ELOGV(INFO, "client_id %d close socket", client_id_); - co_await close(); + close(); r = rpc_result{unexpect_t{}, rpc_error{std::errc::io_error, ret.first.message()}}; co_return r; @@ -419,7 +426,7 @@ class coro_rpc_client { ret = co_await asio_util::async_write( socket, asio::buffer(buffer.data(), REQ_HEAD_LEN - 1)); ELOGV(INFO, "client_id %d close socket", client_id_); - co_await close(); + close(); r = rpc_result{unexpect_t{}, rpc_error{std::errc::io_error, ret.first.message()}}; co_return r; @@ -449,7 +456,7 @@ class coro_rpc_client { client_id_); r = rpc_result{unexpect_t{}, rpc_error{std::errc::io_error, ret.first.message()}}; - co_await close(); + close(); co_return r; } #endif @@ -461,7 +468,7 @@ class coro_rpc_client { auto errc = struct_pack::deserialize_to(header, head, RESP_HEAD_LEN); if (errc != struct_pack::errc::ok) [[unlikely]] { ELOGV(ERROR, "deserialize rpc header failed"); - co_await close(); + close(); r = rpc_result{ unexpect_t{}, rpc_error{std::errc::io_error, struct_pack::error_message(errc)}}; @@ -485,7 +492,7 @@ class coro_rpc_client { r = handle_response_buffer(read_buf_.data(), ret.second, std::errc{header.err_code}); if (!r) { - co_await close(); + close(); } co_return r; } @@ -504,7 +511,7 @@ class coro_rpc_client { r = rpc_result{unexpect_t{}, rpc_error{.code = std::errc::io_error, .msg = ret.first.message()}}; } - co_await close(); + close(); co_return r; } /* @@ -606,60 +613,26 @@ class coro_rpc_client { offset, std::forward(args)...); } - async_simple::coro::Lazy close(bool close_ssl = true) { -#ifdef ENABLE_SSL - if (close_ssl) { - close_ssl_stream(); - } -#endif - if (has_closed_) { - co_return; - } - - ELOGV(INFO, "client_id %d close", client_id_); - - co_await asio_util::async_close(socket_); - has_closed_ = true; - } - -#ifdef ENABLE_SSL - void close_ssl_stream() { - if (ssl_stream_) { - asio::error_code ec; - ssl_stream_->shutdown(ec); - ssl_stream_ = nullptr; - } + void close_socket() { + asio::error_code ignored_ec; + socket_.shutdown(asio::ip::tcp::socket::shutdown_both, ignored_ec); + socket_.close(ignored_ec); } -#endif - void sync_close(bool close_ssl = true) { -#ifdef ENABLE_SSL - if (close_ssl) { - close_ssl_stream(); - stop_inner_io_context(); - } -#endif - if (close_ssl && has_closed_) { - stop_inner_io_context(); + void close() { + if (has_closed_) { return; } ELOGV(INFO, "client_id %d close", client_id_); - - asio::error_code ignored_ec; - socket_.shutdown(asio::ip::tcp::socket::shutdown_both, ignored_ec); - socket_.close(ignored_ec); + close_socket(); has_closed_ = true; - - if (close_ssl) { - stop_inner_io_context(); - } } void stop_inner_io_context() { if (thd_.joinable()) { - inner_io_context_->stop(); + work_ = nullptr; if (thd_.get_id() == std::this_thread::get_id()) { thd_.detach(); } @@ -689,8 +662,8 @@ class coro_rpc_client { } #endif private: - std::shared_ptr inner_io_context_ = - std::make_shared(); + std::unique_ptr inner_io_context_; + std::unique_ptr work_; asio::io_context *io_context_ptr_ = nullptr; std::thread thd_; asio_util::AsioExecutor executor_; @@ -708,5 +681,5 @@ class coro_rpc_client { uint32_t client_id_ = 0; std::atomic has_closed_ = false; -}; // namespace coro_rpc +}; } // namespace coro_rpc