Main: Enforce UUID of reload state

This commit is contained in:
mkwcat 2024-05-08 02:53:03 -04:00
parent b952c9e386
commit 2047badd9f
No known key found for this signature in database
GPG Key ID: 7A505679CE9E7AA9
2 changed files with 127 additions and 21 deletions

View File

@ -62,14 +62,30 @@ func Ready() error {
}
// Shutdown will notify the frontend that the backend is shutting down
func Shutdown() error {
func Shutdown() (string, error) {
if rpcFrontend == nil {
ConnectFrontend()
}
err := rpcFrontend.Call("RPCFrontendPacket.ShutdownBackend", struct{}{}, nil)
var stateUuid string
err := rpcFrontend.Call("RPCFrontendPacket.ShutdownBackend", struct{}{}, &stateUuid)
if err != nil {
logging.Error("COMMON", "Failed to notify frontend that backend is shutting down:", err)
}
return err
return stateUuid, err
}
// VerifyState will verify the state UUID with the frontend
func VerifyState(stateUuid string) (bool, error) {
if rpcFrontend == nil {
ConnectFrontend()
}
valid := false
err := rpcFrontend.Call("RPCFrontendPacket.VerifyState", stateUuid, &valid)
if err != nil {
logging.Error("COMMON", "Failed to verify state UUID with frontend:", err)
}
return valid, err
}

126
main.go
View File

@ -2,6 +2,7 @@ package main
import (
"errors"
"io"
"net"
"net/rpc"
"os"
@ -40,7 +41,7 @@ func main() {
// Start the backend instead of the frontend if the first argument is "backend"
if len(args) > 0 && args[0] == "backend" {
backendMain(len(args) > 1 && args[1] == "reload")
backendMain(len(args) > 1 && args[1] == "noreload")
} else {
frontendMain(len(args) > 0 && args[0] == "frontend")
}
@ -54,7 +55,7 @@ type RPCPacket struct {
}
// backendMain starts all the servers and creates an RPC server to communicate with the frontend
func backendMain(reload bool) {
func backendMain(noreload bool) {
sigExit := make(chan os.Signal, 1)
signal.Notify(sigExit, syscall.SIGINT, syscall.SIGTERM)
@ -73,6 +74,16 @@ func backendMain(reload bool) {
common.ConnectFrontend()
uuid := ""
if !noreload {
uuid = loadUuidFile()
}
reload, err := common.VerifyState(uuid)
if err != nil {
panic(err)
}
wg := &sync.WaitGroup{}
actions := []func(bool){nas.StartServer, gpcm.StartServer, qr2.StartServer, gpsp.StartServer, serverbrowser.StartServer, sake.StartServer, natneg.StartServer, api.StartServer, gamestats.StartServer}
wg.Add(len(actions))
@ -105,12 +116,30 @@ func backendMain(reload bool) {
// Wait for a signal to shutdown
<-sigExit
err = common.Shutdown()
stateUuid, err := common.Shutdown()
if err != nil {
panic(err)
}
(&RPCPacket{}).Shutdown(struct{}{}, &struct{}{})
(&RPCPacket{}).Shutdown(stateUuid, &struct{}{})
}
func loadUuidFile() string {
stateFile, err := os.Open("state/uuid.txt")
if err != nil {
logging.Error("BACKEND", "Failed to open state file:", err)
return ""
}
defer stateFile.Close()
uuid, err := io.ReadAll(stateFile)
if err != nil {
logging.Error("BACKEND", "Failed to read state file:", err)
return ""
}
return string(uuid)
}
// RPCPacket.NewConnection is called by the frontend to notify the backend of a new connection
@ -162,7 +191,7 @@ func (r *RPCPacket) CloseConnection(args RPCPacket, _ *struct{}) error {
}
// RPCPacket.Shutdown is called by the frontend to shutdown the backend
func (r *RPCPacket) Shutdown(_ struct{}, _ *struct{}) error {
func (r *RPCPacket) Shutdown(stateUuid string, _ *struct{}) error {
wg := &sync.WaitGroup{}
actions := []func(){nas.Shutdown, gpcm.Shutdown, qr2.Shutdown, gpsp.Shutdown, serverbrowser.Shutdown, sake.Shutdown, natneg.Shutdown, api.Shutdown, gamestats.Shutdown}
wg.Add(len(actions))
@ -175,6 +204,21 @@ func (r *RPCPacket) Shutdown(_ struct{}, _ *struct{}) error {
wg.Wait()
stateFile, err := os.OpenFile("state/uuid.txt", os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
panic(err)
}
_, err = stateFile.WriteString(stateUuid)
if err != nil {
panic(err)
}
err = stateFile.Close()
if err != nil {
panic(err)
}
os.Exit(0)
return nil
}
@ -197,8 +241,9 @@ var (
rpcMutex sync.Mutex
rpcBusyCount sync.WaitGroup
backendReady = make(chan struct{})
frontendUuid string
connections = map[string]map[uint64]net.Conn{}
connections = map[string]map[uint64]*net.Conn{}
)
// frontendMain starts the backend process and communicates with it using RPC
@ -231,7 +276,7 @@ func frontendMain(noBackend bool) {
}
for _, server := range servers {
connections[server.rpcName] = map[uint64]net.Conn{}
connections[server.rpcName] = map[uint64]*net.Conn{}
go frontendListen(server)
}
@ -309,6 +354,7 @@ func waitForBackend() {
rpcMutex.Unlock()
logging.Notice("FRONTEND", "Connected to backend")
return
}
}
@ -354,7 +400,8 @@ func handleConnection(server serverInfo, conn net.Conn, index uint64) {
rpcMutex.Lock()
rpcBusyCount.Add(1)
connections[server.rpcName][index] = conn
pConn := &conn
connections[server.rpcName][index] = pConn
rpcMutex.Unlock()
err := rpcClient.Call("RPCPacket.NewConnection", RPCPacket{Server: server.rpcName, Index: index, Address: conn.RemoteAddr().String(), Data: []byte{}}, nil)
@ -400,6 +447,11 @@ func handleConnection(server serverInfo, conn net.Conn, index uint64) {
}
rpcMutex.Lock()
if connections[server.rpcName][index] != pConn {
rpcMutex.Unlock()
return
}
rpcBusyCount.Add(1)
delete(connections[server.rpcName], index)
rpcMutex.Unlock()
@ -416,19 +468,23 @@ func handleConnection(server serverInfo, conn net.Conn, index uint64) {
}
}
var ErrBadIndex = errors.New("incorrect connection index")
var (
ErrBadIndex = errors.New("incorrect connection index")
ErrorBusy = errors.New("backend is busy")
)
// RPCFrontendPacket.SendPacket is called by the backend to send a packet to a connection
func (r *RPCFrontendPacket) SendPacket(args RPCFrontendPacket, _ *struct{}) error {
rpcMutex.Lock()
defer rpcMutex.Unlock()
conn, ok := connections[args.Server][args.Index]
if !ok {
conn := connections[args.Server][args.Index]
if conn == nil {
return ErrBadIndex
}
_, err := conn.Write(args.Data)
_, err := (*conn).Write(args.Data)
return err
}
@ -437,20 +493,22 @@ func (r *RPCFrontendPacket) CloseConnection(args RPCFrontendPacket, _ *struct{})
rpcMutex.Lock()
defer rpcMutex.Unlock()
conn, ok := connections[args.Server][args.Index]
if !ok {
conn := connections[args.Server][args.Index]
if conn == nil {
return ErrBadIndex
}
delete(connections[args.Server], args.Index)
return conn.Close()
return (*conn).Close()
}
// RPCFrontendPacket.ReloadBackend is called by an external program to reload the backend
func (r *RPCFrontendPacket) ReloadBackend(_ struct{}, _ *struct{}) error {
r.ShutdownBackend(struct{}{}, &struct{}{})
var stateUid string
r.ShutdownBackend(struct{}{}, &stateUid)
err := rpcClient.Call("RPCPacket.Shutdown", struct{}{}, nil)
err := rpcClient.Call("RPCPacket.Shutdown", stateUid, nil)
if err != nil && !strings.Contains(err.Error(), "An existing connection was forcibly closed by the remote host.") {
logging.Error("FRONTEND", "Failed to reload backend:", err)
}
@ -467,7 +525,7 @@ func (r *RPCFrontendPacket) ReloadBackend(_ struct{}, _ *struct{}) error {
}
// RPCFrontendPacket.ShutdownBackend is called by an external program to shutdown the backend
func (r *RPCFrontendPacket) ShutdownBackend(_ struct{}, _ *struct{}) error {
func (r *RPCFrontendPacket) ShutdownBackend(_ struct{}, uuid *string) error {
logging.Notice("FRONTEND", "Shutting down backend")
// Lock indefinitely
@ -477,11 +535,43 @@ func (r *RPCFrontendPacket) ShutdownBackend(_ struct{}, _ *struct{}) error {
go waitForBackend()
frontendUuid = common.RandomString(32)
*uuid = frontendUuid
return nil
}
// RPCFrontendPacket.VerifyState is called by the backend to verify the state UUID
func (r *RPCFrontendPacket) VerifyState(uuid string, reload *bool) error {
if rpcMutex.TryLock() {
rpcMutex.Unlock()
logging.Error("FRONTEND", "Failed to verify UUID, backend is active")
*reload = false
return ErrorBusy
}
if uuid != frontendUuid {
logging.Notice("FRONTEND", "VerifyState: Resetting all connections")
// Close all connections
for _, server := range connections {
for index, conn := range server {
(*conn).Close()
delete(server, index)
}
}
*reload = false
return nil
}
*reload = true
return nil
}
// RPCFrontendPacket.Ready is called by the backend to indicate it is ready to accept connections
func (r *RPCFrontendPacket) Ready(_ struct{}, _ *struct{}) error {
close(backendReady)
return nil
}