Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial attempt at verifying identity on join #3059

Merged
merged 2 commits into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 69 additions & 9 deletions lib/netplay/netplay.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,8 @@ static char externalIPAddress[40];
/**
* Used for connections with clients.
*/
#define NET_PING_TMP_PING_CHALLENGE_SIZE 128
static std::array<std::vector<uint8_t>, MAX_TMP_SOCKETS> tmp_challenges{};
static Socket *tmp_socket[MAX_TMP_SOCKETS] = { nullptr }; ///< Sockets used to talk to clients which have not yet been assigned a player number (host only).

static SocketSet *tmp_socket_set = nullptr;
Expand Down Expand Up @@ -3668,6 +3670,11 @@ static void NETallowJoining()
debug(LOG_ERROR, "Cannot create socket set: %s", strSockError(getSockErr()));
return;
}
// FIXME: I guess initialization of allowjoining is here now... - FlexCoral
for (auto& challenge : tmp_challenges)
{
challenge.clear();
}
}

// Find the first empty socket slot
Expand Down Expand Up @@ -3762,6 +3769,13 @@ static void NETallowJoining()

// Connection is successful.
connectFailed = false;

// Give client a challenge to solve before connecting
tmp_challenges[i].resize(NET_PING_TMP_PING_CHALLENGE_SIZE);
genSecRandomBytes(tmp_challenges[i].data(), tmp_challenges[i].size());
NETbeginEncode(NETnetTmpQueue(i), NET_PING);
NETbytes(&(tmp_challenges[i]));
NETend();
}
else
{
Expand Down Expand Up @@ -3799,6 +3813,7 @@ static void NETallowJoining()
SocketSet_DelSocket(tmp_socket_set, tmp_socket[i]);
socketClose(tmp_socket[i]);
tmp_socket[i] = nullptr;
tmp_challenges[i].clear();
}
}

Expand Down Expand Up @@ -3844,14 +3859,42 @@ static void NETallowJoining()
char ModList[modlist_string_size] = { '\0' };
char GamePassword[password_string_size] = { '\0' };
uint8_t playerType = 0;
EcKey::Key pkey;
EcKey identity;
EcKey::Sig challengeResponse;

NETbeginDecode(NETnetTmpQueue(i), NET_JOIN);
NETstring(name, sizeof(name));
NETstring(ModList, sizeof(ModList));
NETstring(GamePassword, sizeof(GamePassword));
NETuint8_t(&playerType);
NETbytes(&pkey);
NETbytes(&challengeResponse);
NETend();

identity.fromBytes(pkey, EcKey::Public);
// verify signature that player is joining with, reject him if he can not do that
if (!identity.verify(challengeResponse, tmp_challenges[i].data(), tmp_challenges[i].size()))
{
debug(LOG_ERROR, "freeing temp socket %p, couldn't create player!", static_cast<void *>(tmp_socket[i]));

rejected = ERROR_WRONGDATA;
NETbeginEncode(NETnetTmpQueue(i), NET_REJECTED);
NETuint8_t(&rejected);
NETend();
NETflush();
NETpop(NETnetTmpQueue(i));

SocketSet_DelSocket(tmp_socket_set, tmp_socket[i]);
socketClose(tmp_socket[i]);
tmp_socket[i] = nullptr;
tmp_challenges[i].clear();
sync_counter.cantjoin++;
return;
}

tmp_challenges[i].clear();

