diff --git a/lib/cpp/src/thrift/transport/TServerSocket.cpp b/lib/cpp/src/thrift/transport/TServerSocket.cpp index ffe9ed3a504..b0c9aeb8ace 100644 --- a/lib/cpp/src/thrift/transport/TServerSocket.cpp +++ b/lib/cpp/src/thrift/transport/TServerSocket.cpp @@ -117,7 +117,8 @@ TServerSocket::TServerSocket(int port) listening_(false), interruptSockWriter_(THRIFT_INVALID_SOCKET), interruptSockReader_(THRIFT_INVALID_SOCKET), - childInterruptSockWriter_(THRIFT_INVALID_SOCKET) { + childInterruptSockWriter_(THRIFT_INVALID_SOCKET), + boundSocketType_(SocketType::NONE) { } TServerSocket::TServerSocket(int port, int sendTimeout, int recvTimeout) @@ -136,7 +137,8 @@ TServerSocket::TServerSocket(int port, int sendTimeout, int recvTimeout) listening_(false), interruptSockWriter_(THRIFT_INVALID_SOCKET), interruptSockReader_(THRIFT_INVALID_SOCKET), - childInterruptSockWriter_(THRIFT_INVALID_SOCKET) { + childInterruptSockWriter_(THRIFT_INVALID_SOCKET), + boundSocketType_(SocketType::NONE) { } TServerSocket::TServerSocket(const string& address, int port) @@ -156,7 +158,8 @@ TServerSocket::TServerSocket(const string& address, int port) listening_(false), interruptSockWriter_(THRIFT_INVALID_SOCKET), interruptSockReader_(THRIFT_INVALID_SOCKET), - childInterruptSockWriter_(THRIFT_INVALID_SOCKET) { + childInterruptSockWriter_(THRIFT_INVALID_SOCKET), + boundSocketType_(SocketType::NONE) { } TServerSocket::TServerSocket(const string& path) @@ -176,7 +179,28 @@ TServerSocket::TServerSocket(const string& path) listening_(false), interruptSockWriter_(THRIFT_INVALID_SOCKET), interruptSockReader_(THRIFT_INVALID_SOCKET), - childInterruptSockWriter_(THRIFT_INVALID_SOCKET) { + childInterruptSockWriter_(THRIFT_INVALID_SOCKET), + boundSocketType_(SocketType::NONE) { +} +TServerSocket::TServerSocket(THRIFT_SOCKET sock,SocketType socketType) + : interruptableChildren_(true), + port_(0), + path_(), + serverSocket_(sock), + acceptBacklog_(DEFAULT_BACKLOG), + sendTimeout_(0), + recvTimeout_(0), + accTimeout_(-1), + retryLimit_(0), + retryDelay_(0), + tcpSendBuffer_(0), + tcpRecvBuffer_(0), + keepAlive_(false), + listening_(false), + interruptSockWriter_(THRIFT_INVALID_SOCKET), + interruptSockReader_(THRIFT_INVALID_SOCKET), + childInterruptSockWriter_(THRIFT_INVALID_SOCKET), + boundSocketType_(socketType) { } TServerSocket::~TServerSocket() { @@ -439,7 +463,8 @@ void TServerSocket::listen() { if (isUnixDomainSocket()) { // -- Unix Domain Socket -- // - serverSocket_ = socket(PF_UNIX, SOCK_STREAM, IPPROTO_IP); + if (serverSocket_ == THRIFT_INVALID_SOCKET) + serverSocket_ = socket(PF_UNIX, SOCK_STREAM, IPPROTO_IP); if (serverSocket_ == THRIFT_INVALID_SOCKET) { int errno_copy = THRIFT_GET_SOCKET_ERROR; @@ -471,6 +496,8 @@ void TServerSocket::listen() { throw TTransportException(TTransportException::NOT_OPEN, " Unix Domain socket path not supported"); #endif + } else if( boundSocketType_ != SocketType::NONE){ + // -- Socket is already bound } else { // -- TCP socket -- // @@ -516,25 +543,31 @@ void TServerSocket::listen() { // use short circuit evaluation here to only sleep if we need to } while ((retries++ < retryLimit_) && (THRIFT_SLEEP_SEC(retryDelay_) == 0)); - // retrieve bind info - if (port_ == 0 && retries <= retryLimit_) { - struct sockaddr_storage sa; - socklen_t len = sizeof(sa); - std::memset(&sa, 0, len); - if (::getsockname(serverSocket_, reinterpret_cast(&sa), &len) < 0) { - errno_copy = THRIFT_GET_SOCKET_ERROR; - GlobalOutput.perror("TServerSocket::getPort() getsockname() ", errno_copy); + } // TCP socket // + + // retrieve bind info + if ((port_ == 0 || path_.empty() ) && retries <= retryLimit_) { + struct sockaddr_storage sa; + socklen_t len = sizeof(sa); + std::memset(&sa, 0, len); + if (::getsockname(serverSocket_, reinterpret_cast(&sa), &len) < 0) { + errno_copy = THRIFT_GET_SOCKET_ERROR; + GlobalOutput.perror("TServerSocket::getPort() getsockname() ", errno_copy); + } else { + if (sa.ss_family == AF_INET6) { + const auto* sin = reinterpret_cast(&sa); + port_ = ntohs(sin->sin6_port); + } else if (sa.ss_family == AF_INET) { + const auto* sin = reinterpret_cast(&sa); + port_ = ntohs(sin->sin_port); + } else if (sa.ss_family == AF_UNIX) { + const auto* sin = reinterpret_cast(&sa); + path_ = sin->sun_path; } else { - if (sa.ss_family == AF_INET6) { - const auto* sin = reinterpret_cast(&sa); - port_ = ntohs(sin->sin6_port); - } else { - const auto* sin = reinterpret_cast(&sa); - port_ = ntohs(sin->sin_port); - } + GlobalOutput.perror("TServerSocket::getPort() getsockname() unhandled socket type",EINVAL); } } - } // TCP socket // + } // throw error if socket still wasn't created successfully if (serverSocket_ == THRIFT_INVALID_SOCKET) { @@ -569,7 +602,7 @@ void TServerSocket::listen() { listenCallback_(serverSocket_); // Call listen - if (-1 == ::listen(serverSocket_, acceptBacklog_)) { + if (boundSocketType_ == SocketType::NONE && -1 == ::listen(serverSocket_, acceptBacklog_)) { errno_copy = THRIFT_GET_SOCKET_ERROR; GlobalOutput.perror("TServerSocket::listen() listen() ", errno_copy); close(); @@ -734,7 +767,8 @@ void TServerSocket::close() { concurrency::Guard g(rwMutex_); if (serverSocket_ != THRIFT_INVALID_SOCKET) { shutdown(serverSocket_, THRIFT_SHUT_RDWR); - ::THRIFT_CLOSESOCKET(serverSocket_); + if( boundSocketType_ == SocketType::NONE) //Do not close the server socket if it owned by systemd + ::THRIFT_CLOSESOCKET(serverSocket_); } if (interruptSockWriter_ != THRIFT_INVALID_SOCKET) { ::THRIFT_CLOSESOCKET(interruptSockWriter_); diff --git a/lib/cpp/src/thrift/transport/TServerSocket.h b/lib/cpp/src/thrift/transport/TServerSocket.h index e826707ac55..4d43ccf5ade 100644 --- a/lib/cpp/src/thrift/transport/TServerSocket.h +++ b/lib/cpp/src/thrift/transport/TServerSocket.h @@ -40,6 +40,13 @@ namespace transport { class TSocket; +enum class SocketType { + NONE, + INET, + INET6, + UNIX +}; + /** * Server socket implementation of TServerTransport. Wrapper around a unix * socket listen and accept calls. @@ -82,6 +89,14 @@ class TServerSocket : public TServerTransport { */ TServerSocket(const std::string& path); + /** + * Constructor used for to initialize from an already bound unix socket. + * Useful for socket activation on systemd. + * + * @param fd + */ + TServerSocket(THRIFT_SOCKET sock,SocketType socketType); + ~TServerSocket() override; @@ -172,6 +187,7 @@ class TServerSocket : public TServerTransport { socket_func_t listenCallback_; socket_func_t acceptCallback_; + SocketType boundSocketType_; }; } } diff --git a/test/cpp/src/TestServer.cpp b/test/cpp/src/TestServer.cpp index 858fffa3852..dc95af6ba15 100644 --- a/test/cpp/src/TestServer.cpp +++ b/test/cpp/src/TestServer.cpp @@ -31,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -54,14 +55,21 @@ #ifdef HAVE_SIGNAL_H #include #endif +#ifdef HAVE_SYS_SOCKET_H +#include +#endif +#ifdef HAVE_SYS_UN_H +#include +#endif #include -#include +#include #include +#include #include -#include #include +#include #if _WIN32 #include @@ -570,6 +578,47 @@ class TestHandlerAsync : public ThriftTestCobSvIf { std::shared_ptr _delegate; }; +struct DomainSocketFd { + THRIFT_SOCKET socket_fd; + std::string path; + DomainSocketFd(const std::string& path) : path(path) { +#ifdef HAVE_SYS_UN_H + unlink(path.c_str()); + socket_fd = socket(AF_UNIX, SOCK_STREAM, IPPROTO_IP); + if (socket_fd == -1) { + std::ostringstream os; + os << "Cannot create domain socket: " << strerror(errno); + throw std::runtime_error(os.str()); + } + if (path.size() > sizeof(sockaddr_un::sun_path) - 1) + throw std::runtime_error("Path size on domain socket too big"); + struct sockaddr_un sa; + memset(&sa, 0, sizeof(sa)); + sa.sun_family = AF_UNIX; + strcpy(sa.sun_path, path.c_str()); + int rv = bind(socket_fd, (struct sockaddr*)&sa, sizeof(sa)); + if (rv == -1) { + std::ostringstream os; + os << "Cannot bind domain socket: " << strerror(errno); + throw std::runtime_error(os.str()); + } + + rv = ::listen(socket_fd, 16); + if (rv == -1) { + std::ostringstream os; + os << "Cannot listen on domain socket: " << strerror(errno); + throw std::runtime_error(os.str()); + } +#else + throw std::runtime_error("Cannot create a domain socket without AF_UNIX"); +#endif + } + ~DomainSocketFd() { + ::THRIFT_CLOSESOCKET(socket_fd); + unlink(path.c_str()); + } +}; + namespace po = boost::program_options; int main(int argc, char** argv) { @@ -589,6 +638,8 @@ int main(int argc, char** argv) { string server_type = "simple"; string domain_socket = ""; bool abstract_namespace = false; + bool emulate_socketactivation = false; + std::unique_ptr domain_socket_fd; size_t workers = 4; int string_limit = 0; int container_limit = 0; @@ -599,6 +650,7 @@ int main(int argc, char** argv) { ("port", po::value(&port)->default_value(port), "Port number to listen") ("domain-socket", po::value(&domain_socket) ->default_value(domain_socket), "Unix Domain Socket (e.g. /tmp/ThriftTest.thrift)") ("abstract-namespace", "Create the domain socket in the Abstract Namespace (no connection with filesystem pathnames)") + ("emulate-socketactivation","Open the socket from the tester program and pass the library an already open fd") ("server-type", po::value(&server_type)->default_value(server_type), "type of server, \"simple\", \"thread-pool\", \"threaded\", or \"nonblocking\"") ("transport", po::value(&transport_type)->default_value(transport_type), "transport: buffered, framed, http, websocket, zlib") ("protocol", po::value(&protocol_type)->default_value(protocol_type), "protocol: binary, compact, header, json, multi, multic, multih, multij") @@ -678,6 +730,9 @@ int main(int argc, char** argv) { if (vm.count("abstract-namespace")) { abstract_namespace = true; } + if (vm.count("emulate-socketactivation")) { + emulate_socketactivation = true; + } // Dispatcher std::shared_ptr protocolFactory; @@ -727,8 +782,16 @@ int main(int argc, char** argv) { abstract_socket += domain_socket; serverSocket = std::shared_ptr(new TServerSocket(abstract_socket)); } else { - unlink(domain_socket.c_str()); - serverSocket = std::shared_ptr(new TServerSocket(domain_socket)); + if (emulate_socketactivation) { + unlink(domain_socket.c_str()); + // open and bind the socket + domain_socket_fd.reset(new DomainSocketFd(domain_socket)); + serverSocket = std::shared_ptr( + new TServerSocket(domain_socket_fd->socket_fd, SocketType::UNIX)); + } else { + unlink(domain_socket.c_str()); + serverSocket = std::shared_ptr(new TServerSocket(domain_socket)); + } } port = 0; } else { diff --git a/test/crossrunner/run.py b/test/crossrunner/run.py index 3ccc6e32bd3..e5324170e31 100644 --- a/test/crossrunner/run.py +++ b/test/crossrunner/run.py @@ -306,7 +306,7 @@ def _get_domain_port(self): return port if ok else self._get_domain_port() def alloc_port(self, socket_type): - if socket_type in ('domain', 'abstract'): + if socket_type in ('domain', 'abstract','domain-socketactivated'): return self._get_domain_port() else: return self._get_tcp_port() @@ -323,7 +323,7 @@ def free_port(self, socket_type, port): self._log.debug('free_port') self._lock.acquire() try: - if socket_type == 'domain': + if socket_type in ['domain','domain-socketactivated']: self._dom_ports.remove(port) path = domain_socket_path(port) if os.path.exists(path): diff --git a/test/crossrunner/test.py b/test/crossrunner/test.py index 2a1a4da7c48..3da38f4e257 100644 --- a/test/crossrunner/test.py +++ b/test/crossrunner/test.py @@ -59,9 +59,11 @@ def abs_if_exists(arg): return cmd def _socket_args(self, socket, port): + support_socket_activation = self.kind == 'server' and sys.platform != "win32" return { 'ip-ssl': ['--ssl'], 'domain': ['--domain-socket=%s' % domain_socket_path(port)], + 'domain-socketactivated': (['--emulate-socketactivation'] if support_socket_activation else []) + ['--domain-socket=%s' % domain_socket_path(port)], 'abstract': ['--abstract-namespace', '--domain-socket=%s' % domain_socket_path(port)], }.get(socket, None) diff --git a/test/tests.json b/test/tests.json index 16b47acd8ea..9731a8872df 100644 --- a/test/tests.json +++ b/test/tests.json @@ -404,13 +404,13 @@ "buffered", "http", "framed", - "zlib", - "websocket" + "zlib" ], "sockets": [ "ip", "ip-ssl", - "domain" + "domain", + "domain-socketactivated" ], "protocols": [ "compact",