NAS,NHTTP: Graceful shutdown of HTTP server

This commit is contained in:
mkwcat 2024-05-07 05:02:11 -04:00
parent 94c03d5693
commit d3198757bf
No known key found for this signature in database
GPG Key ID: 7A505679CE9E7AA9
4 changed files with 170 additions and 51 deletions

View File

@ -445,10 +445,8 @@ func (r *RPCFrontendPacket) ShutdownBackend(_ struct{}, _ *struct{}) error {
rpcBusyCount.Wait()
err := rpcClient.Call("RPCPacket.Shutdown", struct{}{}, nil)
if err != nil && err != rpc.ErrShutdown && !strings.Contains(err.Error(), "An existing connection was forcibly closed by the remote host.") {
if err != nil && !strings.Contains(err.Error(), "An existing connection was forcibly closed by the remote host.") {
logging.Error("FRONTEND", "Failed to reload backend:", err)
} else {
err = nil
}
err = rpcClient.Close()

View File

@ -2,11 +2,12 @@ package nas
import (
"context"
"fmt"
"errors"
"net/http"
"regexp"
"strconv"
"strings"
"time"
"wwfc/api"
"wwfc/common"
"wwfc/gamestats"
@ -14,33 +15,18 @@ import (
"wwfc/nhttp"
"wwfc/sake"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/logrusorgru/aurora/v3"
)
var (
ctx = context.Background()
pool *pgxpool.Pool
serverName string
server *nhttp.Server
)
func StartServer(reload bool) {
// Get config
config := common.GetConfig()
// Start SQL
dbString := fmt.Sprintf("postgres://%s:%s@%s/%s", config.Username, config.Password, config.DatabaseAddress, config.DatabaseName)
dbConf, err := pgxpool.ParseConfig(dbString)
if err != nil {
panic(err)
}
pool, err = pgxpool.ConnectConfig(ctx, dbConf)
if err != nil {
panic(err)
}
serverName = config.ServerName
address := *config.NASAddress + ":" + config.NASPort
@ -49,19 +35,38 @@ func StartServer(reload bool) {
go startHTTPSProxy(config)
}
err = CacheProfanityFile()
err := CacheProfanityFile()
if err != nil {
logging.Info("NAS", err)
}
logging.Notice("NAS", "Starting HTTP server on", aurora.BrightCyan(address))
server = &nhttp.Server{
Addr: address,
Handler: http.HandlerFunc(handleRequest),
}
go func() {
panic(nhttp.ListenAndServe(address, http.HandlerFunc(handleRequest)))
logging.Notice("NAS", "Starting HTTP server on", aurora.BrightCyan(address))
err := server.ListenAndServe()
if err != nil && !errors.Is(err, nhttp.ErrServerClosed) {
panic(err)
}
}()
}
func Shutdown() {
pool.Close()
if server == nil {
return
}
ctx, release := context.WithTimeout(context.Background(), 10*time.Second)
defer release()
err := server.Shutdown(ctx)
if err != nil {
logging.Error("NAS", "Error on HTTP shutdown:", err)
}
}
var regexSakeHost = regexp.MustCompile(`^([a-z\-]+\.)?sake\.gs\.`)

View File

@ -5,7 +5,6 @@ import (
"context"
"errors"
"fmt"
"golang.org/x/net/http/httpguts"
"io"
"math"
"net"
@ -15,6 +14,8 @@ import (
"sync"
"sync/atomic"
"time"
"golang.org/x/net/http/httpguts"
)
type connReader struct {
@ -117,7 +118,7 @@ func (cr *connReader) handleReadError(_ error) {
// may be called from multiple goroutines.
func (cr *connReader) closeNotify() {
res, _ := cr.conn.curReq.Load().(*response)
res := cr.conn.curReq.Load()
if res != nil && atomic.CompareAndSwapInt32(&res.didCloseNotify, 0, 1) {
res.closeNotifyCh <- true
}

View File

@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"log"
"math/rand"
"net"
_http "net/http"
"runtime"
@ -53,8 +54,6 @@ type Server struct {
// If non-nil, it must return a non-nil context.
BaseContext func(net.Listener) context.Context
doneChan chan struct{}
mu sync.Mutex
disableKeepAlives int32 // accessed atomically.
@ -75,10 +74,12 @@ type Server struct {
// ReadHeaderTimeout. It is valid to use them both.
ReadTimeout time.Duration
inShutdown atomicBool // true when server is in shutdown
inShutdown atomic.Bool // true when server is in shutdown
listeners map[*net.Listener]struct{}
activeConn map[*conn]struct{}
listenerGroup sync.WaitGroup
}
@ -126,11 +127,10 @@ type conn struct {
// on this connection, if any.
lastMethod string
curReq atomic.Value // of *response (which has a Request in it)
curReq atomic.Pointer[response] // (which has a Request in it)
curState struct{ atomic uint64 } // packed (unixtime<<8|uint8(ConnState))
curState atomic.Uint64 // packed (unixtime<<8|uint8(ConnState))
// mu guards hijackedv
mu sync.Mutex
}
@ -149,9 +149,96 @@ func (oc *onceCloseListener) Close() error {
func (oc *onceCloseListener) close() { oc.closeErr = oc.Listener.Close() }
const shutdownPollIntervalMax = 500 * time.Millisecond
// Shutdown gracefully shuts down the server without interrupting any
// active connections. Shutdown works by first closing all open
// listeners, then closing all idle connections, and then waiting
// indefinitely for connections to return to idle and then shut down.
// If the provided context expires before the shutdown is complete,
// Shutdown returns the context's error, otherwise it returns any
// error returned from closing the [Server]'s underlying Listener(s).
//
// When Shutdown is called, [Serve], [ListenAndServe], and
// [ListenAndServeTLS] immediately return [ErrServerClosed]. Make sure the
// program doesn't exit and waits instead for Shutdown to return.
//
// Once Shutdown has been called on a server, it may not be reused;
// future calls to methods such as Serve will return ErrServerClosed.
func (srv *Server) Shutdown(ctx context.Context) error {
srv.inShutdown.Store(true)
srv.mu.Lock()
lnerr := srv.closeListenersLocked()
srv.mu.Unlock()
srv.listenerGroup.Wait()
pollIntervalBase := time.Millisecond
nextPollInterval := func() time.Duration {
// Add 10% jitter.
interval := pollIntervalBase + time.Duration(rand.Intn(int(pollIntervalBase/10)))
// Double and clamp for next time.
pollIntervalBase *= 2
if pollIntervalBase > shutdownPollIntervalMax {
pollIntervalBase = shutdownPollIntervalMax
}
return interval
}
timer := time.NewTimer(nextPollInterval())
defer timer.Stop()
for {
if srv.closeIdleConns() {
return lnerr
}
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
timer.Reset(nextPollInterval())
}
}
}
// closeIdleConns closes all idle connections and reports whether the
// server is quiescent.
func (s *Server) closeIdleConns() bool {
s.mu.Lock()
defer s.mu.Unlock()
quiescent := true
for c := range s.activeConn {
st, unixSec := c.getState()
// Issue 22682: treat StateNew connections as if
// they're idle if we haven't read the first request's
// header in over 5 seconds.
if st == _http.StateNew && unixSec < time.Now().Unix()-5 {
st = _http.StateIdle
}
if st != _http.StateIdle || unixSec == 0 {
// Assume unixSec == 0 means it's a very new
// connection, without state set yet.
quiescent = false
continue
}
c.rwc.Close()
delete(s.activeConn, c)
}
return quiescent
}
func (s *Server) closeListenersLocked() error {
var err error
for ln := range s.listeners {
if cerr := (*ln).Close(); cerr != nil && err == nil {
err = cerr
}
}
return err
}
func (srv *Server) ListenAndServe() error {
if srv.shuttingDown() {
// return ErrServerClosed
return ErrServerClosed
}
addr := srv.Addr
if addr == "" {
@ -229,21 +316,19 @@ func (srv *Server) Serve(l net.Listener) error {
for {
rw, err := l.Accept()
if err != nil {
select {
case <-srv.getDoneChan():
if srv.shuttingDown() {
return ErrServerClosed
default:
}
if _, ok := err.(net.Error); ok {
if ne, ok := err.(net.Error); ok && ne.Temporary() {
if tempDelay == 0 {
tempDelay = 5 * time.Millisecond
} else {
tempDelay *= 2
}
if _max := 1 * time.Second; tempDelay > _max {
tempDelay = _max
if max := 1 * time.Second; tempDelay > max {
tempDelay = max
}
log.Printf("nhttp: Accept error: %v; retrying in %v\n", err, tempDelay)
log.Printf("nhttp: Accept error: %v; retrying in %v", err, tempDelay)
time.Sleep(tempDelay)
continue
}
@ -289,6 +374,26 @@ func (c *conn) finalFlush() {
}
}
func (c *conn) setState(nc net.Conn, state _http.ConnState) {
srv := c.server
switch state {
case _http.StateNew:
srv.trackConn(c, true)
case _http.StateHijacked, _http.StateClosed:
srv.trackConn(c, false)
}
if state > 0xff || state < 0 {
panic("internal error")
}
packedState := uint64(time.Now().Unix()<<8) | uint64(state)
c.curState.Store(packedState)
}
func (c *conn) getState() (state _http.ConnState, unixSec int64) {
packedState := c.curState.Load()
return _http.ConnState(packedState & 0xff), int64(packedState >> 8)
}
func (c *conn) serve(ctx context.Context) {
c.remoteAddr = c.rwc.RemoteAddr().String()
ctx = context.WithValue(ctx, LocalAddrContextKey, c.rwc.LocalAddr())
@ -309,6 +414,7 @@ func (c *conn) serve(ctx context.Context) {
inFlightResponse.reqBody.Close()
}
c.close()
c.setState(c.rwc, _http.StateClosed)
}()
// HTTP/1.x from here on.
@ -323,6 +429,10 @@ func (c *conn) serve(ctx context.Context) {
for {
w, err := c.readRequest(ctx)
if c.r.remain != c.server.initialReadLimitSize() {
// If we read any bytes off the wire, we're active.
c.setState(c.rwc, _http.StateActive)
}
if err != nil {
const errorHeaders = "\r\nContent-Type: text/plain; charset=utf-8\r\nConnection: close\r\n\r\n"
@ -403,6 +513,8 @@ func (c *conn) serve(ctx context.Context) {
}
return
}
c.setState(c.rwc, _http.StateIdle)
c.curReq.Store(nil)
if !w.conn.server.doKeepAlives() {
// We're in shutdown mode. We might've replied
@ -430,28 +542,31 @@ func (c *conn) serve(ctx context.Context) {
}
}
func (srv *Server) getDoneChan() <-chan struct{} {
srv.mu.Lock()
defer srv.mu.Unlock()
if srv.doneChan == nil {
srv.doneChan = make(chan struct{})
}
return srv.doneChan
}
func (s *Server) doKeepAlives() bool {
return atomic.LoadInt32(&s.disableKeepAlives) == 0 && !s.shuttingDown()
}
func (s *Server) shuttingDown() bool {
return s.inShutdown.isSet()
return s.inShutdown.Load()
}
func (srv *Server) initialReadLimitSize() int64 {
return 1<<20 + 4096 // bufio slop
}
func (s *Server) trackConn(c *conn, add bool) {
s.mu.Lock()
defer s.mu.Unlock()
if s.activeConn == nil {
s.activeConn = make(map[*conn]struct{})
}
if add {
s.activeConn[c] = struct{}{}
} else {
delete(s.activeConn, c)
}
}
func (srv *Server) idleTimeout() time.Duration {
if srv.IdleTimeout != 0 {
return srv.IdleTimeout