diff --git a/include/ylt/coro_io/client_pool.hpp b/include/ylt/coro_io/client_pool.hpp index 636dfe315..fe62f34b7 100644 --- a/include/ylt/coro_io/client_pool.hpp +++ b/include/ylt/coro_io/client_pool.hpp @@ -208,9 +208,9 @@ class client_pool : public std::enable_shared_from_this< ++self->promise_cnt_; self->promise_queue_.enqueue(handler); timer->expires_after( - std::max(std::chrono::milliseconds{0}, - self->pool_config_.max_connection_time - - std::chrono::milliseconds{20})); + (std::max)(std::chrono::milliseconds{0}, + self->pool_config_.max_connection_time - + std::chrono::milliseconds{20})); timer->async_await().start([handler = std::move(handler), client_ptr = client_ptr](auto&& res) { auto has_response = handler->flag_.exchange(true); diff --git a/include/ylt/coro_rpc/impl/coro_rpc_client.hpp b/include/ylt/coro_rpc/impl/coro_rpc_client.hpp index 4271b962a..55d03c019 100644 --- a/include/ylt/coro_rpc/impl/coro_rpc_client.hpp +++ b/include/ylt/coro_rpc/impl/coro_rpc_client.hpp @@ -129,9 +129,8 @@ class coro_rpc_client { */ coro_rpc_client(asio::io_context::executor_type executor, uint32_t client_id = 0) - : executor_(executor), - timer_(executor), - socket_(std::make_shared(executor)) { + : control_(std::make_shared(executor, false)), + timer_(executor) { config_.client_id = client_id; } @@ -142,10 +141,9 @@ class coro_rpc_client { coro_rpc_client( coro_io::ExecutorWrapper<> &executor = *coro_io::get_global_executor(), uint32_t client_id = 0) - : executor_(executor.get_asio_executor()), - timer_(executor.get_asio_executor()), - socket_(std::make_shared( - executor.get_asio_executor())) { + : control_( + std::make_shared(executor.get_asio_executor(), false)), + timer_(executor.get_asio_executor()) { config_.client_id = client_id; } @@ -317,7 +315,7 @@ class coro_rpc_client { } else { #endif - ret = co_await call_impl(*socket_, std::move(args)...); + ret = co_await call_impl(control_->socket_, std::move(args)...); #ifdef YLT_ENABLE_SSL } #endif @@ -325,7 +323,7 @@ class coro_rpc_client { std::error_code err_code; timer_.cancel(err_code); - if (*is_timeout_) { + if (control_->is_timeout_) { ret = rpc_result{ unexpect_t{}, rpc_error{errc::timed_out, "rpc call timed out"}}; } @@ -340,7 +338,7 @@ class coro_rpc_client { /*! * Get inner executor */ - auto &get_executor() { return executor_; } + auto &get_executor() { return control_->executor_; } uint32_t get_client_id() const { return config_.client_id; } @@ -350,7 +348,7 @@ class coro_rpc_client { } has_closed_ = true; ELOGV(INFO, "client_id %d close", config_.client_id); - close_socket(socket_); + close_socket(control_); } bool set_req_attachment(std::string_view attachment) { @@ -379,10 +377,10 @@ class coro_rpc_client { }; void reset() { - close_socket(socket_); - socket_ = - std::make_shared(executor_.get_asio_executor()); - *is_timeout_ = false; + close_socket(control_); + control_->socket_ = + asio::ip::tcp::socket(control_->executor_.get_asio_executor()); + control_->is_timeout_ = false; has_closed_ = false; } static bool is_ok(coro_rpc::err_code ec) noexcept { return !ec; } @@ -411,23 +409,23 @@ class coro_rpc_client { }); std::error_code ec = co_await coro_io::async_connect( - &executor_, *socket_, config_.host, config_.port); + &control_->executor_, control_->socket_, config_.host, config_.port); std::error_code err_code; timer_.cancel(err_code); if (ec) { - if (*is_timeout_) { + if (control_->is_timeout_) { co_return errc::timed_out; } co_return errc::not_connected; } - if (*is_timeout_) { + if (control_->is_timeout_) { ELOGV(WARN, "client_id %d connect timeout", config_.client_id); co_return errc::timed_out; } - socket_->set_option(asio::ip::tcp::no_delay(true), ec); + control_->socket_.set_option(asio::ip::tcp::no_delay(true), ec); #ifdef YLT_ENABLE_SSL if (!config_.ssl_cert_path.empty()) { @@ -465,7 +463,7 @@ class coro_rpc_client { asio::ssl::host_name_verification(config_.ssl_domain)); ssl_stream_ = std::make_unique>( - *socket_, ssl_ctx_); + control_->socket_, ssl_ctx_); ssl_init_ret_ = true; } catch (std::exception &e) { ELOGV(ERROR, "init ssl failed: %s", e.what()); @@ -475,15 +473,17 @@ class coro_rpc_client { #endif async_simple::coro::Lazy timeout(auto duration, std::string err_msg) { timer_.expires_after(duration); - auto socker_watcher = socket_; - auto timeout_watcher = is_timeout_; + std::weak_ptr socket_watcher = control_; bool is_timeout = co_await timer_.async_await(); if (!is_timeout) { co_return false; } - *timeout_watcher = is_timeout; - close_socket(socker_watcher); - co_return true; + if (auto self = socket_watcher.lock()) { + self->is_timeout_ = is_timeout; + close_socket(self); + co_return true; + } + co_return false; } template @@ -582,7 +582,7 @@ class coro_rpc_client { ret = co_await coro_io::async_write( socket, asio::buffer(buffer.data(), coro_rpc_protocol::REQ_HEAD_LEN)); ELOGV(INFO, "client_id %d shutdown", config_.client_id); - socket_->shutdown(asio::ip::tcp::socket::shutdown_send); + control_->socket_.shutdown(asio::ip::tcp::socket::shutdown_send); r = rpc_result{ unexpect_t{}, rpc_error{errc::io_error, ret.first.message()}}; co_return r; @@ -657,10 +657,10 @@ class coro_rpc_client { } #ifdef UNIT_TEST_INJECT if (g_action == inject_action::force_inject_client_write_data_timeout) { - *is_timeout_ = true; + control_->is_timeout_ = true; } #endif - if (*is_timeout_) { + if (control_->is_timeout_) { r = rpc_result{ unexpect_t{}, rpc_error{.code = errc::timed_out, .msg = {}}}; } @@ -790,11 +790,21 @@ class coro_rpc_client { offset, std::forward(args)...); } - static void close_socket(std::shared_ptr socket) { - asio::dispatch(socket->get_executor(), [socket = std::move(socket)]() { + struct control_t { + asio::ip::tcp::socket socket_; + bool is_timeout_; + coro_io::ExecutorWrapper<> executor_; + control_t(asio::io_context::executor_type executor, bool is_timeout) + : socket_(executor), is_timeout_(is_timeout), executor_(executor) {} + }; + + static void close_socket( + std::shared_ptr control) { + control->executor_.schedule([control = std::move(control)]() { asio::error_code ignored_ec; - socket->shutdown(asio::ip::tcp::socket::shutdown_both, ignored_ec); - socket->close(ignored_ec); + control->socket_.shutdown(asio::ip::tcp::socket::shutdown_both, + ignored_ec); + control->socket_.close(ignored_ec); }); } @@ -812,10 +822,10 @@ class coro_rpc_client { call(std::forward(args)...)); } #endif + private: - coro_io::ExecutorWrapper<> executor_; coro_io::period_timer timer_; - std::shared_ptr socket_; + std::shared_ptr control_; std::string read_buf_, resp_attachment_buf_; std::string_view req_attachment_; config config_; @@ -825,7 +835,6 @@ class coro_rpc_client { std::unique_ptr> ssl_stream_; bool ssl_init_ret_ = true; #endif - std::shared_ptr is_timeout_ = std::make_shared(false); std::atomic has_closed_ = false; }; } // namespace coro_rpc diff --git a/src/coro_rpc/tests/test_coro_rpc_client.cpp b/src/coro_rpc/tests/test_coro_rpc_client.cpp index a820f8e73..f2bc1eaec 100644 --- a/src/coro_rpc/tests/test_coro_rpc_client.cpp +++ b/src/coro_rpc/tests/test_coro_rpc_client.cpp @@ -116,7 +116,7 @@ TEST_CASE("testing client") { g_action = {}; auto f = [&io_context, &port]() -> Lazy { auto client = co_await create_client(io_context, port); - auto ret = co_await client->template call_for(20ms); + auto ret = co_await client->template call_for(10ms); CHECK_MESSAGE(ret.error().code == coro_rpc::errc::timed_out, ret.error().msg); co_return;