diff --git a/include/ylt/coro_rpc/impl/coro_rpc_server.hpp b/include/ylt/coro_rpc/impl/coro_rpc_server.hpp index 6296036b2..7f37a47d9 100644 --- a/include/ylt/coro_rpc/impl/coro_rpc_server.hpp +++ b/include/ylt/coro_rpc/impl/coro_rpc_server.hpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -71,20 +72,36 @@ class coro_rpc_server_base { * default no timeout. */ coro_rpc_server_base(size_t thread_num, unsigned short port, + std::string address = "0.0.0.0", std::chrono::steady_clock::duration conn_timeout_duration = std::chrono::seconds(0)) : pool_(thread_num), acceptor_(pool_.get_executor()->get_asio_executor()), port_(port), conn_timeout_duration_(conn_timeout_duration), - flag_{stat::init} {} + flag_{stat::init} { + init_address(std::move(address)); + } + + coro_rpc_server_base(size_t thread_num, + std::string address /* = "0.0.0.0:9001" */, + std::chrono::steady_clock::duration + conn_timeout_duration = std::chrono::seconds(0)) + : pool_(thread_num), + acceptor_(pool_.get_executor()->get_asio_executor()), + conn_timeout_duration_(conn_timeout_duration), + flag_{stat::init} { + init_address(std::move(address)); + } coro_rpc_server_base(const server_config &config = server_config{}) : pool_(config.thread_num), acceptor_(pool_.get_executor()->get_asio_executor()), port_(config.port), conn_timeout_duration_(config.conn_timeout_duration), - flag_{stat::init} {} + flag_{stat::init} { + init_address(config.address); + } ~coro_rpc_server_base() { ELOGV(INFO, "coro_rpc_server will quit"); @@ -118,7 +135,6 @@ class coro_rpc_server_base { [[nodiscard]] coro_rpc::expected, coro_rpc::err_code> async_start() noexcept { - coro_rpc::err_code ec{}; { std::unique_lock lock(start_mtx_); if (flag_ != stat::init) { @@ -131,8 +147,8 @@ class coro_rpc_server_base { return coro_rpc::unexpected{ coro_rpc::errc::server_has_ran}; } - ec = listen(); - if (!ec) { + errc_ = listen(); + if (!errc_) { if constexpr (requires(typename server_config::executor_pool_t & pool) { pool.run(); }) { @@ -146,12 +162,13 @@ class coro_rpc_server_base { flag_ = stat::stop; } } - if (!ec) { + if (!errc_) { async_simple::Promise promise; auto future = promise.getFuture(); - accept().start([p = std::move(promise)](auto &&res) mutable { + accept().start([this, p = std::move(promise)](auto &&res) mutable { if (res.hasError()) { - p.setValue(coro_rpc::err_code{coro_rpc::errc::io_error}); + errc_ = coro_rpc::err_code{coro_rpc::errc::io_error}; + p.setValue(errc_); } else { p.setValue(res.value()); @@ -160,7 +177,7 @@ class coro_rpc_server_base { return std::move(future); } else { - return coro_rpc::unexpected{ec}; + return coro_rpc::unexpected{errc_}; } } @@ -207,6 +224,8 @@ class coro_rpc_server_base { * @return */ uint16_t port() const { return port_; }; + std::string_view address() const { return address_; } + coro_rpc::err_code get_errc() const { return errc_; } /*! * Register RPC service functions (member function) @@ -288,12 +307,27 @@ class coro_rpc_server_base { coro_rpc::err_code listen() { ELOGV(INFO, "begin to listen"); using asio::ip::tcp; - auto endpoint = tcp::endpoint(tcp::v4(), port_); - acceptor_.open(endpoint.protocol()); + asio::error_code ec; + asio::ip::tcp::resolver::query query(address_, std::to_string(port_)); + asio::ip::tcp::resolver resolver(acceptor_.get_executor()); + asio::ip::tcp::resolver::iterator it = resolver.resolve(query, ec); + + asio::ip::tcp::resolver::iterator it_end; + if (ec || it == it_end) { + ELOGV(ERROR, "resolve address %s error : %s", address_.data(), + ec.message().data()); + return coro_rpc::errc::bad_address; + } + + auto endpoint = it->endpoint(); + acceptor_.open(endpoint.protocol(), ec); + if (ec) { + ELOGV(ERROR, "open failed, error : %s", ec.message().data()); + return coro_rpc::errc::open_error; + } #ifdef __GNUC__ - acceptor_.set_option(tcp::acceptor::reuse_address(true)); + acceptor_.set_option(tcp::acceptor::reuse_address(true), ec); #endif - asio::error_code ec; acceptor_.bind(endpoint, ec); if (ec) { ELOGV(ERROR, "bind port %d error : %s", port_.load(), @@ -305,7 +339,14 @@ class coro_rpc_server_base { #ifdef _MSC_VER acceptor_.set_option(tcp::acceptor::reuse_address(true)); #endif - acceptor_.listen(); + acceptor_.listen(asio::socket_base::max_listen_connections, ec); + if (ec) { + ELOGV(ERROR, "port %d listen error : %s", port_.load(), + ec.message().data()); + acceptor_.cancel(ec); + acceptor_.close(ec); + return coro_rpc::errc::listen_error; + } auto end_point = acceptor_.local_endpoint(ec); if (ec) { @@ -383,6 +424,25 @@ class coro_rpc_server_base { acceptor_close_waiter_.get_future().wait(); } + void init_address(std::string address) { + if (size_t pos = address.find(':'); pos != std::string::npos) { + auto port_sv = std::string_view(address).substr(pos + 1); + + uint16_t port; + auto [ptr, ec] = std::from_chars( + port_sv.data(), port_sv.data() + port_sv.size(), port, 10); + if (ec != std::errc{}) { + address_ = std::move(address); + return; + } + + port_ = port; + address = address.substr(0, pos); + } + + address_ = std::move(address); + } + typename server_config::executor_pool_t pool_; asio::ip::tcp::acceptor acceptor_; std::promise acceptor_close_waiter_; @@ -398,6 +458,8 @@ class coro_rpc_server_base { typename server_config::rpc_protocol::router router_; std::atomic port_; + std::string address_; + coro_rpc::err_code errc_ = {}; std::chrono::steady_clock::duration conn_timeout_duration_; #ifdef YLT_ENABLE_SSL diff --git a/include/ylt/coro_rpc/impl/errno.h b/include/ylt/coro_rpc/impl/errno.h index d785b30db..5514da5c3 100644 --- a/include/ylt/coro_rpc/impl/errno.h +++ b/include/ylt/coro_rpc/impl/errno.h @@ -25,6 +25,9 @@ enum class errc : uint16_t { timed_out, invalid_argument, address_in_use, + bad_address, + open_error, + listen_error, operation_canceled, interrupted, function_not_registered, @@ -47,6 +50,12 @@ inline constexpr std::string_view make_error_message(errc ec) noexcept { return "invalid_argument"; case errc::address_in_use: return "address_in_use"; + case errc::bad_address: + return "bad_address"; + case errc::open_error: + return "open_error"; + case errc::listen_error: + return "listen_error"; case errc::operation_canceled: return "operation_canceled"; case errc::interrupted: diff --git a/include/ylt/standalone/cinatra/coro_http_server.hpp b/include/ylt/standalone/cinatra/coro_http_server.hpp index 9f9ae0b97..68fae10fd 100644 --- a/include/ylt/standalone/cinatra/coro_http_server.hpp +++ b/include/ylt/standalone/cinatra/coro_http_server.hpp @@ -1,13 +1,5 @@ #pragma once -#include -#include -#include -#include - -#include "asio/streambuf.hpp" -#include "async_simple/Promise.h" -#include "async_simple/coro/Lazy.h" #include "cinatra/coro_http_client.hpp" #include "cinatra/coro_http_response.hpp" #include "cinatra/coro_http_router.hpp" @@ -27,16 +19,37 @@ enum class file_resp_format_type { }; class coro_http_server { public: - coro_http_server(asio::io_context &ctx, unsigned short port) - : out_ctx_(&ctx), port_(port), acceptor_(ctx), check_timer_(ctx) {} + coro_http_server(asio::io_context &ctx, unsigned short port, + std::string address = "0.0.0.0") + : out_ctx_(&ctx), port_(port), acceptor_(ctx), check_timer_(ctx) { + init_address(std::move(address)); + } + + coro_http_server(asio::io_context &ctx, + std::string address /* = "0.0.0.0:9001" */) + : out_ctx_(&ctx), acceptor_(ctx), check_timer_(ctx) { + init_address(std::move(address)); + } coro_http_server(size_t thread_num, unsigned short port, - bool cpu_affinity = false) + std::string address = "0.0.0.0", bool cpu_affinity = false) : pool_(std::make_unique(thread_num, cpu_affinity)), port_(port), acceptor_(pool_->get_executor()->get_asio_executor()), - check_timer_(pool_->get_executor()->get_asio_executor()) {} + check_timer_(pool_->get_executor()->get_asio_executor()) { + init_address(std::move(address)); + } + + coro_http_server(size_t thread_num, + std::string address /* = "0.0.0.0:9001" */, + bool cpu_affinity = false) + : pool_(std::make_unique(thread_num, + cpu_affinity)), + acceptor_(pool_->get_executor()->get_asio_executor()), + check_timer_(pool_->get_executor()->get_asio_executor()) { + init_address(std::move(address)); + } ~coro_http_server() { CINATRA_LOG_INFO << "coro_http_server will quit"; @@ -64,21 +77,22 @@ class coro_http_server { // only call once, not thread safe. async_simple::Future async_start() { - auto ec = listen(); + errc_ = listen(); async_simple::Promise promise; auto future = promise.getFuture(); - if (ec == std::errc{}) { + if (errc_ == std::errc{}) { if (out_ctx_ == nullptr) { thd_ = std::thread([this] { pool_->run(); }); } - accept().start([p = std::move(promise)](auto &&res) mutable { + accept().start([p = std::move(promise), this](auto &&res) mutable { if (res.hasError()) { - p.setValue(std::errc::io_error); + errc_ = std::errc::io_error; + p.setValue(errc_); } else { p.setValue(res.value()); @@ -86,7 +100,7 @@ class coro_http_server { }); } else { - promise.setValue(ec); + promise.setValue(errc_); } return future; @@ -150,7 +164,7 @@ class coro_http_server { static_assert(std::is_member_function_pointer_v, "must be member function"); using return_type = typename util::function_traits::return_type; - if constexpr (is_lazy_v) { + if constexpr (coro_io::is_lazy_v) { std::function(coro_http_request & req, coro_http_response & resp)> f = std::bind(handler, &owner, std::placeholders::_1, @@ -488,16 +502,36 @@ class coro_http_server { return connections_.size(); } + std::string_view address() { return address_; } + std::errc get_errc() { return errc_; } + private: std::errc listen() { CINATRA_LOG_INFO << "begin to listen"; using asio::ip::tcp; - auto endpoint = tcp::endpoint(tcp::v4(), port_); - acceptor_.open(endpoint.protocol()); + asio::error_code ec; + + asio::ip::tcp::resolver::query query(address_, std::to_string(port_)); + asio::ip::tcp::resolver resolver(acceptor_.get_executor()); + asio::ip::tcp::resolver::iterator it = resolver.resolve(query, ec); + + asio::ip::tcp::resolver::iterator it_end; + if (ec || it == it_end) { + CINATRA_LOG_ERROR << "bad address: " << address_ + << " error: " << ec.message(); + return std::errc::bad_address; + } + + auto endpoint = it->endpoint(); + acceptor_.open(endpoint.protocol(), ec); + if (ec) { + CINATRA_LOG_ERROR << "acceptor open failed" + << " error: " << ec.message(); + return std::errc::io_error; + } #ifdef __GNUC__ - acceptor_.set_option(tcp::acceptor::reuse_address(true)); + acceptor_.set_option(tcp::acceptor::reuse_address(true), ec); #endif - asio::error_code ec; acceptor_.bind(endpoint, ec); if (ec) { CINATRA_LOG_ERROR << "bind port: " << port_ << " error: " << ec.message(); @@ -508,7 +542,12 @@ class coro_http_server { #ifdef _MSC_VER acceptor_.set_option(tcp::acceptor::reuse_address(true)); #endif - acceptor_.listen(); + acceptor_.listen(asio::socket_base::max_listen_connections, ec); + if (ec) { + CINATRA_LOG_ERROR << "get local endpoint port: " << port_ + << " listen error: " << ec.message(); + return std::errc::io_error; + } auto end_point = acceptor_.local_endpoint(ec); if (ec) { @@ -749,11 +788,32 @@ class coro_http_server { response.set_delay(true); } + void init_address(std::string address) { + if (size_t pos = address.find(':'); pos != std::string::npos) { + auto port_sv = std::string_view(address).substr(pos + 1); + + uint16_t port; + auto [ptr, ec] = std::from_chars( + port_sv.data(), port_sv.data() + port_sv.size(), port, 10); + if (ec != std::errc{}) { + address_ = std::move(address); + return; + } + + port_ = port; + address = address.substr(0, pos); + } + + address_ = std::move(address); + } + private: std::unique_ptr pool_; asio::io_context *out_ctx_ = nullptr; std::unique_ptr> out_executor_ = nullptr; uint16_t port_; + std::string address_; + std::errc errc_ = {}; asio::ip::tcp::acceptor acceptor_; std::thread thd_; std::promise acceptor_close_waiter_; diff --git a/src/coro_rpc/tests/ServerTester.hpp b/src/coro_rpc/tests/ServerTester.hpp index 67f33b6cd..d5fb6d2c5 100644 --- a/src/coro_rpc/tests/ServerTester.hpp +++ b/src/coro_rpc/tests/ServerTester.hpp @@ -57,6 +57,7 @@ struct TesterConfig { bool sync_client; bool use_outer_io_context; unsigned short port; + std::string address = "0.0.0.0"; std::chrono::steady_clock::duration conn_timeout_duration = std::chrono::seconds(0); @@ -67,7 +68,8 @@ struct TesterConfig { << " use_ssl: " << config.use_ssl << ";" << " sync_client: " << config.sync_client << ";" << " use_outer_io_context: " << config.use_outer_io_context << ";" - << " port: " << config.port << ";"; + << " port: " << config.port << ";" + << " address: " << config.address << ";"; os << " conn_timeout_duration: "; auto val = std::chrono::duration_cast( config.conn_timeout_duration) diff --git a/src/coro_rpc/tests/test_coro_rpc_server.cpp b/src/coro_rpc/tests/test_coro_rpc_server.cpp index f40848aaf..866815c23 100644 --- a/src/coro_rpc/tests/test_coro_rpc_server.cpp +++ b/src/coro_rpc/tests/test_coro_rpc_server.cpp @@ -25,6 +25,7 @@ #include "async_simple/coro/Lazy.h" #include "doctest.h" #include "rpc_api.hpp" +#include "ylt/coro_rpc/impl/default_config/coro_rpc_config.hpp" #include "ylt/coro_rpc/impl/errno.h" #include "ylt/struct_pack.hpp" @@ -33,7 +34,7 @@ async_simple::coro::Lazy get_coro_value(int val) { co_return val; } struct CoroServerTester : ServerTester { CoroServerTester(TesterConfig config) : ServerTester(config), - server(2, config.port, config.conn_timeout_duration) { + server(2, config.port, config.address, config.conn_timeout_duration) { #ifdef YLT_ENABLE_SSL if (use_ssl) { server.init_ssl_context( @@ -47,9 +48,72 @@ struct CoroServerTester : ServerTester { async_simple::coro::Lazy get_value(int val) { co_return val; } + void test_set_server_address() { + { + coro_rpc_server server(1, 9001); + [[maybe_unused]] auto r = server.async_start(); + CHECK(!server.get_errc()); + } + + { + coro_rpc_server server(1, 9001, "0.0.0.0"); + [[maybe_unused]] auto r = server.async_start(); + CHECK(!server.get_errc()); + } + + { + coro_rpc_server server(1, 9001, "127.0.0.1"); + [[maybe_unused]] auto r = server.async_start(); + CHECK(!server.get_errc()); + } + + { + coro_rpc_server server(1, 9001, "localhost"); + [[maybe_unused]] auto r = server.async_start(); + CHECK(!server.get_errc()); + } + + { + coro_rpc_server server(1, "0.0.0.0:9001"); + [[maybe_unused]] auto r = server.async_start(); + CHECK(!server.get_errc()); + } + + { + coro_rpc_server server(1, "127.0.0.1:9001"); + [[maybe_unused]] auto r = server.async_start(); + CHECK(!server.get_errc()); + } + + { + coro_rpc_server server(1, "localhost:9001"); + [[maybe_unused]] auto r = server.async_start(); + CHECK(!server.get_errc()); + } + + { + coro_rpc_server server(1, 9001, "x.x.x.x"); + [[maybe_unused]] auto r = server.async_start(); + CHECK(server.get_errc() == coro_rpc::errc::bad_address); + } + + { + coro_rpc_server server(1, "x.x.x.x:9001"); + [[maybe_unused]] auto r = server.async_start(); + CHECK(server.get_errc() == coro_rpc::errc::bad_address); + } + + { + coro_rpc_server server(1, "127.0.0.1:aaa"); + [[maybe_unused]] auto r = server.async_start(); + CHECK(server.get_errc() == coro_rpc::errc::bad_address); + } + } + void test_all() override { g_action = {}; ELOGV(INFO, "run %s", __func__); + test_set_server_address(); test_coro_handler(); ServerTester::test_all(); test_function_not_registered();