Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 57 additions & 23 deletions lib/cpp/src/thrift/transport/TServerSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ TServerSocket::TServerSocket(int port)
listening_(false),
interruptSockWriter_(THRIFT_INVALID_SOCKET),
interruptSockReader_(THRIFT_INVALID_SOCKET),
childInterruptSockWriter_(THRIFT_INVALID_SOCKET) {
childInterruptSockWriter_(THRIFT_INVALID_SOCKET),
boundSocketType_(SocketType::NONE) {
}

TServerSocket::TServerSocket(int port, int sendTimeout, int recvTimeout)
Expand All @@ -136,7 +137,8 @@ TServerSocket::TServerSocket(int port, int sendTimeout, int recvTimeout)
listening_(false),
interruptSockWriter_(THRIFT_INVALID_SOCKET),
interruptSockReader_(THRIFT_INVALID_SOCKET),
childInterruptSockWriter_(THRIFT_INVALID_SOCKET) {
childInterruptSockWriter_(THRIFT_INVALID_SOCKET),
boundSocketType_(SocketType::NONE) {
}

TServerSocket::TServerSocket(const string& address, int port)
Expand All @@ -156,7 +158,8 @@ TServerSocket::TServerSocket(const string& address, int port)
listening_(false),
interruptSockWriter_(THRIFT_INVALID_SOCKET),
interruptSockReader_(THRIFT_INVALID_SOCKET),
childInterruptSockWriter_(THRIFT_INVALID_SOCKET) {
childInterruptSockWriter_(THRIFT_INVALID_SOCKET),
boundSocketType_(SocketType::NONE) {
}

TServerSocket::TServerSocket(const string& path)
Expand All @@ -176,7 +179,28 @@ TServerSocket::TServerSocket(const string& path)
listening_(false),
interruptSockWriter_(THRIFT_INVALID_SOCKET),
interruptSockReader_(THRIFT_INVALID_SOCKET),
childInterruptSockWriter_(THRIFT_INVALID_SOCKET) {
childInterruptSockWriter_(THRIFT_INVALID_SOCKET),
boundSocketType_(SocketType::NONE) {
}
TServerSocket::TServerSocket(THRIFT_SOCKET sock,SocketType socketType)
: interruptableChildren_(true),
port_(0),
path_(),
serverSocket_(sock),
acceptBacklog_(DEFAULT_BACKLOG),
sendTimeout_(0),
recvTimeout_(0),
accTimeout_(-1),
retryLimit_(0),
retryDelay_(0),
tcpSendBuffer_(0),
tcpRecvBuffer_(0),
keepAlive_(false),
listening_(false),
interruptSockWriter_(THRIFT_INVALID_SOCKET),
interruptSockReader_(THRIFT_INVALID_SOCKET),
childInterruptSockWriter_(THRIFT_INVALID_SOCKET),
boundSocketType_(socketType) {
}

TServerSocket::~TServerSocket() {
Expand Down Expand Up @@ -439,7 +463,8 @@ void TServerSocket::listen() {
if (isUnixDomainSocket()) {
// -- Unix Domain Socket -- //

serverSocket_ = socket(PF_UNIX, SOCK_STREAM, IPPROTO_IP);
if (serverSocket_ == THRIFT_INVALID_SOCKET)
serverSocket_ = socket(PF_UNIX, SOCK_STREAM, IPPROTO_IP);

if (serverSocket_ == THRIFT_INVALID_SOCKET) {
int errno_copy = THRIFT_GET_SOCKET_ERROR;
Expand Down Expand Up @@ -471,6 +496,8 @@ void TServerSocket::listen() {
throw TTransportException(TTransportException::NOT_OPEN,
" Unix Domain socket path not supported");
#endif
} else if( boundSocketType_ != SocketType::NONE){
// -- Socket is already bound
} else {
// -- TCP socket -- //

Expand Down Expand Up @@ -516,25 +543,31 @@ void TServerSocket::listen() {
// use short circuit evaluation here to only sleep if we need to
} while ((retries++ < retryLimit_) && (THRIFT_SLEEP_SEC(retryDelay_) == 0));

// retrieve bind info
if (port_ == 0 && retries <= retryLimit_) {
struct sockaddr_storage sa;
socklen_t len = sizeof(sa);
std::memset(&sa, 0, len);
if (::getsockname(serverSocket_, reinterpret_cast<struct sockaddr*>(&sa), &len) < 0) {
errno_copy = THRIFT_GET_SOCKET_ERROR;
GlobalOutput.perror("TServerSocket::getPort() getsockname() ", errno_copy);
} // TCP socket //

// retrieve bind info
if ((port_ == 0 || path_.empty() ) && retries <= retryLimit_) {
struct sockaddr_storage sa;
socklen_t len = sizeof(sa);
std::memset(&sa, 0, len);
if (::getsockname(serverSocket_, reinterpret_cast<struct sockaddr*>(&sa), &len) < 0) {
errno_copy = THRIFT_GET_SOCKET_ERROR;
GlobalOutput.perror("TServerSocket::getPort() getsockname() ", errno_copy);
} else {
if (sa.ss_family == AF_INET6) {
const auto* sin = reinterpret_cast<const struct sockaddr_in6*>(&sa);
port_ = ntohs(sin->sin6_port);
} else if (sa.ss_family == AF_INET) {
const auto* sin = reinterpret_cast<const struct sockaddr_in*>(&sa);
port_ = ntohs(sin->sin_port);
} else if (sa.ss_family == AF_UNIX) {
const auto* sin = reinterpret_cast<const struct sockaddr_un*>(&sa);
path_ = sin->sun_path;
} else {
if (sa.ss_family == AF_INET6) {
const auto* sin = reinterpret_cast<const struct sockaddr_in6*>(&sa);
port_ = ntohs(sin->sin6_port);
} else {
const auto* sin = reinterpret_cast<const struct sockaddr_in*>(&sa);
port_ = ntohs(sin->sin_port);
}
GlobalOutput.perror("TServerSocket::getPort() getsockname() unhandled socket type",EINVAL);
}
}
} // TCP socket //
}

// throw error if socket still wasn't created successfully
if (serverSocket_ == THRIFT_INVALID_SOCKET) {
Expand Down Expand Up @@ -569,7 +602,7 @@ void TServerSocket::listen() {
listenCallback_(serverSocket_);

// Call listen
if (-1 == ::listen(serverSocket_, acceptBacklog_)) {
if (boundSocketType_ == SocketType::NONE && -1 == ::listen(serverSocket_, acceptBacklog_)) {
errno_copy = THRIFT_GET_SOCKET_ERROR;
GlobalOutput.perror("TServerSocket::listen() listen() ", errno_copy);
close();
Expand Down Expand Up @@ -734,7 +767,8 @@ void TServerSocket::close() {
concurrency::Guard g(rwMutex_);
if (serverSocket_ != THRIFT_INVALID_SOCKET) {
shutdown(serverSocket_, THRIFT_SHUT_RDWR);
::THRIFT_CLOSESOCKET(serverSocket_);
if( boundSocketType_ == SocketType::NONE) //Do not close the server socket if it owned by systemd
::THRIFT_CLOSESOCKET(serverSocket_);
}
if (interruptSockWriter_ != THRIFT_INVALID_SOCKET) {
::THRIFT_CLOSESOCKET(interruptSockWriter_);
Expand Down
16 changes: 16 additions & 0 deletions lib/cpp/src/thrift/transport/TServerSocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ namespace transport {

class TSocket;

enum class SocketType {
NONE,
INET,
INET6,
UNIX
};

/**
* Server socket implementation of TServerTransport. Wrapper around a unix
* socket listen and accept calls.
Expand Down Expand Up @@ -82,6 +89,14 @@ class TServerSocket : public TServerTransport {
*/
TServerSocket(const std::string& path);

/**
* Constructor used for to initialize from an already bound unix socket.
* Useful for socket activation on systemd.
*
* @param fd
*/
TServerSocket(THRIFT_SOCKET sock,SocketType socketType);

~TServerSocket() override;


Expand Down Expand Up @@ -172,6 +187,7 @@ class TServerSocket : public TServerTransport {

socket_func_t listenCallback_;
socket_func_t acceptCallback_;
SocketType boundSocketType_;
};
}
}
Expand Down
71 changes: 67 additions & 4 deletions test/cpp/src/TestServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <thrift/server/TSimpleServer.h>
#include <thrift/server/TThreadPoolServer.h>
#include <thrift/server/TThreadedServer.h>
#include <thrift/transport/PlatformSocket.h>
#include <thrift/transport/THttpServer.h>
#include <thrift/transport/THttpTransport.h>
#include <thrift/transport/TNonblockingSSLServerSocket.h>
Expand All @@ -54,14 +55,21 @@
#ifdef HAVE_SIGNAL_H
#include <signal.h>
#endif
#ifdef HAVE_SYS_SOCKET_H
#include <sys/socket.h>
#endif
#ifdef HAVE_SYS_UN_H
#include <sys/un.h>
#endif

#include <iostream>
#include <stdexcept>
#include <memory>
#include <sstream>
#include <stdexcept>

#include <boost/algorithm/string.hpp>
#include <boost/program_options.hpp>
#include <boost/filesystem.hpp>
#include <boost/program_options.hpp>

#if _WIN32
#include <thrift/windows/TWinsockSingleton.h>
Expand Down Expand Up @@ -570,6 +578,47 @@ class TestHandlerAsync : public ThriftTestCobSvIf {
std::shared_ptr<TestHandler> _delegate;
};

struct DomainSocketFd {
THRIFT_SOCKET socket_fd;
std::string path;
DomainSocketFd(const std::string& path) : path(path) {
#ifdef HAVE_SYS_UN_H
unlink(path.c_str());
socket_fd = socket(AF_UNIX, SOCK_STREAM, IPPROTO_IP);
if (socket_fd == -1) {
std::ostringstream os;
os << "Cannot create domain socket: " << strerror(errno);
throw std::runtime_error(os.str());
}
if (path.size() > sizeof(sockaddr_un::sun_path) - 1)
throw std::runtime_error("Path size on domain socket too big");
struct sockaddr_un sa;
memset(&sa, 0, sizeof(sa));
sa.sun_family = AF_UNIX;
strcpy(sa.sun_path, path.c_str());
int rv = bind(socket_fd, (struct sockaddr*)&sa, sizeof(sa));
if (rv == -1) {
std::ostringstream os;
os << "Cannot bind domain socket: " << strerror(errno);
throw std::runtime_error(os.str());
}

rv = ::listen(socket_fd, 16);
if (rv == -1) {
std::ostringstream os;
os << "Cannot listen on domain socket: " << strerror(errno);
throw std::runtime_error(os.str());
}
#else
throw std::runtime_error("Cannot create a domain socket without AF_UNIX");
#endif
}
~DomainSocketFd() {
::THRIFT_CLOSESOCKET(socket_fd);
unlink(path.c_str());
}
};

namespace po = boost::program_options;

int main(int argc, char** argv) {
Expand All @@ -589,6 +638,8 @@ int main(int argc, char** argv) {
string server_type = "simple";
string domain_socket = "";
bool abstract_namespace = false;
bool emulate_socketactivation = false;
std::unique_ptr<DomainSocketFd> domain_socket_fd;
size_t workers = 4;
int string_limit = 0;
int container_limit = 0;
Expand All @@ -599,6 +650,7 @@ int main(int argc, char** argv) {
("port", po::value<int>(&port)->default_value(port), "Port number to listen")
("domain-socket", po::value<string>(&domain_socket) ->default_value(domain_socket), "Unix Domain Socket (e.g. /tmp/ThriftTest.thrift)")
("abstract-namespace", "Create the domain socket in the Abstract Namespace (no connection with filesystem pathnames)")
("emulate-socketactivation","Open the socket from the tester program and pass the library an already open fd")
("server-type", po::value<string>(&server_type)->default_value(server_type), "type of server, \"simple\", \"thread-pool\", \"threaded\", or \"nonblocking\"")
("transport", po::value<string>(&transport_type)->default_value(transport_type), "transport: buffered, framed, http, websocket, zlib")
("protocol", po::value<string>(&protocol_type)->default_value(protocol_type), "protocol: binary, compact, header, json, multi, multic, multih, multij")
Expand Down Expand Up @@ -678,6 +730,9 @@ int main(int argc, char** argv) {
if (vm.count("abstract-namespace")) {
abstract_namespace = true;
}
if (vm.count("emulate-socketactivation")) {
emulate_socketactivation = true;
}

// Dispatcher
std::shared_ptr<TProtocolFactory> protocolFactory;
Expand Down Expand Up @@ -727,8 +782,16 @@ int main(int argc, char** argv) {
abstract_socket += domain_socket;
serverSocket = std::shared_ptr<TServerSocket>(new TServerSocket(abstract_socket));
} else {
unlink(domain_socket.c_str());
serverSocket = std::shared_ptr<TServerSocket>(new TServerSocket(domain_socket));
if (emulate_socketactivation) {
unlink(domain_socket.c_str());
// open and bind the socket
domain_socket_fd.reset(new DomainSocketFd(domain_socket));
serverSocket = std::shared_ptr<TServerSocket>(
new TServerSocket(domain_socket_fd->socket_fd, SocketType::UNIX));
} else {
unlink(domain_socket.c_str());
serverSocket = std::shared_ptr<TServerSocket>(new TServerSocket(domain_socket));
}
}
port = 0;
} else {
Expand Down
4 changes: 2 additions & 2 deletions test/crossrunner/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def _get_domain_port(self):
return port if ok else self._get_domain_port()

def alloc_port(self, socket_type):
if socket_type in ('domain', 'abstract'):
if socket_type in ('domain', 'abstract','domain-socketactivated'):
return self._get_domain_port()
else:
return self._get_tcp_port()
Expand All @@ -323,7 +323,7 @@ def free_port(self, socket_type, port):
self._log.debug('free_port')
self._lock.acquire()
try:
if socket_type == 'domain':
if socket_type in ['domain','domain-socketactivated']:
self._dom_ports.remove(port)
path = domain_socket_path(port)
if os.path.exists(path):
Expand Down
2 changes: 2 additions & 0 deletions test/crossrunner/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,11 @@ def abs_if_exists(arg):
return cmd

def _socket_args(self, socket, port):
support_socket_activation = self.kind == 'server' and sys.platform != "win32"
return {
'ip-ssl': ['--ssl'],
'domain': ['--domain-socket=%s' % domain_socket_path(port)],
'domain-socketactivated': (['--emulate-socketactivation'] if support_socket_activation else []) + ['--domain-socket=%s' % domain_socket_path(port)],
'abstract': ['--abstract-namespace', '--domain-socket=%s' % domain_socket_path(port)],
}.get(socket, None)

Expand Down
6 changes: 3 additions & 3 deletions test/tests.json
Original file line number Diff line number Diff line change
Expand Up @@ -404,13 +404,13 @@
"buffered",
"http",
"framed",
"zlib",
"websocket"
"zlib"
],
"sockets": [
"ip",
"ip-ssl",
"domain"
"domain",
"domain-socketactivated"
],
"protocols": [
"compact",
Expand Down
Loading