From 98b8813b49c2e69d792c660e264b1c7961359a95 Mon Sep 17 00:00:00 2001 From: Jacquwes <38167139+Jacquwes@users.noreply.github.com> Date: Fri, 24 Mar 2023 14:33:58 +0100 Subject: [PATCH] Implement EstablishConnection #3 --- Server/Connection.cpp | 56 ++++++++++++++++++--------- Server/Connection.h | 14 ++++--- Server/IdentifyMessage.h | 3 +- ServerUnitTest/ConnectionUnitTest.cpp | 45 ++++++++++++++++++++- 4 files changed, 91 insertions(+), 27 deletions(-) diff --git a/Server/Connection.cpp b/Server/Connection.cpp index 622fa61..4f92182 100644 --- a/Server/Connection.cpp +++ b/Server/Connection.cpp @@ -99,7 +99,7 @@ AsyncTask Connection::SendRawMessage(std::vector const& buffer) const -AsyncOperation> Connection::ReceiveMessage() const +AsyncOperation> Connection::ReceiveMessage() { auto message = std::make_shared(); @@ -116,43 +116,60 @@ AsyncOperation> Connection::ReceiveMess switch (header.messageType) { - using namespace SocketMessages; + using enum SocketMessages::MessageType; - case MessageType::HelloMessage: + case AcknowledgeMessage: + message = std::make_shared(); + + if (!std::dynamic_pointer_cast(message)->ParseBody(body)) + co_return message; + break; + + case HelloMessage: message = std::make_shared(); if (!std::dynamic_pointer_cast(message)->ParseBody(body)) co_return message; break; - case MessageType::IdentifyMessage: + case IdentifyMessage: message = std::make_shared(); if (!std::dynamic_pointer_cast(message)->ParseBody(body)) co_return message; break; - case MessageType::KeepAliveMessage: + case KeepAliveMessage: message = std::make_shared(); if (!std::dynamic_pointer_cast(message)->ParseBody(body)) co_return message; break; - case MessageType::SendChatMessage: + case SendChatMessage: message = std::make_shared(); if (!std::dynamic_pointer_cast(message)->ParseBody(body)) co_return message; break; + case ReceiveChatMessage: + message = std::make_shared(); + + if (!std::dynamic_pointer_cast(message)->ParseBody(body)) + co_return message; + break; + default: co_return message; break; } + message->header = header; + m_lastMessageId = message->header.messageId; + co_return message; } @@ -160,12 +177,7 @@ AsyncOperation> Connection::ReceiveMess AsyncTask Connection::SendMessage(std::shared_ptr const& message) const { - std::vector buffer; - - if (message->header.messageType == SocketMessages::MessageType::HelloMessage) - buffer = std::dynamic_pointer_cast(message)->Serialize(); - else - throw ServerException{ "Trying to send unknown message type." }; + std::vector buffer = message->Serialize(); co_await SendRawMessage(buffer); } @@ -188,6 +200,8 @@ AsyncOperation Connection::EstablishConnection() co_return false; } + co_await SendAck(); + std::cout << " Client passed version check: " << std::dec << m_id << std::endl; if (!(co_await Identify())) @@ -196,6 +210,8 @@ AsyncOperation Connection::EstablishConnection() co_return false; } + co_await SendAck(); + std::cout << " Client successfully identified in as \"" << m_user->m_username << "\": " << std::dec << m_id << std::endl; co_return true; @@ -203,13 +219,11 @@ AsyncOperation Connection::EstablishConnection() -AsyncOperation Connection::CheckVersion() const +AsyncOperation Connection::CheckVersion() { - co_await SendMessage(std::static_pointer_cast( - std::make_shared() - )); + co_await SendMessage(std::make_shared()); - std::shared_ptr hello = co_await ReceiveMessage(); + auto&& hello = co_await ReceiveMessage(); if (hello->header.messageType != SocketMessages::MessageType::HelloMessage) co_return false; @@ -262,4 +276,10 @@ AsyncOperation Connection::ValidateConnection() const if (std::ranges::equal(keyBuffer, response)) co_return true; co_return false; -} \ No newline at end of file +} + +AsyncTask Connection::SendAck() const +{ + auto message = std::make_shared(m_lastMessageId); + co_await SendMessage(message); +} diff --git a/Server/Connection.h b/Server/Connection.h index e791633..c480121 100644 --- a/Server/Connection.h +++ b/Server/Connection.h @@ -22,9 +22,12 @@ class Connection : public std::enable_shared_from_this AsyncTask Listen(); - AsyncOperation> ReceiveMessage() const; + AsyncOperation> ReceiveMessage(); AsyncTask SendMessage(std::shared_ptr const& message) const; + AsyncOperation> ReceiveRawMessage(uint64_t const& bufferSize) const; + AsyncTask SendRawMessage(std::vector const& buffer) const; + [[nodiscard]] constexpr Snowflake const& GetId() const noexcept { return m_id; } [[nodiscard]] constexpr std::jthread& GetThread() noexcept { return m_thread; } [[nodiscard]] constexpr std::shared_ptr const& GetUser() const noexcept { return m_user; } @@ -33,18 +36,17 @@ class Connection : public std::enable_shared_from_this [[nodiscard]] constexpr bool const& IsDisconnecting() const noexcept { return m_disconnecting; } constexpr void SetDisconnecting(bool const& disconnecting) noexcept { m_disconnecting = disconnecting; } -#ifndef MS_CPP_UNITTESTFRAMEWORK private: -#endif - AsyncOperation> ReceiveRawMessage(uint64_t const& bufferSize) const; - AsyncTask SendRawMessage(std::vector const& buffer) const; AsyncOperation EstablishConnection(); - AsyncOperation CheckVersion() const; + AsyncOperation CheckVersion(); AsyncOperation Identify(); AsyncOperation ValidateConnection() const; + AsyncTask SendAck() const; + Snowflake m_id; + Snowflake m_lastMessageId; Server& m_server; SOCKET m_socket; std::shared_ptr m_user{ std::make_shared() }; diff --git a/Server/IdentifyMessage.h b/Server/IdentifyMessage.h index b9caf03..4167ae7 100644 --- a/Server/IdentifyMessage.h +++ b/Server/IdentifyMessage.h @@ -19,7 +19,8 @@ namespace SocketMessages bool ParseBody(std::vector const& buffer) override { - if (buffer.size() != GetBodySize()) + if (buffer.size() < sizeof(m_usernameLength) + UsernameMinLength + || buffer.size() > sizeof(m_usernameLength) + UsernameMaxLength) return false; std::memcpy(std::bit_cast(&m_usernameLength), std::bit_cast(buffer.data()), sizeof(m_usernameLength)); diff --git a/ServerUnitTest/ConnectionUnitTest.cpp b/ServerUnitTest/ConnectionUnitTest.cpp index 1c548f1..4cc048a 100644 --- a/ServerUnitTest/ConnectionUnitTest.cpp +++ b/ServerUnitTest/ConnectionUnitTest.cpp @@ -69,26 +69,67 @@ namespace ServerUnitTest [&]() -> AsyncTask { co_await SwitchThread(clientConnection->GetThread()); + + Logger::WriteMessage("Waiting for token"); std::vector token = co_await clientConnection->ReceiveRawMessage(sizeof(uint64_t)); Assert::AreEqual(token.size(), sizeof(uint64_t), L"Wrong token size"); + Logger::WriteMessage("Received token"); *std::bit_cast(token.data()) ^= 0xF007CAFEC0C0CA7E; co_await clientConnection->SendRawMessage(token); + Logger::WriteMessage("Sent token"); + + + + Logger::WriteMessage("Waiting for Hello"); std::shared_ptr message = co_await clientConnection->ReceiveMessage(); Assert::AreEqual(static_cast(message->header.messageType), static_cast(SocketMessages::MessageType::HelloMessage), - L"Wrong message type instead of Hello"); + L"Wrong message type received instead of Hello"); + Logger::WriteMessage("Received Hello message"); + + + message = std::make_shared(); - co_await clientConnection->SendMessage(message); + uint64_t expectedId = message->header.messageId; + co_await clientConnection->SendMessage(message); + Logger::WriteMessage("Sent Hello message"); + Logger::WriteMessage("Waiting for ACK"); message = co_await clientConnection->ReceiveMessage(); + Assert::AreEqual(static_cast(message->header.messageType), + static_cast(SocketMessages::MessageType::AcknowledgeMessage), + L"Wrong message type instead of ACK"); + Logger::WriteMessage("Received ACK"); + + uint64_t actualId = + std::dynamic_pointer_cast(message) + ->GetAcknowledgedMessageId(); + Assert::AreEqual(expectedId, actualId, L"Wrong message id received"); + + + + + message = std::make_shared(); + std::dynamic_pointer_cast(message)->SetUsername("Username"); + expectedId = message->header.messageId; + co_await clientConnection->SendMessage(message); + Logger::WriteMessage("Sent Identify message"); + message = co_await clientConnection->ReceiveMessage(); Assert::AreEqual(static_cast(message->header.messageType), static_cast(SocketMessages::MessageType::AcknowledgeMessage), L"Wrong message type instead of ACK"); + Logger::WriteMessage("Received ACK"); + + actualId = std::dynamic_pointer_cast(message) + ->GetAcknowledgedMessageId(); + Assert::AreEqual(expectedId, actualId, L"Wrong message id received"); + + closesocket(clientConnection->GetSocket()); shutdown(clientConnection->GetSocket(), SD_BOTH);