diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..187ac4e --- /dev/null +++ b/go.mod @@ -0,0 +1,10 @@ +module wwfc + +go 1.19 + +require ( + github.com/logrusorgru/aurora/v3 v3.0.0 + golang.org/x/net v0.10.0 +) + +require golang.org/x/text v0.9.0 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..871e4ec --- /dev/null +++ b/go.sum @@ -0,0 +1,6 @@ +github.com/logrusorgru/aurora/v3 v3.0.0 h1:R6zcoZZbvVcGMvDCKo45A9U/lzYyzl5NfYIvznmDfE4= +github.com/logrusorgru/aurora/v3 v3.0.0/go.mod h1:vsR12bk5grlLvLXAYrBsb5Oc/N+LxAlxggSjiwMnCUc= +golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= diff --git a/nhttp/ascii.go b/nhttp/ascii.go new file mode 100644 index 0000000..677fabe --- /dev/null +++ b/nhttp/ascii.go @@ -0,0 +1,138 @@ +package nhttp + +import ( + "bufio" + "bytes" + "net/textproto" + "time" +) + +var ( + singleCRLF = []byte("\r\n") + doubleCRLF = []byte("\r\n\r\n") +) + +var ( + crlf = []byte("\r\n") + colonSpace = []byte(": ") +) + +func seeUpcomingDoubleCRLF(r *bufio.Reader) bool { + for peekSize := 4; ; peekSize++ { + // This loop stops when Peek returns an error, + // which it does when r's buffer has been filled. + buf, err := r.Peek(peekSize) + if bytes.HasSuffix(buf, doubleCRLF) { + return true + } + if err != nil { + break + } + } + return false +} + +func numLeadingCRorLF(v []byte) (n int) { + for _, b := range v { + if b == '\r' || b == '\n' { + n++ + continue + } + break + } + return + +} + +// EqualFold is strings.EqualFold, ASCII only. It reports whether s and t +// are equal, ASCII-case-insensitively. +func EqualFold(s, t string) bool { + if len(s) != len(t) { + return false + } + for i := 0; i < len(s); i++ { + if lower(s[i]) != lower(t[i]) { + return false + } + } + return true +} + +// lower returns the ASCII lowercase version of b. +func lower(b byte) byte { + if 'A' <= b && b <= 'Z' { + return b + ('a' - 'A') + } + return b +} + +// CanonicalHeaderKey returns the canonical format of the +// header key s. The canonicalization converts the first +// letter and any letter following a hyphen to upper case; +// the rest are converted to lowercase. For example, the +// canonical key for "accept-encoding" is "Accept-Encoding". +// If s contains a space or invalid header field bytes, it is +// returned without modifications. +func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) } + +// hasToken reports whether token appears with v, ASCII +// case-insensitive, with space or comma boundaries. +// token must be all lowercase. +// v may contain mixed cased. +func hasToken(v, token string) bool { + if len(token) > len(v) || token == "" { + return false + } + if v == token { + return true + } + for sp := 0; sp <= len(v)-len(token); sp++ { + // Check that first character is good. + // The token is ASCII, so checking only a single byte + // is sufficient. We skip this potential starting + // position if both the first byte and its potential + // ASCII uppercase equivalent (b|0x20) don't match. + // False positives ('^' => '~') are caught by EqualFold. + if b := v[sp]; b != token[0] && b|0x20 != token[0] { + continue + } + // Check that start pos is on a valid token boundary. + if sp > 0 && !isTokenBoundary(v[sp-1]) { + continue + } + // Check that end pos is on a valid token boundary. + if endPos := sp + len(token); endPos != len(v) && !isTokenBoundary(v[endPos]) { + continue + } + if EqualFold(v[sp:sp+len(token)], token) { + return true + } + } + return false +} + +func isTokenBoundary(b byte) bool { + return b == ' ' || b == ',' || b == '\t' +} + +// appendTime is a non-allocating version of []byte(t.UTC().Format(TimeFormat)) +func appendTime(b []byte, t time.Time) []byte { + const days = "SunMonTueWedThuFriSat" + const months = "JanFebMarAprMayJunJulAugSepOctNovDec" + + t = t.UTC() + yy, mm, dd := t.Date() + hh, mn, ss := t.Clock() + day := days[3*t.Weekday():] + mon := months[3*(mm-1):] + + return append(b, + day[0], day[1], day[2], ',', ' ', + byte('0'+dd/10), byte('0'+dd%10), ' ', + mon[0], mon[1], mon[2], ' ', + byte('0'+yy/1000), byte('0'+(yy/100)%10), byte('0'+(yy/10)%10), byte('0'+yy%10), ' ', + byte('0'+hh/10), byte('0'+hh%10), ':', + byte('0'+mn/10), byte('0'+mn%10), ':', + byte('0'+ss/10), byte('0'+ss%10), ' ', + 'G', 'M', 'T') +} diff --git a/nhttp/buffer.go b/nhttp/buffer.go new file mode 100644 index 0000000..b8ae130 --- /dev/null +++ b/nhttp/buffer.go @@ -0,0 +1,583 @@ +package nhttp + +import ( + "bufio" + "errors" + "fmt" + "golang.org/x/net/http/httpguts" + "io" + "log" + _http "net/http" + "net/textproto" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" +) + +// NoBody is an io.ReadCloser with no bytes. Read always returns EOF +// and Close always returns nil. It can be used in an outgoing client +// request to explicitly signal that a request has zero bytes. +// An alternative, however, is to simply set Request.Body to nil. +var NoBody = noBody{} + +type noBody struct{} + +func (noBody) Read([]byte) (int, error) { return 0, io.EOF } +func (noBody) Close() error { return nil } +func (noBody) WriteTo(io.Writer) (int64, error) { return 0, nil } + +type atomicBool int32 + +func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 } +func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) } +func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) } + +var textprotoReaderPool sync.Pool + +func newTextprotoReader(br *bufio.Reader) *textproto.Reader { + if v := textprotoReaderPool.Get(); v != nil { + tr := v.(*textproto.Reader) + tr.R = br + return tr + } + return textproto.NewReader(br) +} + +func putTextprotoReader(r *textproto.Reader) { + r.R = nil + textprotoReaderPool.Put(r) +} + +// foreachHeaderElement splits v according to the "#rule" construction +// in RFC 7230 section 7 and calls fn for each non-empty element. +func foreachHeaderElement(v string, fn func(string)) { + v = textproto.TrimString(v) + if v == "" { + return + } + if !strings.Contains(v, ",") { + fn(v) + return + } + for _, f := range strings.Split(v, ",") { + if f = textproto.TrimString(f); f != "" { + fn(f) + } + } +} + +var ( + bufioReaderPool sync.Pool + bufioWriter2kPool sync.Pool + bufioWriter4kPool sync.Pool +) + +var copyBufPool = sync.Pool{ + New: func() any { + b := make([]byte, 32*1024) + return &b + }, +} + +func bufioWriterPool(size int) *sync.Pool { + switch size { + case 2 << 10: + return &bufioWriter2kPool + case 4 << 10: + return &bufioWriter4kPool + } + return nil +} + +func newBufioWriterSize(w io.Writer, size int) *bufio.Writer { + pool := bufioWriterPool(size) + if pool != nil { + if v := pool.Get(); v != nil { + bw := v.(*bufio.Writer) + bw.Reset(w) + return bw + } + } + return bufio.NewWriterSize(w, size) +} + +func newBufioReader(r io.Reader) *bufio.Reader { + if v := bufioReaderPool.Get(); v != nil { + br := v.(*bufio.Reader) + br.Reset(r) + return br + } + // Note: if this reader size is ever changed, update + // TestHandlerBodyClose's assumptions. + return bufio.NewReader(r) +} + +func putBufioReader(br *bufio.Reader) { + br.Reset(nil) + bufioReaderPool.Put(br) +} + +func putBufioWriter(bw *bufio.Writer) { + bw.Reset(nil) + if pool := bufioWriterPool(bw.Available()); pool != nil { + pool.Put(bw) + } +} + +func (cw *chunkWriter) Write(p []byte) (n int, err error) { + if !cw.wroteHeader { + cw.writeHeader(nil) + } + if cw.res.req.Method == "HEAD" { + // Eat writes. + return len(p), nil + } + if cw.chunking { + _, err = fmt.Fprintf(cw.res.conn.bufw, "%x\r\n", len(p)) + if err != nil { + cw.res.conn.rwc.Close() + return + } + } + n, err = cw.res.conn.bufw.Write(p) + if cw.chunking && err == nil { + _, err = cw.res.conn.bufw.Write(crlf) + } + if err != nil { + cw.res.conn.rwc.Close() + } + return +} + +func (cw *chunkWriter) flush() { + if !cw.wroteHeader { + cw.writeHeader(nil) + } + cw.res.conn.bufw.Flush() +} + +func (cw *chunkWriter) close() { + if !cw.wroteHeader { + cw.writeHeader(nil) + } + if cw.chunking { + bw := cw.res.conn.bufw // conn's bufio writer + // zero chunk to mark EOF + bw.WriteString("0\r\n") + if trailers := cw.res.finalTrailers(); trailers != nil { + trailers.Write(bw) // the writer handles noting errors + } + // final blank line after the trailers (whether + // present or not) + bw.WriteString("\r\n") + } +} + +// wrapper around io.ReadCloser which on first read, sends an +// HTTP/1.1 100 Continue header +type expectContinueReader struct { + resp *response + readCloser io.ReadCloser + closed atomicBool + sawEOF atomicBool +} + +func (ecr *expectContinueReader) Read(p []byte) (n int, err error) { + if ecr.closed.isSet() { + return 0, errors.New("nhttp: invalid Read on closed Body") + } + w := ecr.resp + if !w.wroteContinue && w.canWriteContinue.isSet() { + w.wroteContinue = true + w.writeContinueMu.Lock() + if w.canWriteContinue.isSet() { + w.conn.bufw.WriteString("HTTP/1.1 100 Continue\r\n\r\n") + w.conn.bufw.Flush() + w.canWriteContinue.setFalse() + } + w.writeContinueMu.Unlock() + } + n, err = ecr.readCloser.Read(p) + if err == io.EOF { + ecr.sawEOF.setTrue() + } + return +} + +func (ecr *expectContinueReader) Close() error { + ecr.closed.setTrue() + return ecr.readCloser.Close() +} + +// writeHeader finalizes the header sent to the client and writes it +// to cw.res.conn.bufw. +// +// p is not written by writeHeader, but is the first chunk of the body +// that will be written. It is sniffed for a Content-Type if none is +// set explicitly. It's also used to set the Content-Length, if the +// total body size was small and the handler has already finished +// running. +func (cw *chunkWriter) writeHeader(p []byte) { + if cw.wroteHeader { + return + } + cw.wroteHeader = true + + w := cw.res + keepAlivesEnabled := w.conn.server.doKeepAlives() + isHEAD := w.req.Method == "HEAD" + + // header is written out to w.conn.buf below. Depending on the + // state of the handler, we either own the map or not. If we + // don't own it, the exclude map is created lazily for + // WriteSubset to remove headers. The setHeader struct holds + // headers we need to add. + header := cw.header + owned := header != nil + if !owned { + header = w.handlerHeader + } + var excludeHeader map[string]bool + delHeader := func(key string) { + if owned { + header.Del(key) + return + } + if _, ok := header[key]; !ok { + return + } + if excludeHeader == nil { + excludeHeader = make(map[string]bool) + } + excludeHeader[key] = true + } + var setHeader extraHeader + + // Don't write out the fake "Trailer:foo" keys. See TrailerPrefix. + trailers := false + for k := range cw.header { + if strings.HasPrefix(k, TrailerPrefix) { + if excludeHeader == nil { + excludeHeader = make(map[string]bool) + } + excludeHeader[k] = true + trailers = true + } + } + for _, v := range cw.header["Trailer"] { + trailers = true + foreachHeaderElement(v, cw.res.declareTrailer) + } + + te := header.Get("Transfer-Encoding") + hasTE := te != "" + + // If the handler is done but never sent a Content-Length + // response header and this is our first (and last) write, set + // it, even to zero. This helps HTTP/1.0 clients keep their + // "keep-alive" connections alive. + // Exceptions: 304/204/1xx responses never get Content-Length, and if + // it was a HEAD request, we don't know the difference between + // 0 actual bytes and 0 bytes because the handler noticed it + // was a HEAD request and chose not to write anything. So for + // HEAD, the handler should either write the Content-Length or + // write non-zero bytes. If it's actually 0 bytes and the + // handler never looked at the Request.Method, we just don't + // send a Content-Length header. + // Further, we don't send an automatic Content-Length if they + // set a Transfer-Encoding, because they're generally incompatible. + if w.handlerDone.isSet() && !trailers && !hasTE && bodyAllowedForStatus(w.status) && header.Get("Content-Length") == "" && (!isHEAD || len(p) > 0) { + w.contentLength = int64(len(p)) + setHeader.contentLength = strconv.AppendInt(cw.res.clenBuf[:0], int64(len(p)), 10) + } + + // If this was an HTTP/1.0 request with keep-alive and we sent a + // Content-Length back, we can make this a keep-alive response ... + if w.wants10KeepAlive && keepAlivesEnabled { + sentLength := header.Get("Content-Length") != "" + if sentLength && header.Get("Connection") == "keep-alive" { + w.closeAfterReply = false + } + } + + // Check for an explicit (and valid) Content-Length header. + hasCL := w.contentLength != -1 + + if w.wants10KeepAlive && (isHEAD || hasCL || !bodyAllowedForStatus(w.status)) { + _, connectionHeaderSet := header["Connection"] + if !connectionHeaderSet { + setHeader.connection = "keep-alive" + } + } else if !w.req.ProtoAtLeast(1, 1) || w.wantsClose { + w.closeAfterReply = true + } + + if header.Get("Connection") == "close" || !keepAlivesEnabled { + w.closeAfterReply = true + } + + // If the client wanted a 100-continue but we never sent it to + // them (or, more strictly: we never finished reading their + // request body), don't reuse this connection because it's now + // in an unknown state: we might be sending this response at + // the same time the client is now sending its request body + // after a timeout. (Some HTTP clients send Expect: + // 100-continue but knowing that some servers don't support + // it, the clients set a timer and send the body later anyway) + // If we haven't seen EOF, we can't skip over the unread body + // because we don't know if the next bytes on the wire will be + // the body-following-the-timer or the subsequent request. + // See Issue 11549. + if ecr, ok := w.req.Body.(*expectContinueReader); ok && !ecr.sawEOF.isSet() { + w.closeAfterReply = true + } + + // Per RFC 2616, we should consume the request body before + // replying, if the handler hasn't already done so. But we + // don't want to do an unbounded amount of reading here for + // DoS reasons, so we only try up to a threshold. + // TODO(bradfitz): where does RFC 2616 say that? See Issue 15527 + // about HTTP/1.x Handlers concurrently reading and writing, like + // HTTP/2 handlers can do. Maybe this code should be relaxed? + if w.req.ContentLength != 0 && !w.closeAfterReply { + var discard, tooBig bool + + switch bdy := w.req.Body.(type) { + case *expectContinueReader: + if bdy.resp.wroteContinue { + discard = true + } + case *body: + bdy.mu.Lock() + switch { + case bdy.closed: + if !bdy.sawEOF { + // Body was closed in handler with non-EOF error. + w.closeAfterReply = true + } + case bdy.unreadDataSizeLocked() >= 256<<10: + tooBig = true + default: + discard = true + } + bdy.mu.Unlock() + default: + discard = true + } + + if discard { + _, err := io.CopyN(io.Discard, w.reqBody, 256<<10+1) + switch err { + case nil: + // There must be even more data left over. + tooBig = true + case _http.ErrBodyReadAfterClose: + // Body was already consumed and closed. + case io.EOF: + // The remaining body was just consumed, close it. + err = w.reqBody.Close() + if err != nil { + w.closeAfterReply = true + } + default: + // Some other kind of error occurred, like a read timeout, or + // corrupt chunked encoding. In any case, whatever remains + // on the wire must not be parsed as another HTTP request. + w.closeAfterReply = true + } + } + + if tooBig { + w.requestTooLarge() + delHeader("Connection") + setHeader.connection = "close" + } + } + + code := w.status + if bodyAllowedForStatus(code) { + // If no content type, apply sniffing algorithm to body. + _, haveType := header["Content-Type"] + + // If the Content-Encoding was set and is non-blank, + // we shouldn't sniff the body. See Issue 31753. + ce := header.Get("Content-Encoding") + hasCE := len(ce) > 0 + if !hasCE && !haveType && !hasTE && len(p) > 0 { + setHeader.contentType = _http.DetectContentType(p) + } + } else { + for _, k := range suppressedHeaders(code) { + delHeader(k) + } + } + + //if !header.has("Date") { + setHeader.date = appendTime(cw.res.dateBuf[:0], time.Now()) + //} + + if hasCL && hasTE && te != "identity" { + // TODO: return an error if WriteHeader gets a return parameter + // For now just ignore the Content-Length. + log.Printf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d\n", + te, w.contentLength) + delHeader("Content-Length") + hasCL = false + } + + if w.req.Method == "HEAD" || !bodyAllowedForStatus(code) || code == _http.StatusNoContent { + // Response has no body. + delHeader("Transfer-Encoding") + } else if hasCL { + // Content-Length has been provided, so no chunking is to be done. + delHeader("Transfer-Encoding") + } else if w.req.ProtoAtLeast(1, 1) { + // HTTP/1.1 or greater: Transfer-Encoding has been set to identity, and no + // content-length has been provided. The connection must be closed after the + // reply is written, and no chunking is to be done. This is the setup + // recommended in the Server-Sent Events candidate recommendation 11, + // section 8. + if hasTE && te == "identity" { + cw.chunking = false + w.closeAfterReply = true + delHeader("Transfer-Encoding") + } else { + // HTTP/1.1 or greater: use chunked transfer encoding + // to avoid closing the connection at EOF. + cw.chunking = true + setHeader.transferEncoding = "chunked" + if hasTE && te == "chunked" { + // We will send the chunked Transfer-Encoding header later. + delHeader("Transfer-Encoding") + } + } + } else { + // HTTP version < 1.1: cannot do chunked transfer + // encoding and we don't know the Content-Length so + // signal EOF by closing connection. + w.closeAfterReply = true + delHeader("Transfer-Encoding") // in case already set + } + + // Cannot use Content-Length with non-identity Transfer-Encoding. + if cw.chunking { + delHeader("Content-Length") + } + if !w.req.ProtoAtLeast(1, 0) { + return + } + + // Only override the Connection header if it is not a successful + // protocol switch response and if KeepAlives are not enabled. + // See https://golang.org/issue/36381. + delConnectionHeader := w.closeAfterReply && + (!keepAlivesEnabled || !hasToken(cw.header.Get("Connection"), "close")) && + !isProtocolSwitchResponse(w.status, header) + if delConnectionHeader { + delHeader("Connection") + if w.req.ProtoAtLeast(1, 1) { + setHeader.connection = "close" + } + } + + writeStatusLine(w.conn.bufw, w.req.ProtoAtLeast(1, 1), code, w.statusBuf[:]) + cw.header.WriteSubset(w.conn.bufw, excludeHeader) + setHeader.Write(w.conn.bufw) + w.conn.bufw.Write(crlf) +} + +// extraHeader is the set of headers sometimes added by chunkWriter.writeHeader. +// This type is used to avoid extra allocations from cloning and/or populating +// the response Header map and all its 1-element slices. +type extraHeader struct { + contentType string + connection string + transferEncoding string + date []byte // written if not nil + contentLength []byte // written if not nil +} + +// Sorted the same as extraHeader.Write's loop. +var extraHeaderKeys = [][]byte{ + []byte("Content-Type"), + []byte("Connection"), + []byte("Transfer-Encoding"), +} + +var ( + headerContentLength = []byte("Content-Length: ") + headerDate = []byte("Date: ") +) + +// Write writes the headers described in h to w. +// +// This method has a value receiver, despite the somewhat large size +// of h, because it prevents an allocation. The escape analysis isn't +// smart enough to realize this function doesn't mutate h. +func (h extraHeader) Write(w *bufio.Writer) { + if h.date != nil { + w.Write(headerDate) + w.Write(h.date) + w.Write(crlf) + } + if h.contentLength != nil { + w.Write(headerContentLength) + w.Write(h.contentLength) + w.Write(crlf) + } + for i, v := range []string{h.contentType, h.connection, h.transferEncoding} { + if v != "" { + w.Write(extraHeaderKeys[i]) + w.Write(colonSpace) + w.WriteString(v) + w.Write(crlf) + } + } +} + +// declareTrailer is called for each Trailer header when the +// response header is written. It notes that a header will need to be +// written in the trailers at the end of the response. +func (w *response) declareTrailer(k string) { + k = CanonicalHeaderKey(k) + if !httpguts.ValidTrailerHeader(k) { + // Forbidden by RFC 7230, section 4.1.2 + return + } + w.trailers = append(w.trailers, k) +} + +// unreadDataSizeLocked returns the number of bytes of unread input. +// It returns -1 if unknown. +// b.mu must be held. +func (b *body) unreadDataSizeLocked() int64 { + if lr, ok := b.src.(*io.LimitedReader); ok { + return lr.N + } + return -1 +} + +// requestTooLarge is called by maxBytesReader when too much input has +// been read from the client. +func (w *response) requestTooLarge() { + w.closeAfterReply = true + w.requestBodyLimitHit = true + if !w.wroteHeader { + w.Header().Set("Connection", "close") + } +} + +// isProtocolSwitchResponse reports whether the response code and +// response header indicate a successful protocol upgrade response. +func isProtocolSwitchResponse(code int, h _http.Header) bool { + return code == _http.StatusSwitchingProtocols && isProtocolSwitchHeader(h) +} + +// isProtocolSwitchHeader reports whether the request or response header +// is for a protocol switch. +func isProtocolSwitchHeader(h _http.Header) bool { + return h.Get("Upgrade") != "" && + httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") +} diff --git a/nhttp/chunked.go b/nhttp/chunked.go new file mode 100644 index 0000000..493bf3a --- /dev/null +++ b/nhttp/chunked.go @@ -0,0 +1,262 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// The wire protocol for HTTP's "chunked" Transfer-Encoding. + +// Package internal contains HTTP internals shared by net/nhttp and +// net/nhttp/httputil. +package nhttp + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" +) + +const maxLineLength = 4096 // assumed <= bufio.defaultBufSize + +var ErrLineTooLong = errors.New("header line too long") + +// NewChunkedReader returns a new chunkedReader that translates the data read from r +// out of HTTP "chunked" format before returning it. +// The chunkedReader returns io.EOF when the final 0-length chunk is read. +// +// NewChunkedReader is not needed by normal applications. The nhttp package +// automatically decodes chunking when reading response bodies. +func NewChunkedReader(r io.Reader) io.Reader { + br, ok := r.(*bufio.Reader) + if !ok { + br = bufio.NewReader(r) + } + return &chunkedReader{r: br} +} + +type chunkedReader struct { + r *bufio.Reader + n uint64 // unread bytes in chunk + err error + buf [2]byte + checkEnd bool // whether need to check for \r\n chunk footer +} + +func (cr *chunkedReader) beginChunk() { + // chunk-size CRLF + var line []byte + line, cr.err = readChunkLine(cr.r) + if cr.err != nil { + return + } + cr.n, cr.err = parseHexUint(line) + if cr.err != nil { + return + } + if cr.n == 0 { + cr.err = io.EOF + } +} + +func (cr *chunkedReader) chunkHeaderAvailable() bool { + n := cr.r.Buffered() + if n > 0 { + peek, _ := cr.r.Peek(n) + return bytes.IndexByte(peek, '\n') >= 0 + } + return false +} + +func (cr *chunkedReader) Read(b []uint8) (n int, err error) { + for cr.err == nil { + if cr.checkEnd { + if n > 0 && cr.r.Buffered() < 2 { + // We have some data. Return early (per the io.Reader + // contract) instead of potentially blocking while + // reading more. + break + } + if _, cr.err = io.ReadFull(cr.r, cr.buf[:2]); cr.err == nil { + if string(cr.buf[:]) != "\r\n" { + cr.err = errors.New("malformed chunked encoding") + break + } + } else { + if cr.err == io.EOF { + cr.err = io.ErrUnexpectedEOF + } + break + } + cr.checkEnd = false + } + if cr.n == 0 { + if n > 0 && !cr.chunkHeaderAvailable() { + // We've read enough. Don't potentially block + // reading a new chunk header. + break + } + cr.beginChunk() + continue + } + if len(b) == 0 { + break + } + rbuf := b + if uint64(len(rbuf)) > cr.n { + rbuf = rbuf[:cr.n] + } + var n0 int + n0, cr.err = cr.r.Read(rbuf) + n += n0 + b = b[n0:] + cr.n -= uint64(n0) + // If we're at the end of a chunk, read the next two + // bytes to verify they are "\r\n". + if cr.n == 0 && cr.err == nil { + cr.checkEnd = true + } else if cr.err == io.EOF { + cr.err = io.ErrUnexpectedEOF + } + } + return n, cr.err +} + +// Read a line of bytes (up to \n) from b. +// Give up if the line exceeds maxLineLength. +// The returned bytes are owned by the bufio.Reader +// so they are only valid until the next bufio read. +func readChunkLine(b *bufio.Reader) ([]byte, error) { + p, err := b.ReadSlice('\n') + if err != nil { + // We always know when EOF is coming. + // If the caller asked for a line, there should be a line. + if err == io.EOF { + err = io.ErrUnexpectedEOF + } else if err == bufio.ErrBufferFull { + err = ErrLineTooLong + } + return nil, err + } + if len(p) >= maxLineLength { + return nil, ErrLineTooLong + } + p = trimTrailingWhitespace(p) + p, err = removeChunkExtension(p) + if err != nil { + return nil, err + } + return p, nil +} + +func trimTrailingWhitespace(b []byte) []byte { + for len(b) > 0 && isASCIISpace(b[len(b)-1]) { + b = b[:len(b)-1] + } + return b +} + +func isASCIISpace(b byte) bool { + return b == ' ' || b == '\t' || b == '\n' || b == '\r' +} + +var semi = []byte(";") + +// removeChunkExtension removes any chunk-extension from p. +// For example, +// +// "0" => "0" +// "0;token" => "0" +// "0;token=val" => "0" +// `0;token="quoted string"` => "0" +func removeChunkExtension(p []byte) ([]byte, error) { + p, _, _ = bytes.Cut(p, semi) + // TODO: care about exact syntax of chunk extensions? We're + // ignoring and stripping them anyway. For now just never + // return an error. + return p, nil +} + +// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP +// "chunked" format before writing them to w. Closing the returned chunkedWriter +// sends the final 0-length chunk that marks the end of the stream but does +// not send the final CRLF that appears after trailers; trailers and the last +// CRLF must be written separately. +// +// NewChunkedWriter is not needed by normal applications. The nhttp +// package adds chunking automatically if handlers don't set a +// Content-Length header. Using newChunkedWriter inside a handler +// would result in double chunking or chunking with a Content-Length +// length, both of which are wrong. +func NewChunkedWriter(w io.Writer) io.WriteCloser { + return &chunkedWriter{w} +} + +// Writing to chunkedWriter translates to writing in HTTP chunked Transfer +// Encoding wire format to the underlying Wire chunkedWriter. +type chunkedWriter struct { + Wire io.Writer +} + +// Write the contents of data as one chunk to Wire. +// NOTE: Note that the corresponding chunk-writing procedure in Conn.Write has +// a bug since it does not check for success of io.WriteString +func (cw *chunkedWriter) Write(data []byte) (n int, err error) { + + // Don't send 0-length data. It looks like EOF for chunked encoding. + if len(data) == 0 { + return 0, nil + } + + if _, err = fmt.Fprintf(cw.Wire, "%x\r\n", len(data)); err != nil { + return 0, err + } + if n, err = cw.Wire.Write(data); err != nil { + return + } + if n != len(data) { + err = io.ErrShortWrite + return + } + if _, err = io.WriteString(cw.Wire, "\r\n"); err != nil { + return + } + if bw, ok := cw.Wire.(*FlushAfterChunkWriter); ok { + err = bw.Flush() + } + return +} + +func (cw *chunkedWriter) Close() error { + _, err := io.WriteString(cw.Wire, "0\r\n") + return err +} + +// FlushAfterChunkWriter signals from the caller of NewChunkedWriter +// that each chunk should be followed by a flush. It is used by the +// nhttp.Transport code to keep the buffering behavior for headers and +// trailers, but flush out chunks aggressively in the middle for +// request bodies which may be generated slowly. See Issue 6574. +type FlushAfterChunkWriter struct { + *bufio.Writer +} + +func parseHexUint(v []byte) (n uint64, err error) { + for i, b := range v { + switch { + case '0' <= b && b <= '9': + b = b - '0' + case 'a' <= b && b <= 'f': + b = b - 'a' + 10 + case 'A' <= b && b <= 'F': + b = b - 'A' + 10 + default: + return 0, errors.New("invalid byte in chunk length") + } + if i == 16 { + return 0, errors.New("nhttp chunk length too large") + } + n <<= 4 + n |= uint64(b) + } + return +} diff --git a/nhttp/errors.go b/nhttp/errors.go new file mode 100644 index 0000000..d038847 --- /dev/null +++ b/nhttp/errors.go @@ -0,0 +1,49 @@ +package nhttp + +import ( + "io" + "net" + _http "net/http" +) + +// unsupportedTEError reports unsupported transfer-encodings. +type unsupportedTEError struct { + err string +} + +func (uste *unsupportedTEError) Error() string { + return uste.err +} + +// isUnsupportedTEError checks if the error is of type +// unsupportedTEError. It is usually invoked with a non-nil err. +func isUnsupportedTEError(err error) bool { + _, ok := err.(*unsupportedTEError) + return ok +} + +// isCommonNetReadError reports whether err is a common error +// encountered during reading a request off the network when the +// client has gone away or had its read fail somehow. This is used to +// determine which logs are interesting enough to log about. +func isCommonNetReadError(err error) bool { + if err == io.EOF { + return true + } + if neterr, ok := err.(net.Error); ok && neterr.Timeout() { + return true + } + if oe, ok := err.(*net.OpError); ok && oe.Op == "read" { + return true + } + return false +} + +// statusError is an error used to respond to a request with an HTTP status. +// The text should be plain text without user info or other embedded errors. +type statusError struct { + code int + text string +} + +func (e statusError) Error() string { return _http.StatusText(e.code) + ": " + e.text } diff --git a/nhttp/handler.go b/nhttp/handler.go new file mode 100644 index 0000000..620f22d --- /dev/null +++ b/nhttp/handler.go @@ -0,0 +1,36 @@ +package nhttp + +import ( + "context" + "log" + _http "net/http" + "strings" + "sync/atomic" +) + +// serverHandler delegates to either the server's Handler or +// DefaultServeMux and also handles "OPTIONS *" requests. +type serverHandler struct { + srv *Server +} + +func (sh serverHandler) ServeHTTP(rw _http.ResponseWriter, req *_http.Request) { + handler := sh.srv.Handler + if handler == nil { + handler = _http.DefaultServeMux + } + + if req.URL != nil && strings.Contains(req.URL.RawQuery, ";") { + var allowQuerySemicolonsInUse int32 + req = req.WithContext(context.WithValue(req.Context(), &contextKey{"silence-semicolons"}, func() { + atomic.StoreInt32(&allowQuerySemicolonsInUse, 1) + })) + defer func() { + if atomic.LoadInt32(&allowQuerySemicolonsInUse) == 0 { + log.Printf("nhttp: URL query contains semicolon, which is no longer a supported separator; parts of the query may be stripped when parsed; see golang.org/issue/25192\n") + } + }() + } + + handler.ServeHTTP(rw, req) +} diff --git a/nhttp/request.go b/nhttp/request.go new file mode 100644 index 0000000..ce26ca7 --- /dev/null +++ b/nhttp/request.go @@ -0,0 +1,506 @@ +package nhttp + +import ( + "bufio" + "context" + "errors" + "fmt" + "golang.org/x/net/http/httpguts" + "io" + "math" + "net" + _http "net/http" + "net/url" + "strings" + "sync" + "sync/atomic" + "time" +) + +type connReader struct { + conn *conn + + mu sync.Mutex // guards following + hasByte bool + byteBuf [1]byte + cond *sync.Cond + inRead bool + aborted bool // set true before conn.rwc deadline is set to past + remain int64 // bytes remaining +} + +// aLongTimeAgo is a non-zero time, far in the past, used for +// immediate cancellation of network operations. +var aLongTimeAgo = time.Unix(1, 0) + +func (cr *connReader) lock() { + cr.mu.Lock() + if cr.cond == nil { + cr.cond = sync.NewCond(&cr.mu) + } +} + +func (cr *connReader) unlock() { cr.mu.Unlock() } + +func (cr *connReader) abortPendingRead() { + cr.lock() + defer cr.unlock() + if !cr.inRead { + return + } + cr.aborted = true + cr.conn.rwc.SetReadDeadline(aLongTimeAgo) + for cr.inRead { + cr.cond.Wait() + } + cr.conn.rwc.SetReadDeadline(time.Time{}) +} + +func (cr *connReader) startBackgroundRead() { + cr.lock() + defer cr.unlock() + if cr.inRead { + panic("invalid concurrent Body.Read call") + } + if cr.hasByte { + return + } + cr.inRead = true + cr.conn.rwc.SetReadDeadline(time.Time{}) + go cr.backgroundRead() +} + +func (cr *connReader) backgroundRead() { + n, err := cr.conn.rwc.Read(cr.byteBuf[:]) + cr.lock() + if n == 1 { + cr.hasByte = true + // We were past the end of the previous request's body already + // (since we wouldn't be in a background read otherwise), so + // this is a pipelined HTTP request. Prior to Go 1.11 we used to + // send on the CloseNotify channel and cancel the context here, + // but the behavior was documented as only "may", and we only + // did that because that's how CloseNotify accidentally behaved + // in very early Go releases prior to context support. Once we + // added context support, people used a Handler's + // Request.Context() and passed it along. Having that context + // cancel on pipelined HTTP requests caused problems. + // Fortunately, almost nothing uses HTTP/1.x pipelining. + // Unfortunately, apt-get does, or sometimes does. + // New Go 1.11 behavior: don't fire CloseNotify or cancel + // contexts on pipelined requests. Shouldn't affect people, but + // fixes cases like Issue 23921. This does mean that a client + // closing their TCP connection after sending a pipelined + // request won't cancel the context, but we'll catch that on any + // write failure (in checkConnErrorWriter.Write). + // If the server never writes, yes, there are still contrived + // server & client behaviors where this fails to ever cancel the + // context, but that's kinda why HTTP/1.x pipelining died + // anyway. + } + if ne, ok := err.(net.Error); ok && cr.aborted && ne.Timeout() { + // Ignore this error. It's the expected error from + // another goroutine calling abortPendingRead. + } else if err != nil { + cr.handleReadError(err) + } + cr.aborted = false + cr.inRead = false + cr.unlock() + cr.cond.Broadcast() +} + +func (cr *connReader) handleReadError(_ error) { + cr.conn.cancelCtx() + cr.closeNotify() +} + +// may be called from multiple goroutines. +func (cr *connReader) closeNotify() { + res, _ := cr.conn.curReq.Load().(*response) + if res != nil && atomic.CompareAndSwapInt32(&res.didCloseNotify, 0, 1) { + res.closeNotifyCh <- true + } +} + +func (cr *connReader) Read(p []byte) (n int, err error) { + cr.lock() + + if cr.hitReadLimit() { + cr.unlock() + return 0, io.EOF + } + if len(p) == 0 { + cr.unlock() + return 0, nil + } + if int64(len(p)) > cr.remain { + p = p[:cr.remain] + } + if cr.hasByte { + p[0] = cr.byteBuf[0] + cr.hasByte = false + cr.unlock() + return 1, nil + } + cr.inRead = true + cr.unlock() + n, err = cr.conn.rwc.Read(p) + + cr.lock() + cr.inRead = false + if err != nil { + cr.handleReadError(err) + } + cr.remain -= int64(n) + cr.unlock() + + cr.cond.Broadcast() + return n, err +} + +func (cr *connReader) setReadLimit(remain int64) { cr.remain = remain } +func (cr *connReader) setInfiniteReadLimit() { cr.remain = math.MaxInt64 } +func (cr *connReader) hitReadLimit() bool { return cr.remain <= 0 } + +// chunkWriter writes to a response's conn buffer, and is the writer +// wrapped by the response.w buffered writer. +// +// chunkWriter also is responsible for finalizing the Header, including +// conditionally setting the Content-Type and setting a Content-Length +// in cases where the handler's final output is smaller than the buffer +// size. It also conditionally adds chunk headers, when in chunking mode. +// +// See the comment above (*response).Write for the entire write flow. +type chunkWriter struct { + res *response + + // header is either nil or a deep clone of res.handlerHeader + // at the time of res.writeHeader, if res.writeHeader is + // called and extra buffering is being done to calculate + // Content-Type and/or Content-Length. + header _http.Header + + // wroteHeader tells whether the header's been written to "the + // wire" (or rather: w.conn.buf). this is unlike + // (*response).wroteHeader, which tells only whether it was + // logically written. + wroteHeader bool + + // set by the writeHeader method: + chunking bool // using chunked transfer encoding for reply body +} + +// TimeFormat is the time format to use when generating times in HTTP +// headers. It is like time.RFC1123 but hard-codes GMT as the time +// zone. The time being formatted must be in UTC for Format to +// generate the correct format. +// +// For parsing this time format, see ParseTime. +const TimeFormat = "Mon, 02 Jan 2006 15:04:05 GMT" + +// A response represents the server side of an HTTP response. +type response struct { + conn *conn + req *_http.Request // request for this response + reqBody io.ReadCloser + cancelCtx context.CancelFunc // when ServeHTTP exits + wroteHeader bool // a non-1xx header has been (logically) written + wroteContinue bool // 100 Continue response was written + wants10KeepAlive bool // HTTP/1.0 w/ Connection "keep-alive" + wantsClose bool // HTTP request has Connection "close" + + // canWriteContinue is a boolean value accessed as an atomic int32 + // that says whether or not a 100 Continue header can be written + // to the connection. + // writeContinueMu must be held while writing the header. + // These two fields together synchronize the body reader + // (the expectContinueReader, which wants to write 100 Continue) + // against the main writer. + canWriteContinue atomicBool + writeContinueMu sync.Mutex + + w *bufio.Writer // buffers output in chunks to chunkWriter + cw chunkWriter + + // handlerHeader is the Header that Handlers get access to, + // which may be retained and mutated even after WriteHeader. + // handlerHeader is copied into cw.header at WriteHeader + // time, and privately mutated thereafter. + handlerHeader _http.Header + calledHeader bool // handler accessed handlerHeader via Header + + written int64 // number of bytes written in body + contentLength int64 // explicitly-declared Content-Length; or -1 + status int // status code passed to WriteHeader + + // close connection after this reply. set on request and + // updated after response from handler if there's a + // "Connection: keep-alive" response header and a + // Content-Length. + closeAfterReply bool + + // requestBodyLimitHit is set by requestTooLarge when + // maxBytesReader hits its max size. It is checked in + // WriteHeader, to make sure we don't consume the + // remaining request body to try to advance to the next HTTP + // request. Instead, when this is set, we stop reading + // subsequent requests on this connection and stop reading + // input from it. + requestBodyLimitHit bool + + // trailers are the headers to be sent after the handler + // finishes writing the body. This field is initialized from + // the Trailer response header when the response header is + // written. + trailers []string + + handlerDone atomicBool // set true when the handler exits + + // Buffers for Date, Content-Length, and status code + dateBuf [len(TimeFormat)]byte + clenBuf [10]byte + statusBuf [3]byte + + // closeNotifyCh is the channel returned by CloseNotify. + // TODO(bradfitz): this is currently (for Go 1.8) always + // non-nil. Make this lazily-created again as it used to be? + closeNotifyCh chan bool + didCloseNotify int32 // atomic (only 0->1 winner should send) +} + +// http1ServerSupportsRequest reports whether Go's HTTP/1.x server +// supports the given request. +func http1ServerSupportsRequest(req *_http.Request) bool { + if req.ProtoMajor == 1 { + return true + } + // Accept "PRI * HTTP/2.0" upgrade requests, so Handlers can + // wire up their own HTTP/2 upgrades. + if req.ProtoMajor == 2 && req.ProtoMinor == 0 && + req.Method == "PRI" && req.RequestURI == "*" { + return true + } + // Reject HTTP/0.x, and all other HTTP/2+ requests (which + // aren't encoded in ASCII anyway). + return false +} + +var errTooLarge = errors.New("nhttp: request too large") + +// Read next request from connection. +func (c *conn) readRequest(ctx context.Context) (w *response, err error) { + var ( + wholeReqDeadline time.Time // or zero if none + hdrDeadline time.Time // or zero if none + ) + + c.r.setReadLimit(c.server.initialReadLimitSize()) + if c.lastMethod == "POST" { + // RFC 7230 section 3 tolerance for old buggy clients. + peek, _ := c.bufr.Peek(4) // ReadRequest will get err below + c.bufr.Discard(numLeadingCRorLF(peek)) + } + req, err := readRequest(c.bufr) + if err != nil { + if c.r.hitReadLimit() { + return nil, errTooLarge + } + return nil, err + } + + if !http1ServerSupportsRequest(req) { + return nil, errors.New("unsupported protocol version") + } + + c.lastMethod = req.Method + c.r.setInfiniteReadLimit() + + hosts, haveHost := req.Header["Host"] + if req.ProtoAtLeast(1, 1) && (!haveHost || len(hosts) == 0) && req.Method != "CONNECT" { + return nil, errors.New("missing required Host header") + } + if len(hosts) == 1 && !httpguts.ValidHostHeader(hosts[0]) { + return nil, errors.New("malformed Host header") + } + for k, vv := range req.Header { + if !httpguts.ValidHeaderFieldName(k) { + return nil, errors.New("invalid header name") + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + return nil, errors.New("invalid header value") + } + } + } + delete(req.Header, "Host") + + ctx, cancelCtx := context.WithCancel(ctx) + req = req.WithContext(ctx) + req.RemoteAddr = c.remoteAddr + req.TLS = c.tlsState + if body, ok := req.Body.(*body); ok { + body.doEarlyClose = true + } + + // Adjust the read deadline if necessary. + if !hdrDeadline.Equal(wholeReqDeadline) { + c.rwc.SetReadDeadline(wholeReqDeadline) + } + + w = &response{ + conn: c, + cancelCtx: cancelCtx, + req: req, + reqBody: req.Body, + handlerHeader: make(_http.Header), + contentLength: -1, + closeNotifyCh: make(chan bool, 1), + + // We populate these ahead of time so we're not + // reading from req.Header after their Handler starts + // and maybe mutates it (Issue 14940) + wants10KeepAlive: wantsHttp10KeepAlive(req), + wantsClose: wantsClose(req), + } + + w.cw.res = w + w.w = newBufioWriterSize(&w.cw, 2048) + return w, nil +} + +func wantsHttp10KeepAlive(r *_http.Request) bool { + if r.ProtoMajor != 1 || r.ProtoMinor != 0 { + return false + } + return hasToken(r.Header.Get("Connection"), "keep-alive") +} + +func wantsClose(r *_http.Request) bool { + if r.Close { + return true + } + return hasToken(r.Header.Get("Connection"), "close") +} + +func readRequest(b *bufio.Reader) (req *_http.Request, err error) { + tp := newTextprotoReader(b) + req = new(_http.Request) + + // First line: GET /index.html HTTP/1.0 + var s string + if s, err = tp.ReadLine(); err != nil { + return nil, err + } + defer func() { + putTextprotoReader(tp) + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + }() + + var ok bool + req.Method, req.RequestURI, req.Proto, ok = parseRequestLine(s) + if !ok { + return nil, errors.New("malformed HTTP request") + } + + rawurl := req.RequestURI + if req.ProtoMajor, req.ProtoMinor, ok = _http.ParseHTTPVersion(req.Proto); !ok { + return nil, errors.New("malformed HTTP version") + } + + // CONNECT requests are used two different ways, and neither uses a full URL: + // The standard use is to tunnel HTTPS through an HTTP proxy. + // It looks like "CONNECT www.google.com:443 HTTP/1.1", and the parameter is + // just the authority section of a URL. This information should go in req.URL.Host. + // + // The net/rpc package also uses CONNECT, but there the parameter is a path + // that starts with a slash. It can be parsed with the regular URL parser, + // and the path will end up in req.URL.Path, where it needs to be in order for + // RPC to work. + justAuthority := req.Method == "CONNECT" && !strings.HasPrefix(rawurl, "/") + if justAuthority { + rawurl = "http://" + rawurl + } + + if req.URL, err = url.ParseRequestURI(rawurl); err != nil { + return nil, err + } + + if justAuthority { + // Strip the bogus "nhttp://" back off. + req.URL.Scheme = "" + } + + // Subsequent lines: Key: value. + mimeHeader, err := tp.ReadMIMEHeader() + if err != nil { + return nil, err + } + req.Header = _http.Header(mimeHeader) + + // RFC 7230, section 5.3: Must treat + // GET /index.html HTTP/1.1 + // Host: www.google.com + // and + // GET http://www.google.com/index.html HTTP/1.1 + // Host: doesntmatter + // the same. In the second case, any Host line is ignored. + req.Host = req.URL.Host + if req.Host == "" { + req.Host = req.Header.Get("Host") + } + + fixPragmaCacheControl(req.Header) + + req.Close = shouldClose(req.ProtoMajor, req.ProtoMinor, req.Header, false) + + err = readTransfer(req, b) + if err != nil { + return nil, err + } + + return req, nil +} + +// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts. +func parseRequestLine(line string) (method, requestURI, proto string, ok bool) { + method, rest, ok1 := strings.Cut(line, " ") + requestURI, proto, ok2 := strings.Cut(rest, " ") + if !ok1 || !ok2 { + return "", "", "", false + } + return method, requestURI, proto, true +} + +func expectsContinue(r *_http.Request) bool { + return hasToken(r.Header.Get("Expect"), "100-continue") +} + +// requestBodyRemains reports whether future calls to Read +// on rc might yield more data. +func requestBodyRemains(rc io.ReadCloser) bool { + if rc == NoBody { + return false + } + switch v := rc.(type) { + case *expectContinueReader: + return requestBodyRemains(v.readCloser) + case *body: + return v.bodyRemains() + default: + panic("unexpected type " + fmt.Sprintf("%T", rc)) + } +} + +func registerOnHitEOF(rc io.ReadCloser, fn func()) { + switch v := rc.(type) { + case *expectContinueReader: + registerOnHitEOF(v.readCloser, fn) + case *body: + v.registerOnHitEOF(fn) + default: + panic("unexpected type " + fmt.Sprintf("%T", rc)) + } +} diff --git a/nhttp/response.go b/nhttp/response.go new file mode 100644 index 0000000..80afb08 --- /dev/null +++ b/nhttp/response.go @@ -0,0 +1,301 @@ +package nhttp + +import ( + "bufio" + "fmt" + "log" + "net" + _http "net/http" + "path" + "runtime" + "strconv" + "strings" + "time" +) + +const TrailerPrefix = "Trailer:" + +func fixPragmaCacheControl(header _http.Header) { + if hp, ok := header["Pragma"]; ok && len(hp) > 0 && hp[0] == "no-cache" { + if _, presentcc := header["Cache-Control"]; !presentcc { + header["Cache-Control"] = []string{"no-cache"} + } + } +} + +func (w *response) finalTrailers() _http.Header { + var t _http.Header + for k, vv := range w.handlerHeader { + if strings.HasPrefix(k, TrailerPrefix) { + if t == nil { + t = make(_http.Header) + } + t[strings.TrimPrefix(k, TrailerPrefix)] = vv + } + } + for _, k := range w.trailers { + if t == nil { + t = make(_http.Header) + } + for _, v := range w.handlerHeader[k] { + t.Add(k, v) + } + } + return t +} + +func (w *response) sendExpectationFailed() { + // TODO(bradfitz): let ServeHTTP handlers handle + // requests with non-standard expectation[s]? Seems + // theoretical at best, and doesn't fit into the + // current ServeHTTP model anyway. We'd need to + // make the ResponseWriter an optional + // "ExpectReplier" interface or something. + // + // For now we'll just obey RFC 7231 5.1.1 which says + // "A server that receives an Expect field-value other + // than 100-continue MAY respond with a 417 (Expectation + // Failed) status code to indicate that the unexpected + // expectation cannot be met." + w.Header().Set("Connection", "close") + w.WriteHeader(_http.StatusExpectationFailed) + w.finishRequest() +} + +func (w *response) finishRequest() { + w.handlerDone.setTrue() + + if !w.wroteHeader { + w.WriteHeader(_http.StatusOK) + } + + w.w.Flush() + putBufioWriter(w.w) + w.cw.close() + w.conn.bufw.Flush() + + w.conn.r.abortPendingRead() + + // Close the body (regardless of w.closeAfterReply) so we can + // re-use its bufio.Reader later safely. + w.reqBody.Close() + + if w.req.MultipartForm != nil { + w.req.MultipartForm.RemoveAll() + } +} + +func (w *response) Header() _http.Header { + if w.cw.header == nil && w.wroteHeader && !w.cw.wroteHeader { + // Accessing the header between logically writing it + // and physically writing it means we need to allocate + // a clone to snapshot the logically written state. + w.cw.header = w.handlerHeader.Clone() + } + w.calledHeader = true + return w.handlerHeader +} + +func (w *response) WriteHeader(code int) { + if w.wroteHeader { + caller := relevantCaller() + log.Printf("nhttp: superfluous response.WriteHeader call from %s (%s:%d)\n", caller.Function, path.Base(caller.File), caller.Line) + return + } + checkWriteHeaderCode(code) + + // Handle informational headers + if code >= 100 && code <= 199 { + // Prevent a potential race with an automatically-sent 100 Continue triggered by Request.Body.Read() + if code == 100 && w.canWriteContinue.isSet() { + w.writeContinueMu.Lock() + w.canWriteContinue.setFalse() + w.writeContinueMu.Unlock() + } + + writeStatusLine(w.conn.bufw, w.req.ProtoAtLeast(1, 1), code, w.statusBuf[:]) + + // Per RFC 8297 we must not clear the current header map + w.handlerHeader.WriteSubset(w.conn.bufw, map[string]bool{"Content-Length": true, "Transfer-Encoding": true}) + w.conn.bufw.Write(crlf) + w.conn.bufw.Flush() + + return + } + + w.wroteHeader = true + w.status = code + + if w.calledHeader && w.cw.header == nil { + w.cw.header = w.handlerHeader.Clone() + } + + if cl := w.handlerHeader.Get("Content-Length"); cl != "" { + v, err := strconv.ParseInt(cl, 10, 64) + if err == nil && v >= 0 { + w.contentLength = v + } else { + log.Printf("nhttp: invalid Content-Length of %q\n", cl) + w.handlerHeader.Del("Content-Length") + } + } +} + +// relevantCaller searches the call stack for the first function outside of net/nhttp. +// The purpose of this function is to provide more helpful error messages. +func relevantCaller() runtime.Frame { + pc := make([]uintptr, 16) + n := runtime.Callers(1, pc) + frames := runtime.CallersFrames(pc[:n]) + var frame runtime.Frame + for { + frame, more := frames.Next() + if !strings.HasPrefix(frame.Function, "net/nhttp.") { + return frame + } + if !more { + break + } + } + return frame +} + +func checkWriteHeaderCode(code int) { + // Issue 22880: require valid WriteHeader status codes. + // For now we only enforce that it's three digits. + // In the future we might block things over 599 (600 and above aren't defined + // at https://httpwg.org/specs/rfc7231.html#status.codes). + // But for now any three digits. + // + // We used to send "HTTP/1.1 000 0" on the wire in responses but there's + // no equivalent bogus thing we can realistically send in HTTP/2, + // so we'll consistently panic instead and help people find their bugs + // early. (We can't return an error from WriteHeader even if we wanted to.) + if code < 100 || code > 999 { + panic(fmt.Sprintf("invalid WriteHeader code %v", code)) + } +} + +// writeStatusLine writes an HTTP/1.x Status-Line (RFC 7230 Section 3.1.2) +// to bw. is11 is whether the HTTP request is HTTP/1.1. false means HTTP/1.0. +// code is the response status code. +// scratch is an optional scratch buffer. If it has at least capacity 3, it's used. +func writeStatusLine(bw *bufio.Writer, is11 bool, code int, scratch []byte) { + if is11 { + bw.WriteString("HTTP/1.1 ") + } else { + bw.WriteString("HTTP/1.0 ") + } + if text := _http.StatusText(code); text != "" { + bw.Write(strconv.AppendInt(scratch[:0], int64(code), 10)) + bw.WriteByte(' ') + bw.WriteString(text) + bw.WriteString("\r\n") + } else { + // don't worry about performance + fmt.Fprintf(bw, "%03d status code %d\r\n", code, code) + } +} + +func (w *response) Write(data []byte) (n int, err error) { + return w.write(len(data), data, "") +} + +func (w *response) WriteString(data string) (n int, err error) { + return w.write(len(data), nil, data) +} + +// either dataB or dataS is non-zero. +func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err error) { + if w.canWriteContinue.isSet() { + // Body reader wants to write 100 Continue but hasn't yet. + // Tell it not to. The store must be done while holding the lock + // because the lock makes sure that there is not an active write + // this very moment. + w.writeContinueMu.Lock() + w.canWriteContinue.setFalse() + w.writeContinueMu.Unlock() + } + + if !w.wroteHeader { + w.WriteHeader(_http.StatusOK) + } + if lenData == 0 { + return 0, nil + } + /*if !w.bodyAllowed() { + return 0, _http.ErrBodyNotAllowed + }*/ + + w.written += int64(lenData) // ignoring errors, for errorKludge + if w.contentLength != -1 && w.written > w.contentLength { + return 0, _http.ErrContentLength + } + if dataB != nil { + return w.w.Write(dataB) + } else { + return w.w.WriteString(dataS) + } +} + +// shouldReuseConnection reports whether the underlying TCP connection can be reused. +// It must only be called after the handler is done executing. +func (w *response) shouldReuseConnection() bool { + if w.closeAfterReply { + // The request or something set while executing the + // handler indicated we shouldn't reuse this + // connection. + return false + } + + if w.req.Method != "HEAD" && w.contentLength != -1 && w.contentLength != w.written { + // Did not write enough. Avoid getting out of sync. + return false + } + + // There was some error writing to the underlying connection + // during the request, so don't re-use this conn. + if w.conn.werr != nil { + return false + } + + if w.closedRequestBodyEarly() { + return false + } + + return true +} + +func (w *response) closedRequestBodyEarly() bool { + body, ok := w.req.Body.(*body) + return ok && body.didEarlyClose() +} + +// rstAvoidanceDelay is the amount of time we sleep after closing the +// write side of a TCP connection before closing the entire socket. +// By sleeping, we increase the chances that the client sees our FIN +// and processes its final data before they process the subsequent RST +// from closing a connection with known unread data. +// This RST seems to occur mostly on BSD systems. (And Windows?) +// This timeout is somewhat arbitrary (~latency around the planet). +const rstAvoidanceDelay = 500 * time.Millisecond + +type closeWriter interface { + CloseWrite() error +} + +var _ closeWriter = (*net.TCPConn)(nil) + +// closeWrite flushes any outstanding data and sends a FIN packet (if +// client is connected via TCP), signaling that we're done. We then +// pause for a bit, hoping the client processes it before any +// subsequent RST. +// +// See https://golang.org/issue/3595 +func (c *conn) closeWriteAndWait() { + c.finalFlush() + if tcp, ok := c.rwc.(closeWriter); ok { + tcp.CloseWrite() + } + time.Sleep(rstAvoidanceDelay) +} diff --git a/nhttp/server.go b/nhttp/server.go new file mode 100644 index 0000000..657bb92 --- /dev/null +++ b/nhttp/server.go @@ -0,0 +1,476 @@ +package nhttp + +import ( + "bufio" + "context" + "crypto/tls" + "errors" + "fmt" + "log" + "net" + _http "net/http" + "runtime" + "sync" + "sync/atomic" + "time" +) + +// ErrServerClosed is returned by the Server's Serve, ServeTLS, ListenAndServe, +// and ListenAndServeTLS methods after a call to Shutdown or Close. +var ErrServerClosed = errors.New("nhttp: Server closed") + +// contextKey is a value for use with context.WithValue. It's used as +// a pointer so it fits in an interface{} without allocation. +type contextKey struct { + name string +} + +func (k *contextKey) String() string { return "net/nhttp context value " + k.name } + +var ( + // ServerContextKey is a context key. It can be used in HTTP + // handlers with Context.Value to access the server that + // started the handler. The associated value will be of + // type *Server. + ServerContextKey = &contextKey{"nhttp-server"} + + // LocalAddrContextKey is a context key. It can be used in + // HTTP handlers with Context.Value to access the local + // address the connection arrived on. + // The associated value will be of type net.Addr. + LocalAddrContextKey = &contextKey{"local-addr"} +) + +type Server struct { + Addr string + Handler _http.Handler + + // BaseContext optionally specifies a function that returns + // the base context for incoming requests on this server. + // The provided Listener is the specific Listener that's + // about to start accepting requests. + // If BaseContext is nil, the default is context.Background(). + // If non-nil, it must return a non-nil context. + BaseContext func(net.Listener) context.Context + + doneChan chan struct{} + + mu sync.Mutex + + disableKeepAlives int32 // accessed atomically. + + // IdleTimeout is the maximum amount of time to wait for the + // next request when keep-alives are enabled. If IdleTimeout + // is zero, the value of ReadTimeout is used. If both are + // zero, there is no timeout. + IdleTimeout time.Duration + + // ReadTimeout is the maximum duration for reading the entire + // request, including the body. A zero or negative value means + // there will be no timeout. + // + // Because ReadTimeout does not let Handlers make per-request + // decisions on each request body's acceptable deadline or + // upload rate, most users will prefer to use + // ReadHeaderTimeout. It is valid to use them both. + ReadTimeout time.Duration + + inShutdown atomicBool // true when server is in shutdown + + listeners map[*net.Listener]struct{} + + listenerGroup sync.WaitGroup +} + +// conn represents the server side of an HTTP connection. +type conn struct { + // server is the server on which the connection arrived. + // Immutable; never nil. + server *Server + + // cancelCtx cancels the connection-level context. + cancelCtx context.CancelFunc + + // rwc is the underlying network connection. + // This is never wrapped by other types and is the value given out + // to CloseNotifier callers. It is usually of type *net.TCPConn or + // *tls.Conn. + rwc net.Conn + + // remoteAddr is rwc.RemoteAddr().String(). It is not populated synchronously + // inside the Listener's Accept goroutine, as some implementations block. + // It is populated immediately inside the (*conn).serve goroutine. + // This is the value of a Handler's (*Request).RemoteAddr. + remoteAddr string + + // tlsState is the TLS connection state when using TLS. + // nil means not TLS. + tlsState *tls.ConnectionState + + // werr is set to the first write error to rwc. + // It is set via checkConnErrorWriter{w}, where bufw writes. + werr error + + // r is bufr's read source. It's a wrapper around rwc that provides + // io.LimitedReader-style limiting (while reading request headers) + // and functionality to support CloseNotifier. See *connReader docs. + r *connReader + + // bufr reads from r. + bufr *bufio.Reader + + // bufw writes to checkConnErrorWriter{c}, which populates werr on error. + bufw *bufio.Writer + + // lastMethod is the method of the most recent request + // on this connection, if any. + lastMethod string + + curReq atomic.Value // of *response (which has a Request in it) + + curState struct{ atomic uint64 } // packed (unixtime<<8|uint8(ConnState)) + + // mu guards hijackedv + mu sync.Mutex +} + +// onceCloseListener wraps a net.Listener, protecting it from +// multiple Close calls. +type onceCloseListener struct { + net.Listener + once sync.Once + closeErr error +} + +func (oc *onceCloseListener) Close() error { + oc.once.Do(oc.close) + return oc.closeErr +} + +func (oc *onceCloseListener) close() { oc.closeErr = oc.Listener.Close() } + +func (srv *Server) ListenAndServe() error { + if srv.shuttingDown() { + // return ErrServerClosed + } + addr := srv.Addr + if addr == "" { + addr = ":nhttp" + } + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + return srv.Serve(ln) +} + +// ListenAndServe listens on the TCP network address addr and then calls +// Serve with handler to handle requests on incoming connections. +// Accepted connections are configured to enable TCP keep-alives. +// +// The handler is typically nil, in which case the DefaultServeMux is used. +// +// ListenAndServe always returns a non-nil error. +func ListenAndServe(addr string, handler _http.Handler) error { + server := &Server{Addr: addr, Handler: handler} + return server.ListenAndServe() +} + +// trackListener adds or removes a net.Listener to the set of tracked +// listeners. +// +// We store a pointer to interface in the map set, in case the +// net.Listener is not comparable. This is safe because we only call +// trackListener via Serve and can track+defer untrack the same +// pointer to local variable there. We never need to compare a +// Listener from another caller. +// +// It reports whether the server is still up (not Shutdown or Closed). +func (s *Server) trackListener(ln *net.Listener, add bool) bool { + s.mu.Lock() + defer s.mu.Unlock() + if s.listeners == nil { + s.listeners = make(map[*net.Listener]struct{}) + } + if add { + if s.shuttingDown() { + return false + } + s.listeners[ln] = struct{}{} + s.listenerGroup.Add(1) + } else { + delete(s.listeners, ln) + s.listenerGroup.Done() + } + return true +} + +func (srv *Server) Serve(l net.Listener) error { + origListener := l + l = &onceCloseListener{Listener: l} + defer l.Close() + + if !srv.trackListener(&l, true) { + return ErrServerClosed + } + defer srv.trackListener(&l, false) + + baseCtx := context.Background() + if srv.BaseContext != nil { + baseCtx = srv.BaseContext(origListener) + if baseCtx == nil { + panic("BaseContext returned a nil context") + } + } + + // How long to sleep on accept failure + var tempDelay time.Duration + ctx := context.WithValue(baseCtx, ServerContextKey, srv) + for { + rw, err := l.Accept() + if err != nil { + select { + case <-srv.getDoneChan(): + return ErrServerClosed + default: + } + if ne, ok := err.(net.Error); ok && ne.Temporary() { + if tempDelay == 0 { + tempDelay = 5 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 1 * time.Second; tempDelay > max { + tempDelay = max + } + log.Printf("nhttp: Accept error: %v; retrying in %v\n", err, tempDelay) + time.Sleep(tempDelay) + continue + } + return err + } + connCtx := ctx + tempDelay = 0 + c := &conn{ + server: srv, + rwc: rw, + } + + go c.serve(connCtx) + } +} + +// ErrAbortHandler is a sentinel panic value to abort a handler. +// While any panic from ServeHTTP aborts the response to the client, +// panicking with ErrAbortHandler also suppresses logging of a stack +// trace to the server's error log. +var ErrAbortHandler = errors.New("net/nhttp: abort Handler") + +// Close the connection. +func (c *conn) close() { + c.finalFlush() + c.rwc.Close() +} + +func (c *conn) finalFlush() { + if c.bufr != nil { + // Steal the bufio.Reader (~4KB worth of memory) and its associated + // reader for a future connection. + putBufioReader(c.bufr) + c.bufr = nil + } + + if c.bufw != nil { + c.bufw.Flush() + // Steal the bufio.Writer (~4KB worth of memory) and its associated + // writer for a future connection. + putBufioWriter(c.bufw) + c.bufw = nil + } +} + +func (c *conn) serve(ctx context.Context) { + c.remoteAddr = c.rwc.RemoteAddr().String() + ctx = context.WithValue(ctx, LocalAddrContextKey, c.rwc.LocalAddr()) + var inFlightResponse *response + defer func() { + if err := recover(); err != nil && err != ErrAbortHandler { + const size = 64 << 10 + buf := make([]byte, size) + buf = buf[:runtime.Stack(buf, false)] + log.Printf("nhttp: panic serving %v: %v\n%s\n", c.remoteAddr, err, buf) + } + if inFlightResponse != nil { + inFlightResponse.cancelCtx() + } + + if inFlightResponse != nil { + inFlightResponse.conn.r.abortPendingRead() + inFlightResponse.reqBody.Close() + } + c.close() + }() + + // HTTP/1.x from here on. + + ctx, cancelCtx := context.WithCancel(ctx) + c.cancelCtx = cancelCtx + defer cancelCtx() + + c.r = &connReader{conn: c} + c.bufr = newBufioReader(c.r) + c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10) + + for { + w, err := c.readRequest(ctx) + if err != nil { + const errorHeaders = "\r\nContent-Type: text/plain; charset=utf-8\r\nConnection: close\r\n\r\n" + + switch { + case err == errTooLarge: + // Their HTTP client may or may not be + // able to read this if we're + // responding to them and hanging up + // while they're still writing their + // request. Undefined behavior. + const publicErr = "431 Request Header Fields Too Large" + fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr) + c.closeWriteAndWait() + return + + case isUnsupportedTEError(err): + // Respond as per RFC 7230 Section 3.3.1 which says, + // A server that receives a request message with a + // transfer coding it does not understand SHOULD + // respond with 501 (Unimplemented). + code := _http.StatusNotImplemented + + // We purposefully aren't echoing back the transfer-encoding's value, + // so as to mitigate the risk of cross side scripting by an attacker. + fmt.Fprintf(c.rwc, "HTTP/1.1 %d %s%sUnsupported transfer encoding", code, _http.StatusText(code), errorHeaders) + return + + case isCommonNetReadError(err): + return // don't reply + + default: + if v, ok := err.(statusError); ok { + fmt.Fprintf(c.rwc, "HTTP/1.1 %d %s: %s%s%d %s: %s", v.code, _http.StatusText(v.code), v.text, errorHeaders, v.code, _http.StatusText(v.code), v.text) + return + } + publicErr := "400 Bad Request" + fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr) + return + } + } + + // Expect 100 Continue support + req := w.req + if expectsContinue(req) { + if req.ProtoAtLeast(1, 1) && req.ContentLength != 0 { + // Wrap the Body reader with one that replies on the connection + req.Body = &expectContinueReader{readCloser: req.Body, resp: w} + w.canWriteContinue.setTrue() + } + } else if req.Header.Get("Expect") != "" { + w.sendExpectationFailed() + return + } + + c.curReq.Store(w) + + if requestBodyRemains(req.Body) { + registerOnHitEOF(req.Body, w.conn.r.startBackgroundRead) + } else { + w.conn.r.startBackgroundRead() + } + + // HTTP cannot have multiple simultaneous active requests.[*] + // Until the server replies to this request, it can't read another, + // so we might as well run the handler in this goroutine. + // [*] Not strictly true: HTTP pipelining. We could let them all process + // in parallel even if their responses need to be serialized. + // But we're not going to implement HTTP pipelining because it + // was never deployed in the wild and the answer is HTTP/2. + inFlightResponse = w + serverHandler{c.server}.ServeHTTP(w, w.req) + inFlightResponse = nil + w.cancelCtx() + w.finishRequest() + if !w.shouldReuseConnection() { + if w.requestBodyLimitHit || w.closedRequestBodyEarly() { + c.closeWriteAndWait() + } + return + } + + if !w.conn.server.doKeepAlives() { + // We're in shutdown mode. We might've replied + // to the user without "Connection: close" and + // they might think they can send another + // request, but such is life with HTTP/1.1. + return + } + + if d := c.server.idleTimeout(); d != 0 { + c.rwc.SetReadDeadline(time.Now().Add(d)) + if _, err := c.bufr.Peek(4); err != nil { + return + } + } + + c.curReq.Store((*response)(nil)) + if d := c.server.idleTimeout(); d != 0 { + c.rwc.SetReadDeadline(time.Now().Add(d)) + if _, err := c.bufr.Peek(4); err != nil { + return + } + } + c.rwc.SetReadDeadline(time.Time{}) + } +} + +func (srv *Server) getDoneChan() <-chan struct{} { + srv.mu.Lock() + defer srv.mu.Unlock() + if srv.doneChan == nil { + srv.doneChan = make(chan struct{}) + } + + return srv.doneChan +} + +func (s *Server) doKeepAlives() bool { + return atomic.LoadInt32(&s.disableKeepAlives) == 0 && !s.shuttingDown() +} + +func (s *Server) shuttingDown() bool { + return s.inShutdown.isSet() +} + +func (srv *Server) initialReadLimitSize() int64 { + return 1<<20 + 4096 // bufio slop +} + +func (srv *Server) idleTimeout() time.Duration { + if srv.IdleTimeout != 0 { + return srv.IdleTimeout + } + return srv.ReadTimeout +} + +// checkConnErrorWriter writes to c.rwc and records any write errors to c.werr. +// It only contains one field (and a pointer field at that), so it +// fits in an interface value without an extra allocation. +type checkConnErrorWriter struct { + c *conn +} + +func (w checkConnErrorWriter) Write(p []byte) (n int, err error) { + n, err = w.c.rwc.Write(p) + if err != nil && w.c.werr == nil { + w.c.werr = err + w.c.cancelCtx() + } + return +} diff --git a/nhttp/transfer.go b/nhttp/transfer.go new file mode 100644 index 0000000..ae7d2e3 --- /dev/null +++ b/nhttp/transfer.go @@ -0,0 +1,580 @@ +package nhttp + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "golang.org/x/net/http/httpguts" + "io" + _http "net/http" + "net/textproto" + "strconv" + "sync" +) + +type transferReader struct { + // Input + Header _http.Header + StatusCode int + RequestMethod string + ProtoMajor int + ProtoMinor int + // Output + Body io.ReadCloser + ContentLength int64 + Chunked bool + Close bool + Trailer _http.Header +} + +// bodyLocked is an io.Reader reading from a *body when its mutex is +// already held. +type bodyLocked struct { + b *body +} + +func (bl bodyLocked) Read(p []byte) (n int, err error) { + if bl.b.closed { + return 0, errors.New("nhttp: invalid Read on closed Body") + } + return bl.b.readLocked(p) +} + +// body turns a Reader into a ReadCloser. +// Close ensures that the body has been fully read +// and then reads the trailer if necessary. +type body struct { + src io.Reader + hdr any // non-nil (Response or Request) value means read trailer + r *bufio.Reader // underlying wire-format reader for the trailer + closing bool // is the connection to be closed after reading body? + doEarlyClose bool // whether Close should stop early + + mu sync.Mutex // guards following, and calls to Read and Close + sawEOF bool + closed bool + earlyClose bool // Close called and we didn't read to the end of src + onHitEOF func() // if non-nil, func to call when EOF is Read +} + +func (b *body) Read(p []byte) (n int, err error) { + b.mu.Lock() + defer b.mu.Unlock() + if b.closed { + return 0, errors.New("nhttp: invalid Read on closed Body") + } + return b.readLocked(p) +} + +// Must hold b.mu. +func (b *body) readLocked(p []byte) (n int, err error) { + if b.sawEOF { + return 0, io.EOF + } + n, err = b.src.Read(p) + + if err == io.EOF { + b.sawEOF = true + // Chunked case. Read the trailer. + if b.hdr != nil { + if e := b.readTrailer(); e != nil { + err = e + // Something went wrong in the trailer, we must not allow any + // further reads of any kind to succeed from body, nor any + // subsequent requests on the server connection. See + // golang.org/issue/12027 + b.sawEOF = false + b.closed = true + } + b.hdr = nil + } else { + // If the server declared the Content-Length, our body is a LimitedReader + // and we need to check whether this EOF arrived early. + if lr, ok := b.src.(*io.LimitedReader); ok && lr.N > 0 { + err = io.ErrUnexpectedEOF + } + } + } + + // If we can return an EOF here along with the read data, do + // so. This is optional per the io.Reader contract, but doing + // so helps the HTTP transport code recycle its connection + // earlier (since it will see this EOF itself), even if the + // client doesn't do future reads or Close. + if err == nil && n > 0 { + if lr, ok := b.src.(*io.LimitedReader); ok && lr.N == 0 { + err = io.EOF + b.sawEOF = true + } + } + + if b.sawEOF && b.onHitEOF != nil { + b.onHitEOF() + } + + return n, err +} + +func (b *body) Close() error { + b.mu.Lock() + defer b.mu.Unlock() + if b.closed { + return nil + } + var err error + switch { + case b.sawEOF: + // Already saw EOF, so no need going to look for it. + case b.hdr == nil && b.closing: + // no trailer and closing the connection next. + // no point in reading to EOF. + case b.doEarlyClose: + // Read up to maxPostHandlerReadBytes bytes of the body, looking + // for EOF (and trailers), so we can re-use this connection. + if lr, ok := b.src.(*io.LimitedReader); ok && lr.N > 256<<10 { + // There was a declared Content-Length, and we have more bytes remaining + // than our maxPostHandlerReadBytes tolerance. So, give up. + b.earlyClose = true + } else { + var n int64 + // Consume the body, or, which will also lead to us reading + // the trailer headers after the body, if present. + n, err = io.CopyN(io.Discard, bodyLocked{b}, 256<<10) + if err == io.EOF { + err = nil + } + if n == 256<<10 { + b.earlyClose = true + } + } + default: + // Fully consume the body, which will also lead to us reading + // the trailer headers after the body, if present. + _, err = io.Copy(io.Discard, bodyLocked{b}) + } + b.closed = true + return err +} + +func (b *body) didEarlyClose() bool { + b.mu.Lock() + defer b.mu.Unlock() + return b.earlyClose +} + +// bodyRemains reports whether future Read calls might +// yield data. +func (b *body) bodyRemains() bool { + b.mu.Lock() + defer b.mu.Unlock() + return !b.sawEOF +} + +func (b *body) registerOnHitEOF(fn func()) { + b.mu.Lock() + defer b.mu.Unlock() + b.onHitEOF = fn +} + +// Determine whether to hang up after sending a request and body, or +// receiving a response and body +// 'header' is the request headers +func shouldClose(major, minor int, header _http.Header, removeCloseHeader bool) bool { + if major < 1 { + return true + } + + conv := header["Connection"] + hasClose := httpguts.HeaderValuesContainsToken(conv, "close") + if major == 1 && minor == 0 { + return hasClose || !httpguts.HeaderValuesContainsToken(conv, "keep-alive") + } + + if hasClose && removeCloseHeader { + header.Del("Connection") + } + + return hasClose +} + +// msg is *Request or *Response. +func readTransfer(msg any, r *bufio.Reader) (err error) { + t := &transferReader{RequestMethod: "GET"} + + // Unify input + isResponse := false + switch rr := msg.(type) { + case *_http.Response: + t.Header = rr.Header + t.StatusCode = rr.StatusCode + t.ProtoMajor = rr.ProtoMajor + t.ProtoMinor = rr.ProtoMinor + t.Close = shouldClose(t.ProtoMajor, t.ProtoMinor, t.Header, true) + isResponse = true + if rr.Request != nil { + t.RequestMethod = rr.Request.Method + } + case *_http.Request: + t.Header = rr.Header + t.RequestMethod = rr.Method + t.ProtoMajor = rr.ProtoMajor + t.ProtoMinor = rr.ProtoMinor + // Transfer semantics for Requests are exactly like those for + // Responses with status code 200, responding to a GET method + t.StatusCode = 200 + t.Close = rr.Close + default: + panic("unexpected type") + } + + // Default to HTTP/1.1 + if t.ProtoMajor == 0 && t.ProtoMinor == 0 { + t.ProtoMajor, t.ProtoMinor = 1, 1 + } + + // Transfer-Encoding: chunked, and overriding Content-Length. + if err := t.parseTransferEncoding(); err != nil { + return err + } + + realLength, err := fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.Chunked) + if err != nil { + return err + } + if isResponse && t.RequestMethod == "HEAD" { + if n, err := parseContentLength(t.Header.Get("Content-Length")); err != nil { + return err + } else { + t.ContentLength = n + } + } else { + t.ContentLength = realLength + } + + // Trailer + t.Trailer, err = fixTrailer(t.Header, t.Chunked) + if err != nil { + return err + } + + // If there is no Content-Length or chunked Transfer-Encoding on a *Response + // and the status is not 1xx, 204 or 304, then the body is unbounded. + // See RFC 7230, section 3.3. + switch msg.(type) { + case *_http.Response: + if realLength == -1 && !t.Chunked && bodyAllowedForStatus(t.StatusCode) { + // Unbounded body. + t.Close = true + } + } + + // Prepare body reader. ContentLength < 0 means chunked encoding + // or close connection when finished, since multipart is not supported yet + switch { + case t.Chunked: + if t.RequestMethod == "HEAD" || !bodyAllowedForStatus(t.StatusCode) { + t.Body = NoBody + } else { + t.Body = &body{src: NewChunkedReader(r), hdr: msg, r: r, closing: t.Close} + } + case realLength == 0: + t.Body = NoBody + case realLength > 0: + t.Body = &body{src: io.LimitReader(r, realLength), closing: t.Close} + default: + // realLength < 0, i.e. "Content-Length" not mentioned in header + if t.Close { + // Close semantics (i.e. HTTP/1.0) + t.Body = &body{src: r, closing: t.Close} + } else { + // Persistent connection (i.e. HTTP/1.1) + t.Body = NoBody + } + } + + // Unify output + switch rr := msg.(type) { + case *_http.Request: + rr.Body = t.Body + rr.ContentLength = t.ContentLength + if t.Chunked { + rr.TransferEncoding = []string{"chunked"} + } + rr.Close = t.Close + rr.Trailer = t.Trailer + case *_http.Response: + rr.Body = t.Body + rr.ContentLength = t.ContentLength + if t.Chunked { + rr.TransferEncoding = []string{"chunked"} + } + rr.Close = t.Close + rr.Trailer = t.Trailer + } + + return nil +} + +// parseTransferEncoding sets t.Chunked based on the Transfer-Encoding header. +func (t *transferReader) parseTransferEncoding() error { + raw, present := t.Header["Transfer-Encoding"] + if !present { + return nil + } + delete(t.Header, "Transfer-Encoding") + + // Issue 12785; ignore Transfer-Encoding on HTTP/1.0 requests. + if !t.protoAtLeast(1, 1) { + return nil + } + + // Like nginx, we only support a single Transfer-Encoding header field, and + // only if set to "chunked". This is one of the most security sensitive + // surfaces in HTTP/1.1 due to the risk of request smuggling, so we keep it + // strict and simple. + if len(raw) != 1 { + return errors.New(fmt.Sprintf("too many transfer encodings: %q", raw)) + } + if !EqualFold(raw[0], "chunked") { + return errors.New(fmt.Sprintf("unsupported transfer encoding: %q", raw[0])) + } + + // RFC 7230 3.3.2 says "A sender MUST NOT send a Content-Length header field + // in any message that contains a Transfer-Encoding header field." + // + // but also: "If a message is received with both a Transfer-Encoding and a + // Content-Length header field, the Transfer-Encoding overrides the + // Content-Length. Such a message might indicate an attempt to perform + // request smuggling (Section 9.5) or response splitting (Section 9.4) and + // ought to be handled as an error. A sender MUST remove the received + // Content-Length field prior to forwarding such a message downstream." + // + // Reportedly, these appear in the wild. + delete(t.Header, "Content-Length") + + t.Chunked = true + return nil +} + +func (t *transferReader) protoAtLeast(m, n int) bool { + return t.ProtoMajor > m || (t.ProtoMajor == m && t.ProtoMinor >= n) +} + +// Determine the expected body length, using RFC 7230 Section 3.3. This +// function is not a method, because ultimately it should be shared by +// ReadResponse and ReadRequest. +func fixLength(isResponse bool, status int, requestMethod string, header _http.Header, chunked bool) (int64, error) { + isRequest := !isResponse + contentLens := header["Content-Length"] + + // Hardening against HTTP request smuggling + if len(contentLens) > 1 { + // Per RFC 7230 Section 3.3.2, prevent multiple + // Content-Length headers if they differ in value. + // If there are dups of the value, remove the dups. + // See Issue 16490. + first := textproto.TrimString(contentLens[0]) + for _, ct := range contentLens[1:] { + if first != textproto.TrimString(ct) { + return 0, fmt.Errorf("nhttp: message cannot contain multiple Content-Length headers; got %q", contentLens) + } + } + + // deduplicate Content-Length + header.Del("Content-Length") + header.Add("Content-Length", first) + + contentLens = header["Content-Length"] + } + + // Logic based on response type or status + if requestMethod == "HEAD" { + // For HTTP requests, as part of hardening against request + // smuggling (RFC 7230), don't allow a Content-Length header for + // methods which don't permit bodies. As an exception, allow + // exactly one Content-Length header if its value is "0". + if isRequest && len(contentLens) > 0 && !(len(contentLens) == 1 && contentLens[0] == "0") { + return 0, fmt.Errorf("nhttp: method cannot contain a Content-Length; got %q", contentLens) + } + return 0, nil + } + if status/100 == 1 { + return 0, nil + } + switch status { + case 204, 304: + return 0, nil + } + + // Logic based on Transfer-Encoding + if chunked { + return -1, nil + } + + // Logic based on Content-Length + var cl string + if len(contentLens) == 1 { + cl = textproto.TrimString(contentLens[0]) + } + if cl != "" { + n, err := parseContentLength(cl) + if err != nil { + return -1, err + } + return n, nil + } + header.Del("Content-Length") + + if isRequest { + // RFC 7230 neither explicitly permits nor forbids an + // entity-body on a GET request so we permit one if + // declared, but we default to 0 here (not -1 below) + // if there's no mention of a body. + // Likewise, all other request methods are assumed to have + // no body if neither Transfer-Encoding chunked nor a + // Content-Length are set. + return 0, nil + } + + // Body-EOF logic based on other methods (like closing, or chunked coding) + return -1, nil +} + +func parseContentLength(cl string) (int64, error) { + cl = textproto.TrimString(cl) + if cl == "" { + return -1, nil + } + n, err := strconv.ParseUint(cl, 10, 63) + if err != nil { + return 0, errors.New("bad Content-Length") + } + return int64(n), nil + +} + +// Parse the trailer header +func fixTrailer(header _http.Header, chunked bool) (_http.Header, error) { + vv, ok := header["Trailer"] + if !ok { + return nil, nil + } + if !chunked { + // Trailer and no chunking: + // this is an invalid use case for trailer header. + // Nevertheless, no error will be returned and we + // let users decide if this is a valid HTTP message. + // The Trailer header will be kept in Response.Header + // but not populate Response.Trailer. + // See issue #27197. + return nil, nil + } + header.Del("Trailer") + + trailer := make(_http.Header) + var err error + for _, v := range vv { + foreachHeaderElement(v, func(key string) { + key = CanonicalHeaderKey(key) + switch key { + case "Transfer-Encoding", "Trailer", "Content-Length": + if err == nil { + err = errors.New("bad trailer key") + return + } + } + trailer[key] = nil + }) + } + if err != nil { + return nil, err + } + if len(trailer) == 0 { + return nil, nil + } + return trailer, nil +} + +func bodyAllowedForStatus(status int) bool { + switch { + case status >= 100 && status <= 199: + return false + case status == 204: + return false + case status == 304: + return false + } + return true +} + +var errTrailerEOF = errors.New("nhttp: unexpected EOF reading trailer") + +func (b *body) readTrailer() error { + // The common case, since nobody uses trailers. + buf, err := b.r.Peek(2) + if bytes.Equal(buf, singleCRLF) { + b.r.Discard(2) + return nil + } + if len(buf) < 2 { + return errTrailerEOF + } + if err != nil { + return err + } + + // Make sure there's a header terminator coming up, to prevent + // a DoS with an unbounded size Trailer. It's not easy to + // slip in a LimitReader here, as textproto.NewReader requires + // a concrete *bufio.Reader. Also, we can't get all the way + // back up to our conn's LimitedReader that *might* be backing + // this bufio.Reader. Instead, a hack: we iteratively Peek up + // to the bufio.Reader's max size, looking for a double CRLF. + // This limits the trailer to the underlying buffer size, typically 4kB. + if !seeUpcomingDoubleCRLF(b.r) { + return errors.New("nhttp: suspiciously long trailer after chunked body") + } + + hdr, err := textproto.NewReader(b.r).ReadMIMEHeader() + if err != nil { + if err == io.EOF { + return errTrailerEOF + } + return err + } + switch rr := b.hdr.(type) { + case *_http.Request: + mergeSetHeader(&rr.Trailer, _http.Header(hdr)) + case *_http.Response: + mergeSetHeader(&rr.Trailer, _http.Header(hdr)) + } + return nil +} + +func mergeSetHeader(dst *_http.Header, src _http.Header) { + if *dst == nil { + *dst = src + return + } + for k, vv := range src { + (*dst)[k] = vv + } +} + +var ( + suppressedHeaders304 = []string{"Content-Type", "Content-Length", "Transfer-Encoding"} + suppressedHeadersNoBody = []string{"Content-Length", "Transfer-Encoding"} + excludedHeadersNoBody = map[string]bool{"Content-Length": true, "Transfer-Encoding": true} +) + +func suppressedHeaders(status int) []string { + switch { + case status == 304: + // RFC 7232 section 4.1 + return suppressedHeaders304 + case !bodyAllowedForStatus(status): + return suppressedHeadersNoBody + } + return nil +}