Skip to content

Commit

Permalink
Implement EstablishConnection #3
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacquwes committed Mar 24, 2023
1 parent 494e103 commit 98b8813
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 27 deletions.
56 changes: 38 additions & 18 deletions Server/Connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ AsyncTask Connection::SendRawMessage(std::vector<uint8_t> const& buffer) const



AsyncOperation<std::shared_ptr<SocketMessages::Message>> Connection::ReceiveMessage() const
AsyncOperation<std::shared_ptr<SocketMessages::Message>> Connection::ReceiveMessage()
{
auto message = std::make_shared<SocketMessages::Message>();

Expand All @@ -116,56 +116,68 @@ AsyncOperation<std::shared_ptr<SocketMessages::Message>> Connection::ReceiveMess

switch (header.messageType)
{
using namespace SocketMessages;
using enum SocketMessages::MessageType;

case MessageType::HelloMessage:
case AcknowledgeMessage:
message = std::make_shared<SocketMessages::AcknowledgeMessage>();

if (!std::dynamic_pointer_cast<SocketMessages::AcknowledgeMessage>(message)->ParseBody(body))
co_return message;
break;

case HelloMessage:
message = std::make_shared<SocketMessages::HelloMessage>();

if (!std::dynamic_pointer_cast<SocketMessages::HelloMessage>(message)->ParseBody(body))
co_return message;
break;

case MessageType::IdentifyMessage:
case IdentifyMessage:
message = std::make_shared<SocketMessages::IdentifyMessage>();

if (!std::dynamic_pointer_cast<SocketMessages::IdentifyMessage>(message)->ParseBody(body))
co_return message;
break;

case MessageType::KeepAliveMessage:
case KeepAliveMessage:
message = std::make_shared<SocketMessages::KeepAliveMessage>();

if (!std::dynamic_pointer_cast<SocketMessages::KeepAliveMessage>(message)->ParseBody(body))
co_return message;
break;

case MessageType::SendChatMessage:
case SendChatMessage:
message = std::make_shared<SocketMessages::SendChatMessage>();

if (!std::dynamic_pointer_cast<SocketMessages::SendChatMessage>(message)->ParseBody(body))
co_return message;
break;

case ReceiveChatMessage:
message = std::make_shared<SocketMessages::ReceiveChatMessage>();

if (!std::dynamic_pointer_cast<SocketMessages::ReceiveChatMessage>(message)->ParseBody(body))
co_return message;
break;

default:
co_return message;
break;
}


message->header = header;

m_lastMessageId = message->header.messageId;

co_return message;
}



AsyncTask Connection::SendMessage(std::shared_ptr<SocketMessages::Message> const& message) const
{
std::vector<uint8_t> buffer;

if (message->header.messageType == SocketMessages::MessageType::HelloMessage)
buffer = std::dynamic_pointer_cast<SocketMessages::HelloMessage>(message)->Serialize();
else
throw ServerException{ "Trying to send unknown message type." };
std::vector<uint8_t> buffer = message->Serialize();

co_await SendRawMessage(buffer);
}
Expand All @@ -188,6 +200,8 @@ AsyncOperation<bool> Connection::EstablishConnection()
co_return false;
}

co_await SendAck();

std::cout << " Client passed version check: " << std::dec << m_id << std::endl;

if (!(co_await Identify()))
Expand All @@ -196,20 +210,20 @@ AsyncOperation<bool> Connection::EstablishConnection()
co_return false;
}

co_await SendAck();

std::cout << " Client successfully identified in as \"" << m_user->m_username << "\": " << std::dec << m_id << std::endl;

co_return true;
}



AsyncOperation<bool> Connection::CheckVersion() const
AsyncOperation<bool> Connection::CheckVersion()
{
co_await SendMessage(std::static_pointer_cast<SocketMessages::Message>(
std::make_shared<SocketMessages::HelloMessage>()
));
co_await SendMessage(std::make_shared<SocketMessages::HelloMessage>());

std::shared_ptr<SocketMessages::Message> hello = co_await ReceiveMessage();
auto&& hello = co_await ReceiveMessage();
if (hello->header.messageType != SocketMessages::MessageType::HelloMessage)
co_return false;

Expand Down Expand Up @@ -262,4 +276,10 @@ AsyncOperation<bool> Connection::ValidateConnection() const
if (std::ranges::equal(keyBuffer, response))
co_return true;
co_return false;
}
}

AsyncTask Connection::SendAck() const
{
auto message = std::make_shared<SocketMessages::AcknowledgeMessage>(m_lastMessageId);
co_await SendMessage(message);
}
14 changes: 8 additions & 6 deletions Server/Connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@ class Connection : public std::enable_shared_from_this<Connection>

AsyncTask Listen();

AsyncOperation<std::shared_ptr<SocketMessages::Message>> ReceiveMessage() const;
AsyncOperation<std::shared_ptr<SocketMessages::Message>> ReceiveMessage();
AsyncTask SendMessage(std::shared_ptr<SocketMessages::Message> const& message) const;

