wfc-server/nhttp/server.go

592 lines
16 KiB
Go

package nhttp
import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
"log"
"math/rand"
"net"
_http "net/http"
"runtime"
"sync"
"sync/atomic"
"time"
)
// ErrServerClosed is the sentinel error reported once the server has
// entered shutdown. In this file it is returned by ListenAndServe when the
// server is already shutting down, and by Serve when the listener cannot
// be tracked or Accept fails while the server is shutting down.
var ErrServerClosed = errors.New("nhttp: Server closed")
// contextKey is the key type used with context.WithValue by this package.
// Keys are used through a pointer so that they fit in an interface value
// without allocation.
type contextKey struct {
	name string
}

// String returns a human-readable description of the key, made of a fixed
// prefix followed by the key's name.
func (k *contextKey) String() string {
	const prefix = "net/nhttp context value "
	return prefix + k.name
}
var (
// ServerContextKey is a context key. It can be used in HTTP
// handlers with Context.Value to access the server that
// started the handler. The associated value will be of
// type *Server. Serve attaches it to the base context before
// accepting connections.
ServerContextKey = &contextKey{"nhttp-server"}
// LocalAddrContextKey is a context key. It can be used in
// HTTP handlers with Context.Value to access the local
// address the connection arrived on.
// The associated value will be of type net.Addr. It is attached
// per connection, from the connection's LocalAddr, at the start of
// (*conn).serve.
LocalAddrContextKey = &contextKey{"local-addr"}
)
// Server holds the configuration and the runtime bookkeeping used by
// Serve, ListenAndServe and Shutdown.
type Server struct {
// Addr is the TCP address ListenAndServe listens on. When it is
// empty, ListenAndServe substitutes a default address.
Addr string
// Handler is the request handler for this server.
// NOTE(review): requests are dispatched through serverHandler, which is
// defined outside this view; confirm there how a nil Handler is treated.
Handler _http.Handler
// BaseContext optionally specifies a function that returns
// the base context for incoming requests on this server.
// The provided Listener is the specific Listener that's
// about to start accepting requests.
// If BaseContext is nil, the default is context.Background().
// If non-nil, it must return a non-nil context.
BaseContext func(net.Listener) context.Context
// mu guards listeners and activeConn.
mu sync.Mutex
// disableKeepAlives is read by doKeepAlives; any non-zero value makes
// doKeepAlives report false.
disableKeepAlives int32 // accessed atomically.
// IdleTimeout is the maximum amount of time to wait for the
// next request when keep-alives are enabled. If IdleTimeout
// is zero, the value of ReadTimeout is used. If both are
// zero, there is no timeout.
IdleTimeout time.Duration
// ReadTimeout is the maximum duration for reading the entire
// request, including the body. A zero or negative value means
// there will be no timeout.
//
// NOTE(review): in the visible part of this file ReadTimeout is only
// consulted as the fallback for IdleTimeout (see idleTimeout); confirm
// where, if anywhere, it is applied to request reads.
ReadTimeout time.Duration
inShutdown atomic.Bool // true when server is in shutdown
// listeners is the set of listeners currently being served, keyed by
// pointer (see trackListener). Shutdown closes all of them.
listeners map[*net.Listener]struct{}
// activeConn is the set of connections registered through trackConn;
// closeIdleConns walks it during Shutdown.
activeConn map[*conn]struct{}
// listenerGroup counts tracked listeners; Shutdown waits on it after
// closing them.
listenerGroup sync.WaitGroup
}
// conn represents the server side of an HTTP connection.
type conn struct {
// server is the server on which the connection arrived.
// Immutable; never nil.
server *Server
// cancelCtx cancels the connection-level context. It is assigned at the
// start of (*conn).serve and is also invoked by checkConnErrorWriter on
// the first write error.
cancelCtx context.CancelFunc
// rwc is the underlying network connection.
// This is never wrapped by other types and is the value given out
// to CloseNotifier callers. It is usually of type *net.TCPConn or
// *tls.Conn.
rwc net.Conn
// remoteAddr is rwc.RemoteAddr().String(). It is not populated synchronously
// inside the Listener's Accept goroutine, as some implementations block.
// It is populated immediately inside the (*conn).serve goroutine.
// This is the value of a Handler's (*Request).RemoteAddr.
remoteAddr string
// tlsState is the TLS connection state when using TLS.
// nil means not TLS.
tlsState *tls.ConnectionState
// werr is set to the first write error to rwc.
// It is set via checkConnErrorWriter{w}, where bufw writes.
werr error
// r is bufr's read source. It's a wrapper around rwc that provides
// io.LimitedReader-style limiting (while reading request headers)
// and functionality to support CloseNotifier. See *connReader docs.
r *connReader
// bufr reads from r.
bufr *bufio.Reader
// bufw writes to checkConnErrorWriter{c}, which populates werr on error.
bufw *bufio.Writer
// lastMethod is the method of the most recent request
// on this connection, if any.
lastMethod string
// curReq is the response currently being served; (*conn).serve stores it
// before running the handler and resets it to nil once the connection
// goes idle.
curReq atomic.Pointer[response] // (which has a Request in it)
// curState is written by setState and decoded by getState.
curState atomic.Uint64 // packed (unixtime<<8|uint8(ConnState))
// NOTE(review): mu is not used in the visible part of this file; confirm
// what it guards.
mu sync.Mutex
}
// onceCloseListener decorates a net.Listener so that the underlying
// listener's Close runs at most once, no matter how many times Close is
// called on the wrapper.
type onceCloseListener struct {
	net.Listener
	once     sync.Once
	closeErr error
}

