Skip to content

Commit

Permalink
add more checks to a received connection request packet (#39)
Browse files Browse the repository at this point in the history
* add try_read method for PrivateConnectToken

* add constructor for PrivateConnectTokenAssociatedData

* add try_decrypt function

* add try_decrypt method for ConnectionRequestPacket

* decrypt and check decryption when accepting a ConnectionRequestPacket

* check that app_server_netcode_address is in netcode_addresses in private connect token

* check if client's netcode_address is already connected

* add client_id field to ClientSlot and check if client_id is already connected

* add struct ConnectTokenSlot and check for used connect tokens

* add try_add! method for adding client slot to room is space is available
  • Loading branch information
Sid-Bhatia-0 authored Mar 25, 2024
1 parent d806a3e commit fbf6dd2
Show file tree
Hide file tree
Showing 3 changed files with 212 additions and 20 deletions.
46 changes: 46 additions & 0 deletions netcode/serialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,49 @@ function try_read(io::IO, ::Type{ConnectionRequestPacket})
encrypted_private_connect_token_data,
)
end

function try_read(io::IO, ::Type{PrivateConnectToken})
client_id = read(io, TYPE_OF_CLIENT_ID)

timeout_seconds = read(io, TYPE_OF_TIMEOUT_SECONDS)

num_server_addresses = read(io, TYPE_OF_NUM_SERVER_ADDRESSES)
if !(1 <= num_server_addresses <= MAX_NUM_SERVER_ADDRESSES)
return nothing
end

netcode_addresses = NetcodeAddress[]

for i in 1:num_server_addresses
netcode_address = try_read(io, NetcodeAddress)
if !isnothing(netcode_address)
push!(netcode_addresses, netcode_address)
else
return nothing
end
end

client_to_server_key = read(io, SIZE_OF_KEY)

server_to_client_key = read(io, SIZE_OF_KEY)

user_data = read(io, SIZE_OF_USER_DATA)

# TODO(fix): don't read until eof, read only padding size because we can't assume the size of io
while !eof(io)
x = read(io, UInt8)
if x != 0
return nothing
end
end

return PrivateConnectToken(
client_id,
timeout_seconds,
num_server_addresses,
netcode_addresses,
client_to_server_key,
server_to_client_key,
user_data,
)
end
134 changes: 114 additions & 20 deletions netcode/simulate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ include("serialization.jl")

const NULL_NETCODE_ADDRESS = NetcodeAddress(0, 0, 0, 0)

const NULL_CLIENT_SLOT = ClientSlot(false, NULL_NETCODE_ADDRESS)
const NULL_CLIENT_SLOT = ClientSlot(false, NULL_NETCODE_ADDRESS, 0)

const PROTOCOL_ID = parse(TYPE_OF_PROTOCOL_ID, bytes2hex(SHA.sha3_256(cat(NETCODE_VERSION_INFO, Vector{UInt8}("Netcode.jl"), dims = 1)))[1:16], base = 16)

Expand All @@ -34,6 +34,10 @@ const APP_SERVER_ADDRESSES = [Sockets.InetAddr(Sockets.localhost, 10001)]

const APP_SERVER_ADDRESS = APP_SERVER_ADDRESSES[1]

const USED_CONNECT_TOKEN_HISTORY_SIZE = ROOM_SIZE

const NULL_CONNECT_TOKEN_SLOT = ConnectTokenSlot(0, UInt8[], NULL_NETCODE_ADDRESS)

@assert 1 <= length(APP_SERVER_ADDRESSES) <= MAX_NUM_SERVER_ADDRESSES

# TODO: salts must be randomly generated during user registration
Expand Down Expand Up @@ -96,13 +100,71 @@ function create_df_debug_info(debug_info)
)
end

function start_app_server(app_server_address, room_size)
function is_client_already_connected(room, client_netcode_address, client_id)
for client_slot in room
if client_slot.is_used
if client_slot.netcode_address == client_netcode_address
@info "client_netcode_address already connected"
return true
end

if client_slot.client_id == client_id
@info "client_id already connected"
return true
end
end
end

return false
end

function try_add!(used_connect_token_history::Vector{ConnectTokenSlot}, connect_token_slot::ConnectTokenSlot)
i_oldest = 1
last_seen_timestamp_oldest = used_connect_token_history[i_oldest].last_seen_timestamp

for i in axes(used_connect_token_history, 1)
if used_connect_token_history[i].hmac == connect_token_slot.hmac
if used_connect_token_history[i].netcode_address != connect_token_slot.netcode_address
return false
elseif used_connect_token_history[i].last_seen_timestamp < connect_token_slot.last_seen_timestamp
used_connect_token_history[i] = connect_token_slot
return true
end
end

if last_seen_timestamp_oldest > used_connect_token_history[i].last_seen_timestamp
i_oldest = i
last_seen_timestamp_oldest = used_connect_token_history[i].last_seen_timestamp
end
end

used_connect_token_history[i_oldest] = connect_token_slot

return true
end

function try_add!(room::Vector{ClientSlot}, client_slot::ClientSlot)
for i in axes(room, 1)
if !room[i].is_used
room[i] = client_slot
return true
end
end

return false
end

function start_app_server(app_server_address, room_size, used_connect_token_history_size)
room = fill(NULL_CLIENT_SLOT, room_size)

used_connect_token_history = fill(NULL_CONNECT_TOKEN_SLOT, used_connect_token_history_size)

socket = Sockets.UDPSocket()

Sockets.bind(socket, app_server_address.host, app_server_address.port)

app_server_netcode_address = NetcodeAddress(app_server_address)

