Skip to content

Commit

Permalink
query: allow multiple session IDs per connection (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
katrinafyi authored Dec 1, 2024
1 parent d86be25 commit 806601a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
13 changes: 7 additions & 6 deletions src/query/packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@ import (
"errors"
"fmt"
"io"
"net"
"main/src/util"
"math/rand"
"strconv"
"strings"
)

func writeHandshakePacket(w io.Writer, sessionID int32) error {
func writeHandshakePacket(w io.Writer, addr net.Addr, sessionID int32) error {
challengeToken := strconv.FormatInt(int64(rand.Int31()), 10)

sessionsMutex.Lock()
sessions[sessionID] = challengeToken
sessions[addr.String()] = challengeToken
sessionsMutex.Unlock()

// Type - byte
Expand All @@ -36,13 +37,13 @@ func writeHandshakePacket(w io.Writer, sessionID int32) error {
return nil
}

func readRequestPacket(r io.Reader, w io.Writer, sessionID int32) (bool, error) {
func readRequestPacket(r io.Reader, w io.Writer, addr net.Addr, sessionID int32) (bool, error) {
sessionsMutex.Lock()

defer sessionsMutex.Unlock()

if _, ok := sessions[sessionID]; !ok {
return false, fmt.Errorf("query: invalid or expired session ID: %X", sessionID)
if _, ok := sessions[addr.String()]; !ok {
return false, fmt.Errorf("query: no currently active challenges for %s", addr.String())
}

// Challenge Token - int32
Expand All @@ -53,7 +54,7 @@ func readRequestPacket(r io.Reader, w io.Writer, sessionID int32) (bool, error)
return false, err
}

if sessions[sessionID] != strconv.FormatInt(int64(challengeToken), 10) {
if sessions[addr.String()] != strconv.FormatInt(int64(challengeToken), 10) {
return false, fmt.Errorf("query: received challenge token did not match stored")
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/query/socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
var (
socket net.PacketConn = nil
conf *config.Config = nil
sessions map[int32]string = make(map[int32]string)
sessions map[string]string = make(map[string]string) // Map of net.Addr.String() to challenge string
sessionsMutex *sync.Mutex = &sync.Mutex{}
)

Expand Down Expand Up @@ -76,15 +76,15 @@ func handlePacket(data []byte, addr net.Addr) {
switch packetType {
case 0x09: // Generate challenge token
{
if err = writeHandshakePacket(buf, sessionID); err != nil {
if err = writeHandshakePacket(buf, addr, sessionID); err != nil {
return
}

break
}
case 0x00: // Request
{
isFullStat, err := readRequestPacket(r, buf, sessionID)
isFullStat, err := readRequestPacket(r, buf, addr, sessionID)

if err != nil {
return
Expand Down

0 comments on commit 806601a

Please sign in to comment.