diff --git a/01_project.sh b/01_project.sh new file mode 100644 index 0000000..ba989aa --- /dev/null +++ b/01_project.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +cmake --preset=linux-dbg -S . + diff --git a/CMakePresets.json b/CMakePresets.json index 014a716..ed76077 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -2,11 +2,20 @@ "version": 3, "configurePresets": [ { - "name": "default", - "binaryDir": "$env{BUILD_ROOT}/dns_server", + "name": "windows-dbg", + "binaryDir": "C:/build/dns_server", "cacheVariables": { - "CMAKE_TOOLCHAIN_FILE": "$env{VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake" + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_TOOLCHAIN_FILE": "C:/vcpkg/scripts/buildsystems/vcpkg.cmake" + } + }, + { + "name": "linux-dbg", + "binaryDir": "/home/sergv/Documents/build/dns_server", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_TOOLCHAIN_FILE": "/home/sergv/Documents/vcpkg/scripts/buildsystems/vcpkg.cmake" } } ] -} \ No newline at end of file +} diff --git a/libdns/CMakeLists.txt b/libdns/CMakeLists.txt index ea5cddc..940b309 100644 --- a/libdns/CMakeLists.txt +++ b/libdns/CMakeLists.txt @@ -9,12 +9,19 @@ add_library(dns STATIC dns_answer.cpp dns_answer.h dns_auth_server.cpp dns_auth_server.h dns_package.cpp dns_package.h + dns_selector.cpp dns_selector.h + dns_socket.cpp dns_socket.h dns.cpp dns.h ) + target_include_directories(dns PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) target_link_libraries(dns PRIVATE JsonCpp::JsonCpp) if(WIN32) target_link_libraries(dns PUBLIC wsock32 ws2_32) + target_sources(dns PRIVATE dns_selector_win32.cpp) +else() + target_link_libraries(dns PUBLIC pthread) + target_sources(dns PRIVATE dns_selector_posix.cpp) endif() diff --git a/libdns/dns.cpp b/libdns/dns.cpp index 45b1309..1e34e24 100644 --- a/libdns/dns.cpp +++ b/libdns/dns.cpp @@ -1,10 +1,18 @@ -#include +#if defined(_WIN32) +#include #include +#else +#include +#include +#include +#include +#include +#endif #include #include #include #include -#include +#include #include #include #include @@ -15,9 +23,12 @@ #include "dns_buffer.h" #include "dns_request.h" #include "dns_package.h" +#include "dns_selector.h" -class DNSServerImpl +class DNSServerImpl: private ISocketHandler { + friend class ISocketHandler; + struct Request { DNSRecordType type; @@ -53,15 +64,15 @@ class DNSServerImpl void closeTcpSocket(SOCKET s) { - pollfds_del(&readfds, s); - pollfds_del(&writefds, s); + selector.removeReadSocket(s); + selector.removeWriteSocket(s); closesocket(s); } void closeUdpSocket(SOCKET s) { - pollfds_del(&readfds, s); - pollfds_del(&writefds, s); + selector.removeReadSocket(s); + selector.removeWriteSocket(s); closesocket(s); } @@ -108,8 +119,8 @@ class DNSServerImpl } // now be ready to write response - pollfds_del(&readfds, s); - pollfds_add(&writefds, s); + selector.removeReadSocket(s); + selector.addWriteSocket(s); } void writeTcpSocket(SOCKET s) @@ -157,7 +168,7 @@ class DNSServerImpl void readUdpSocket(SOCKET s) { char message[UDP_SIZE] = {}; - int slen = sizeof(udp_socket_data.client); + socklen_t slen = sizeof(udp_socket_data.client); int msg_len = recvfrom(s, message, UDP_SIZE, 0, (sockaddr*)&udp_socket_data.client, &slen); if (msg_len <= 0) { @@ -169,8 +180,8 @@ class DNSServerImpl udp_socket_data.request.assign(ptr, ptr + msg_len); // now be ready to write response - pollfds_del(&readfds, s); - pollfds_add(&writefds, s); + selector.removeReadSocket(s); + selector.addWriteSocket(s); } void writeUdpSocket(SOCKET s) @@ -200,10 +211,11 @@ class DNSServerImpl sendto(s, reinterpret_cast(&buf.result[0]), bytes_to_write, 0, (sockaddr*)&udp_socket_data.client, slen); } - udp_socket_data.request.clear(); + // now be ready to read requests + selector.removeWriteSocket(s); + selector.addReadSocket(s); - pollfds_del(&writefds, s); - pollfds_add(&readfds, s); + udp_socket_data.request.clear(); } void process(const uint8_t* query, DNSBuffer& buf) @@ -257,126 +269,105 @@ class DNSServerImpl } } - static void pollfds_add(fd_set* set, SOCKET fd) + // ISocketHandler + virtual void socketReadyRead(SOCKET s) { - FD_SET(fd, set); + if (s == socket_tcp) + { + struct sockaddr_storage client_addr; + socklen_t client_addr_len = sizeof(client_addr); + SOCKET client = ::accept(s, (struct sockaddr*)&client_addr, &client_addr_len); + selector.addReadSocket(client); + setupsocket(client); + } + else if (s == socket_udp) + { + readUdpSocket(s); + } + else + { + readTcpSocket(s); + } } - static void pollfds_del(fd_set* set, SOCKET fd) + // ISocketHandler + virtual void socketReadyWrite(SOCKET s) { - FD_CLR(fd, set); + if (s == socket_udp) + { + writeUdpSocket(s); + } + else if (s != socket_tcp) + { + writeTcpSocket(s); + } } void process() { - u_long mode = 1; // 1 to enable non-blocking socket - - fd_set readfds_work; - fd_set writefds_work; - - FD_ZERO(&readfds); - FD_ZERO(&writefds); - sockaddr_in server = { 0 }; server.sin_family = AF_INET; server.sin_addr.s_addr = INADDR_ANY; server.sin_port = htons(port); // UDP socket - SOCKET socket_udp = 0; if ((socket_udp = socket(AF_INET, SOCK_DGRAM, 0)) == INVALID_SOCKET) { throw std::runtime_error("Create UDP socket failed"); } - - ioctlsocket(socket_udp, FIONBIO, &mode); + setupsocket(socket_udp); if (bind(socket_udp, (sockaddr*)&server, sizeof(server)) == SOCKET_ERROR) { + std::cerr << "Bind error:" << errno << std::endl; throw std::runtime_error("Bind UDP socket failed"); } - pollfds_add(&readfds, socket_udp); + selector.addReadSocket(socket_udp); // TCP socket - SOCKET socket_tcp = 0; if ((socket_tcp = socket(AF_INET, SOCK_STREAM, 0)) == INVALID_SOCKET) { throw std::runtime_error("Create TCP socket failed"); } - ioctlsocket(socket_tcp, FIONBIO, &mode); + setupsocket(socket_tcp); if (bind(socket_tcp, (sockaddr*)&server, sizeof(server)) == SOCKET_ERROR) { - throw std::runtime_error("Bind UDP socket failed"); + throw std::runtime_error("Bind TCP socket failed"); } listen(socket_tcp, 5); - pollfds_add(&readfds, socket_tcp); + selector.addReadSocket(socket_tcp); canExit = false; while (!canExit) { - readfds_work = readfds; - writefds_work = writefds; - - int res = select(0, &readfds_work, &writefds_work, NULL, NULL); - if (res == -1) - { - int code = WSAGetLastError(); - continue; - } - - for (auto i = 0u; i < readfds_work.fd_count; i++) - { - SOCKET fd = readfds_work.fd_array[i]; - if (fd == socket_tcp) - { - struct sockaddr_storage client_addr; - socklen_t client_addr_len = sizeof(client_addr); - SOCKET client = accept(fd, (struct sockaddr*)&client_addr, &client_addr_len); - pollfds_add(&readfds, client); - u_long mode = 1; // 1 to enable non-blocking socket - ioctlsocket(client, FIONBIO, &mode); - } - else if (fd == socket_udp) - { - readUdpSocket(fd); - } - else - { - readTcpSocket(fd); - } - } - for (auto i = 0u; i < writefds_work.fd_count; i++) - { - SOCKET fd = writefds_work.fd_array[i]; - if (fd == socket_udp) // client socket - { - writeUdpSocket(fd); - } - else if (fd != socket_tcp) - { - writeTcpSocket(fd); - } - } + selector.select(); } closeUdpSocket(socket_udp); - closeTcpSocket(socket_tcp); for (auto it = tcp_socket_data.begin(); it != tcp_socket_data.end();) { closeTcpSocket(it->first); it = tcp_socket_data.erase(it); } + closeTcpSocket(socket_tcp); } public: DNSServerImpl(const std::string& host, int port) - : wsa{0} + : selector(this) , host(host) , port(port) + , socket_udp(INVALID_SOCKET) + , socket_tcp(INVALID_SOCKET) +#ifdef _WIN32 + , wsa{0} +#endif { +#ifdef _WIN32 if (WSAStartup(MAKEWORD(2, 2), &wsa) != 0) { throw std::runtime_error("WSAStartup() failed"); } +#endif } void addRecord(DNSRecordType type, const std::string& host, const std::vector& answer) @@ -395,7 +386,7 @@ class DNSServerImpl } private: - WSADATA wsa; + DNSSelector selector; std::string host; int port; SOCKET socket_udp, socket_tcp; @@ -406,6 +397,9 @@ class DNSServerImpl fd_set writefds; std::thread thread; bool canExit; +#ifdef _WIN32 + WSADATA wsa; +#endif }; DNSServer::DNSServer(const std::string& host, int port) @@ -496,12 +490,11 @@ DNSPackage DNSClient::requestUdp(uint16_t id, DNSRecordType type, const std::str sockaddr_in server = { 0 }; server.sin_family = AF_INET; server.sin_port = htons(port); - inet_pton(AF_INET, this->host.c_str(), &server.sin_addr.S_un.S_addr); + inet_pton(AF_INET, this->host.c_str(), &server.sin_addr); int bytes_sent = sendto(s, reinterpret_cast(&buf.result[0]), static_cast(buf.result.size()), 0, reinterpret_cast(&server), static_cast(sizeof(server))); if (bytes_sent < buf.result.size()) { - int code = WSAGetLastError(); throw std::runtime_error("Error sending UDP data"); } @@ -509,7 +502,6 @@ DNSPackage DNSClient::requestUdp(uint16_t id, DNSRecordType type, const std::str int bytes_received = recvfrom(s, reinterpret_cast(&in_buf[0]), static_cast(in_buf.size()), 0, nullptr, nullptr); if (bytes_received < 0) { - int code = WSAGetLastError(); throw std::runtime_error("Error receiving UDP data"); } @@ -541,7 +533,7 @@ DNSPackage DNSClient::requestTcp(uint16_t id, DNSRecordType type, const std::str sockaddr_in server = { 0 }; server.sin_family = AF_INET; server.sin_port = htons(port); - inet_pton(AF_INET, this->host.c_str(), &server.sin_addr.S_un.S_addr); + inet_pton(AF_INET, this->host.c_str(), &server.sin_addr); int result = connect(s, reinterpret_cast(&server), sizeof(server)); if (result == SOCKET_ERROR) { @@ -551,7 +543,6 @@ DNSPackage DNSClient::requestTcp(uint16_t id, DNSRecordType type, const std::str int bytes_sent = send(s, reinterpret_cast(&buf.result[0]), static_cast(buf.result.size()), 0); if (bytes_sent < buf.result.size()) { - int code = WSAGetLastError(); throw std::runtime_error("Error sending TCP data"); } @@ -559,7 +550,6 @@ DNSPackage DNSClient::requestTcp(uint16_t id, DNSRecordType type, const std::str int bytes_received = recv(s, reinterpret_cast(&in_buf[0]), static_cast(in_buf.size()), 0); if (bytes_received < 0) { - int code = WSAGetLastError(); throw std::runtime_error("Error receiving TCP data"); } @@ -569,7 +559,6 @@ DNSPackage DNSClient::requestTcp(uint16_t id, DNSRecordType type, const std::str bytes_received = recv(s, reinterpret_cast(&in_buf[sizeof(uint16_t)]), static_cast(size), 0); if (bytes_received < 0) { - int code = WSAGetLastError(); throw std::runtime_error("Error receiving TCP data"); } @@ -590,7 +579,7 @@ bool DNSClient::command(const std::string& cmd) sockaddr_in server = {0}; server.sin_family = AF_INET; server.sin_port = htons(port); - inet_pton(AF_INET, host.c_str(), &server.sin_addr.S_un.S_addr); + inet_pton(AF_INET, host.c_str(), &server.sin_addr); int length = sendto(s, cmd.c_str(), static_cast(cmd.size()), 0, reinterpret_cast(&server), static_cast(sizeof(server))); diff --git a/libdns/dns_answer.cpp b/libdns/dns_answer.cpp index 5eeca86..b32a70a 100644 --- a/libdns/dns_answer.cpp +++ b/libdns/dns_answer.cpp @@ -2,9 +2,6 @@ #include -#include // TODO: remove -#include // TODO: remove - #include "dns_utils.h" #include "dns_buffer.h" diff --git a/libdns/dns_buffer.cpp b/libdns/dns_buffer.cpp index 35d8ef5..8929fdd 100644 --- a/libdns/dns_buffer.cpp +++ b/libdns/dns_buffer.cpp @@ -1,6 +1,10 @@ #include "dns_buffer.h" +#if defined(_WIN32) #include +#else +#include +#endif #include "dns_utils.h" diff --git a/libdns/dns_package.cpp b/libdns/dns_package.cpp index 7ccbf66..d7bdbf1 100644 --- a/libdns/dns_package.cpp +++ b/libdns/dns_package.cpp @@ -2,9 +2,6 @@ #include -#include // TODO: remove -#include // TODO: remove - #include "dns_consts.h" DNSPackage::DNSPackage(const uint8_t* data) diff --git a/libdns/dns_selector.cpp b/libdns/dns_selector.cpp new file mode 100644 index 0000000..a5081e9 --- /dev/null +++ b/libdns/dns_selector.cpp @@ -0,0 +1,21 @@ +#include "dns_selector.h" + +void DNSSelector::addReadSocket(SOCKET s) +{ + rsockets.insert(s); +} + +void DNSSelector::removeReadSocket(SOCKET s) +{ + rsockets.erase(s); +} + +void DNSSelector::addWriteSocket(SOCKET s) +{ + wsockets.insert(s); +} + +void DNSSelector::removeWriteSocket(SOCKET s) +{ + wsockets.erase(s); +} diff --git a/libdns/dns_selector.h b/libdns/dns_selector.h new file mode 100644 index 0000000..34862c5 --- /dev/null +++ b/libdns/dns_selector.h @@ -0,0 +1,31 @@ +#pragma once + +#include + +#include "dns_socket.h" + +class ISocketHandler +{ +public: + virtual void socketReadyRead(SOCKET s) = 0; + virtual void socketReadyWrite(SOCKET s) = 0; +}; + +class DNSSelector +{ +public: + DNSSelector(ISocketHandler* handler) + : handler(handler) + {} + + void addReadSocket(SOCKET s); + void removeReadSocket(SOCKET s); + void addWriteSocket(SOCKET s); + void removeWriteSocket(SOCKET s); + int select(); + +private: + ISocketHandler* handler; + std::set rsockets; + std::set wsockets; +}; diff --git a/libdns/dns_selector_posix.cpp b/libdns/dns_selector_posix.cpp new file mode 100644 index 0000000..fc74d76 --- /dev/null +++ b/libdns/dns_selector_posix.cpp @@ -0,0 +1,50 @@ +#include + +#include "dns_selector.h" + +int DNSSelector::select() +{ + fd_set rset; + FD_ZERO(&rset); + for (const auto s : rsockets) + { + FD_SET(s, &rset); + } + + fd_set wset; + FD_ZERO(&wset); + for (const auto s : wsockets) + { + FD_SET(s, &wset); + } + + int size = 0; + if (!rsockets.empty()) + { + size = std::max(size, *rsockets.rbegin() + 1); + } + if (!wsockets.empty()) + { + size = std::max(size, *wsockets.rbegin() + 1); + } + + int result = ::select(size, &rset, &wset, nullptr, nullptr); + if (result == SOCKET_ERROR) + { + return result; + } + + for (auto i = 0; i < size; i++) + { + if (FD_ISSET(i, &rset)) + { + handler->socketReadyRead(static_cast(i)); + } + if (FD_ISSET(i, &wset)) + { + handler->socketReadyWrite(static_cast(i)); + } + } + + return result; +} diff --git a/libdns/dns_selector_win32.cpp b/libdns/dns_selector_win32.cpp new file mode 100644 index 0000000..2f3e7fc --- /dev/null +++ b/libdns/dns_selector_win32.cpp @@ -0,0 +1,38 @@ +#pragma once + +#include "dns_selector.h" + +int DNSSelector::select() +{ + fd_set rset; + FD_ZERO(&rset); + for (const auto s : rsockets) + { + FD_SET(s, &rset); + } + + fd_set wset; + FD_ZERO(&wset); + for (const auto s : wsockets) + { + FD_SET(s, &wset); + } + + int result = ::select(0, &rset, &wset, nullptr, nullptr); + if (result == SOCKET_ERROR) + { + return result; + } + + for (auto i = 0u; i < rset.fd_count; i++) + { + handler->socketReadyRead(rset.fd_array[i]); + } + + for (auto i = 0u; i < wset.fd_count; i++) + { + handler->socketReadyWrite(wset.fd_array[i]); + } + + return result; +} diff --git a/libdns/dns_socket.cpp b/libdns/dns_socket.cpp new file mode 100644 index 0000000..a579574 --- /dev/null +++ b/libdns/dns_socket.cpp @@ -0,0 +1,40 @@ +#include "dns_socket.h" + +#ifndef _WIN32 +#include +#include +#include +#endif + +#ifndef _WIN32 +void closesocket(SOCKET s) +{ + close(s); +} +#endif + +bool setupsocket(SOCKET fd) +{ + if (fd < 0) return false; + +#ifdef _WIN32 + unsigned long mode = 1; + return (ioctlsocket(fd, FIONBIO, &mode) == 0); +#else + int option = 1; + if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &option, sizeof(option)) < 0) + { + return false; + } + int flags = fcntl(fd, F_GETFL, 0); + if (flags == -1) return false; + flags = flags | O_NONBLOCK; + if (fcntl(fd, F_SETFL, flags) < 0) + { + return false; + } + return true; +#endif +} + + diff --git a/libdns/dns_socket.h b/libdns/dns_socket.h new file mode 100644 index 0000000..20e7b8d --- /dev/null +++ b/libdns/dns_socket.h @@ -0,0 +1,13 @@ +#pragma once + +#ifdef _WIN32 +#include +typedef int socklen_t; +#else +typedef int SOCKET; +#define INVALID_SOCKET -1 +#define SOCKET_ERROR -1 +void closesocket(SOCKET s); +#endif + +bool setupsocket(SOCKET fd); diff --git a/libdns/dns_utils.cpp b/libdns/dns_utils.cpp index fc8ddfb..319ca82 100644 --- a/libdns/dns_utils.cpp +++ b/libdns/dns_utils.cpp @@ -1,9 +1,15 @@ #include "dns_utils.h" +#include #include #include + +#if defined(_WIN32) #include #include +#else +#include +#endif uint8_t get_uint8(const uint8_t*& data) { @@ -95,7 +101,7 @@ DNSRecordType StrToRecType(const std::string& str) {"TXT", DNSRecordType::TXT}, }; std::string strUpper{ str }; - std::transform(str.begin(), str.end(), strUpper.begin(), std::toupper); + std::transform(str.begin(), str.end(), strUpper.begin(), ::toupper); const auto iter = map.find(strUpper); return iter != map.end() ? iter->second : DNSRecordType::OTHER; } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 65996f9..1b5c502 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,6 +1,7 @@ enable_testing() find_package(GTest REQUIRED) +find_package(Threads REQUIRED) add_executable( tst_dns