// Close closes the wrapped listener on the first call and remembers the
// result; every call, including later ones, returns that remembered error.
func (oc *onceCloseListener) Close() error {
	oc.once.Do(func() { oc.close() })
	return oc.closeErr
}

// close performs the real Close and records its error. It must only be
// reached through oc.once.
func (oc *onceCloseListener) close() {
	err := oc.Listener.Close()
	oc.closeErr = err
}
// shutdownPollIntervalMax caps the delay between two idle-connection sweeps
// performed by Shutdown.
const shutdownPollIntervalMax = 500 * time.Millisecond

// Shutdown stops the server gracefully.
//
// It marks the server as shutting down, closes every tracked listener and
// waits for the corresponding Serve calls to untrack them. It then
// repeatedly sweeps the tracked connections with closeIdleConns until that
// reports the server quiescent. Sweeps start 1ms apart (plus up to 10%
// jitter) and the interval doubles up to shutdownPollIntervalMax.
//
// If ctx is done before the server becomes quiescent, Shutdown returns
// ctx.Err(). Otherwise it returns the first error, if any, encountered
// while closing the listeners.
//
// Once the server is shutting down, Serve and ListenAndServe return
// ErrServerClosed; the server cannot be reused afterwards.
func (srv *Server) Shutdown(ctx context.Context) error {
	srv.inShutdown.Store(true)

	srv.mu.Lock()
	closeErr := srv.closeListenersLocked()
	srv.mu.Unlock()
	srv.listenerGroup.Wait()

	base := time.Millisecond
	// jittered returns the next wait (base plus up to 10% jitter) and then
	// doubles base, clamped to shutdownPollIntervalMax, for the next call.
	jittered := func() time.Duration {
		wait := base + time.Duration(rand.Intn(int(base/10)))
		if base *= 2; base > shutdownPollIntervalMax {
			base = shutdownPollIntervalMax
		}
		return wait
	}

	timer := time.NewTimer(jittered())
	defer timer.Stop()
	for !srv.closeIdleConns() {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-timer.C:
			timer.Reset(jittered())
		}
	}
	return closeErr
}
// closeIdleConns closes every tracked connection that is idle and removes
// it from the tracked set. It reports whether all tracked connections were
// idle, i.e. whether the server is quiescent.
func (s *Server) closeIdleConns() bool {
	s.mu.Lock()
	defer s.mu.Unlock()

	allIdle := true
	for c := range s.activeConn {
		state, since := c.getState()
		// Issue 22682: a StateNew connection that has not delivered its
		// first request header for more than 5 seconds counts as idle.
		if state == _http.StateNew && since < time.Now().UTC().Unix()-5 {
			state = _http.StateIdle
		}
		if state == _http.StateIdle && since != 0 {
			c.rwc.Close()
			delete(s.activeConn, c)
			continue
		}
		// Either the connection is busy, or since == 0, which is taken to
		// mean a very new connection whose state has not been set yet.
		allIdle = false
	}
	return allIdle
}
// closeListenersLocked closes every tracked listener and returns the first
// non-nil Close error it sees (map iteration order, so "first" is
// arbitrary). As the name indicates, the caller is expected to hold s.mu.
func (s *Server) closeListenersLocked() error {
	var firstErr error
	for ln := range s.listeners {
		cerr := (*ln).Close()
		if firstErr == nil && cerr != nil {
			firstErr = cerr
		}
	}
	return firstErr
}
// ListenAndServe listens on the TCP network address srv.Addr and then
// calls Serve to handle requests on incoming connections.
//
// If srv.Addr is blank, ":http" is used. (The previous default, ":nhttp",
// is not a resolvable service name, so listening with an empty Addr always
// failed.)
//
// ListenAndServe always returns a non-nil error: ErrServerClosed if the
// server is already shutting down, the error from net.Listen, or whatever
// Serve returns.
func (srv *Server) ListenAndServe() error {
	if srv.shuttingDown() {
		return ErrServerClosed
	}
	addr := srv.Addr
	if addr == "" {
		addr = ":http"
	}
	ln, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}
	return srv.Serve(ln)
}
// ListenAndServe is a convenience wrapper: it builds a Server whose Addr is
// addr and whose Handler is handler, and runs its ListenAndServe method.
//
// NOTE(review): how a nil handler is treated depends on serverHandler,
// which is defined outside this view.
//
// ListenAndServe always returns a non-nil error.
func ListenAndServe(addr string, handler _http.Handler) error {
	return (&Server{Addr: addr, Handler: handler}).ListenAndServe()
}
// trackListener registers (add == true) or unregisters (add == false) a
// listener in the server's set of tracked listeners, adjusting
// listenerGroup to match.
//
// The set is keyed by pointer-to-interface so that it works even when the
// concrete net.Listener is not comparable. That is safe because Serve is
// the only caller and it tracks and untracks the address of the same local
// variable; listeners from different callers are never compared.
//
// It returns false only when asked to add a listener while the server is
// shutting down; in that case nothing is registered.
func (s *Server) trackListener(ln *net.Listener, add bool) bool {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.listeners == nil {
		s.listeners = map[*net.Listener]struct{}{}
	}

	if !add {
		delete(s.listeners, ln)
		s.listenerGroup.Done()
		return true
	}

	if s.shuttingDown() {
		return false
	}
	s.listeners[ln] = struct{}{}
	s.listenerGroup.Add(1)
	return true
}
// Serve accepts incoming connections on the Listener l and starts a new
// goroutine running (*conn).serve for each of them.
//
// The listener is wrapped so that its Close is invoked at most once, and it
// is closed when Serve returns. The request base context comes from
// srv.BaseContext (context.Background() when that is nil) and carries the
// server under ServerContextKey.
//
// Temporary Accept errors are logged and retried with a backoff that starts
// at 5ms and doubles up to 1s. Serve always returns a non-nil error:
// ErrServerClosed when the server is shutting down, otherwise the Accept
// error.
func (srv *Server) Serve(l net.Listener) error {
	origListener := l
	l = &onceCloseListener{Listener: l}
	defer l.Close()

	if !srv.trackListener(&l, true) {
		return ErrServerClosed
	}
	defer srv.trackListener(&l, false)

	baseCtx := context.Background()
	if srv.BaseContext != nil {
		baseCtx = srv.BaseContext(origListener)
		if baseCtx == nil {
			panic("BaseContext returned a nil context")
		}
	}

	// How long to sleep on accept failure.
	var tempDelay time.Duration

	ctx := context.WithValue(baseCtx, ServerContextKey, srv)
	for {
		rw, err := l.Accept()
		if err != nil {
			if srv.shuttingDown() {
				return ErrServerClosed
			}
			if ne, ok := err.(net.Error); ok && ne.Temporary() {
				if tempDelay == 0 {
					tempDelay = 5 * time.Millisecond
				} else {
					tempDelay *= 2
				}
				if limit := 1 * time.Second; tempDelay > limit {
					tempDelay = limit
				}
				log.Printf("nhttp: Accept error: %v; retrying in %v", err, tempDelay)
				time.Sleep(tempDelay)
				continue
			}
			return err
		}
		tempDelay = 0

		c := &conn{
			server: srv,
			rwc:    rw,
		}
		// Register the connection before Serve can return. setState only
		// adds a connection to the server's tracked set on StateNew, so
		// without this call Shutdown would never see the connection: it
		// could neither close it while idle nor wait for it while active.
		c.setState(c.rwc, _http.StateNew)
		go c.serve(ctx)
	}
}
// ErrAbortHandler is a sentinel panic value to abort a handler.
// While any panic from ServeHTTP aborts the response to the client,
// panicking with ErrAbortHandler also suppresses logging of a stack
// trace to the server's error log: the deferred recover in (*conn).serve
// skips its log.Printf when the recovered value equals ErrAbortHandler.
var ErrAbortHandler = errors.New("net/nhttp: abort Handler")
// close releases the connection's buffers via finalFlush and then closes
// the underlying network connection. The error from closing is discarded.
func (c *conn) close() {
	c.finalFlush()
	_ = c.rwc.Close()
}
// finalFlush hands the connection's buffered reader and writer back for
// reuse by a future connection (each is ~4KB worth of memory). The writer
// is flushed before it is returned. Both fields are cleared, so calling
// finalFlush again is a no-op.
func (c *conn) finalFlush() {
	if r := c.bufr; r != nil {
		putBufioReader(r)
		c.bufr = nil
	}
	if w := c.bufw; w != nil {
		w.Flush()
		putBufioWriter(w)
		c.bufw = nil
	}
}
// setState records a new connection state together with the current Unix
// time, packed as unixtime<<8 | state in c.curState.
//
// As a side effect, StateNew adds the connection to the server's tracked
// set, while StateHijacked and StateClosed remove it. The nc argument is
// not used. A state outside [0, 0xff] cannot be packed and causes a panic.
func (c *conn) setState(nc net.Conn, state _http.ConnState) {
	if state == _http.StateNew {
		c.server.trackConn(c, true)
	} else if state == _http.StateHijacked || state == _http.StateClosed {
		c.server.trackConn(c, false)
	}

	if state < 0 || state > 0xff {
		panic("internal error")
	}

	now := time.Now().UTC().Unix()
	c.curState.Store(uint64(now<<8) | uint64(state))
}
// getState unpacks c.curState: the low 8 bits are the connection state and
// the remaining bits are the Unix time (seconds) at which it was recorded.
// Both are zero if setState has not been called yet.
func (c *conn) getState() (state _http.ConnState, unixSec int64) {
	packed := c.curState.Load()
	state = _http.ConnState(packed & 0xff)
	unixSec = int64(packed >> 8)
	return state, unixSec
}
// serve runs the HTTP/1.x request loop for a single connection.
//
// Each iteration reads one request, replies directly on the socket if the
// request could not be read, and otherwise runs the handler synchronously
// in this goroutine. The loop continues while the connection can be reused
// and keep-alives are enabled. Between requests the connection is marked
// idle and, when a positive idle timeout is configured, serve waits up to
// that long for the first bytes of the next request.
//
// When serve returns, or the handler panics, the deferred cleanup logs the
// panic (unless it is ErrAbortHandler), cancels any in-flight response,
// closes the connection and records StateClosed.
func (c *conn) serve(ctx context.Context) {
	c.remoteAddr = c.rwc.RemoteAddr().String()
	ctx = context.WithValue(ctx, LocalAddrContextKey, c.rwc.LocalAddr())
	var inFlightResponse *response
	defer func() {
		if err := recover(); err != nil && err != ErrAbortHandler {
			const size = 64 << 10
			buf := make([]byte, size)
			buf = buf[:runtime.Stack(buf, false)]
			log.Printf("nhttp: panic serving %v: %v\n%s\n", c.remoteAddr, err, buf)
		}
		// Non-nil only if the handler panicked before returning.
		if inFlightResponse != nil {
			inFlightResponse.cancelCtx()
			inFlightResponse.conn.r.abortPendingRead()
			inFlightResponse.reqBody.Close()
		}
		c.close()
		c.setState(c.rwc, _http.StateClosed)
	}()

	// HTTP/1.x from here on.

	ctx, cancelCtx := context.WithCancel(ctx)
	c.cancelCtx = cancelCtx
	defer cancelCtx()

	c.r = &connReader{conn: c}
	c.bufr = newBufioReader(c.r)
	c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10)

	for {
		w, err := c.readRequest(ctx)
		if c.r.remain != c.server.initialReadLimitSize() {
			// If we read any bytes off the wire, we're active.
			c.setState(c.rwc, _http.StateActive)
		}
		if err != nil {
			const errorHeaders = "\r\nContent-Type: text/plain; charset=utf-8\r\nConnection: close\r\n\r\n"

			switch {
			case err == errTooLarge:
				// Their HTTP client may or may not be
				// able to read this if we're
				// responding to them and hanging up
				// while they're still writing their
				// request. Undefined behavior.
				const publicErr = "431 Request Header Fields Too Large"
				fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr)
				c.closeWriteAndWait()
				return

			case isUnsupportedTEError(err):
				// Respond as per RFC 7230 Section 3.3.1 which says,
				//      A server that receives a request message with a
				//      transfer coding it does not understand SHOULD
				//      respond with 501 (Unimplemented).
				code := _http.StatusNotImplemented

				// We purposefully aren't echoing back the transfer-encoding's value,
				// so as to mitigate the risk of cross-site scripting by an attacker.
				fmt.Fprintf(c.rwc, "HTTP/1.1 %d %s%sUnsupported transfer encoding", code, _http.StatusText(code), errorHeaders)
				return

			case isCommonNetReadError(err):
				return // don't reply

			default:
				if v, ok := err.(statusError); ok {
					fmt.Fprintf(c.rwc, "HTTP/1.1 %d %s: %s%s%d %s: %s", v.code, _http.StatusText(v.code), v.text, errorHeaders, v.code, _http.StatusText(v.code), v.text)
					return
				}
				const publicErr = "400 Bad Request"
				fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr)
				return
			}
		}

		// Expect 100 Continue support
		req := w.req
		if expectsContinue(req) {
			if req.ProtoAtLeast(1, 1) && req.ContentLength != 0 {
				// Wrap the Body reader with one that replies on the connection
				req.Body = &expectContinueReader{readCloser: req.Body, resp: w}
				w.canWriteContinue.SetTrue()
			}
		} else if req.Header.Get("Expect") != "" {
			w.sendExpectationFailed()
			return
		}

		c.curReq.Store(w)

		if requestBodyRemains(req.Body) {
			registerOnHitEOF(req.Body, w.conn.r.startBackgroundRead)
		} else {
			w.conn.r.startBackgroundRead()
		}

		// HTTP cannot have multiple simultaneous active requests.[*]
		// Until the server replies to this request, it can't read another,
		// so we might as well run the handler in this goroutine.
		// [*] Not strictly true: HTTP pipelining. We could let them all process
		// in parallel even if their responses need to be serialized.
		// But we're not going to implement HTTP pipelining because it
		// was never deployed in the wild and the answer is HTTP/2.
		inFlightResponse = w
		serverHandler{c.server}.ServeHTTP(w, w.req)
		inFlightResponse = nil
		w.cancelCtx()
		w.finishRequest()
		if !w.shouldReuseConnection() {
			if w.requestBodyLimitHit || w.closedRequestBodyEarly() {
				c.closeWriteAndWait()
			}
			return
		}
		c.setState(c.rwc, _http.StateIdle)
		c.curReq.Store(nil)

		if !w.conn.server.doKeepAlives() {
			// We're in shutdown mode. We might've replied
			// to the user without "Connection: close" and
			// they might think they can send another
			// request, but such is life with HTTP/1.1.
			return
		}

		// Wait for the first bytes of the next request, bounded by the idle
		// timeout. Only a positive timeout arms a deadline: a negative value
		// would produce a deadline already in the past and drop the
		// keep-alive connection immediately, whereas the Server field docs
		// describe zero or negative as "no timeout".
		if d := c.server.idleTimeout(); d > 0 {
			c.rwc.SetReadDeadline(time.Now().UTC().Add(d))
			if _, err := c.bufr.Peek(4); err != nil {
				return
			}
		}
		// Clear the idle deadline so it does not apply to the next request.
		c.rwc.SetReadDeadline(time.Time{})
	}
}
// doKeepAlives reports whether connections may be kept open for further
// requests: keep-alives must not be disabled and the server must not be
// shutting down.
func (s *Server) doKeepAlives() bool {
	if atomic.LoadInt32(&s.disableKeepAlives) != 0 {
		return false
	}
	return !s.shuttingDown()
}
// shuttingDown reports whether the inShutdown flag has been set, which
// Shutdown does as its first step.
func (s *Server) shuttingDown() bool {
	closing := s.inShutdown.Load()
	return closing
}
// initialReadLimitSize returns the initial read limit for a request:
// 1 MiB plus 4096 bytes of bufio slop. serve compares the connection
// reader's remaining budget against this value to tell whether any bytes
// were read off the wire.
func (srv *Server) initialReadLimitSize() int64 {
	const (
		base      = 1 << 20
		bufioSlop = 4096
	)
	return base + bufioSlop
}
// trackConn adds c to (add == true) or removes c from (add == false) the
// server's set of tracked connections, creating the set on first use.
// It takes s.mu itself.
func (s *Server) trackConn(c *conn, add bool) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.activeConn == nil {
		s.activeConn = map[*conn]struct{}{}
	}
	if !add {
		delete(s.activeConn, c)
		return
	}
	s.activeConn[c] = struct{}{}
}
// idleTimeout returns the timeout to apply while waiting for the next
// request on a kept-alive connection: IdleTimeout when it is non-zero,
// otherwise ReadTimeout.
func (srv *Server) idleTimeout() time.Duration {
	if d := srv.IdleTimeout; d != 0 {
		return d
	}
	return srv.ReadTimeout
}
// checkConnErrorWriter is the io.Writer placed underneath the connection's
// buffered writer. It forwards writes to c.rwc and remembers the first
// write error in c.werr. It holds a single pointer field, so it fits in an
// interface value without an extra allocation.
type checkConnErrorWriter struct {
	c *conn
}

// Write writes p to the underlying connection. On the first failing write
// it stores the error in c.werr and cancels the connection-level context;
// later failures are returned to the caller but not recorded again.
func (w checkConnErrorWriter) Write(p []byte) (n int, err error) {
	c := w.c
	n, err = c.rwc.Write(p)
	if err == nil || c.werr != nil {
		return n, err
	}
	c.werr = err
	c.cancelCtx()
	return n, err
}