From 4170d320a9c56c161018ff524ded8dac15386250 Mon Sep 17 00:00:00 2001 From: Jacquwes <38167139+Jacquwes@users.noreply.github.com> Date: Sat, 25 Mar 2023 00:01:41 +0100 Subject: [PATCH] Implement Send and Receive chat message #4 --- Server/Connection.cpp | 10 +- Server/Connection.h | 8 +- Server/ReceiveChatMessage.h | 13 ++- Server/SendChatMessage.h | 3 +- Server/Server.cpp | 4 +- Server/Server.h | 2 +- Server/ServerOnMessage.cpp | 24 ++++- ServerUnitTest/ConnectionUnitTest.cpp | 131 ++++++++++++++++++++++---- 8 files changed, 163 insertions(+), 32 deletions(-) diff --git a/Server/Connection.cpp b/Server/Connection.cpp index 7869244..0b55180 100644 --- a/Server/Connection.cpp +++ b/Server/Connection.cpp @@ -13,9 +13,10 @@ -Connection::Connection(SOCKET socket, Server& server) +Connection::Connection(SOCKET socket, Server& server, bool serverConnection) : m_server{ server } , m_socket{ socket } + , m_serverConnection{ serverConnection } { std::cout << "New client trying to connect: " << std::dec << m_id << std::endl; } @@ -77,6 +78,13 @@ AsyncOperation> Connection::ReceiveRawMessage(uint64_t cons int n = recv(m_socket, std::bit_cast(buffer.data()), static_cast(bufferSize), 0); if (n == SOCKET_ERROR) { + int error = WSAGetLastError(); + + if (error == WSAECONNRESET || error == WSAECONNABORTED && m_serverConnection) + { + co_return buffer; + } + throw ServerException{ "Failed to receive message: " + WSAGetLastError() }; } diff --git a/Server/Connection.h b/Server/Connection.h index c480121..71a71e8 100644 --- a/Server/Connection.h +++ b/Server/Connection.h @@ -17,7 +17,7 @@ class Server; class Connection : public std::enable_shared_from_this { public: - Connection(SOCKET socket, Server& server); + Connection(SOCKET socket, Server& server, bool serverConnection = true); ~Connection(); AsyncTask Listen(); @@ -36,6 +36,8 @@ class Connection : public std::enable_shared_from_this [[nodiscard]] constexpr bool const& IsDisconnecting() const noexcept { return m_disconnecting; } constexpr void SetDisconnecting(bool const& disconnecting) noexcept { m_disconnecting = disconnecting; } + AsyncTask SendAck() const; + private: AsyncOperation EstablishConnection(); @@ -43,13 +45,13 @@ class Connection : public std::enable_shared_from_this AsyncOperation Identify(); AsyncOperation ValidateConnection() const; - AsyncTask SendAck() const; - Snowflake m_id; Snowflake m_lastMessageId; Server& m_server; SOCKET m_socket; std::shared_ptr m_user{ std::make_shared() }; std::jthread m_thread; + bool m_disconnecting{ false }; + bool m_serverConnection{ true }; }; diff --git a/Server/ReceiveChatMessage.h b/Server/ReceiveChatMessage.h index 59b7700..2829e0f 100644 --- a/Server/ReceiveChatMessage.h +++ b/Server/ReceiveChatMessage.h @@ -18,18 +18,29 @@ namespace SocketMessages bool ParseBody(std::vector const& buffer) override { - if (buffer.size() != GetBodySize()) + if (buffer.size() < sizeof(m_authorUsernameLength) + UsernameMinLength + sizeof(m_chatMessageLength) + ChatMessageMinLength) + return false; + + if (buffer.size() > sizeof(m_authorUsernameLength) + UsernameMaxLength + sizeof(m_chatMessageLength) + ChatMessageMaxLength) return false; std::memcpy(std::bit_cast(&m_authorUsernameLength), buffer.data(), sizeof(m_authorUsernameLength)); + + if (buffer.size() < sizeof(m_authorUsernameLength) + m_authorUsernameLength + sizeof(m_chatMessageLength) + ChatMessageMinLength) + return false; + m_authorUsername = std::string(buffer.begin() + sizeof(m_authorUsernameLength), buffer.begin() + sizeof(m_authorUsernameLength) + m_authorUsernameLength); std::memcpy(std::bit_cast(&m_chatMessageLength), buffer.data() + sizeof(m_authorUsernameLength) + m_authorUsernameLength, sizeof(m_chatMessageLength)); + + if (buffer.size() != sizeof(m_authorUsernameLength) + m_authorUsernameLength + sizeof(m_chatMessageLength) + m_chatMessageLength) + return false; + m_chatMessage = std::string(buffer.begin() + sizeof(m_authorUsernameLength) + m_authorUsernameLength + sizeof(m_chatMessageLength), buffer.end()); diff --git a/Server/SendChatMessage.h b/Server/SendChatMessage.h index 890ceb3..21f399f 100644 --- a/Server/SendChatMessage.h +++ b/Server/SendChatMessage.h @@ -20,7 +20,8 @@ namespace SocketMessages bool ParseBody(std::vector const& buffer) override { - if (buffer.size() != GetBodySize()) + if (buffer.size() < sizeof(m_chatMessageLength) + ChatMessageMinLength + || buffer.size() > sizeof (m_chatMessageLength) + ChatMessageMaxLength) return false; std::memcpy(std::bit_cast(&m_chatMessageLength), diff --git a/Server/Server.cpp b/Server/Server.cpp index 7792df6..06f450e 100644 --- a/Server/Server.cpp +++ b/Server/Server.cpp @@ -83,7 +83,7 @@ AsyncTask Server::DisconnectClient(Snowflake clientId) return potentialClient->GetId() == clientId; }); - if (client != m_clients.end()) + if (client != m_clients.end() && !(*client)->IsDisconnecting()) { (*client)->SetDisconnecting(true); closesocket((*client)->GetSocket()); @@ -99,8 +99,6 @@ AsyncTask Server::DisconnectClient(Snowflake clientId) } m_cv.notify_all(); - lock.unlock(); - lock.release(); co_return; } diff --git a/Server/Server.h b/Server/Server.h index 2397e2a..386f5fe 100644 --- a/Server/Server.h +++ b/Server/Server.h @@ -25,7 +25,7 @@ class Server AsyncTask DisconnectClient(Snowflake clientId); AsyncTask MessageClient(std::shared_ptr const& client, std::shared_ptr const& message) const; AsyncTask OnConnect(std::shared_ptr const& client) const; - AsyncTask OnMessage(std::shared_ptr const& client, std::shared_ptr const&); + AsyncTask OnMessage(std::shared_ptr client, std::shared_ptr message); #ifndef MS_CPP_UNITTESTFRAMEWORK private: diff --git a/Server/ServerOnMessage.cpp b/Server/ServerOnMessage.cpp index bc7464a..385ec75 100644 --- a/Server/ServerOnMessage.cpp +++ b/Server/ServerOnMessage.cpp @@ -7,7 +7,7 @@ #include #include -AsyncTask Server::OnMessage(std::shared_ptr const& client, std::shared_ptr const& message) +AsyncTask Server::OnMessage(std::shared_ptr client, std::shared_ptr message) { std::cout << " Message received from client: " << std::dec << client->GetId() << std::endl; @@ -16,20 +16,36 @@ AsyncTask Server::OnMessage(std::shared_ptr const& client, std::shar using enum SocketMessages::MessageType; case AcknowledgeMessage: - [[fallthroug]] case ErrorMessage: - [[fallthroug]] case ReceiveChatMessage: std::cout << " Message type not expected from client: " << std::dec << static_cast(message->header.messageType) << std::endl; DisconnectClient(client->GetId()); break; case HelloMessage: - [[fallthroug]] case IdentifyMessage: std::cout << " Message type not expected from client now: " << std::dec << static_cast(message->header.messageType) << std::endl; break; + case SendChatMessage: + { + std::cout << " Received chat message from client: " << std::dec << client->GetId() << std::endl; + auto receiveChat = std::make_shared(); + receiveChat->SetAuthorUsername(client->GetUser()->GetUsername()); + receiveChat->SetChatMessage(std::dynamic_pointer_cast(message)->GetChatMessage()); + + co_await client->SendAck(); + + for (auto const& iteratedClient : m_clients) + { + if (iteratedClient->GetId() == client->GetId()) + continue; + co_await client->SendMessage(receiveChat); + } + + break; + } + default: std::cout << " Unknown message type: " << std::dec << static_cast(message->header.messageType) << std::endl; DisconnectClient(client->GetId()); diff --git a/ServerUnitTest/ConnectionUnitTest.cpp b/ServerUnitTest/ConnectionUnitTest.cpp index 641e411..90af926 100644 --- a/ServerUnitTest/ConnectionUnitTest.cpp +++ b/ServerUnitTest/ConnectionUnitTest.cpp @@ -51,7 +51,31 @@ namespace ServerUnitTest freeaddrinfo(result); - return std::make_shared(sendingSocket, server); + return std::make_shared(sendingSocket, server, false); + } + + AsyncTask ConnectClient(std::shared_ptr client, std::string username) const + { + std::vector token = co_await client->ReceiveRawMessage(sizeof(uint64_t)); + *std::bit_cast(token.data()) ^= 0xF007CAFEC0C0CA7E; + co_await client->SendRawMessage(token); + + // Hello + std::shared_ptr message = co_await client->ReceiveMessage(); + + message = std::make_shared(); + // Hello + co_await client->SendMessage(message); + // ACK + message = co_await client->ReceiveMessage(); + + message = std::make_shared(); + std::dynamic_pointer_cast(message)->SetUsername(username); + + // Identify + co_await client->SendMessage(message); + // ACK + message = co_await client->ReceiveMessage(); } Server server; @@ -62,11 +86,12 @@ namespace ServerUnitTest { StartServer(); - std::shared_ptr clientConnection = CreateConnection(); - std::condition_variable cv; bool finished = false; + [&]() -> AsyncTask { + std::shared_ptr clientConnection = CreateConnection(); + co_await SwitchThread(clientConnection->GetThread()); @@ -77,7 +102,6 @@ namespace ServerUnitTest *std::bit_cast(token.data()) ^= 0xF007CAFEC0C0CA7E; co_await clientConnection->SendRawMessage(token); - Logger::WriteMessage("Sent token"); @@ -118,7 +142,7 @@ namespace ServerUnitTest L"Wrong message type instead of ACK"); actualId = std::dynamic_pointer_cast(message) - ->GetAcknowledgedMessageId(); + ->GetAcknowledgedMessageId(); Assert::AreEqual(expectedId, actualId, L"Wrong message id received"); @@ -132,25 +156,96 @@ namespace ServerUnitTest co_return; }(); - std::unique_lock lock{ mutex }; - cv.wait(lock, [&finished] { return finished; }); - - lock.unlock(); - lock.release(); - - if (clientConnection->GetThread().joinable()) { - clientConnection->GetThread().request_stop(); - clientConnection->GetThread().join(); + std::unique_lock lock{ mutex }; + cv.wait(lock, [&finished] { return finished; }); } - server.Stop(); + } + + TEST_METHOD(SendChatMessage) + { + StartServer(); + + std::condition_variable cv; + std::pair finished = { false, false }; + std::pair connected = { false, false }; + + std::string chatContent = "Hello world"; + std::string senderUsername = "Sender"; + + std::mutex readyMutex; + + + + [&]() -> AsyncTask { + auto receiverConnection = CreateConnection(); + + co_await SwitchThread(receiverConnection->GetThread()); + + co_await ConnectClient(receiverConnection, "Receiver"); - if (serverThread.joinable()) - serverThread.join(); + connected.second = true; + cv.notify_all(); - return; + auto&& message = co_await receiverConnection->ReceiveMessage(); + Assert::AreEqual(static_cast(message->header.messageType), + static_cast(SocketMessages::MessageType::SendChatMessage), + L"Wrong message type received instead of SendChatMessage"); + Assert::AreEqual(std::dynamic_pointer_cast(message)->GetChatMessage(), + chatContent, + L"Wrong message received"); + Assert::AreEqual(std::dynamic_pointer_cast(message)->GetAuthorUsername(), + senderUsername, + L"Wrong message received"); + + + closesocket(receiverConnection->GetSocket()); + shutdown(receiverConnection->GetSocket(), SD_BOTH); + + finished.second = true; + cv.notify_all(); + }(); + + + + [&]() -> AsyncTask { + auto senderConnection = CreateConnection(); + + co_await SwitchThread(senderConnection->GetThread()); + + co_await ConnectClient(senderConnection, senderUsername); + + { + std::unique_lock lock{ readyMutex }; + cv.wait(lock, [&connected] { return connected.second; }); + } + + auto message = std::make_shared(); + message->SetChatMessage(chatContent); + co_await senderConnection->SendMessage(message); + + auto&& Ack = co_await senderConnection->ReceiveMessage(); + Assert::AreEqual(static_cast(Ack->header.messageType), + static_cast(SocketMessages::MessageType::AcknowledgeMessage), + L"Wrong message type received instead of ACK"); + + closesocket(senderConnection->GetSocket()); + shutdown(senderConnection->GetSocket(), SD_BOTH); + + finished.first = true; + cv.notify_all(); + }(); + + + + { + std::unique_lock lock{ mutex }; + cv.wait(lock, [&finished] { return finished.first && finished.second; }); + } + + server.Stop(); } }; } \ No newline at end of file