@info "Server started listening"

while true
Expand All @@ -121,26 +183,58 @@ function start_app_server(app_server_address, room_size)
io = IOBuffer(data)

connection_request_packet = try_read(io, ConnectionRequestPacket)
if isnothing(connection_request_packet)
@info "Invalid connection request packet received"
continue
end

if !isnothing(connection_request_packet)
@info "Received PACKET_TYPE_CONNECTION_REQUEST_PACKET"
pprint(connection_request_packet)

for i in 1:room_size
if !room[i].is_used
client_slot = ClientSlot(true, NetcodeAddress(client_address))
room[i] = client_slot
@info "Client accepted" client_address
break
end
end
pprint(connection_request_packet)

if all(client_slot -> client_slot.is_used, room)
@info "Room full" app_server_address room
break
end
private_connect_token = try_decrypt(connection_request_packet, SERVER_SIDE_SHARED_KEY)
if isnothing(private_connect_token)
@info "Invalid connection request packet received"
continue
end

pprint(private_connect_token)

if !(app_server_netcode_address in private_connect_token.netcode_addresses)
@info "Invalid connection request packet received"
continue
end

client_netcode_address = NetcodeAddress(client_address)

if is_client_already_connected(room, client_netcode_address, private_connect_token.client_id)
@info "Client already connected"
continue
end

connect_token_slot = ConnectTokenSlot(time_ns(), connection_request_packet.encrypted_private_connect_token_data[end - SIZE_OF_HMAC + 1 : end], client_netcode_address)

if !try_add!(used_connect_token_history, connect_token_slot)
@info "connect token already used by another netcode_address"
continue
end

pprint(used_connect_token_history)

client_slot = ClientSlot(true, NetcodeAddress(client_address), private_connect_token.client_id)

is_client_added = try_add!(room, client_slot)

if is_client_added
@info "Client accepted" client_address
else
@info "Received malformed PACKET_TYPE_CONNECTION_REQUEST_PACKET"
@info "no empty client slots available"
continue
end

pprint(room)

if all(client_slot -> client_slot.is_used, room)
@info "Room full" app_server_address room
break
end
else
@info "Received unknown packet type"
Expand Down Expand Up @@ -253,7 +347,7 @@ if length(ARGS) == 1
if ARGS[1] == "--app_server"
@info "Running as app_server" APP_SERVER_ADDRESS AUTH_SERVER_ADDRESS

start_app_server(APP_SERVER_ADDRESS, ROOM_SIZE)
start_app_server(APP_SERVER_ADDRESS, ROOM_SIZE, USED_CONNECT_TOKEN_HISTORY_SIZE)

elseif ARGS[1] == "--auth_server"
@info "Running as auth_server" APP_SERVER_ADDRESS AUTH_SERVER_ADDRESS
Expand Down
52 changes: 52 additions & 0 deletions netcode/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ end
struct ClientSlot
is_used::Bool
netcode_address::NetcodeAddress
client_id::TYPE_OF_CLIENT_ID
end

struct ConnectTokenInfo
Expand Down Expand Up @@ -58,6 +59,12 @@ struct PrivateConnectTokenAssociatedData
expire_timestamp::TYPE_OF_TIMESTAMP
end

struct ConnectTokenSlot
last_seen_timestamp::TYPE_OF_TIMESTAMP
hmac::Vector{UInt8} # TODO(perf): can store hash of hmac instead of hmac
netcode_address::NetcodeAddress
end

abstract type AbstractPacket end

struct ConnectTokenPacket <: AbstractPacket
Expand Down Expand Up @@ -152,6 +159,14 @@ function PrivateConnectTokenAssociatedData(connect_token_info::ConnectTokenInfo)
)
end

function PrivateConnectTokenAssociatedData(connection_request_packet::ConnectionRequestPacket)
return PrivateConnectTokenAssociatedData(
connection_request_packet.netcode_version_info,
connection_request_packet.protocol_id,
connection_request_packet.expire_timestamp,
)
end

function encrypt(message, associated_data, nonce, key)
ciphertext = zeros(UInt8, length(message) + SIZE_OF_HMAC)
ciphertext_length_ref = Ref{UInt}()
Expand All @@ -164,6 +179,43 @@ function encrypt(message, associated_data, nonce, key)
return ciphertext
end

function try_decrypt(ciphertext, associated_data, nonce, key)
decrypted = zeros(UInt8, length(ciphertext) - SIZE_OF_HMAC)
decrypted_length_ref = Ref{UInt}()

decrypt_status = Sodium.LibSodium.crypto_aead_xchacha20poly1305_ietf_decrypt(decrypted, decrypted_length_ref, C_NULL, ciphertext, length(ciphertext), associated_data, length(associated_data), nonce, key)

if decrypt_status != 0
return nothing
end

@assert decrypted_length_ref[] == length(decrypted)

return decrypted
end

function try_decrypt(connection_request_packet::ConnectionRequestPacket, key)
decrypted = try_decrypt(
connection_request_packet.encrypted_private_connect_token_data,
get_serialized_data(PrivateConnectTokenAssociatedData(connection_request_packet)),
connection_request_packet.nonce,
key,
)

if isnothing(decrypted)
return nothing
end

io = IOBuffer(decrypted)

private_connect_token = try_read(io, PrivateConnectToken)
if isnothing(private_connect_token)
return nothing
end

return private_connect_token
end

function ConnectTokenPacket(connect_token_info::ConnectTokenInfo)
message = get_serialized_data(PrivateConnectToken(connect_token_info))

Expand Down

0 comments on commit fbf6dd2

Please sign in to comment.