Skip to content

Commit

Permalink
Implement Send and Receive chat message #4
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacquwes committed Mar 24, 2023
1 parent c57cf71 commit 4170d32
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 32 deletions.
10 changes: 9 additions & 1 deletion Server/Connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@



Connection::Connection(SOCKET socket, Server& server)
Connection::Connection(SOCKET socket, Server& server, bool serverConnection)
: m_server{ server }
, m_socket{ socket }
, m_serverConnection{ serverConnection }
{
std::cout << "New client trying to connect: " << std::dec << m_id << std::endl;
}
Expand Down Expand Up @@ -77,6 +78,13 @@ AsyncOperation<std::vector<uint8_t>> Connection::ReceiveRawMessage(uint64_t cons
int n = recv(m_socket, std::bit_cast<char*>(buffer.data()), static_cast<int>(bufferSize), 0);
if (n == SOCKET_ERROR)
{
int error = WSAGetLastError();

if (error == WSAECONNRESET || error == WSAECONNABORTED && m_serverConnection)
{
co_return buffer;
}

throw ServerException{ "Failed to receive message: " + WSAGetLastError() };
}

Expand Down
8 changes: 5 additions & 3 deletions Server/Connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class Server;
class Connection : public std::enable_shared_from_this<Connection>
{
public:
Connection(SOCKET socket, Server& server);
Connection(SOCKET socket, Server& server, bool serverConnection = true);
~Connection();

AsyncTask Listen();
Expand All @@ -36,20 +36,22 @@ class Connection : public std::enable_shared_from_this<Connection>
[[nodiscard]] constexpr bool const& IsDisconnecting() const noexcept { return m_disconnecting; }
constexpr void SetDisconnecting(bool const& disconnecting) noexcept { m_disconnecting = disconnecting; }

AsyncTask SendAck() const;

private:

AsyncOperation<bool> EstablishConnection();
AsyncOperation<bool> CheckVersion();
AsyncOperation<bool> Identify();
AsyncOperation<bool> ValidateConnection() const;

AsyncTask SendAck() const;

Snowflake m_id;
Snowflake m_lastMessageId;
Server& m_server;
SOCKET m_socket;
std::shared_ptr<User> m_user{ std::make_shared<User>() };
std::jthread m_thread;

bool m_disconnecting{ false };
bool m_serverConnection{ true };
};
13 changes: 12 additions & 1 deletion Server/ReceiveChatMessage.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,29 @@ namespace SocketMessages

bool ParseBody(std::vector<uint8_t> const& buffer) override
{
if (buffer.size() != GetBodySize())
if (buffer.size() < sizeof(m_authorUsernameLength) + UsernameMinLength + sizeof(m_chatMessageLength) + ChatMessageMinLength)
return false;

if (buffer.size() > sizeof(m_authorUsernameLength) + UsernameMaxLength + sizeof(m_chatMessageLength) + ChatMessageMaxLength)
return false;

std::memcpy(std::bit_cast<uint8_t*>(&m_authorUsernameLength),
buffer.data(),
sizeof(m_authorUsernameLength));

if (buffer.size() < sizeof(m_authorUsernameLength) + m_authorUsernameLength + sizeof(m_chatMessageLength) + ChatMessageMinLength)
return false;

m_authorUsername = std::string(buffer.begin() + sizeof(m_authorUsernameLength),
buffer.begin() + sizeof(m_authorUsernameLength) + m_authorUsernameLength);

std::memcpy(std::bit_cast<uint16_t*>(&m_chatMessageLength),
buffer.data() + sizeof(m_authorUsernameLength) + m_authorUsernameLength,
sizeof(m_chatMessageLength));

if (buffer.size() != sizeof(m_authorUsernameLength) + m_authorUsernameLength + sizeof(m_chatMessageLength) + m_chatMessageLength)
return false;

m_chatMessage = std::string(buffer.begin() + sizeof(m_authorUsernameLength) + m_authorUsernameLength + sizeof(m_chatMessageLength),
buffer.end());

Expand Down
3 changes: 2 additions & 1 deletion Server/SendChatMessage.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ namespace SocketMessages

bool ParseBody(std::vector<uint8_t> const& buffer) override
{
if (buffer.size() != GetBodySize())
if (buffer.size() < sizeof(m_chatMessageLength) + ChatMessageMinLength
|| buffer.size() > sizeof (m_chatMessageLength) + ChatMessageMaxLength)
return false;

std::memcpy(std::bit_cast<uint8_t*>(&m_chatMessageLength),
Expand Down
4 changes: 1 addition & 3 deletions Server/Server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ AsyncTask Server::DisconnectClient(Snowflake clientId)
return potentialClient->GetId() == clientId;
});

if (client != m_clients.end())
if (client != m_clients.end() && !(*client)->IsDisconnecting())
{
(*client)->SetDisconnecting(true);
closesocket((*client)->GetSocket());
Expand All @@ -99,8 +99,6 @@ AsyncTask Server::DisconnectClient(Snowflake clientId)
}

m_cv.notify_all();
lock.unlock();
lock.release();

co_return;
}
Expand Down
2 changes: 1 addition & 1 deletion Server/Server.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class Server
AsyncTask DisconnectClient(Snowflake clientId);
AsyncTask MessageClient(std::shared_ptr<Connection> const& client, std::shared_ptr<SocketMessages::Message> const& message) const;
AsyncTask OnConnect(std::shared_ptr<Connection> const& client) const;
AsyncTask OnMessage(std::shared_ptr<Connection> const& client, std::shared_ptr<SocketMessages::Message> const&);
AsyncTask OnMessage(std::shared_ptr<Connection> client, std::shared_ptr<SocketMessages::Message> message);

#ifndef MS_CPP_UNITTESTFRAMEWORK
private:
Expand Down
24 changes: 20 additions & 4 deletions Server/ServerOnMessage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <iostream>
#include <memory>

AsyncTask Server::OnMessage(std::shared_ptr<Connection> const& client, std::shared_ptr<SocketMessages::Message> const& message)
AsyncTask Server::OnMessage(std::shared_ptr<Connection> client, std::shared_ptr<SocketMessages::Message> message)
{
std::cout << " Message received from client: " << std::dec << client->GetId() << std::endl;

Expand All @@ -16,20 +16,36 @@ AsyncTask Server::OnMessage(std::shared_ptr<Connection> const& client, std::shar
using enum SocketMessages::MessageType;

case AcknowledgeMessage:
[[fallthroug]]
case ErrorMessage:
[[fallthroug]]
case ReceiveChatMessage:
std::cout << " Message type not expected from client: " << std::dec << static_cast<int>(message->header.messageType) << std::endl;
DisconnectClient(client->GetId());
break;

case HelloMessage:
[[fallthroug]]
case IdentifyMessage:
std::cout << " Message type not expected from client now: " << std::dec << static_cast<int>(message->header.messageType) << std::endl;
break;

case SendChatMessage:
{
std::cout << " Received chat message from client: " << std::dec << client->GetId() << std::endl;
auto receiveChat = std::make_shared<SocketMessages::ReceiveChatMessage>();
receiveChat->SetAuthorUsername(client->GetUser()->GetUsername());
receiveChat->SetChatMessage(std::dynamic_pointer_cast<SocketMessages::SendChatMessage>(message)->GetChatMessage());

co_await client->SendAck();

for (auto const& iteratedClient : m_clients)
{
if (iteratedClient->GetId() == client->GetId())
continue;
co_await client->SendMessage(receiveChat);
}

break;
}

default:
std::cout << " Unknown message type: " << std::dec << static_cast<int>(message->header.messageType) << std::endl;
DisconnectClient(client->GetId());
Expand Down
131 changes: 113 additions & 18 deletions ServerUnitTest/ConnectionUnitTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,31 @@ namespace ServerUnitTest

freeaddrinfo(result);

return std::make_shared<Connection>(sendingSocket, server);
return std::make_shared<Connection>(sendingSocket, server, false);
}

AsyncTask ConnectClient(std::shared_ptr<Connection> client, std::string username) const
{
std::vector<uint8_t> token = co_await client->ReceiveRawMessage(sizeof(uint64_t));
*std::bit_cast<uint64_t*>(token.data()) ^= 0xF007CAFEC0C0CA7E;
co_await client->SendRawMessage(token);

// Hello
std::shared_ptr<SocketMessages::Message> message = co_await client->ReceiveMessage();

message = std::make_shared<SocketMessages::HelloMessage>();
// Hello
co_await client->SendMessage(message);
// ACK
message = co_await client->ReceiveMessage();

message = std::make_shared<SocketMessages::IdentifyMessage>();
std::dynamic_pointer_cast<SocketMessages::IdentifyMessage>(message)->SetUsername(username);

// Identify
co_await client->SendMessage(message);
// ACK
message = co_await client->ReceiveMessage();
}

Server server;
Expand All @@ -62,11 +86,12 @@ namespace ServerUnitTest
{
StartServer();

std::shared_ptr<Connection> clientConnection = CreateConnection();

std::condition_variable cv;
bool finished = false;

[&]() -> AsyncTask {
std::shared_ptr<Connection> clientConnection = CreateConnection();

co_await SwitchThread(clientConnection->GetThread());


Expand All @@ -77,7 +102,6 @@ namespace ServerUnitTest
*std::bit_cast<uint64_t*>(token.data()) ^= 0xF007CAFEC0C0CA7E;

co_await clientConnection->SendRawMessage(token);
Logger::WriteMessage("Sent token");



Expand Down Expand Up @@ -118,7 +142,7 @@ namespace ServerUnitTest
L"Wrong message type instead of ACK");

actualId = std::dynamic_pointer_cast<SocketMessages::AcknowledgeMessage>(message)
->GetAcknowledgedMessageId();
->GetAcknowledgedMessageId();
Assert::AreEqual(expectedId, actualId, L"Wrong message id received");


Expand All @@ -132,25 +156,96 @@ namespace ServerUnitTest
co_return;
}();

std::unique_lock lock{ mutex };
cv.wait(lock, [&finished] { return finished; });

lock.unlock();
lock.release();

if (clientConnection->GetThread().joinable())
{
clientConnection->GetThread().request_stop();
clientConnection->GetThread().join();
std::unique_lock lock{ mutex };
cv.wait(lock, [&finished] { return finished; });
}


server.Stop();
}

TEST_METHOD(SendChatMessage)
{
StartServer();

std::condition_variable cv;
std::pair<bool, bool> finished = { false, false };
std::pair<bool, bool> connected = { false, false };

std::string chatContent = "Hello world";
std::string senderUsername = "Sender";

std::mutex readyMutex;



[&]() -> AsyncTask {
auto receiverConnection = CreateConnection();

co_await SwitchThread(receiverConnection->GetThread());

co_await ConnectClient(receiverConnection, "Receiver");

if (serverThread.joinable())
serverThread.join();
connected.second = true;
cv.notify_all();

return;
auto&& message = co_await receiverConnection->ReceiveMessage();
Assert::AreEqual(static_cast<uint8_t>(message->header.messageType),
static_cast<uint8_t>(SocketMessages::MessageType::SendChatMessage),
L"Wrong message type received instead of SendChatMessage");
Assert::AreEqual(std::dynamic_pointer_cast<SocketMessages::ReceiveChatMessage>(message)->GetChatMessage(),
chatContent,
L"Wrong message received");
Assert::AreEqual(std::dynamic_pointer_cast<SocketMessages::ReceiveChatMessage>(message)->GetAuthorUsername(),
senderUsername,
L"Wrong message received");


closesocket(receiverConnection->GetSocket());
shutdown(receiverConnection->GetSocket(), SD_BOTH);

finished.second = true;
cv.notify_all();
}();



[&]() -> AsyncTask {
auto senderConnection = CreateConnection();

co_await SwitchThread(senderConnection->GetThread());

co_await ConnectClient(senderConnection, senderUsername);

{
std::unique_lock lock{ readyMutex };
cv.wait(lock, [&connected] { return connected.second; });
}

auto message = std::make_shared<SocketMessages::SendChatMessage>();
message->SetChatMessage(chatContent);
co_await senderConnection->SendMessage(message);

auto&& Ack = co_await senderConnection->ReceiveMessage();
Assert::AreEqual(static_cast<uint8_t>(Ack->header.messageType),
static_cast<uint8_t>(SocketMessages::MessageType::AcknowledgeMessage),
L"Wrong message type received instead of ACK");

closesocket(senderConnection->GetSocket());
shutdown(senderConnection->GetSocket(), SD_BOTH);

finished.first = true;
cv.notify_all();
}();



{
std::unique_lock lock{ mutex };
cv.wait(lock, [&finished] { return finished.first && finished.second; });
}

server.Stop();
}
};
}

0 comments on commit 4170d32

Please sign in to comment.