AsyncOperation<std::vector<uint8_t>> ReceiveRawMessage(uint64_t const& bufferSize) const;
AsyncTask SendRawMessage(std::vector<uint8_t> const& buffer) const;

[[nodiscard]] constexpr Snowflake const& GetId() const noexcept { return m_id; }
[[nodiscard]] constexpr std::jthread& GetThread() noexcept { return m_thread; }
[[nodiscard]] constexpr std::shared_ptr<User> const& GetUser() const noexcept { return m_user; }
Expand All @@ -33,18 +36,17 @@ class Connection : public std::enable_shared_from_this<Connection>
[[nodiscard]] constexpr bool const& IsDisconnecting() const noexcept { return m_disconnecting; }
constexpr void SetDisconnecting(bool const& disconnecting) noexcept { m_disconnecting = disconnecting; }

#ifndef MS_CPP_UNITTESTFRAMEWORK
private:
#endif
AsyncOperation<std::vector<uint8_t>> ReceiveRawMessage(uint64_t const& bufferSize) const;
AsyncTask SendRawMessage(std::vector<uint8_t> const& buffer) const;

AsyncOperation<bool> EstablishConnection();
AsyncOperation<bool> CheckVersion() const;
AsyncOperation<bool> CheckVersion();
AsyncOperation<bool> Identify();
AsyncOperation<bool> ValidateConnection() const;

AsyncTask SendAck() const;

Snowflake m_id;
Snowflake m_lastMessageId;
Server& m_server;
SOCKET m_socket;
std::shared_ptr<User> m_user{ std::make_shared<User>() };
Expand Down
3 changes: 2 additions & 1 deletion Server/IdentifyMessage.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ namespace SocketMessages

bool ParseBody(std::vector<uint8_t> const& buffer) override
{
if (buffer.size() != GetBodySize())
if (buffer.size() < sizeof(m_usernameLength) + UsernameMinLength
|| buffer.size() > sizeof(m_usernameLength) + UsernameMaxLength)
return false;

std::memcpy(std::bit_cast<void*>(&m_usernameLength), std::bit_cast<void*>(buffer.data()), sizeof(m_usernameLength));
Expand Down
45 changes: 43 additions & 2 deletions ServerUnitTest/ConnectionUnitTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,26 +69,67 @@ namespace ServerUnitTest
[&]() -> AsyncTask {
co_await SwitchThread(clientConnection->GetThread());


Logger::WriteMessage("Waiting for token");
std::vector<uint8_t> token = co_await clientConnection->ReceiveRawMessage(sizeof(uint64_t));
Assert::AreEqual(token.size(), sizeof(uint64_t), L"Wrong token size");
Logger::WriteMessage("Received token");

*std::bit_cast<uint64_t*>(token.data()) ^= 0xF007CAFEC0C0CA7E;

co_await clientConnection->SendRawMessage(token);
Logger::WriteMessage("Sent token");




Logger::WriteMessage("Waiting for Hello");
std::shared_ptr<SocketMessages::Message> message = co_await clientConnection->ReceiveMessage();
Assert::AreEqual(static_cast<uint8_t>(message->header.messageType),
static_cast<uint8_t>(SocketMessages::MessageType::HelloMessage),
L"Wrong message type instead of Hello");
L"Wrong message type received instead of Hello");
Logger::WriteMessage("Received Hello message");




message = std::make_shared<SocketMessages::HelloMessage>();
co_await clientConnection->SendMessage(message);
uint64_t expectedId = message->header.messageId;

co_await clientConnection->SendMessage(message);
Logger::WriteMessage("Sent Hello message");
Logger::WriteMessage("Waiting for ACK");
message = co_await clientConnection->ReceiveMessage();
Assert::AreEqual(static_cast<uint8_t>(message->header.messageType),
static_cast<uint8_t>(SocketMessages::MessageType::AcknowledgeMessage),
L"Wrong message type instead of ACK");
Logger::WriteMessage("Received ACK");

uint64_t actualId =
std::dynamic_pointer_cast<SocketMessages::AcknowledgeMessage>(message)
->GetAcknowledgedMessageId();
Assert::AreEqual(expectedId, actualId, L"Wrong message id received");




message = std::make_shared<SocketMessages::IdentifyMessage>();
std::dynamic_pointer_cast<SocketMessages::IdentifyMessage>(message)->SetUsername("Username");
expectedId = message->header.messageId;

co_await clientConnection->SendMessage(message);
Logger::WriteMessage("Sent Identify message");
message = co_await clientConnection->ReceiveMessage();
Assert::AreEqual(static_cast<uint8_t>(message->header.messageType),
static_cast<uint8_t>(SocketMessages::MessageType::AcknowledgeMessage),
L"Wrong message type instead of ACK");
Logger::WriteMessage("Received ACK");

actualId = std::dynamic_pointer_cast<SocketMessages::AcknowledgeMessage>(message)
->GetAcknowledgedMessageId();
Assert::AreEqual(expectedId, actualId, L"Wrong message id received");



closesocket(clientConnection->GetSocket());
shutdown(clientConnection->GetSocket(), SD_BOTH);
Expand Down

0 comments on commit 98b8813

Please sign in to comment.