diff --git a/common/strings.go b/common/strings.go index eba2dcc..0d41d97 100644 --- a/common/strings.go +++ b/common/strings.go @@ -2,6 +2,7 @@ package common import ( "bytes" + "errors" "math/rand" ) @@ -25,9 +26,14 @@ func RandomHexString(n int) string { return string(b) } -func GetString(buf []byte) string { +func GetString(buf []byte) (string, error) { nullTerminator := bytes.IndexByte(buf, 0) - return string(buf[:nullTerminator]) + + if nullTerminator == -1 { + return "", errors.New("buf is not null-terminated") + } + + return string(buf[:nullTerminator]), nil } // IsUppercaseAlphanumeric checks if the given string is composed exclusively of uppercase alphanumeric characters. diff --git a/natneg/main.go b/natneg/main.go index f14b25c..e644d5c 100644 --- a/natneg/main.go +++ b/natneg/main.go @@ -100,7 +100,7 @@ func StartServer() { func handleConnection(conn net.PacketConn, addr net.Addr, buffer []byte) { // Validate the packet magic - if !bytes.Equal(buffer[:6], []byte{0xfd, 0xfc, 0x1e, 0x66, 0x6a, 0xb2}) { + if len(buffer) < 12 || !bytes.Equal(buffer[:6], []byte{0xfd, 0xfc, 0x1e, 0x66, 0x6a, 0xb2}) { logging.Error("NATNEG:"+addr.String(), "Invalid packet header") return } @@ -232,12 +232,21 @@ func getPortTypeName(portType byte) string { } func (session *NATNEGSession) handleInit(conn net.PacketConn, addr net.Addr, buffer []byte, moduleName string, version byte) { + if len(buffer) < 10 { + logging.Error(moduleName, "Invalid packet size") + return + } + portType := buffer[0] clientIndex := buffer[1] useGamePort := buffer[2] localIPBytes := buffer[3:7] localPort := binary.BigEndian.Uint16(buffer[7:9]) - gameName := common.GetString(buffer[9:]) + gameName, err := common.GetString(buffer[9:]) + if err != nil { + logging.Error(moduleName, "Invalid gameName") + return + } expectedSize := 9 + len(gameName) + 1 if len(buffer) != expectedSize { diff --git a/serverbrowser/server.go b/serverbrowser/server.go index 0df413c..e81b5c1 100644 --- a/serverbrowser/server.go +++ b/serverbrowser/server.go @@ -2,6 +2,7 @@ package serverbrowser import ( "encoding/binary" + "errors" "fmt" "github.com/logrusorgru/aurora/v3" "net" @@ -39,29 +40,79 @@ const ( LimitResultCountOption = 1 << 7 // 0x80 / 128 ) -func popString(buffer []byte, index int) (string, int) { - str := common.GetString(buffer[index:]) - return str, index + len(str) + 1 +var ( + IndexOutOfBoundsError = errors.New("index is out of bounds") +) + +func popString(buffer []byte, index int) (string, int, error) { + if index < 0 || index >= len(buffer) { + return "", 0, IndexOutOfBoundsError + } + + str, err := common.GetString(buffer[index:]) + + if err != nil { + return "", 0, err + } + + return str, index + len(str) + 1, nil } -func popBytes(buffer []byte, index int, size int) ([]byte, int) { - return buffer[index : index+size], index + size +func popBytes(buffer []byte, index int, size int) ([]byte, int, error) { + bufferLen := len(buffer) + + if index < 0 || index >= bufferLen { + return nil, 0, IndexOutOfBoundsError + } + if size < 0 || index+size > bufferLen { + return nil, 0, IndexOutOfBoundsError + } + + return buffer[index : index+size], index + size, nil } -func popUint32(buffer []byte, index int) (uint32, int) { - return binary.BigEndian.Uint32(buffer[index:]), index + 4 +func popUint32(buffer []byte, index int) (uint32, int, error) { + if index < 0 || index+4 > len(buffer) { + return 0, 0, IndexOutOfBoundsError + } + + return binary.BigEndian.Uint32(buffer[index:]), index + 4, nil } var regexSelfLookup = regexp.MustCompile(`^dwc_pid ?= ?(\d{1,10})$`) func handleServerListRequest(conn net.Conn, buffer []byte) { index := 9 - queryGame, index := popString(buffer, index) - gameName, index := popString(buffer, index) - challenge, index := popBytes(buffer, index, 8) - filter, index := popString(buffer, index) - fields, index := popString(buffer, index) - options, index := popUint32(buffer, index) + queryGame, index, err := popString(buffer, index) + if err != nil { + logging.Error(ModuleName, "Invalid queryGame") + return + } + gameName, index, err := popString(buffer, index) + if err != nil { + logging.Error(ModuleName, "Invalid gameName") + return + } + challenge, index, err := popBytes(buffer, index, 8) + if err != nil { + logging.Error(ModuleName, "Invalid challenge") + return + } + filter, index, err := popString(buffer, index) + if err != nil { + logging.Error(ModuleName, "Invalid filter") + return + } + fields, index, err := popString(buffer, index) + if err != nil { + logging.Error(ModuleName, "Invalid fields") + return + } + options, index, err := popUint32(buffer, index) + if err != nil { + logging.Error(ModuleName, "Invalid options") + return + } logging.Info(ModuleName, "queryGame:", aurora.Cyan(queryGame).String(), "- gameName:", aurora.Cyan(gameName).String(), "- filter:", aurora.Cyan(filter).String(), "- fields:", aurora.Cyan(fields).String())