From d3198757bfb5bca39f3d6fea6e5ba688d8a4619e Mon Sep 17 00:00:00 2001 From: mkwcat Date: Tue, 7 May 2024 05:02:11 -0400 Subject: [PATCH] NAS,NHTTP: Graceful shutdown of HTTP server --- main.go | 4 +- nas/main.go | 47 ++++++++------ nhttp/request.go | 5 +- nhttp/server.go | 165 ++++++++++++++++++++++++++++++++++++++++------- 4 files changed, 170 insertions(+), 51 deletions(-) diff --git a/main.go b/main.go index 1fea9d9..6c5a711 100644 --- a/main.go +++ b/main.go @@ -445,10 +445,8 @@ func (r *RPCFrontendPacket) ShutdownBackend(_ struct{}, _ *struct{}) error { rpcBusyCount.Wait() err := rpcClient.Call("RPCPacket.Shutdown", struct{}{}, nil) - if err != nil && err != rpc.ErrShutdown && !strings.Contains(err.Error(), "An existing connection was forcibly closed by the remote host.") { + if err != nil && !strings.Contains(err.Error(), "An existing connection was forcibly closed by the remote host.") { logging.Error("FRONTEND", "Failed to reload backend:", err) - } else { - err = nil } err = rpcClient.Close() diff --git a/nas/main.go b/nas/main.go index cacac39..9172cc7 100644 --- a/nas/main.go +++ b/nas/main.go @@ -2,11 +2,12 @@ package nas import ( "context" - "fmt" + "errors" "net/http" "regexp" "strconv" "strings" + "time" "wwfc/api" "wwfc/common" "wwfc/gamestats" @@ -14,33 +15,18 @@ import ( "wwfc/nhttp" "wwfc/sake" - "github.com/jackc/pgx/v4/pgxpool" "github.com/logrusorgru/aurora/v3" ) var ( - ctx = context.Background() - pool *pgxpool.Pool - serverName string + server *nhttp.Server ) func StartServer(reload bool) { // Get config config := common.GetConfig() - // Start SQL - dbString := fmt.Sprintf("postgres://%s:%s@%s/%s", config.Username, config.Password, config.DatabaseAddress, config.DatabaseName) - dbConf, err := pgxpool.ParseConfig(dbString) - if err != nil { - panic(err) - } - - pool, err = pgxpool.ConnectConfig(ctx, dbConf) - if err != nil { - panic(err) - } - serverName = config.ServerName address := *config.NASAddress + ":" + config.NASPort @@ -49,19 +35,38 @@ func StartServer(reload bool) { go startHTTPSProxy(config) } - err = CacheProfanityFile() + err := CacheProfanityFile() if err != nil { logging.Info("NAS", err) } - logging.Notice("NAS", "Starting HTTP server on", aurora.BrightCyan(address)) + server = &nhttp.Server{ + Addr: address, + Handler: http.HandlerFunc(handleRequest), + } + go func() { - panic(nhttp.ListenAndServe(address, http.HandlerFunc(handleRequest))) + logging.Notice("NAS", "Starting HTTP server on", aurora.BrightCyan(address)) + + err := server.ListenAndServe() + if err != nil && !errors.Is(err, nhttp.ErrServerClosed) { + panic(err) + } }() } func Shutdown() { - pool.Close() + if server == nil { + return + } + + ctx, release := context.WithTimeout(context.Background(), 10*time.Second) + defer release() + + err := server.Shutdown(ctx) + if err != nil { + logging.Error("NAS", "Error on HTTP shutdown:", err) + } } var regexSakeHost = regexp.MustCompile(`^([a-z\-]+\.)?sake\.gs\.`) diff --git a/nhttp/request.go b/nhttp/request.go index 5c99207..68dcfa3 100644 --- a/nhttp/request.go +++ b/nhttp/request.go @@ -5,7 +5,6 @@ import ( "context" "errors" "fmt" - "golang.org/x/net/http/httpguts" "io" "math" "net" @@ -15,6 +14,8 @@ import ( "sync" "sync/atomic" "time" + + "golang.org/x/net/http/httpguts" ) type connReader struct { @@ -117,7 +118,7 @@ func (cr *connReader) handleReadError(_ error) { // may be called from multiple goroutines. func (cr *connReader) closeNotify() { - res, _ := cr.conn.curReq.Load().(*response) + res := cr.conn.curReq.Load() if res != nil && atomic.CompareAndSwapInt32(&res.didCloseNotify, 0, 1) { res.closeNotifyCh <- true } diff --git a/nhttp/server.go b/nhttp/server.go index ddbba2b..50ddb2d 100644 --- a/nhttp/server.go +++ b/nhttp/server.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "log" + "math/rand" "net" _http "net/http" "runtime" @@ -53,8 +54,6 @@ type Server struct { // If non-nil, it must return a non-nil context. BaseContext func(net.Listener) context.Context - doneChan chan struct{} - mu sync.Mutex disableKeepAlives int32 // accessed atomically. @@ -75,10 +74,12 @@ type Server struct { // ReadHeaderTimeout. It is valid to use them both. ReadTimeout time.Duration - inShutdown atomicBool // true when server is in shutdown + inShutdown atomic.Bool // true when server is in shutdown listeners map[*net.Listener]struct{} + activeConn map[*conn]struct{} + listenerGroup sync.WaitGroup } @@ -126,11 +127,10 @@ type conn struct { // on this connection, if any. lastMethod string - curReq atomic.Value // of *response (which has a Request in it) + curReq atomic.Pointer[response] // (which has a Request in it) - curState struct{ atomic uint64 } // packed (unixtime<<8|uint8(ConnState)) + curState atomic.Uint64 // packed (unixtime<<8|uint8(ConnState)) - // mu guards hijackedv mu sync.Mutex } @@ -149,9 +149,96 @@ func (oc *onceCloseListener) Close() error { func (oc *onceCloseListener) close() { oc.closeErr = oc.Listener.Close() } +const shutdownPollIntervalMax = 500 * time.Millisecond + +// Shutdown gracefully shuts down the server without interrupting any +// active connections. Shutdown works by first closing all open +// listeners, then closing all idle connections, and then waiting +// indefinitely for connections to return to idle and then shut down. +// If the provided context expires before the shutdown is complete, +// Shutdown returns the context's error, otherwise it returns any +// error returned from closing the [Server]'s underlying Listener(s). +// +// When Shutdown is called, [Serve], [ListenAndServe], and +// [ListenAndServeTLS] immediately return [ErrServerClosed]. Make sure the +// program doesn't exit and waits instead for Shutdown to return. +// +// Once Shutdown has been called on a server, it may not be reused; +// future calls to methods such as Serve will return ErrServerClosed. +func (srv *Server) Shutdown(ctx context.Context) error { + srv.inShutdown.Store(true) + + srv.mu.Lock() + lnerr := srv.closeListenersLocked() + srv.mu.Unlock() + srv.listenerGroup.Wait() + + pollIntervalBase := time.Millisecond + nextPollInterval := func() time.Duration { + // Add 10% jitter. + interval := pollIntervalBase + time.Duration(rand.Intn(int(pollIntervalBase/10))) + // Double and clamp for next time. + pollIntervalBase *= 2 + if pollIntervalBase > shutdownPollIntervalMax { + pollIntervalBase = shutdownPollIntervalMax + } + return interval + } + + timer := time.NewTimer(nextPollInterval()) + defer timer.Stop() + for { + if srv.closeIdleConns() { + return lnerr + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + timer.Reset(nextPollInterval()) + } + } +} + +// closeIdleConns closes all idle connections and reports whether the +// server is quiescent. +func (s *Server) closeIdleConns() bool { + s.mu.Lock() + defer s.mu.Unlock() + quiescent := true + for c := range s.activeConn { + st, unixSec := c.getState() + // Issue 22682: treat StateNew connections as if + // they're idle if we haven't read the first request's + // header in over 5 seconds. + if st == _http.StateNew && unixSec < time.Now().Unix()-5 { + st = _http.StateIdle + } + if st != _http.StateIdle || unixSec == 0 { + // Assume unixSec == 0 means it's a very new + // connection, without state set yet. + quiescent = false + continue + } + c.rwc.Close() + delete(s.activeConn, c) + } + return quiescent +} + +func (s *Server) closeListenersLocked() error { + var err error + for ln := range s.listeners { + if cerr := (*ln).Close(); cerr != nil && err == nil { + err = cerr + } + } + return err +} + func (srv *Server) ListenAndServe() error { if srv.shuttingDown() { - // return ErrServerClosed + return ErrServerClosed } addr := srv.Addr if addr == "" { @@ -229,21 +316,19 @@ func (srv *Server) Serve(l net.Listener) error { for { rw, err := l.Accept() if err != nil { - select { - case <-srv.getDoneChan(): + if srv.shuttingDown() { return ErrServerClosed - default: } - if _, ok := err.(net.Error); ok { + if ne, ok := err.(net.Error); ok && ne.Temporary() { if tempDelay == 0 { tempDelay = 5 * time.Millisecond } else { tempDelay *= 2 } - if _max := 1 * time.Second; tempDelay > _max { - tempDelay = _max + if max := 1 * time.Second; tempDelay > max { + tempDelay = max } - log.Printf("nhttp: Accept error: %v; retrying in %v\n", err, tempDelay) + log.Printf("nhttp: Accept error: %v; retrying in %v", err, tempDelay) time.Sleep(tempDelay) continue } @@ -289,6 +374,26 @@ func (c *conn) finalFlush() { } } +func (c *conn) setState(nc net.Conn, state _http.ConnState) { + srv := c.server + switch state { + case _http.StateNew: + srv.trackConn(c, true) + case _http.StateHijacked, _http.StateClosed: + srv.trackConn(c, false) + } + if state > 0xff || state < 0 { + panic("internal error") + } + packedState := uint64(time.Now().Unix()<<8) | uint64(state) + c.curState.Store(packedState) +} + +func (c *conn) getState() (state _http.ConnState, unixSec int64) { + packedState := c.curState.Load() + return _http.ConnState(packedState & 0xff), int64(packedState >> 8) +} + func (c *conn) serve(ctx context.Context) { c.remoteAddr = c.rwc.RemoteAddr().String() ctx = context.WithValue(ctx, LocalAddrContextKey, c.rwc.LocalAddr()) @@ -309,6 +414,7 @@ func (c *conn) serve(ctx context.Context) { inFlightResponse.reqBody.Close() } c.close() + c.setState(c.rwc, _http.StateClosed) }() // HTTP/1.x from here on. @@ -323,6 +429,10 @@ func (c *conn) serve(ctx context.Context) { for { w, err := c.readRequest(ctx) + if c.r.remain != c.server.initialReadLimitSize() { + // If we read any bytes off the wire, we're active. + c.setState(c.rwc, _http.StateActive) + } if err != nil { const errorHeaders = "\r\nContent-Type: text/plain; charset=utf-8\r\nConnection: close\r\n\r\n" @@ -403,6 +513,8 @@ func (c *conn) serve(ctx context.Context) { } return } + c.setState(c.rwc, _http.StateIdle) + c.curReq.Store(nil) if !w.conn.server.doKeepAlives() { // We're in shutdown mode. We might've replied @@ -430,28 +542,31 @@ func (c *conn) serve(ctx context.Context) { } } -func (srv *Server) getDoneChan() <-chan struct{} { - srv.mu.Lock() - defer srv.mu.Unlock() - if srv.doneChan == nil { - srv.doneChan = make(chan struct{}) - } - - return srv.doneChan -} - func (s *Server) doKeepAlives() bool { return atomic.LoadInt32(&s.disableKeepAlives) == 0 && !s.shuttingDown() } func (s *Server) shuttingDown() bool { - return s.inShutdown.isSet() + return s.inShutdown.Load() } func (srv *Server) initialReadLimitSize() int64 { return 1<<20 + 4096 // bufio slop } +func (s *Server) trackConn(c *conn, add bool) { + s.mu.Lock() + defer s.mu.Unlock() + if s.activeConn == nil { + s.activeConn = make(map[*conn]struct{}) + } + if add { + s.activeConn[c] = struct{}{} + } else { + delete(s.activeConn, c) + } +} + func (srv *Server) idleTimeout() time.Duration { if srv.IdleTimeout != 0 { return srv.IdleTimeout