diff --git a/include/ylt/coro_rpc/impl/coro_rpc_server.hpp b/include/ylt/coro_rpc/impl/coro_rpc_server.hpp index 6296036b2..9b550a26f 100644 --- a/include/ylt/coro_rpc/impl/coro_rpc_server.hpp +++ b/include/ylt/coro_rpc/impl/coro_rpc_server.hpp @@ -71,20 +71,25 @@ class coro_rpc_server_base { * default no timeout. */ coro_rpc_server_base(size_t thread_num, unsigned short port, + std::string address = "0.0.0.0", std::chrono::steady_clock::duration conn_timeout_duration = std::chrono::seconds(0)) : pool_(thread_num), acceptor_(pool_.get_executor()->get_asio_executor()), port_(port), conn_timeout_duration_(conn_timeout_duration), - flag_{stat::init} {} + flag_{stat::init} { + init_address(std::move(address)); + } coro_rpc_server_base(const server_config &config = server_config{}) : pool_(config.thread_num), acceptor_(pool_.get_executor()->get_asio_executor()), port_(config.port), conn_timeout_duration_(config.conn_timeout_duration), - flag_{stat::init} {} + flag_{stat::init} { + init_address(config.address); + } ~coro_rpc_server_base() { ELOGV(INFO, "coro_rpc_server will quit"); @@ -118,7 +123,6 @@ class coro_rpc_server_base { [[nodiscard]] coro_rpc::expected, coro_rpc::err_code> async_start() noexcept { - coro_rpc::err_code ec{}; { std::unique_lock lock(start_mtx_); if (flag_ != stat::init) { @@ -131,8 +135,8 @@ class coro_rpc_server_base { return coro_rpc::unexpected{ coro_rpc::errc::server_has_ran}; } - ec = listen(); - if (!ec) { + errc_ = listen(); + if (!errc_) { if constexpr (requires(typename server_config::executor_pool_t & pool) { pool.run(); }) { @@ -146,12 +150,13 @@ class coro_rpc_server_base { flag_ = stat::stop; } } - if (!ec) { + if (!errc_) { async_simple::Promise promise; auto future = promise.getFuture(); - accept().start([p = std::move(promise)](auto &&res) mutable { + accept().start([this, p = std::move(promise)](auto &&res) mutable { if (res.hasError()) { - p.setValue(coro_rpc::err_code{coro_rpc::errc::io_error}); + errc_ = coro_rpc::err_code{coro_rpc::errc::io_error}; + p.setValue(errc_); } else { p.setValue(res.value()); @@ -160,7 +165,7 @@ class coro_rpc_server_base { return std::move(future); } else { - return coro_rpc::unexpected{ec}; + return coro_rpc::unexpected{errc_}; } } @@ -207,6 +212,8 @@ class coro_rpc_server_base { * @return */ uint16_t port() const { return port_; }; + std::string_view address() const { return address_; } + coro_rpc::err_code get_errc() const { return errc_; } /*! * Register RPC service functions (member function) @@ -288,12 +295,23 @@ class coro_rpc_server_base { coro_rpc::err_code listen() { ELOGV(INFO, "begin to listen"); using asio::ip::tcp; - auto endpoint = tcp::endpoint(tcp::v4(), port_); - acceptor_.open(endpoint.protocol()); + asio::error_code ec; + auto addr = asio::ip::address::from_string(address_, ec); + if (ec) { + ELOGV(ERROR, "resolve address %s error : %s", address_.data(), + ec.message().data()); + return coro_rpc::errc::bad_address; + } + + auto endpoint = tcp::endpoint(addr, port_); + acceptor_.open(endpoint.protocol(), ec); + if (ec) { + ELOGV(ERROR, "open failed, error : %s", ec.message().data()); + return coro_rpc::errc::open_error; + } #ifdef __GNUC__ - acceptor_.set_option(tcp::acceptor::reuse_address(true)); + acceptor_.set_option(tcp::acceptor::reuse_address(true), ec); #endif - asio::error_code ec; acceptor_.bind(endpoint, ec); if (ec) { ELOGV(ERROR, "bind port %d error : %s", port_.load(), @@ -305,7 +323,14 @@ class coro_rpc_server_base { #ifdef _MSC_VER acceptor_.set_option(tcp::acceptor::reuse_address(true)); #endif - acceptor_.listen(); + acceptor_.listen(asio::socket_base::max_listen_connections, ec); + if (ec) { + ELOGV(ERROR, "port %d listen error : %s", port_.load(), + ec.message().data()); + acceptor_.cancel(ec); + acceptor_.close(ec); + return coro_rpc::errc::listen_error; + } auto end_point = acceptor_.local_endpoint(ec); if (ec) { @@ -383,6 +408,20 @@ class coro_rpc_server_base { acceptor_close_waiter_.get_future().wait(); } + bool iequal(std::string_view a, std::string_view b) { + return std::equal(a.begin(), a.end(), b.begin(), b.end(), + [](char a, char b) { + return tolower(a) == tolower(b); + }); + } + + void init_address(std::string address) { + if (iequal(address, "localhost")) { + address = "127.0.0.1"; + } + address_ = std::move(address); + } + typename server_config::executor_pool_t pool_; asio::ip::tcp::acceptor acceptor_; std::promise acceptor_close_waiter_; @@ -398,6 +437,8 @@ class coro_rpc_server_base { typename server_config::rpc_protocol::router router_; std::atomic port_; + std::string address_; + coro_rpc::err_code errc_ = {}; std::chrono::steady_clock::duration conn_timeout_duration_; #ifdef YLT_ENABLE_SSL diff --git a/include/ylt/coro_rpc/impl/errno.h b/include/ylt/coro_rpc/impl/errno.h index d785b30db..5514da5c3 100644 --- a/include/ylt/coro_rpc/impl/errno.h +++ b/include/ylt/coro_rpc/impl/errno.h @@ -25,6 +25,9 @@ enum class errc : uint16_t { timed_out, invalid_argument, address_in_use, + bad_address, + open_error, + listen_error, operation_canceled, interrupted, function_not_registered, @@ -47,6 +50,12 @@ inline constexpr std::string_view make_error_message(errc ec) noexcept { return "invalid_argument"; case errc::address_in_use: return "address_in_use"; + case errc::bad_address: + return "bad_address"; + case errc::open_error: + return "open_error"; + case errc::listen_error: + return "listen_error"; case errc::operation_canceled: return "operation_canceled"; case errc::interrupted: diff --git a/include/ylt/standalone/cinatra/coro_http_server.hpp b/include/ylt/standalone/cinatra/coro_http_server.hpp index 9f9ae0b97..68006adae 100644 --- a/include/ylt/standalone/cinatra/coro_http_server.hpp +++ b/include/ylt/standalone/cinatra/coro_http_server.hpp @@ -27,16 +27,21 @@ enum class file_resp_format_type { }; class coro_http_server { public: - coro_http_server(asio::io_context &ctx, unsigned short port) - : out_ctx_(&ctx), port_(port), acceptor_(ctx), check_timer_(ctx) {} + coro_http_server(asio::io_context &ctx, unsigned short port, + std::string address = "0.0.0.0") + : out_ctx_(&ctx), port_(port), acceptor_(ctx), check_timer_(ctx) { + init_address(address); + } coro_http_server(size_t thread_num, unsigned short port, - bool cpu_affinity = false) + std::string address = "0.0.0.0", bool cpu_affinity = false) : pool_(std::make_unique(thread_num, cpu_affinity)), port_(port), acceptor_(pool_->get_executor()->get_asio_executor()), - check_timer_(pool_->get_executor()->get_asio_executor()) {} + check_timer_(pool_->get_executor()->get_asio_executor()) { + init_address(address); + } ~coro_http_server() { CINATRA_LOG_INFO << "coro_http_server will quit"; @@ -64,21 +69,22 @@ class coro_http_server { // only call once, not thread safe. async_simple::Future async_start() { - auto ec = listen(); + errc_ = listen(); async_simple::Promise promise; auto future = promise.getFuture(); - if (ec == std::errc{}) { + if (errc_ == std::errc{}) { if (out_ctx_ == nullptr) { thd_ = std::thread([this] { pool_->run(); }); } - accept().start([p = std::move(promise)](auto &&res) mutable { + accept().start([p = std::move(promise), this](auto &&res) mutable { if (res.hasError()) { - p.setValue(std::errc::io_error); + errc_ = std::errc::io_error; + p.setValue(errc_); } else { p.setValue(res.value()); @@ -86,7 +92,7 @@ class coro_http_server { }); } else { - promise.setValue(ec); + promise.setValue(errc_); } return future; @@ -150,7 +156,7 @@ class coro_http_server { static_assert(std::is_member_function_pointer_v, "must be member function"); using return_type = typename util::function_traits::return_type; - if constexpr (is_lazy_v) { + if constexpr (coro_io::is_lazy_v) { std::function(coro_http_request & req, coro_http_response & resp)> f = std::bind(handler, &owner, std::placeholders::_1, @@ -488,16 +494,31 @@ class coro_http_server { return connections_.size(); } + std::string_view address() { return address_; } + std::errc get_errc() { return errc_; } + private: std::errc listen() { CINATRA_LOG_INFO << "begin to listen"; using asio::ip::tcp; - auto endpoint = tcp::endpoint(tcp::v4(), port_); - acceptor_.open(endpoint.protocol()); + asio::error_code ec; + auto addr = asio::ip::address::from_string(address_, ec); + if (ec) { + CINATRA_LOG_ERROR << "bad address: " << address_ + << " error: " << ec.message(); + return std::errc::bad_address; + } + + auto endpoint = tcp::endpoint(addr, port_); + acceptor_.open(endpoint.protocol(), ec); + if (ec) { + CINATRA_LOG_ERROR << "acceptor open failed" + << " error: " << ec.message(); + return std::errc::io_error; + } #ifdef __GNUC__ - acceptor_.set_option(tcp::acceptor::reuse_address(true)); + acceptor_.set_option(tcp::acceptor::reuse_address(true), ec); #endif - asio::error_code ec; acceptor_.bind(endpoint, ec); if (ec) { CINATRA_LOG_ERROR << "bind port: " << port_ << " error: " << ec.message(); @@ -508,7 +529,12 @@ class coro_http_server { #ifdef _MSC_VER acceptor_.set_option(tcp::acceptor::reuse_address(true)); #endif - acceptor_.listen(); + acceptor_.listen(asio::socket_base::max_listen_connections, ec); + if (ec) { + CINATRA_LOG_ERROR << "get local endpoint port: " << port_ + << " listen error: " << ec.message(); + return std::errc::io_error; + } auto end_point = acceptor_.local_endpoint(ec); if (ec) { @@ -749,11 +775,20 @@ class coro_http_server { response.set_delay(true); } + void init_address(std::string &address) { + if (iequal0(address, "localhost")) { + address = "127.0.0.1"; + } + address_ = std::move(address); + } + private: std::unique_ptr pool_; asio::io_context *out_ctx_ = nullptr; std::unique_ptr> out_executor_ = nullptr; uint16_t port_; + std::string address_; + std::errc errc_ = {}; asio::ip::tcp::acceptor acceptor_; std::thread thd_; std::promise acceptor_close_waiter_; diff --git a/src/coro_rpc/tests/ServerTester.hpp b/src/coro_rpc/tests/ServerTester.hpp index 67f33b6cd..d5fb6d2c5 100644 --- a/src/coro_rpc/tests/ServerTester.hpp +++ b/src/coro_rpc/tests/ServerTester.hpp @@ -57,6 +57,7 @@ struct TesterConfig { bool sync_client; bool use_outer_io_context; unsigned short port; + std::string address = "0.0.0.0"; std::chrono::steady_clock::duration conn_timeout_duration = std::chrono::seconds(0); @@ -67,7 +68,8 @@ struct TesterConfig { << " use_ssl: " << config.use_ssl << ";" << " sync_client: " << config.sync_client << ";" << " use_outer_io_context: " << config.use_outer_io_context << ";" - << " port: " << config.port << ";"; + << " port: " << config.port << ";" + << " address: " << config.address << ";"; os << " conn_timeout_duration: "; auto val = std::chrono::duration_cast( config.conn_timeout_duration) diff --git a/src/coro_rpc/tests/test_coro_rpc_server.cpp b/src/coro_rpc/tests/test_coro_rpc_server.cpp index f40848aaf..ed26d3158 100644 --- a/src/coro_rpc/tests/test_coro_rpc_server.cpp +++ b/src/coro_rpc/tests/test_coro_rpc_server.cpp @@ -25,6 +25,7 @@ #include "async_simple/coro/Lazy.h" #include "doctest.h" #include "rpc_api.hpp" +#include "ylt/coro_rpc/impl/default_config/coro_rpc_config.hpp" #include "ylt/coro_rpc/impl/errno.h" #include "ylt/struct_pack.hpp" @@ -33,7 +34,7 @@ async_simple::coro::Lazy get_coro_value(int val) { co_return val; } struct CoroServerTester : ServerTester { CoroServerTester(TesterConfig config) : ServerTester(config), - server(2, config.port, config.conn_timeout_duration) { + server(2, config.port, config.address, config.conn_timeout_duration) { #ifdef YLT_ENABLE_SSL if (use_ssl) { server.init_ssl_context( @@ -47,9 +48,42 @@ struct CoroServerTester : ServerTester { async_simple::coro::Lazy get_value(int val) { co_return val; } + void test_set_server_address() { + { + coro_rpc_server server(1, 9001); + [[maybe_unused]] auto r = server.async_start(); + CHECK(!server.get_errc()); + } + + { + coro_rpc_server server(1, 9001, "0.0.0.0"); + [[maybe_unused]] auto r = server.async_start(); + CHECK(!server.get_errc()); + } + + { + coro_rpc_server server(1, 9001, "127.0.0.1"); + [[maybe_unused]] auto r = server.async_start(); + CHECK(!server.get_errc()); + } + + { + coro_rpc_server server(1, 9001, "localhost"); + [[maybe_unused]] auto r = server.async_start(); + CHECK(!server.get_errc()); + } + + { + coro_rpc_server server(1, 9001, "x.x.x"); + [[maybe_unused]] auto r = server.async_start(); + CHECK(server.get_errc() == coro_rpc::errc::bad_address); + } + } + void test_all() override { g_action = {}; ELOGV(INFO, "run %s", __func__); + test_set_server_address(); test_coro_handler(); ServerTester::test_all(); test_function_not_registered();