mirror of
https://github.com/WiiLink24/wfc-server.git
synced 2026-03-21 17:44:58 -05:00
592 lines
16 KiB
Go
592 lines
16 KiB
Go
package nhttp
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"math/rand"
|
|
"net"
|
|
_http "net/http"
|
|
"runtime"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
)
|
|
|
|
// ErrServerClosed is returned by the Server's Serve, ServeTLS, ListenAndServe,
// and ListenAndServeTLS methods after a call to Shutdown or Close.
// Compare with errors.Is; it is a sentinel, not a wrapped error.
var ErrServerClosed = errors.New("nhttp: Server closed")
|
|
|
|
// contextKey is a value for use with context.WithValue. It's used as
// a pointer so it fits in an interface{} without allocation.
type contextKey struct {
	// name is a human-readable label used only by String for debugging.
	name string
}

// String identifies the key in debug output (e.g. %v of a context).
func (k *contextKey) String() string { return "net/nhttp context value " + k.name }
|
|
|
|
var (
	// ServerContextKey is a context key. It can be used in HTTP
	// handlers with Context.Value to access the server that
	// started the handler. The associated value will be of
	// type *Server.
	ServerContextKey = &contextKey{"nhttp-server"}

	// LocalAddrContextKey is a context key. It can be used in
	// HTTP handlers with Context.Value to access the local
	// address the connection arrived on.
	// The associated value will be of type net.Addr.
	LocalAddrContextKey = &contextKey{"local-addr"}
)
|
|
|
|
// Server defines parameters for running an HTTP/1.x server.
// It is a trimmed-down analogue of net/http.Server.
type Server struct {
	// Addr is the TCP address to listen on, in "host:port" form,
	// used by ListenAndServe.
	Addr string

	// Handler is invoked for every request read off a connection.
	// NOTE(review): whether a nil Handler falls back to a default mux
	// depends on serverHandler (defined elsewhere in this package) —
	// confirm before relying on nil.
	Handler _http.Handler

	// BaseContext optionally specifies a function that returns
	// the base context for incoming requests on this server.
	// The provided Listener is the specific Listener that's
	// about to start accepting requests.
	// If BaseContext is nil, the default is context.Background().
	// If non-nil, it must return a non-nil context.
	BaseContext func(net.Listener) context.Context

	// mu guards listeners and activeConn.
	mu sync.Mutex

	// disableKeepAlives is non-zero when keep-alives are disabled.
	disableKeepAlives int32 // accessed atomically.

	// IdleTimeout is the maximum amount of time to wait for the
	// next request when keep-alives are enabled. If IdleTimeout
	// is zero, the value of ReadTimeout is used. If both are
	// zero, there is no timeout.
	IdleTimeout time.Duration

	// ReadTimeout is the maximum duration for reading the entire
	// request, including the body. A zero or negative value means
	// there will be no timeout.
	//
	// Because ReadTimeout does not let Handlers make per-request
	// decisions on each request body's acceptable deadline or
	// upload rate, most users will prefer to use
	// ReadHeaderTimeout. It is valid to use them both.
	ReadTimeout time.Duration

	inShutdown atomic.Bool // true when server is in shutdown

	// listeners is the set of listeners currently being served.
	// Keys are pointers so non-comparable Listeners are safe;
	// see trackListener.
	listeners map[*net.Listener]struct{}

	// activeConn tracks connections that are neither closed nor
	// hijacked; Shutdown polls it via closeIdleConns.
	activeConn map[*conn]struct{}

	// listenerGroup lets Shutdown wait until every Serve call has
	// unregistered its listener.
	listenerGroup sync.WaitGroup
}
|
|
|
|
// conn represents the server side of an HTTP connection.
type conn struct {
	// server is the server on which the connection arrived.
	// Immutable; never nil.
	server *Server

	// cancelCtx cancels the connection-level context.
	cancelCtx context.CancelFunc

	// rwc is the underlying network connection.
	// This is never wrapped by other types and is the value given out
	// to CloseNotifier callers. It is usually of type *net.TCPConn or
	// *tls.Conn.
	rwc net.Conn

	// remoteAddr is rwc.RemoteAddr().String(). It is not populated synchronously
	// inside the Listener's Accept goroutine, as some implementations block.
	// It is populated immediately inside the (*conn).serve goroutine.
	// This is the value of a Handler's (*Request).RemoteAddr.
	remoteAddr string

	// tlsState is the TLS connection state when using TLS.
	// nil means not TLS.
	tlsState *tls.ConnectionState

	// werr is set to the first write error to rwc.
	// It is set via checkConnErrorWriter{w}, where bufw writes.
	werr error

	// r is bufr's read source. It's a wrapper around rwc that provides
	// io.LimitedReader-style limiting (while reading request headers)
	// and functionality to support CloseNotifier. See *connReader docs.
	r *connReader

	// bufr reads from r.
	bufr *bufio.Reader

	// bufw writes to checkConnErrorWriter{c}, which populates werr on error.
	bufw *bufio.Writer

	// lastMethod is the method of the most recent request
	// on this connection, if any.
	lastMethod string

	// curReq holds the in-flight response (which has a Request in it),
	// or nil between requests.
	curReq atomic.Pointer[response] // (which has a Request in it)

	// curState is the packed (unixtime<<8 | uint8(ConnState)) word
	// written by setState and decoded by getState.
	curState atomic.Uint64 // packed (unixtime<<8|uint8(ConnState))

	// mu guards per-connection mutable state.
	// NOTE(review): no lock/unlock of this field is visible in this chunk;
	// confirm which fields it protects before relying on it.
	mu sync.Mutex
}
|
|
|
|
// onceCloseListener wraps a net.Listener, protecting it from
// multiple Close calls: only the first call reaches the underlying
// Listener, and every call returns that first result.
type onceCloseListener struct {
	net.Listener
	once     sync.Once
	closeErr error
}