if ((playerType == NET_JOIN_SPECTATOR) || (int)NetPlay.playercount <= gamestruct.desc.dwMaxPlayers)
{
tmp = NET_CreatePlayer(name, false, (playerType == NET_JOIN_SPECTATOR));
Expand Down Expand Up @@ -3959,6 +4002,7 @@ static void NETallowJoining()
snprintf(buf, sizeof(buf), "%s[%" PRIu8 "] %s has joined, IP is: %s", pPlayerType, index, name, NetPlay.players[index].IPtextAddress);
debug(LOG_INFO, "%s", buf);
NETlogEntry(buf, SYNC_FLAG, index);
wz_command_interface_output("WZEVENT: player join: %u %s %s %s\n", i, base64Encode(pkey).c_str(), identity.publicHashString().c_str(), NetPlay.players[i].IPtextAddress);

debug(LOG_NET, "%s, %s, with index of %u has joined using socket %p", pPlayerType, name, (unsigned int)index, static_cast<void *>(connected_bsocket[index]));

Expand All @@ -3967,6 +4011,8 @@ static void NETallowJoining()

MultiPlayerJoin(index);

ingame.VerifiedIdentity[index] = true;

// Narrowcast to new player that everyone has joined.
for (j = 0; j < MAX_CONNECTED_PLAYERS; ++j)
{
Expand Down Expand Up @@ -4402,7 +4448,7 @@ bool NETfindGame(uint32_t gameId, GAMESTRUCT& output)
// ////////////////////////////////////////////////////////////////////////
// ////////////////////////////////////////////////////////////////////////
// Functions used to setup and join games.
bool NETjoinGame(const char *host, uint32_t port, const char *playername, bool asSpectator /*= false*/)
bool NETjoinGame(const char *host, uint32_t port, const char *playername, const EcKey& playerIdentity, bool asSpectator /*= false*/)
{
SocketAddress *hosts = nullptr;
unsigned int i;
Expand Down Expand Up @@ -4496,14 +4542,7 @@ bool NETjoinGame(const char *host, uint32_t port, const char *playername, bool a
socketBeginCompression(bsocket);

uint8_t playerType = (!asSpectator) ? NET_JOIN_PLAYER : NET_JOIN_SPECTATOR;

// Send a join message to the host
NETbeginEncode(NETnetQueue(NET_HOST_ONLY), NET_JOIN);
NETstring(playername, 64);
NETstring(getModList().c_str(), modlist_string_size);
NETstring(NetPlay.gamePassword, sizeof(NetPlay.gamePassword));
NETuint8_t(&playerType);
NETend();

if (bsocket == nullptr)
{
return false; // Connection dropped while sending NET_JOIN.
Expand Down Expand Up @@ -4592,6 +4631,27 @@ bool NETjoinGame(const char *host, uint32_t port, const char *playername, bool a
NETclose();
return false;
}
else if (type == NET_PING)
{
std::vector<uint8_t> challenge(NET_PING_TMP_PING_CHALLENGE_SIZE, 0);
NETbeginDecode(NETnetQueue(NET_HOST_ONLY), NET_PING);
NETbytes(&challenge, NET_PING_TMP_PING_CHALLENGE_SIZE * 4);
NETend();
NETpop(queue);

EcKey::Sig challengeResponse = playerIdentity.sign(challenge.data(), challenge.size());
EcKey::Key identity = playerIdentity.toBytes(EcKey::Public);

NETbeginEncode(NETnetQueue(NET_HOST_ONLY), NET_JOIN);
NETstring(playername, 64);
NETstring(getModList().c_str(), modlist_string_size);
NETstring(NetPlay.gamePassword, sizeof(NetPlay.gamePassword));
NETuint8_t(&playerType);
NETbytes(&identity);
NETbytes(&challengeResponse);
NETend();
NETflush();
}
else
{
debug(LOG_ERROR, "Unexpected %s.", messageTypeToString(type));
Expand Down
2 changes: 1 addition & 1 deletion lib/netplay/netplay.h
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ bool NEThaltJoining(); // stop new players joining this game
bool NETenumerateGames(const std::function<bool (const GAMESTRUCT& game)>& handleEnumerateGameFunc);
bool NETfindGames(std::vector<GAMESTRUCT>& results, size_t startingIndex, size_t resultsLimit, bool onlyMatchingLocalVersion = false);
bool NETfindGame(uint32_t gameId, GAMESTRUCT& output);
bool NETjoinGame(const char *host, uint32_t port, const char *playername, bool asSpectator = false); // join game given with playername
bool NETjoinGame(const char *host, uint32_t port, const char *playername, const EcKey& playerIdentity, bool asSpectator = false); // join game given with playername
bool NEThostGame(const char *SessionName, const char *PlayerName, bool spectatorHost, // host a game
uint32_t gameType, uint32_t two, uint32_t three, uint32_t four, UDWORD plyrs);
bool NETchangePlayerName(UDWORD player, char *newName);// change a players name.
Expand Down
2 changes: 1 addition & 1 deletion lib/netplay/nettypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,7 @@ void NETbytes(std::vector<uint8_t> *vec, unsigned maxLen)

if (len > maxLen)
{
debug(LOG_ERROR, "NETstring: %s packet, length %u truncated at %u", NETgetPacketDir() == PACKET_ENCODE ? "Encoding" : "Decoding", len, maxLen);
debug(LOG_ERROR, "NETbytes: %s packet, length %u truncated at %u", NETgetPacketDir() == PACKET_ENCODE ? "Encoding" : "Decoding", len, maxLen);
}

len = std::min<unsigned>(len, maxLen); // Truncate length if necessary.
Expand Down
4 changes: 2 additions & 2 deletions src/multiint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1049,13 +1049,14 @@ static JoinGameResult joinGameInternalConnect(const char *host, uint32_t port, s
{
// oldUI may get captured for use in the password dialog, among other things.
PLAYERSTATS playerStats;
loadMultiStats(sPlayer, &playerStats);

if (ingame.localJoiningInProgress)
{
return JoinGameResult::FAILED;
}

if (!NETjoinGame(host, port, (char *)sPlayer, asSpectator)) // join
if (!NETjoinGame(host, port, (char *)sPlayer, playerStats.identity, asSpectator)) // join
{
switch (getLobbyError())
{
Expand Down Expand Up @@ -1084,7 +1085,6 @@ static JoinGameResult joinGameInternalConnect(const char *host, uint32_t port, s
}
ingame.localJoiningInProgress = true;

loadMultiStats(sPlayer, &playerStats);
setMultiStats(selectedPlayer, playerStats, false);
setMultiStats(selectedPlayer, playerStats, true);

Expand Down
Loading