From 806601a925d0d92855ead903622be949c64d0bcd Mon Sep 17 00:00:00 2001 From: Kait <39479354+katrinafyi@users.noreply.github.com> Date: Mon, 2 Dec 2024 09:13:37 +1000 Subject: [PATCH] query: allow multiple session IDs per connection (#2) --- src/query/packets.go | 13 +++++++------ src/query/socket.go | 6 +++--- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/query/packets.go b/src/query/packets.go index 23976b1..901f95e 100644 --- a/src/query/packets.go +++ b/src/query/packets.go @@ -5,17 +5,18 @@ import ( "errors" "fmt" "io" + "net" "main/src/util" "math/rand" "strconv" "strings" ) -func writeHandshakePacket(w io.Writer, sessionID int32) error { +func writeHandshakePacket(w io.Writer, addr net.Addr, sessionID int32) error { challengeToken := strconv.FormatInt(int64(rand.Int31()), 10) sessionsMutex.Lock() - sessions[sessionID] = challengeToken + sessions[addr.String()] = challengeToken sessionsMutex.Unlock() // Type - byte @@ -36,13 +37,13 @@ func writeHandshakePacket(w io.Writer, sessionID int32) error { return nil } -func readRequestPacket(r io.Reader, w io.Writer, sessionID int32) (bool, error) { +func readRequestPacket(r io.Reader, w io.Writer, addr net.Addr, sessionID int32) (bool, error) { sessionsMutex.Lock() defer sessionsMutex.Unlock() - if _, ok := sessions[sessionID]; !ok { - return false, fmt.Errorf("query: invalid or expired session ID: %X", sessionID) + if _, ok := sessions[addr.String()]; !ok { + return false, fmt.Errorf("query: no currently active challenges for %s", addr.String()) } // Challenge Token - int32 @@ -53,7 +54,7 @@ func readRequestPacket(r io.Reader, w io.Writer, sessionID int32) (bool, error) return false, err } - if sessions[sessionID] != strconv.FormatInt(int64(challengeToken), 10) { + if sessions[addr.String()] != strconv.FormatInt(int64(challengeToken), 10) { return false, fmt.Errorf("query: received challenge token did not match stored") } } diff --git a/src/query/socket.go b/src/query/socket.go index e9dda83..033ae89 100644 --- a/src/query/socket.go +++ b/src/query/socket.go @@ -13,7 +13,7 @@ import ( var ( socket net.PacketConn = nil conf *config.Config = nil - sessions map[int32]string = make(map[int32]string) + sessions map[string]string = make(map[string]string) // Map of net.Addr.String() to challenge string sessionsMutex *sync.Mutex = &sync.Mutex{} ) @@ -76,7 +76,7 @@ func handlePacket(data []byte, addr net.Addr) { switch packetType { case 0x09: // Generate challenge token { - if err = writeHandshakePacket(buf, sessionID); err != nil { + if err = writeHandshakePacket(buf, addr, sessionID); err != nil { return } @@ -84,7 +84,7 @@ func handlePacket(data []byte, addr net.Addr) { } case 0x00: // Request { - isFullStat, err := readRequestPacket(r, buf, sessionID) + isFullStat, err := readRequestPacket(r, buf, addr, sessionID) if err != nil { return