package nhttp import ( "bufio" "context" "crypto/tls" "errors" "fmt" "log" "math/rand" "net" _http "net/http" "runtime" "sync" "sync/atomic" "time" ) // ErrServerClosed is returned by the Server's Serve, ServeTLS, ListenAndServe, // and ListenAndServeTLS methods after a call to Shutdown or Close. var ErrServerClosed = errors.New("nhttp: Server closed") // contextKey is a value for use with context.WithValue. It's used as // a pointer so it fits in an interface{} without allocation. type contextKey struct { name string } func (k *contextKey) String() string { return "net/nhttp context value " + k.name } var ( // ServerContextKey is a context key. It can be used in HTTP // handlers with Context.Value to access the server that // started the handler. The associated value will be of // type *Server. ServerContextKey = &contextKey{"nhttp-server"} // LocalAddrContextKey is a context key. It can be used in // HTTP handlers with Context.Value to access the local // address the connection arrived on. // The associated value will be of type net.Addr. LocalAddrContextKey = &contextKey{"local-addr"} ) type Server struct { Addr string Handler _http.Handler // BaseContext optionally specifies a function that returns // the base context for incoming requests on this server. // The provided Listener is the specific Listener that's // about to start accepting requests. // If BaseContext is nil, the default is context.Background(). // If non-nil, it must return a non-nil context. BaseContext func(net.Listener) context.Context mu sync.Mutex disableKeepAlives int32 // accessed atomically. // IdleTimeout is the maximum amount of time to wait for the // next request when keep-alives are enabled. If IdleTimeout // is zero, the value of ReadTimeout is used. If both are // zero, there is no timeout. IdleTimeout time.Duration // ReadTimeout is the maximum duration for reading the entire // request, including the body. A zero or negative value means // there will be no timeout. // // Because ReadTimeout does not let Handlers make per-request // decisions on each request body's acceptable deadline or // upload rate, most users will prefer to use // ReadHeaderTimeout. It is valid to use them both. ReadTimeout time.Duration inShutdown atomic.Bool // true when server is in shutdown listeners map[*net.Listener]struct{} activeConn map[*conn]struct{} listenerGroup sync.WaitGroup } // conn represents the server side of an HTTP connection. type conn struct { // server is the server on which the connection arrived. // Immutable; never nil. server *Server // cancelCtx cancels the connection-level context. cancelCtx context.CancelFunc // rwc is the underlying network connection. // This is never wrapped by other types and is the value given out // to CloseNotifier callers. It is usually of type *net.TCPConn or // *tls.Conn. rwc net.Conn // remoteAddr is rwc.RemoteAddr().String(). It is not populated synchronously // inside the Listener's Accept goroutine, as some implementations block. // It is populated immediately inside the (*conn).serve goroutine. // This is the value of a Handler's (*Request).RemoteAddr. remoteAddr string // tlsState is the TLS connection state when using TLS. // nil means not TLS. tlsState *tls.ConnectionState // werr is set to the first write error to rwc. // It is set via checkConnErrorWriter{w}, where bufw writes. werr error // r is bufr's read source. It's a wrapper around rwc that provides // io.LimitedReader-style limiting (while reading request headers) // and functionality to support CloseNotifier. See *connReader docs. r *connReader // bufr reads from r. bufr *bufio.Reader // bufw writes to checkConnErrorWriter{c}, which populates werr on error. bufw *bufio.Writer // lastMethod is the method of the most recent request // on this connection, if any. lastMethod string curReq atomic.Pointer[response] // (which has a Request in it) curState atomic.Uint64 // packed (unixtime<<8|uint8(ConnState)) mu sync.Mutex } // onceCloseListener wraps a net.Listener, protecting it from // multiple Close calls. type onceCloseListener struct { net.Listener once sync.Once closeErr error } func (oc *onceCloseListener) Close() error { oc.once.Do(oc.close) return oc.closeErr } func (oc *onceCloseListener) close() { oc.closeErr = oc.Listener.Close() } const shutdownPollIntervalMax = 500 * time.Millisecond // Shutdown gracefully shuts down the server without interrupting any // active connections. Shutdown works by first closing all open // listeners, then closing all idle connections, and then waiting // indefinitely for connections to return to idle and then shut down. // If the provided context expires before the shutdown is complete, // Shutdown returns the context's error, otherwise it returns any // error returned from closing the [Server]'s underlying Listener(s). // // When Shutdown is called, [Serve], [ListenAndServe], and // [ListenAndServeTLS] immediately return [ErrServerClosed]. Make sure the // program doesn't exit and waits instead for Shutdown to return. // // Once Shutdown has been called on a server, it may not be reused; // future calls to methods such as Serve will return ErrServerClosed. func (srv *Server) Shutdown(ctx context.Context) error { srv.inShutdown.Store(true) srv.mu.Lock() lnerr := srv.closeListenersLocked() srv.mu.Unlock() srv.listenerGroup.Wait() pollIntervalBase := time.Millisecond nextPollInterval := func() time.Duration { // Add 10% jitter. interval := pollIntervalBase + time.Duration(rand.Intn(int(pollIntervalBase/10))) // Double and clamp for next time. pollIntervalBase *= 2 if pollIntervalBase > shutdownPollIntervalMax { pollIntervalBase = shutdownPollIntervalMax } return interval } timer := time.NewTimer(nextPollInterval()) defer timer.Stop() for { if srv.closeIdleConns() { return lnerr } select { case <-ctx.Done(): return ctx.Err() case <-timer.C: timer.Reset(nextPollInterval()) } } } // closeIdleConns closes all idle connections and reports whether the // server is quiescent. func (s *Server) closeIdleConns() bool { s.mu.Lock() defer s.mu.Unlock() quiescent := true for c := range s.activeConn { st, unixSec := c.getState() // Issue 22682: treat StateNew connections as if // they're idle if we haven't read the first request's // header in over 5 seconds. if st == _http.StateNew && unixSec < time.Now().UTC().Unix()-5 { st = _http.StateIdle } if st != _http.StateIdle || unixSec == 0 { // Assume unixSec == 0 means it's a very new // connection, without state set yet. quiescent = false continue } c.rwc.Close() delete(s.activeConn, c) } return quiescent } func (s *Server) closeListenersLocked() error { var err error for ln := range s.listeners { if cerr := (*ln).Close(); cerr != nil && err == nil { err = cerr } } return err } func (srv *Server) ListenAndServe() error { if srv.shuttingDown() { return ErrServerClosed } addr := srv.Addr if addr == "" { addr = ":nhttp" } ln, err := net.Listen("tcp", addr) if err != nil { return err } return srv.Serve(ln) } // ListenAndServe listens on the TCP network address addr and then calls // Serve with handler to handle requests on incoming connections. // Accepted connections are configured to enable TCP keep-alives. // // The handler is typically nil, in which case the DefaultServeMux is used. // // ListenAndServe always returns a non-nil error. func ListenAndServe(addr string, handler _http.Handler) error { server := &Server{Addr: addr, Handler: handler} return server.ListenAndServe() } // trackListener adds or removes a net.Listener to the set of tracked // listeners. // // We store a pointer to interface in the map set, in case the // net.Listener is not comparable. This is safe because we only call // trackListener via Serve and can track+defer untrack the same // pointer to local variable there. We never need to compare a // Listener from another caller. // // It reports whether the server is still up (not Shutdown or Closed). func (s *Server) trackListener(ln *net.Listener, add bool) bool { s.mu.Lock() defer s.mu.Unlock() if s.listeners == nil { s.listeners = make(map[*net.Listener]struct{}) } if add { if s.shuttingDown() { return false } s.listeners[ln] = struct{}{} s.listenerGroup.Add(1) } else { delete(s.listeners, ln) s.listenerGroup.Done() } return true } func (srv *Server) Serve(l net.Listener) error { origListener := l l = &onceCloseListener{Listener: l} defer l.Close() if !srv.trackListener(&l, true) { return ErrServerClosed } defer srv.trackListener(&l, false) baseCtx := context.Background() if srv.BaseContext != nil { baseCtx = srv.BaseContext(origListener) if baseCtx == nil { panic("BaseContext returned a nil context") } } // How long to sleep on accept failure var tempDelay time.Duration ctx := context.WithValue(baseCtx, ServerContextKey, srv) for { rw, err := l.Accept() if err != nil { if srv.shuttingDown() { return ErrServerClosed } if ne, ok := err.(net.Error); ok && ne.Temporary() { if tempDelay == 0 { tempDelay = 5 * time.Millisecond } else { tempDelay *= 2 } if max := 1 * time.Second; tempDelay > max { tempDelay = max } log.Printf("nhttp: Accept error: %v; retrying in %v", err, tempDelay) time.Sleep(tempDelay) continue } return err } connCtx := ctx tempDelay = 0 c := &conn{ server: srv, rwc: rw, } go c.serve(connCtx) } } // ErrAbortHandler is a sentinel panic value to abort a handler. // While any panic from ServeHTTP aborts the response to the client, // panicking with ErrAbortHandler also suppresses logging of a stack // trace to the server's error log. var ErrAbortHandler = errors.New("net/nhttp: abort Handler") // Close the connection. func (c *conn) close() { c.finalFlush() c.rwc.Close() } func (c *conn) finalFlush() { if c.bufr != nil { // Steal the bufio.Reader (~4KB worth of memory) and its associated // reader for a future connection. putBufioReader(c.bufr) c.bufr = nil } if c.bufw != nil { c.bufw.Flush() // Steal the bufio.Writer (~4KB worth of memory) and its associated // writer for a future connection. putBufioWriter(c.bufw) c.bufw = nil } } func (c *conn) setState(nc net.Conn, state _http.ConnState) { srv := c.server switch state { case _http.StateNew: srv.trackConn(c, true) case _http.StateHijacked, _http.StateClosed: srv.trackConn(c, false) } if state > 0xff || state < 0 { panic("internal error") } packedState := uint64(time.Now().UTC().Unix()<<8) | uint64(state) c.curState.Store(packedState) } func (c *conn) getState() (state _http.ConnState, unixSec int64) { packedState := c.curState.Load() return _http.ConnState(packedState & 0xff), int64(packedState >> 8) } func (c *conn) serve(ctx context.Context) { c.remoteAddr = c.rwc.RemoteAddr().String() ctx = context.WithValue(ctx, LocalAddrContextKey, c.rwc.LocalAddr()) var inFlightResponse *response defer func() { if err := recover(); err != nil && err != ErrAbortHandler { const size = 64 << 10 buf := make([]byte, size) buf = buf[:runtime.Stack(buf, false)] log.Printf("nhttp: panic serving %v: %v\n%s\n", c.remoteAddr, err, buf) } if inFlightResponse != nil { inFlightResponse.cancelCtx() } if inFlightResponse != nil { inFlightResponse.conn.r.abortPendingRead() inFlightResponse.reqBody.Close() } c.close() c.setState(c.rwc, _http.StateClosed) }() // HTTP/1.x from here on. ctx, cancelCtx := context.WithCancel(ctx) c.cancelCtx = cancelCtx defer cancelCtx() c.r = &connReader{conn: c} c.bufr = newBufioReader(c.r) c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10) for { w, err := c.readRequest(ctx) if c.r.remain != c.server.initialReadLimitSize() { // If we read any bytes off the wire, we're active. c.setState(c.rwc, _http.StateActive) } if err != nil { const errorHeaders = "\r\nContent-Type: text/plain; charset=utf-8\r\nConnection: close\r\n\r\n" switch { case err == errTooLarge: // Their HTTP client may or may not be // able to read this if we're // responding to them and hanging up // while they're still writing their // request. Undefined behavior. const publicErr = "431 Request Header Fields Too Large" fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr) c.closeWriteAndWait() return case isUnsupportedTEError(err): // Respond as per RFC 7230 Section 3.3.1 which says, // A server that receives a request message with a // transfer coding it does not understand SHOULD // respond with 501 (Unimplemented). code := _http.StatusNotImplemented // We purposefully aren't echoing back the transfer-encoding's value, // so as to mitigate the risk of cross side scripting by an attacker. fmt.Fprintf(c.rwc, "HTTP/1.1 %d %s%sUnsupported transfer encoding", code, _http.StatusText(code), errorHeaders) return case isCommonNetReadError(err): return // don't reply default: if v, ok := err.(statusError); ok { fmt.Fprintf(c.rwc, "HTTP/1.1 %d %s: %s%s%d %s: %s", v.code, _http.StatusText(v.code), v.text, errorHeaders, v.code, _http.StatusText(v.code), v.text) return } publicErr := "400 Bad Request" fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr) return } } // Expect 100 Continue support req := w.req if expectsContinue(req) { if req.ProtoAtLeast(1, 1) && req.ContentLength != 0 { // Wrap the Body reader with one that replies on the connection req.Body = &expectContinueReader{readCloser: req.Body, resp: w} w.canWriteContinue.SetTrue() } } else if req.Header.Get("Expect") != "" { w.sendExpectationFailed() return } c.curReq.Store(w) if requestBodyRemains(req.Body) { registerOnHitEOF(req.Body, w.conn.r.startBackgroundRead) } else { w.conn.r.startBackgroundRead() } // HTTP cannot have multiple simultaneous active requests.[*] // Until the server replies to this request, it can't read another, // so we might as well run the handler in this goroutine. // [*] Not strictly true: HTTP pipelining. We could let them all process // in parallel even if their responses need to be serialized. // But we're not going to implement HTTP pipelining because it // was never deployed in the wild and the answer is HTTP/2. inFlightResponse = w serverHandler{c.server}.ServeHTTP(w, w.req) inFlightResponse = nil w.cancelCtx() w.finishRequest() if !w.shouldReuseConnection() { if w.requestBodyLimitHit || w.closedRequestBodyEarly() { c.closeWriteAndWait() } return } c.setState(c.rwc, _http.StateIdle) c.curReq.Store(nil) if !w.conn.server.doKeepAlives() { // We're in shutdown mode. We might've replied // to the user without "Connection: close" and // they might think they can send another // request, but such is life with HTTP/1.1. return } if d := c.server.idleTimeout(); d != 0 { c.rwc.SetReadDeadline(time.Now().UTC().Add(d)) if _, err := c.bufr.Peek(4); err != nil { return } } c.curReq.Store((*response)(nil)) if d := c.server.idleTimeout(); d != 0 { c.rwc.SetReadDeadline(time.Now().UTC().Add(d)) if _, err := c.bufr.Peek(4); err != nil { return } } c.rwc.SetReadDeadline(time.Time{}) } } func (s *Server) doKeepAlives() bool { return atomic.LoadInt32(&s.disableKeepAlives) == 0 && !s.shuttingDown() } func (s *Server) shuttingDown() bool { return s.inShutdown.Load() } func (srv *Server) initialReadLimitSize() int64 { return 1<<20 + 4096 // bufio slop } func (s *Server) trackConn(c *conn, add bool) { s.mu.Lock() defer s.mu.Unlock() if s.activeConn == nil { s.activeConn = make(map[*conn]struct{}) } if add { s.activeConn[c] = struct{}{} } else { delete(s.activeConn, c) } } func (srv *Server) idleTimeout() time.Duration { if srv.IdleTimeout != 0 { return srv.IdleTimeout } return srv.ReadTimeout } // checkConnErrorWriter writes to c.rwc and records any write errors to c.werr. // It only contains one field (and a pointer field at that), so it // fits in an interface value without an extra allocation. type checkConnErrorWriter struct { c *conn } func (w checkConnErrorWriter) Write(p []byte) (n int, err error) { n, err = w.c.rwc.Write(p) if err != nil && w.c.werr == nil { w.c.werr = err w.c.cancelCtx() } return }