// Close closes the underlying Listener exactly once and returns the
// error (if any) recorded by that single real close.
func (oc *onceCloseListener) Close() error {
	oc.once.Do(oc.close)
	return oc.closeErr
}

// close performs the real close; it runs only once, via oc.once.
func (oc *onceCloseListener) close() { oc.closeErr = oc.Listener.Close() }
|
|
|
|
// shutdownPollIntervalMax caps the exponential backoff used while
// Shutdown polls for all connections to become idle.
const shutdownPollIntervalMax = 500 * time.Millisecond

// Shutdown gracefully shuts down the server without interrupting any
// active connections. Shutdown works by first closing all open
// listeners, then closing all idle connections, and then waiting
// indefinitely for connections to return to idle and then shut down.
// If the provided context expires before the shutdown is complete,
// Shutdown returns the context's error, otherwise it returns any
// error returned from closing the [Server]'s underlying Listener(s).
//
// When Shutdown is called, [Serve], [ListenAndServe], and
// [ListenAndServeTLS] immediately return [ErrServerClosed]. Make sure the
// program doesn't exit and waits instead for Shutdown to return.
//
// Once Shutdown has been called on a server, it may not be reused;
// future calls to methods such as Serve will return ErrServerClosed.
func (srv *Server) Shutdown(ctx context.Context) error {
	srv.inShutdown.Store(true)

	srv.mu.Lock()
	lnerr := srv.closeListenersLocked()
	srv.mu.Unlock()
	// Wait until every Serve loop has observed the close and
	// unregistered its listener.
	srv.listenerGroup.Wait()

	// Poll for quiescence with exponential backoff starting at 1ms.
	pollIntervalBase := time.Millisecond
	nextPollInterval := func() time.Duration {
		// Add 10% jitter.
		interval := pollIntervalBase + time.Duration(rand.Intn(int(pollIntervalBase/10)))
		// Double and clamp for next time.
		pollIntervalBase *= 2
		if pollIntervalBase > shutdownPollIntervalMax {
			pollIntervalBase = shutdownPollIntervalMax
		}
		return interval
	}

	timer := time.NewTimer(nextPollInterval())
	defer timer.Stop()
	for {
		// closeIdleConns reports true once no connection is active.
		if srv.closeIdleConns() {
			return lnerr
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-timer.C:
			timer.Reset(nextPollInterval())
		}
	}
}
|
|
|
|
// closeIdleConns closes all idle connections and reports whether the
|
|
// server is quiescent.
|
|
func (s *Server) closeIdleConns() bool {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
quiescent := true
|
|
for c := range s.activeConn {
|
|
st, unixSec := c.getState()
|
|
// Issue 22682: treat StateNew connections as if
|
|
// they're idle if we haven't read the first request's
|
|
// header in over 5 seconds.
|
|
if st == _http.StateNew && unixSec < time.Now().UTC().Unix()-5 {
|
|
st = _http.StateIdle
|
|
}
|
|
if st != _http.StateIdle || unixSec == 0 {
|
|
// Assume unixSec == 0 means it's a very new
|
|
// connection, without state set yet.
|
|
quiescent = false
|
|
continue
|
|
}
|
|
c.rwc.Close()
|
|
delete(s.activeConn, c)
|
|
}
|
|
return quiescent
|
|
}
|
|
|
|
func (s *Server) closeListenersLocked() error {
|
|
var err error
|
|
for ln := range s.listeners {
|
|
if cerr := (*ln).Close(); cerr != nil && err == nil {
|
|
err = cerr
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (srv *Server) ListenAndServe() error {
|
|
if srv.shuttingDown() {
|
|
return ErrServerClosed
|
|
}
|
|
addr := srv.Addr
|
|
if addr == "" {
|
|
addr = ":nhttp"
|
|
}
|
|
ln, err := net.Listen("tcp", addr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return srv.Serve(ln)
|
|
}
|
|
|
|
// ListenAndServe listens on the TCP network address addr and then calls
|
|
// Serve with handler to handle requests on incoming connections.
|
|
// Accepted connections are configured to enable TCP keep-alives.
|
|
//
|
|
// The handler is typically nil, in which case the DefaultServeMux is used.
|
|
//
|
|
// ListenAndServe always returns a non-nil error.
|
|
func ListenAndServe(addr string, handler _http.Handler) error {
|
|
server := &Server{Addr: addr, Handler: handler}
|
|
return server.ListenAndServe()
|
|
}
|
|
|
|
// trackListener adds or removes a net.Listener to the set of tracked
|
|
// listeners.
|
|
//
|
|
// We store a pointer to interface in the map set, in case the
|
|
// net.Listener is not comparable. This is safe because we only call
|
|
// trackListener via Serve and can track+defer untrack the same
|
|
// pointer to local variable there. We never need to compare a
|
|
// Listener from another caller.
|
|
//
|
|
// It reports whether the server is still up (not Shutdown or Closed).
|
|
func (s *Server) trackListener(ln *net.Listener, add bool) bool {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.listeners == nil {
|
|
s.listeners = make(map[*net.Listener]struct{})
|
|
}
|
|
if add {
|
|
if s.shuttingDown() {
|
|
return false
|
|
}
|
|
s.listeners[ln] = struct{}{}
|
|
s.listenerGroup.Add(1)
|
|
} else {
|
|
delete(s.listeners, ln)
|
|
s.listenerGroup.Done()
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (srv *Server) Serve(l net.Listener) error {
|
|
origListener := l
|
|
l = &onceCloseListener{Listener: l}
|
|
defer l.Close()
|
|
|
|
if !srv.trackListener(&l, true) {
|
|
return ErrServerClosed
|
|
}
|
|
defer srv.trackListener(&l, false)
|
|
|
|
baseCtx := context.Background()
|
|
if srv.BaseContext != nil {
|
|
baseCtx = srv.BaseContext(origListener)
|
|
if baseCtx == nil {
|
|
panic("BaseContext returned a nil context")
|
|
}
|
|
}
|
|
|
|
// How long to sleep on accept failure
|
|
var tempDelay time.Duration
|
|
ctx := context.WithValue(baseCtx, ServerContextKey, srv)
|
|
for {
|
|
rw, err := l.Accept()
|
|
if err != nil {
|
|
if srv.shuttingDown() {
|
|
return ErrServerClosed
|
|
}
|
|
if ne, ok := err.(net.Error); ok && ne.Temporary() {
|
|
if tempDelay == 0 {
|
|
tempDelay = 5 * time.Millisecond
|
|
} else {
|
|
tempDelay *= 2
|
|
}
|
|
if max := 1 * time.Second; tempDelay > max {
|
|
tempDelay = max
|
|
}
|
|
log.Printf("nhttp: Accept error: %v; retrying in %v", err, tempDelay)
|
|
time.Sleep(tempDelay)
|
|
continue
|
|
}
|
|
return err
|
|
}
|
|
connCtx := ctx
|
|
tempDelay = 0
|
|
c := &conn{
|
|
server: srv,
|
|
rwc: rw,
|
|
}
|
|
|
|
go c.serve(connCtx)
|
|
}
|
|
}
|
|
|
|
// ErrAbortHandler is a sentinel panic value to abort a handler.
// While any panic from ServeHTTP aborts the response to the client,
// panicking with ErrAbortHandler also suppresses logging of a stack
// trace to the server's error log (see the recover in (*conn).serve).
var ErrAbortHandler = errors.New("net/nhttp: abort Handler")
|
|
|
|
// close flushes any buffered response data, recycles the connection's
// bufio buffers, and closes the underlying network connection.
func (c *conn) close() {
	c.finalFlush()
	c.rwc.Close()
}
|
|
|
|
// finalFlush flushes any pending response bytes to the wire and returns
// the connection's bufio.Reader/Writer to their pools for reuse by a
// future connection.
func (c *conn) finalFlush() {
	if c.bufr != nil {
		// Steal the bufio.Reader (~4KB worth of memory) and its associated
		// reader for a future connection.
		putBufioReader(c.bufr)
		c.bufr = nil
	}

	if c.bufw != nil {
		// Flush buffered output before the writer is recycled.
		c.bufw.Flush()
		// Steal the bufio.Writer (~4KB worth of memory) and its associated
		// writer for a future connection.
		putBufioWriter(c.bufw)
		c.bufw = nil
	}
}
|
|
|
|
// setState records the connection's ConnState: it keeps the server's
// activeConn set in sync (StateNew registers, StateHijacked/StateClosed
// unregister) and stores (unixtime<<8 | state) atomically in c.curState
// for getState to decode.
// NOTE(review): nc is unused here; it appears to exist for signature
// symmetry with net/http's ConnState hook — confirm before removing.
func (c *conn) setState(nc net.Conn, state _http.ConnState) {
	srv := c.server
	switch state {
	case _http.StateNew:
		srv.trackConn(c, true)
	case _http.StateHijacked, _http.StateClosed:
		// The server is no longer responsible for hijacked/closed conns.
		srv.trackConn(c, false)
	}
	// The state must fit in the low byte of the packed word.
	if state > 0xff || state < 0 {
		panic("internal error")
	}
	// Pack the current unix time (high bits) with the state (low byte).
	packedState := uint64(time.Now().UTC().Unix()<<8) | uint64(state)
	c.curState.Store(packedState)
}
|
|
|
|
func (c *conn) getState() (state _http.ConnState, unixSec int64) {
|
|
packedState := c.curState.Load()
|
|
return _http.ConnState(packedState & 0xff), int64(packedState >> 8)
|
|
}
|
|
|
|
func (c *conn) serve(ctx context.Context) {
|
|
c.remoteAddr = c.rwc.RemoteAddr().String()
|
|
ctx = context.WithValue(ctx, LocalAddrContextKey, c.rwc.LocalAddr())
|
|
var inFlightResponse *response
|
|
defer func() {
|
|
if err := recover(); err != nil && err != ErrAbortHandler {
|
|
const size = 64 << 10
|
|
buf := make([]byte, size)
|
|
buf = buf[:runtime.Stack(buf, false)]
|
|
log.Printf("nhttp: panic serving %v: %v\n%s\n", c.remoteAddr, err, buf)
|
|
}
|
|
if inFlightResponse != nil {
|
|
inFlightResponse.cancelCtx()
|
|
}
|
|
|
|
if inFlightResponse != nil {
|
|
inFlightResponse.conn.r.abortPendingRead()
|
|
inFlightResponse.reqBody.Close()
|
|
}
|
|
c.close()
|
|
c.setState(c.rwc, _http.StateClosed)
|
|
}()
|
|
|
|
// HTTP/1.x from here on.
|
|
|
|
ctx, cancelCtx := context.WithCancel(ctx)
|
|
c.cancelCtx = cancelCtx
|
|
defer cancelCtx()
|
|
|
|
c.r = &connReader{conn: c}
|
|
c.bufr = newBufioReader(c.r)
|
|
c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10)
|
|
|
|
for {
|
|
w, err := c.readRequest(ctx)
|
|
if c.r.remain != c.server.initialReadLimitSize() {
|
|
// If we read any bytes off the wire, we're active.
|
|
c.setState(c.rwc, _http.StateActive)
|
|
}
|
|
if err != nil {
|
|
const errorHeaders = "\r\nContent-Type: text/plain; charset=utf-8\r\nConnection: close\r\n\r\n"
|
|
|
|
switch {
|
|
case err == errTooLarge:
|
|
// Their HTTP client may or may not be
|
|
// able to read this if we're
|
|
// responding to them and hanging up
|
|
// while they're still writing their
|
|
// request. Undefined behavior.
|
|
const publicErr = "431 Request Header Fields Too Large"
|
|
fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr)
|
|
c.closeWriteAndWait()
|
|
return
|
|
|
|
case isUnsupportedTEError(err):
|
|
// Respond as per RFC 7230 Section 3.3.1 which says,
|
|
// A server that receives a request message with a
|
|
// transfer coding it does not understand SHOULD
|
|
// respond with 501 (Unimplemented).
|
|
code := _http.StatusNotImplemented
|
|
|
|
// We purposefully aren't echoing back the transfer-encoding's value,
|
|
// so as to mitigate the risk of cross side scripting by an attacker.
|
|
fmt.Fprintf(c.rwc, "HTTP/1.1 %d %s%sUnsupported transfer encoding", code, _http.StatusText(code), errorHeaders)
|
|
return
|
|
|
|
case isCommonNetReadError(err):
|
|
return // don't reply
|
|
|
|
default:
|
|
if v, ok := err.(statusError); ok {
|
|
fmt.Fprintf(c.rwc, "HTTP/1.1 %d %s: %s%s%d %s: %s", v.code, _http.StatusText(v.code), v.text, errorHeaders, v.code, _http.StatusText(v.code), v.text)
|
|
return
|
|
}
|
|
publicErr := "400 Bad Request"
|
|
fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Expect 100 Continue support
|
|
req := w.req
|
|
if expectsContinue(req) {
|
|
if req.ProtoAtLeast(1, 1) && req.ContentLength != 0 {
|
|
// Wrap the Body reader with one that replies on the connection
|
|
req.Body = &expectContinueReader{readCloser: req.Body, resp: w}
|
|
w.canWriteContinue.SetTrue()
|
|
}
|
|
} else if req.Header.Get("Expect") != "" {
|
|
w.sendExpectationFailed()
|
|
return
|
|
}
|
|
|
|
c.curReq.Store(w)
|
|
|
|
if requestBodyRemains(req.Body) {
|
|
registerOnHitEOF(req.Body, w.conn.r.startBackgroundRead)
|
|
} else {
|
|
w.conn.r.startBackgroundRead()
|
|
}
|
|
|
|
// HTTP cannot have multiple simultaneous active requests.[*]
|
|
// Until the server replies to this request, it can't read another,
|
|
// so we might as well run the handler in this goroutine.
|
|
// [*] Not strictly true: HTTP pipelining. We could let them all process
|
|
// in parallel even if their responses need to be serialized.
|
|
// But we're not going to implement HTTP pipelining because it
|
|
// was never deployed in the wild and the answer is HTTP/2.
|
|
inFlightResponse = w
|
|
serverHandler{c.server}.ServeHTTP(w, w.req)
|
|
inFlightResponse = nil
|
|
w.cancelCtx()
|
|
w.finishRequest()
|
|
if !w.shouldReuseConnection() {
|
|
if w.requestBodyLimitHit || w.closedRequestBodyEarly() {
|
|
c.closeWriteAndWait()
|
|
}
|
|
return
|
|
}
|
|
c.setState(c.rwc, _http.StateIdle)
|
|
c.curReq.Store(nil)
|
|
|
|
if !w.conn.server.doKeepAlives() {
|
|
// We're in shutdown mode. We might've replied
|
|
// to the user without "Connection: close" and
|
|
// they might think they can send another
|
|
// request, but such is life with HTTP/1.1.
|
|
return
|
|
}
|
|
|
|
if d := c.server.idleTimeout(); d != 0 {
|
|
c.rwc.SetReadDeadline(time.Now().UTC().Add(d))
|
|
if _, err := c.bufr.Peek(4); err != nil {
|
|
return
|
|
}
|
|
}
|
|
|
|
c.curReq.Store((*response)(nil))
|
|
if d := c.server.idleTimeout(); d != 0 {
|
|
c.rwc.SetReadDeadline(time.Now().UTC().Add(d))
|
|
if _, err := c.bufr.Peek(4); err != nil {
|
|
return
|
|
}
|
|
}
|
|
c.rwc.SetReadDeadline(time.Time{})
|
|
}
|
|
}
|
|
|
|
func (s *Server) doKeepAlives() bool {
|
|
return atomic.LoadInt32(&s.disableKeepAlives) == 0 && !s.shuttingDown()
|
|
}
|
|
|
|
// shuttingDown reports whether Shutdown has been called on this server.
func (s *Server) shuttingDown() bool {
	return s.inShutdown.Load()
}
|
|
|
|
// initialReadLimitSize is the byte limit imposed while reading a
// request's headers: a fixed 1 MiB header budget plus bufio slop.
// (*conn).serve compares connReader.remain against this value to
// detect whether any bytes were read off the wire.
func (srv *Server) initialReadLimitSize() int64 {
	return 1<<20 + 4096 // bufio slop
}
|
|
|
|
func (s *Server) trackConn(c *conn, add bool) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.activeConn == nil {
|
|
s.activeConn = make(map[*conn]struct{})
|
|
}
|
|
if add {
|
|
s.activeConn[c] = struct{}{}
|
|
} else {
|
|
delete(s.activeConn, c)
|
|
}
|
|
}
|
|
|
|
func (srv *Server) idleTimeout() time.Duration {
|
|
if srv.IdleTimeout != 0 {
|
|
return srv.IdleTimeout
|
|
}
|
|
return srv.ReadTimeout
|
|
}
|
|
|
|
// checkConnErrorWriter writes to c.rwc and records any write errors to c.werr.
// It only contains one field (and a pointer field at that), so it
// fits in an interface value without an extra allocation.
type checkConnErrorWriter struct {
	// c is the connection whose rwc receives the writes and whose
	// werr records the first failure.
	c *conn
}
|
|
|
|
func (w checkConnErrorWriter) Write(p []byte) (n int, err error) {
|
|
n, err = w.c.rwc.Write(p)
|
|
if err != nil && w.c.werr == nil {
|
|
w.c.werr = err
|
|
w.c.cancelCtx()
|
|
}
|
|
return
|
|
}
|