diff --git a/natneg/main.go b/natneg/main.go index 9688c96..cff7856 100644 --- a/natneg/main.go +++ b/natneg/main.go @@ -3,8 +3,10 @@ package natneg import ( "bytes" "encoding/binary" + "encoding/gob" "fmt" "net" + "os" "sync" "time" "wwfc/common" @@ -60,7 +62,7 @@ type NATNEGSession struct { Open bool Version byte Cookie uint32 - Mutex sync.RWMutex + mutex sync.RWMutex Clients map[byte]*NATNEGClient } @@ -80,6 +82,9 @@ var ( sessions = map[uint32]*NATNEGSession{} mutex = sync.RWMutex{} natnegConn net.PacketConn + + inShutdown = false + waitGroup = sync.WaitGroup{} ) func StartServer(reload bool) { @@ -93,28 +98,90 @@ func StartServer(reload bool) { } natnegConn = conn + inShutdown = false + + if reload { + // Load state + file, err := os.Open("state/natneg_sessions.gob") + if err != nil { + panic(err) + } + + decoder := gob.NewDecoder(file) + + err = decoder.Decode(&sessions) + file.Close() + + if err != nil { + panic(err) + } + + for _, session := range sessions { + cur := session + time.AfterFunc(30*time.Second, func() { + closeSession("NATNEG:"+fmt.Sprintf("%08x", cur.Cookie), cur) + }) + } + + logging.Notice("NATNEG", "Loaded", aurora.Cyan(len(sessions)), "sessions") + } + + waitGroup.Add(1) go func() { + defer waitGroup.Done() + // Close the listener when the application closes. defer conn.Close() logging.Notice("NATNEG", "Listening on", aurora.BrightCyan(address)) for { + if inShutdown { + return + } + buffer := make([]byte, 1024) size, addr, err := conn.ReadFrom(buffer) if err != nil { continue } + waitGroup.Add(1) + go handleConnection(conn, addr, buffer[:size]) } }() } func Shutdown() { + inShutdown = true + natnegConn.Close() + waitGroup.Wait() + + // Save state + mutex.Lock() + defer mutex.Unlock() + + file, err := os.OpenFile("state/natneg_sessions.gob", os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + panic(err) + } + + encoder := gob.NewEncoder(file) + + err = encoder.Encode(sessions) + file.Close() + + if err != nil { + panic(err) + } + + logging.Notice("NATNEG", "Saved", aurora.Cyan(len(sessions)), "sessions") } func handleConnection(conn net.PacketConn, addr net.Addr, buffer []byte) { + defer waitGroup.Done() + // Validate the packet magic if len(buffer) < 12 || !bytes.Equal(buffer[:6], []byte{0xfd, 0xfc, 0x1e, 0x66, 0x6a, 0xb2}) { logging.Error("NATNEG:"+addr.String(), "Invalid packet header") @@ -145,37 +212,14 @@ func handleConnection(conn net.PacketConn, addr net.Addr, buffer []byte) { Open: true, Version: version, Cookie: cookie, - Mutex: sync.RWMutex{}, + mutex: sync.RWMutex{}, Clients: map[byte]*NATNEGClient{}, } sessions[cookie] = session // Session has TTL of 30 seconds time.AfterFunc(30*time.Second, func() { - session.Open = false - - mutex.Lock() - delete(sessions, cookie) - mutex.Unlock() - - session.Mutex.Lock() - defer session.Mutex.Unlock() - - // Disconnect each client - for _, client := range session.Clients { - if client.ConnectingIndex == client.Index { - continue - } - - logging.Info(moduleName, "Disconnecting client", aurora.Cyan(client.Index)) - // Send report ack, which will cause the client to cancel - reportAck := createPacketHeader(version, NNReportReply, session.Cookie) - reportAck = append(reportAck, 0x00, client.Index, 0x00) - reportAck = append(reportAck, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00) - conn.WriteTo(reportAck, addr) - } - - logging.Info(moduleName, "Deleted session") + closeSession(moduleName, session) }) } mutex.Unlock() @@ -185,8 +229,8 @@ func handleConnection(conn net.PacketConn, addr net.Addr, buffer []byte) { return } - session.Mutex.Lock() - defer session.Mutex.Unlock() + session.mutex.Lock() + defer session.mutex.Unlock() } switch command { @@ -249,6 +293,43 @@ func handleConnection(conn net.PacketConn, addr net.Addr, buffer []byte) { } } +func closeSession(moduleName string, session *NATNEGSession) { + mutex.Lock() + if inShutdown { + mutex.Unlock() + return + } + + session.Open = false + delete(sessions, session.Cookie) + mutex.Unlock() + + session.mutex.Lock() + defer session.mutex.Unlock() + + // Disconnect each client + for _, client := range session.Clients { + if client.ConnectingIndex == client.Index { + continue + } + + logging.Info("NATNEG", "Disconnecting client", aurora.Cyan(client.Index)) + // Send report ack, which will cause the client to cancel + reportAck := createPacketHeader(session.Version, NNReportReply, session.Cookie) + reportAck = append(reportAck, 0x00, client.Index, 0x00) + reportAck = append(reportAck, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00) + + addr, err := net.ResolveUDPAddr("udp", client.NegotiateIP) + if err != nil { + panic(err) + } + + natnegConn.WriteTo(reportAck, addr) + } + + logging.Info("NATNEG", "Deleted session") +} + func getPortTypeName(portType byte) string { switch portType { default: