Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 57 additions & 23 deletions lib/cpp/src/thrift/transport/TServerSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ TServerSocket::TServerSocket(int port)
listening_(false),
interruptSockWriter_(THRIFT_INVALID_SOCKET),
interruptSockReader_(THRIFT_INVALID_SOCKET),
childInterruptSockWriter_(THRIFT_INVALID_SOCKET) {
childInterruptSockWriter_(THRIFT_INVALID_SOCKET),
boundSocketType_(SocketType::NONE) {
}

TServerSocket::TServerSocket(int port, int sendTimeout, int recvTimeout)
Expand All @@ -136,7 +137,8 @@ TServerSocket::TServerSocket(int port, int sendTimeout, int recvTimeout)
listening_(false),
interruptSockWriter_(THRIFT_INVALID_SOCKET),
interruptSockReader_(THRIFT_INVALID_SOCKET),
childInterruptSockWriter_(THRIFT_INVALID_SOCKET) {
childInterruptSockWriter_(THRIFT_INVALID_SOCKET),
boundSocketType_(SocketType::NONE) {
}

TServerSocket::TServerSocket(const string& address, int port)
Expand All @@ -156,7 +158,8 @@ TServerSocket::TServerSocket(const string& address, int port)
listening_(false),
interruptSockWriter_(THRIFT_INVALID_SOCKET),
interruptSockReader_(THRIFT_INVALID_SOCKET),
childInterruptSockWriter_(THRIFT_INVALID_SOCKET) {
childInterruptSockWriter_(THRIFT_INVALID_SOCKET),
boundSocketType_(SocketType::NONE) {
}

TServerSocket::TServerSocket(const string& path)
Expand All @@ -176,7 +179,28 @@ TServerSocket::TServerSocket(const string& path)
listening_(false),
interruptSockWriter_(THRIFT_INVALID_SOCKET),
interruptSockReader_(THRIFT_INVALID_SOCKET),
childInterruptSockWriter_(THRIFT_INVALID_SOCKET) {
childInterruptSockWriter_(THRIFT_INVALID_SOCKET),
boundSocketType_(SocketType::NONE) {
}
TServerSocket::TServerSocket(THRIFT_SOCKET sock,SocketType socketType)
: interruptableChildren_(true),
port_(0),
path_(),
serverSocket_(sock),
acceptBacklog_(DEFAULT_BACKLOG),
sendTimeout_(0),
recvTimeout_(0),
accTimeout_(-1),
retryLimit_(0),
retryDelay_(0),
tcpSendBuffer_(0),
tcpRecvBuffer_(0),
keepAlive_(false),
listening_(false),
interruptSockWriter_(THRIFT_INVALID_SOCKET),
interruptSockReader_(THRIFT_INVALID_SOCKET),
childInterruptSockWriter_(THRIFT_INVALID_SOCKET),
boundSocketType_(socketType) {
}

TServerSocket::~TServerSocket() {
Expand Down Expand Up @@ -439,7 +463,8 @@ void TServerSocket::listen() {
if (isUnixDomainSocket()) {
// -- Unix Domain Socket -- //

serverSocket_ = socket(PF_UNIX, SOCK_STREAM, IPPROTO_IP);
if(serverSocket_!= THRIFT_INVALID_SOCKET)
serverSocket_ = socket(PF_UNIX, SOCK_STREAM, IPPROTO_IP);

if (serverSocket_ == THRIFT_INVALID_SOCKET) {
int errno_copy = THRIFT_GET_SOCKET_ERROR;
Expand Down Expand Up @@ -471,6 +496,8 @@ void TServerSocket::listen() {
throw TTransportException(TTransportException::NOT_OPEN,
" Unix Domain socket path not supported");
#endif
} else if( boundSocketType_ != SocketType::NONE){
// -- Socket is already bound
} else {
// -- TCP socket -- //

Expand Down Expand Up @@ -516,25 +543,31 @@ void TServerSocket::listen() {
// use short circuit evaluation here to only sleep if we need to
} while ((retries++ < retryLimit_) && (THRIFT_SLEEP_SEC(retryDelay_) == 0));

// retrieve bind info
if (port_ == 0 && retries <= retryLimit_) {
struct sockaddr_storage sa;
socklen_t len = sizeof(sa);
std::memset(&sa, 0, len);
if (::getsockname(serverSocket_, reinterpret_cast<struct sockaddr*>(&sa), &len) < 0) {
errno_copy = THRIFT_GET_SOCKET_ERROR;
GlobalOutput.perror("TServerSocket::getPort() getsockname() ", errno_copy);
} // TCP socket //

// retrieve bind info
if ((port_ == 0 || path_.empty() ) && retries <= retryLimit_) {
struct sockaddr_storage sa;
socklen_t len = sizeof(sa);
std::memset(&sa, 0, len);
if (::getsockname(serverSocket_, reinterpret_cast<struct sockaddr*>(&sa), &len) < 0) {
errno_copy = THRIFT_GET_SOCKET_ERROR;
GlobalOutput.perror("TServerSocket::getPort() getsockname() ", errno_copy);
} else {
if (sa.ss_family == AF_INET6) {
const auto* sin = reinterpret_cast<const struct sockaddr_in6*>(&sa);
port_ = ntohs(sin->sin6_port);
} else if (sa.ss_family == AF_INET) {
const auto* sin = reinterpret_cast<const struct sockaddr_in*>(&sa);
port_ = ntohs(sin->sin_port);
} else if (sa.ss_family == AF_UNIX) {
const auto* sin = reinterpret_cast<const struct sockaddr_un*>(&sa);
path_ = sin->sun_path;
} else {
if (sa.ss_family == AF_INET6) {
const auto* sin = reinterpret_cast<const struct sockaddr_in6*>(&sa);
port_ = ntohs(sin->sin6_port);
} else {
const auto* sin = reinterpret_cast<const struct sockaddr_in*>(&sa);
port_ = ntohs(sin->sin_port);
}
GlobalOutput.perror("TServerSocket::getPort() getsockname() unhandled socket type",EINVAL);
}
}
} // TCP socket //
}

// throw error if socket still wasn't created successfully
if (serverSocket_ == THRIFT_INVALID_SOCKET) {
Expand Down Expand Up @@ -569,7 +602,7 @@ void TServerSocket::listen() {
listenCallback_(serverSocket_);

// Call listen
if (-1 == ::listen(serverSocket_, acceptBacklog_)) {
if (boundSocketType_ == SocketType::NONE && -1 == ::listen(serverSocket_, acceptBacklog_)) {
errno_copy = THRIFT_GET_SOCKET_ERROR;
GlobalOutput.perror("TServerSocket::listen() listen() ", errno_copy);
close();
Expand Down Expand Up @@ -734,7 +767,8 @@ void TServerSocket::close() {
concurrency::Guard g(rwMutex_);
if (serverSocket_ != THRIFT_INVALID_SOCKET) {
shutdown(serverSocket_, THRIFT_SHUT_RDWR);
::THRIFT_CLOSESOCKET(serverSocket_);
if( boundSocketType_ == SocketType::NONE) //Do not close the server socket if it owned by systemd
::THRIFT_CLOSESOCKET(serverSocket_);
}
if (interruptSockWriter_ != THRIFT_INVALID_SOCKET) {
::THRIFT_CLOSESOCKET(interruptSockWriter_);
Expand Down
16 changes: 16 additions & 0 deletions lib/cpp/src/thrift/transport/TServerSocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ namespace transport {

class TSocket;

enum class SocketType {
NONE,
INET,
INET6,
UNIX
};

/**
* Server socket implementation of TServerTransport. Wrapper around a unix
* socket listen and accept calls.
Expand Down Expand Up @@ -82,6 +89,14 @@ class TServerSocket : public TServerTransport {
*/
TServerSocket(const std::string& path);

/**
* Constructor used for to initialize from an already bound unix socket.
* Useful for socket activation on systemd.
*
* @param fd
*/
TServerSocket(THRIFT_SOCKET sock,SocketType socketType);

~TServerSocket() override;


Expand Down Expand Up @@ -172,6 +187,7 @@ class TServerSocket : public TServerTransport {

socket_func_t listenCallback_;
socket_func_t acceptCallback_;
SocketType boundSocketType_;
};
}
}
Expand Down
Loading