diff --git a/src/Driver.cpp b/src/Driver.cpp index d2e4ab8..8a86bdc 100644 --- a/src/Driver.cpp +++ b/src/Driver.cpp @@ -160,8 +160,8 @@ int Driver::getFileDescriptor() const bool Driver::isValid() const { return m_stream; } static void validateURIScheme(std::string const& scheme) { - char const* knownSchemes[6] = { "serial", "tcp", "udp", "udpserver", "file", "test" }; - for (int i = 0; i < 6; ++i) { + char const* knownSchemes[7] = { "serial", "tcp", "tcpserver", "udp", "udpserver", "file", "test" }; + for (int i = 0; i < 7; ++i) { if (scheme == knownSchemes[i]) { return; } @@ -215,6 +215,12 @@ void Driver::openURI(std::string const& uri_string) { } openTCP(uri.getHost(), uri.getPort()); } + else if (scheme == "tcpserver") { + if (uri.getPort() == 0) { + throw std::invalid_argument("missing port specification in tcp server URI"); + } + openTCPServer(uri.getPort()); + } else if (scheme == "udp") { // UDP udp://hostname:remoteport openURI_UDP(uri); } @@ -337,6 +343,9 @@ static int createIPServerSocket(const char* port, addrinfo const& hints) if (sfd == -1) continue; + int option = 1; + setsockopt(sfd, SOL_SOCKET, SO_REUSEADDR, &option, sizeof(option)); + if (::bind(sfd, rp->ai_addr, rp->ai_addrlen) == 0) { return sfd; } @@ -432,6 +441,20 @@ void Driver::openTCP(std::string const& hostname, int port){ } } +void Driver::openTCPServer(int port) { + struct addrinfo hints; + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_INET; /* Allow IPv4 or IPv6 */ + hints.ai_socktype = SOCK_STREAM; /* Datagram socket */ + hints.ai_flags = AI_PASSIVE; /* For wildcard IP address */ + + int sfd = createIPServerSocket(lexical_cast(port).c_str(), hints); + setMainStream(new TCPServerStream(sfd)); + + listen(sfd,5); + fcntl(sfd,F_SETFL,O_NONBLOCK); +} + void Driver::openUDP(std::string const& hostname, int port) { if (hostname.empty()) diff --git a/src/Driver.hpp b/src/Driver.hpp index 5dc6448..80f97cc 100644 --- a/src/Driver.hpp +++ b/src/Driver.hpp @@ -259,6 +259,11 @@ class Driver */ void openTCP(std::string const& hostname, int port); + /** + * Opens a TCP connection for one client, + */ + void openTCPServer(int port); + /** * Opens a UDP connection * diff --git a/src/IOStream.cpp b/src/IOStream.cpp index 1f9f83c..e2358c4 100644 --- a/src/IOStream.cpp +++ b/src/IOStream.cpp @@ -113,6 +113,103 @@ bool FDStream::setNonBlockingFlag(int fd) } int FDStream::getFileDescriptor() const { return m_fd; } +TCPServerStream::TCPServerStream(int socket_fd) + : FDStream(socket_fd, false) + , m_client_fd(0) +{ + m_clilen = sizeof(m_cli_addr); +} + +TCPServerStream::~TCPServerStream() { + if(m_client_fd) { + std::cout << "~TCPServerStream:: close client connection" << std::endl; + ::close(m_client_fd); + } + std::cout << "~TCPServerStream:: close server socket" << std::endl; + ::close(m_fd); +} + +int TCPServerStream::getFileDescriptor() const { return m_client_fd; } + +size_t TCPServerStream::read(uint8_t* buffer, size_t buffer_size) { + if (m_client_fd == 0) + return 0; + int c = ::read(m_client_fd, buffer, buffer_size); + + if (c > 0) + return c; + else if (c == 0) + { + m_eof = m_has_eof; + return 0; + } + else + { + if (errno == EAGAIN) + return 0; + throw UnixError("readPacket(): error reading the file descriptor"); + } +} + +size_t TCPServerStream::write(uint8_t const* buffer, size_t buffer_size) { + if (m_client_fd == 0) + return 0; + + int c = ::write(m_client_fd, buffer, buffer_size); + + if (c == -1 && errno != EAGAIN && errno != ENOBUFS) + throw UnixError("writePacket(): error during write"); + if (c == -1) + return 0; + return c; +} + +bool TCPServerStream::waitRead(base::Time const& timeout) { + return checkClientConnection(timeout); +} + +bool TCPServerStream::waitWrite(base::Time const& timeout) { + return checkClientConnection(timeout); +} + +bool TCPServerStream::checkClientConnection(base::Time const& timeout) { + fd_set set; + FD_ZERO(&set); + FD_SET(m_fd, &set); + + timeval timeout_spec = { static_cast(timeout.toSeconds()), suseconds_t(timeout.toMicroseconds() % 1000000)}; + int ret = select(m_fd + 1, &set, NULL, NULL, &timeout_spec); + if (ret < 0 && errno != EINTR) + throw UnixError("checkClientConnection(): error in select()"); + else if (ret == 0) + throw TimeoutError(TimeoutError::NONE, "waitWrite(): timeout"); + + if (!FD_ISSET(m_fd, &set)) // no new client + throw UnixError("File descriptor is not set"); + + int new_client = accept(m_fd, (struct sockaddr *) &m_cli_addr, &m_clilen); + if(new_client < 0) + throw UnixError("checkClientConnection(): error in accept()"); + + if(m_client_fd) { + std::cout << "checkClientConnection(): close the connection to the previous client, since there is a new client" << std::endl; + close(m_client_fd); + } + + std::cout << "New client is connected" << std::endl; + + setNonBlockingFlag(new_client); + + m_client_fd = new_client; + return true; +} + +bool TCPServerStream::isClientConnected() { + if (m_client_fd == 0) + return false; + return true; +} + UDPServerStream::UDPServerStream(int fd, bool auto_close) : FDStream(fd,auto_close) , m_s_len(sizeof(m_si_other)) diff --git a/src/IOStream.hpp b/src/IOStream.hpp index 9dc094d..2c66da8 100644 --- a/src/IOStream.hpp +++ b/src/IOStream.hpp @@ -70,6 +70,34 @@ namespace iodrivers_base void setAutoClose(bool flag); }; + class TCPServerStream : public FDStream + { + int m_client_fd; + + /** + * Internal members to handle the connection + */ + struct sockaddr_in m_cli_addr; + + /** + * Internal members to handle the connection + */ + socklen_t m_clilen; + + public: + TCPServerStream(int socket_fd); + ~TCPServerStream(); + int getFileDescriptor() const; + size_t read(uint8_t* buffer, size_t buffer_size); + size_t write(uint8_t const* buffer, size_t buffer_size); + virtual bool waitRead(base::Time const& timeout); + virtual bool waitWrite(base::Time const& timeout); + + bool checkClientConnection(base::Time const& timeout); + + bool isClientConnected(); + }; + class UDPServerStream : public FDStream { public: diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index fa150cf..45fd6b6 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -15,9 +15,19 @@ rock_executable(test_tcp_write test_tcp_write.cpp DEPS iodrivers_base NOINSTALL) +rock_executable(test_tcpserver_read test_tcpserver_read.cpp + DEPS iodrivers_base + NOINSTALL) + +rock_executable(test_tcpserver_write test_tcpserver_write.cpp + DEPS iodrivers_base + NOINSTALL) + rock_executable(test_udp_read test_udp_read.cpp DEPS iodrivers_base NOINSTALL) rock_executable(test_udp_write test_udp_write.cpp DEPS iodrivers_base NOINSTALL) + + diff --git a/test/test_tcpserver_read.cpp b/test/test_tcpserver_read.cpp new file mode 100644 index 0000000..df19366 --- /dev/null +++ b/test/test_tcpserver_read.cpp @@ -0,0 +1,36 @@ +#include +#include +#include + +using namespace iodrivers_base; +using std::string; + +struct DisplayDriver : public iodrivers_base::Driver +{ + DisplayDriver() + : iodrivers_base::Driver(10000) {} + int extractPacket(uint8_t const* buffer, size_t size) const + { + std::cout << iodrivers_base::Driver::printable_com(buffer, size) << std::endl; + return -size; + } +}; + +int main(int argc, char const* const* argv) +{ + if (argc < 2) + throw UnixError("to few arguments, add tcp server port"); + + string addr = string("tcpserver://localhost:") + argv[1]; + std::cout << "TCP server: " << addr << std::endl; + + DisplayDriver driver; + driver.openURI(addr); + + uint8_t buffer[10000]; + driver.setReadTimeout(base::Time::fromSeconds(60)); + + driver.readPacket(buffer, 10000); + return 0; +} + diff --git a/test/test_tcpserver_write.cpp b/test/test_tcpserver_write.cpp new file mode 100644 index 0000000..acb5a36 --- /dev/null +++ b/test/test_tcpserver_write.cpp @@ -0,0 +1,36 @@ +#include +#include +#include + +using namespace iodrivers_base; +using std::string; + +struct DisplayDriver : public iodrivers_base::Driver +{ + DisplayDriver() + : iodrivers_base::Driver(10000) {} + int extractPacket(uint8_t const* buffer, size_t size) const + { + std::cout << iodrivers_base::Driver::printable_com(buffer, size) << std::endl; + return -size; + } +}; + +int main(int argc, char const* const* argv) +{ + if (argc < 3) + throw UnixError("to few arguments, add tcp server port and message to send"); + + string addr = string("tcpserver://localhost:") + argv[1]; + std::cout << "TCP server: " << addr << std::endl; + + DisplayDriver driver; + driver.openURI(addr); + + uint8_t buffer[10000]; + driver.setWriteTimeout(base::Time::fromSeconds(60)); + + driver.writePacket(reinterpret_cast(argv[2]), strlen(argv[2])); + return 0; +} +