mirror of
https://github.com/WiiLink24/wfc-server.git
synced 2026-03-21 17:44:58 -05:00
NAS,NHTTP: Graceful shutdown of HTTP server
This commit is contained in:
parent
94c03d5693
commit
d3198757bf
4
main.go
4
main.go
|
|
@ -445,10 +445,8 @@ func (r *RPCFrontendPacket) ShutdownBackend(_ struct{}, _ *struct{}) error {
|
|||
rpcBusyCount.Wait()
|
||||
|
||||
err := rpcClient.Call("RPCPacket.Shutdown", struct{}{}, nil)
|
||||
if err != nil && err != rpc.ErrShutdown && !strings.Contains(err.Error(), "An existing connection was forcibly closed by the remote host.") {
|
||||
if err != nil && !strings.Contains(err.Error(), "An existing connection was forcibly closed by the remote host.") {
|
||||
logging.Error("FRONTEND", "Failed to reload backend:", err)
|
||||
} else {
|
||||
err = nil
|
||||
}
|
||||
|
||||
err = rpcClient.Close()
|
||||
|
|
|
|||
47
nas/main.go
47
nas/main.go
|
|
@ -2,11 +2,12 @@ package nas
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"errors"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"wwfc/api"
|
||||
"wwfc/common"
|
||||
"wwfc/gamestats"
|
||||
|
|
@ -14,33 +15,18 @@ import (
|
|||
"wwfc/nhttp"
|
||||
"wwfc/sake"
|
||||
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
"github.com/logrusorgru/aurora/v3"
|
||||
)
|
||||
|
||||
var (
|
||||
ctx = context.Background()
|
||||
pool *pgxpool.Pool
|
||||
|
||||
serverName string
|
||||
server *nhttp.Server
|
||||
)
|
||||
|
||||
func StartServer(reload bool) {
|
||||
// Get config
|
||||
config := common.GetConfig()
|
||||
|
||||
// Start SQL
|
||||
dbString := fmt.Sprintf("postgres://%s:%s@%s/%s", config.Username, config.Password, config.DatabaseAddress, config.DatabaseName)
|
||||
dbConf, err := pgxpool.ParseConfig(dbString)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
pool, err = pgxpool.ConnectConfig(ctx, dbConf)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
serverName = config.ServerName
|
||||
|
||||
address := *config.NASAddress + ":" + config.NASPort
|
||||
|
|
@ -49,19 +35,38 @@ func StartServer(reload bool) {
|
|||
go startHTTPSProxy(config)
|
||||
}
|
||||
|
||||
err = CacheProfanityFile()
|
||||
err := CacheProfanityFile()
|
||||
if err != nil {
|
||||
logging.Info("NAS", err)
|
||||
}
|
||||
|
||||
logging.Notice("NAS", "Starting HTTP server on", aurora.BrightCyan(address))
|
||||
server = &nhttp.Server{
|
||||
Addr: address,
|
||||
Handler: http.HandlerFunc(handleRequest),
|
||||
}
|
||||
|
||||
go func() {
|
||||
panic(nhttp.ListenAndServe(address, http.HandlerFunc(handleRequest)))
|
||||
logging.Notice("NAS", "Starting HTTP server on", aurora.BrightCyan(address))
|
||||
|
||||
err := server.ListenAndServe()
|
||||
if err != nil && !errors.Is(err, nhttp.ErrServerClosed) {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func Shutdown() {
|
||||
pool.Close()
|
||||
if server == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, release := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer release()
|
||||
|
||||
err := server.Shutdown(ctx)
|
||||
if err != nil {
|
||||
logging.Error("NAS", "Error on HTTP shutdown:", err)
|
||||
}
|
||||
}
|
||||
|
||||
var regexSakeHost = regexp.MustCompile(`^([a-z\-]+\.)?sake\.gs\.`)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"golang.org/x/net/http/httpguts"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
|
|
@ -15,6 +14,8 @@ import (
|
|||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/http/httpguts"
|
||||
)
|
||||
|
||||
type connReader struct {
|
||||
|
|
@ -117,7 +118,7 @@ func (cr *connReader) handleReadError(_ error) {
|
|||
|
||||
// may be called from multiple goroutines.
|
||||
func (cr *connReader) closeNotify() {
|
||||
res, _ := cr.conn.curReq.Load().(*response)
|
||||
res := cr.conn.curReq.Load()
|
||||
if res != nil && atomic.CompareAndSwapInt32(&res.didCloseNotify, 0, 1) {
|
||||
res.closeNotifyCh <- true
|
||||
}
|
||||
|
|
|
|||
165
nhttp/server.go
165
nhttp/server.go
|
|
@ -7,6 +7,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net"
|
||||
_http "net/http"
|
||||
"runtime"
|
||||
|
|
@ -53,8 +54,6 @@ type Server struct {
|
|||
// If non-nil, it must return a non-nil context.
|
||||
BaseContext func(net.Listener) context.Context
|
||||
|
||||
doneChan chan struct{}
|
||||
|
||||
mu sync.Mutex
|
||||
|
||||
disableKeepAlives int32 // accessed atomically.
|
||||
|
|
@ -75,10 +74,12 @@ type Server struct {
|
|||
// ReadHeaderTimeout. It is valid to use them both.
|
||||
ReadTimeout time.Duration
|
||||
|
||||
inShutdown atomicBool // true when server is in shutdown
|
||||
inShutdown atomic.Bool // true when server is in shutdown
|
||||
|
||||
listeners map[*net.Listener]struct{}
|
||||
|
||||
activeConn map[*conn]struct{}
|
||||
|
||||
listenerGroup sync.WaitGroup
|
||||
}
|
||||
|
||||
|
|
@ -126,11 +127,10 @@ type conn struct {
|
|||
// on this connection, if any.
|
||||
lastMethod string
|
||||
|
||||
curReq atomic.Value // of *response (which has a Request in it)
|
||||
curReq atomic.Pointer[response] // (which has a Request in it)
|
||||
|
||||
curState struct{ atomic uint64 } // packed (unixtime<<8|uint8(ConnState))
|
||||
curState atomic.Uint64 // packed (unixtime<<8|uint8(ConnState))
|
||||
|
||||
// mu guards hijackedv
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
|
|
@ -149,9 +149,96 @@ func (oc *onceCloseListener) Close() error {
|
|||
|
||||
func (oc *onceCloseListener) close() { oc.closeErr = oc.Listener.Close() }
|
||||
|
||||
const shutdownPollIntervalMax = 500 * time.Millisecond
|
||||
|
||||
// Shutdown gracefully shuts down the server without interrupting any
|
||||
// active connections. Shutdown works by first closing all open
|
||||
// listeners, then closing all idle connections, and then waiting
|
||||
// indefinitely for connections to return to idle and then shut down.
|
||||
// If the provided context expires before the shutdown is complete,
|
||||
// Shutdown returns the context's error, otherwise it returns any
|
||||
// error returned from closing the [Server]'s underlying Listener(s).
|
||||
//
|
||||
// When Shutdown is called, [Serve], [ListenAndServe], and
|
||||
// [ListenAndServeTLS] immediately return [ErrServerClosed]. Make sure the
|
||||
// program doesn't exit and waits instead for Shutdown to return.
|
||||
//
|
||||
// Once Shutdown has been called on a server, it may not be reused;
|
||||
// future calls to methods such as Serve will return ErrServerClosed.
|
||||
func (srv *Server) Shutdown(ctx context.Context) error {
|
||||
srv.inShutdown.Store(true)
|
||||
|
||||
srv.mu.Lock()
|
||||
lnerr := srv.closeListenersLocked()
|
||||
srv.mu.Unlock()
|
||||
srv.listenerGroup.Wait()
|
||||
|
||||
pollIntervalBase := time.Millisecond
|
||||
nextPollInterval := func() time.Duration {
|
||||
// Add 10% jitter.
|
||||
interval := pollIntervalBase + time.Duration(rand.Intn(int(pollIntervalBase/10)))
|
||||
// Double and clamp for next time.
|
||||
pollIntervalBase *= 2
|
||||
if pollIntervalBase > shutdownPollIntervalMax {
|
||||
pollIntervalBase = shutdownPollIntervalMax
|
||||
}
|
||||
return interval
|
||||
}
|
||||
|
||||
timer := time.NewTimer(nextPollInterval())
|
||||
defer timer.Stop()
|
||||
for {
|
||||
if srv.closeIdleConns() {
|
||||
return lnerr
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
timer.Reset(nextPollInterval())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// closeIdleConns closes all idle connections and reports whether the
|
||||
// server is quiescent.
|
||||
func (s *Server) closeIdleConns() bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
quiescent := true
|
||||
for c := range s.activeConn {
|
||||
st, unixSec := c.getState()
|
||||
// Issue 22682: treat StateNew connections as if
|
||||
// they're idle if we haven't read the first request's
|
||||
// header in over 5 seconds.
|
||||
if st == _http.StateNew && unixSec < time.Now().Unix()-5 {
|
||||
st = _http.StateIdle
|
||||
}
|
||||
if st != _http.StateIdle || unixSec == 0 {
|
||||
// Assume unixSec == 0 means it's a very new
|
||||
// connection, without state set yet.
|
||||
quiescent = false
|
||||
continue
|
||||
}
|
||||
c.rwc.Close()
|
||||
delete(s.activeConn, c)
|
||||
}
|
||||
return quiescent
|
||||
}
|
||||
|
||||
func (s *Server) closeListenersLocked() error {
|
||||
var err error
|
||||
for ln := range s.listeners {
|
||||
if cerr := (*ln).Close(); cerr != nil && err == nil {
|
||||
err = cerr
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (srv *Server) ListenAndServe() error {
|
||||
if srv.shuttingDown() {
|
||||
// return ErrServerClosed
|
||||
return ErrServerClosed
|
||||
}
|
||||
addr := srv.Addr
|
||||
if addr == "" {
|
||||
|
|
@ -229,21 +316,19 @@ func (srv *Server) Serve(l net.Listener) error {
|
|||
for {
|
||||
rw, err := l.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-srv.getDoneChan():
|
||||
if srv.shuttingDown() {
|
||||
return ErrServerClosed
|
||||
default:
|
||||
}
|
||||
if _, ok := err.(net.Error); ok {
|
||||
if ne, ok := err.(net.Error); ok && ne.Temporary() {
|
||||
if tempDelay == 0 {
|
||||
tempDelay = 5 * time.Millisecond
|
||||
} else {
|
||||
tempDelay *= 2
|
||||
}
|
||||
if _max := 1 * time.Second; tempDelay > _max {
|
||||
tempDelay = _max
|
||||
if max := 1 * time.Second; tempDelay > max {
|
||||
tempDelay = max
|
||||
}
|
||||
log.Printf("nhttp: Accept error: %v; retrying in %v\n", err, tempDelay)
|
||||
log.Printf("nhttp: Accept error: %v; retrying in %v", err, tempDelay)
|
||||
time.Sleep(tempDelay)
|
||||
continue
|
||||
}
|
||||
|
|
@ -289,6 +374,26 @@ func (c *conn) finalFlush() {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *conn) setState(nc net.Conn, state _http.ConnState) {
|
||||
srv := c.server
|
||||
switch state {
|
||||
case _http.StateNew:
|
||||
srv.trackConn(c, true)
|
||||
case _http.StateHijacked, _http.StateClosed:
|
||||
srv.trackConn(c, false)
|
||||
}
|
||||
if state > 0xff || state < 0 {
|
||||
panic("internal error")
|
||||
}
|
||||
packedState := uint64(time.Now().Unix()<<8) | uint64(state)
|
||||
c.curState.Store(packedState)
|
||||
}
|
||||
|
||||
func (c *conn) getState() (state _http.ConnState, unixSec int64) {
|
||||
packedState := c.curState.Load()
|
||||
return _http.ConnState(packedState & 0xff), int64(packedState >> 8)
|
||||
}
|
||||
|
||||
func (c *conn) serve(ctx context.Context) {
|
||||
c.remoteAddr = c.rwc.RemoteAddr().String()
|
||||
ctx = context.WithValue(ctx, LocalAddrContextKey, c.rwc.LocalAddr())
|
||||
|
|
@ -309,6 +414,7 @@ func (c *conn) serve(ctx context.Context) {
|
|||
inFlightResponse.reqBody.Close()
|
||||
}
|
||||
c.close()
|
||||
c.setState(c.rwc, _http.StateClosed)
|
||||
}()
|
||||
|
||||
// HTTP/1.x from here on.
|
||||
|
|
@ -323,6 +429,10 @@ func (c *conn) serve(ctx context.Context) {
|
|||
|
||||
for {
|
||||
w, err := c.readRequest(ctx)
|
||||
if c.r.remain != c.server.initialReadLimitSize() {
|
||||
// If we read any bytes off the wire, we're active.
|
||||
c.setState(c.rwc, _http.StateActive)
|
||||
}
|
||||
if err != nil {
|
||||
const errorHeaders = "\r\nContent-Type: text/plain; charset=utf-8\r\nConnection: close\r\n\r\n"
|
||||
|
||||
|
|
@ -403,6 +513,8 @@ func (c *conn) serve(ctx context.Context) {
|
|||
}
|
||||
return
|
||||
}
|
||||
c.setState(c.rwc, _http.StateIdle)
|
||||
c.curReq.Store(nil)
|
||||
|
||||
if !w.conn.server.doKeepAlives() {
|
||||
// We're in shutdown mode. We might've replied
|
||||
|
|
@ -430,28 +542,31 @@ func (c *conn) serve(ctx context.Context) {
|
|||
}
|
||||
}
|
||||
|
||||
func (srv *Server) getDoneChan() <-chan struct{} {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
if srv.doneChan == nil {
|
||||
srv.doneChan = make(chan struct{})
|
||||
}
|
||||
|
||||
return srv.doneChan
|
||||
}
|
||||
|
||||
func (s *Server) doKeepAlives() bool {
|
||||
return atomic.LoadInt32(&s.disableKeepAlives) == 0 && !s.shuttingDown()
|
||||
}
|
||||
|
||||
func (s *Server) shuttingDown() bool {
|
||||
return s.inShutdown.isSet()
|
||||
return s.inShutdown.Load()
|
||||
}
|
||||
|
||||
func (srv *Server) initialReadLimitSize() int64 {
|
||||
return 1<<20 + 4096 // bufio slop
|
||||
}
|
||||
|
||||
func (s *Server) trackConn(c *conn, add bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.activeConn == nil {
|
||||
s.activeConn = make(map[*conn]struct{})
|
||||
}
|
||||
if add {
|
||||
s.activeConn[c] = struct{}{}
|
||||
} else {
|
||||
delete(s.activeConn, c)
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) idleTimeout() time.Duration {
|
||||
if srv.IdleTimeout != 0 {
|
||||
return srv.IdleTimeout
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user