Publicize nhttp.AtomicBool for use in qr2

This commit is contained in:
ppeb 2025-04-27 16:16:30 -05:00
parent ad221202f2
commit c67efb33a0
No known key found for this signature in database
GPG Key ID: CC147AD1B3D318D0
5 changed files with 27 additions and 26 deletions

View File

@ -28,11 +28,11 @@ func (noBody) Read([]byte) (int, error) { return 0, io.EOF }
func (noBody) Close() error { return nil } func (noBody) Close() error { return nil }
func (noBody) WriteTo(io.Writer) (int64, error) { return 0, nil } func (noBody) WriteTo(io.Writer) (int64, error) { return 0, nil }
type atomicBool int32 type AtomicBool int32
func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 } func (b *AtomicBool) IsSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 }
func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) } func (b *AtomicBool) SetTrue() { atomic.StoreInt32((*int32)(b), 1) }
func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) } func (b *AtomicBool) SetFalse() { atomic.StoreInt32((*int32)(b), 0) }
var textprotoReaderPool sync.Pool var textprotoReaderPool sync.Pool
@ -180,34 +180,34 @@ func (cw *chunkWriter) close() {
type expectContinueReader struct { type expectContinueReader struct {
resp *response resp *response
readCloser io.ReadCloser readCloser io.ReadCloser
closed atomicBool closed AtomicBool
sawEOF atomicBool sawEOF AtomicBool
} }
func (ecr *expectContinueReader) Read(p []byte) (n int, err error) { func (ecr *expectContinueReader) Read(p []byte) (n int, err error) {
if ecr.closed.isSet() { if ecr.closed.IsSet() {
return 0, errors.New("nhttp: invalid Read on closed Body") return 0, errors.New("nhttp: invalid Read on closed Body")
} }
w := ecr.resp w := ecr.resp
if !w.wroteContinue && w.canWriteContinue.isSet() { if !w.wroteContinue && w.canWriteContinue.IsSet() {
w.wroteContinue = true w.wroteContinue = true
w.writeContinueMu.Lock() w.writeContinueMu.Lock()
if w.canWriteContinue.isSet() { if w.canWriteContinue.IsSet() {
w.conn.bufw.WriteString("HTTP/1.1 100 Continue\r\n\r\n") w.conn.bufw.WriteString("HTTP/1.1 100 Continue\r\n\r\n")
w.conn.bufw.Flush() w.conn.bufw.Flush()
w.canWriteContinue.setFalse() w.canWriteContinue.SetFalse()
} }
w.writeContinueMu.Unlock() w.writeContinueMu.Unlock()
} }
n, err = ecr.readCloser.Read(p) n, err = ecr.readCloser.Read(p)
if err == io.EOF { if err == io.EOF {
ecr.sawEOF.setTrue() ecr.sawEOF.SetTrue()
} }
return return
} }
func (ecr *expectContinueReader) Close() error { func (ecr *expectContinueReader) Close() error {
ecr.closed.setTrue() ecr.closed.SetTrue()
return ecr.readCloser.Close() return ecr.readCloser.Close()
} }
@ -288,7 +288,7 @@ func (cw *chunkWriter) writeHeader(p []byte) {
// send a Content-Length header. // send a Content-Length header.
// Further, we don't send an automatic Content-Length if they // Further, we don't send an automatic Content-Length if they
// set a Transfer-Encoding, because they're generally incompatible. // set a Transfer-Encoding, because they're generally incompatible.
if w.handlerDone.isSet() && !trailers && !hasTE && bodyAllowedForStatus(w.status) && header.Get("Content-Length") == "" && (!isHEAD || len(p) > 0) { if w.handlerDone.IsSet() && !trailers && !hasTE && bodyAllowedForStatus(w.status) && header.Get("Content-Length") == "" && (!isHEAD || len(p) > 0) {
w.contentLength = int64(len(p)) w.contentLength = int64(len(p))
setHeader.contentLength = strconv.AppendInt(cw.res.clenBuf[:0], int64(len(p)), 10) setHeader.contentLength = strconv.AppendInt(cw.res.clenBuf[:0], int64(len(p)), 10)
} }
@ -330,7 +330,7 @@ func (cw *chunkWriter) writeHeader(p []byte) {
// because we don't know if the next bytes on the wire will be // because we don't know if the next bytes on the wire will be
// the body-following-the-timer or the subsequent request. // the body-following-the-timer or the subsequent request.
// See Issue 11549. // See Issue 11549.
if ecr, ok := w.req.Body.(*expectContinueReader); ok && !ecr.sawEOF.isSet() { if ecr, ok := w.req.Body.(*expectContinueReader); ok && !ecr.sawEOF.IsSet() {
w.closeAfterReply = true w.closeAfterReply = true
} }

View File

@ -218,7 +218,7 @@ type response struct {
// These two fields together synchronize the body reader // These two fields together synchronize the body reader
// (the expectContinueReader, which wants to write 100 Continue) // (the expectContinueReader, which wants to write 100 Continue)
// against the main writer. // against the main writer.
canWriteContinue atomicBool canWriteContinue AtomicBool
writeContinueMu sync.Mutex writeContinueMu sync.Mutex
w *bufio.Writer // buffers output in chunks to chunkWriter w *bufio.Writer // buffers output in chunks to chunkWriter
@ -256,7 +256,7 @@ type response struct {
// written. // written.
trailers []string trailers []string
handlerDone atomicBool // set true when the handler exits handlerDone AtomicBool // set true when the handler exits
// Buffers for Date, Content-Length, and status code // Buffers for Date, Content-Length, and status code
dateBuf [len(TimeFormat)]byte dateBuf [len(TimeFormat)]byte

View File

@ -63,7 +63,7 @@ func (w *response) sendExpectationFailed() {
} }
func (w *response) finishRequest() { func (w *response) finishRequest() {
w.handlerDone.setTrue() w.handlerDone.SetTrue()
if !w.wroteHeader { if !w.wroteHeader {
w.WriteHeader(_http.StatusOK) w.WriteHeader(_http.StatusOK)
@ -107,9 +107,9 @@ func (w *response) WriteHeader(code int) {
// Handle informational headers // Handle informational headers
if code >= 100 && code <= 199 { if code >= 100 && code <= 199 {
// Prevent a potential race with an automatically-sent 100 Continue triggered by Request.Body.Read() // Prevent a potential race with an automatically-sent 100 Continue triggered by Request.Body.Read()
if code == 100 && w.canWriteContinue.isSet() { if code == 100 && w.canWriteContinue.IsSet() {
w.writeContinueMu.Lock() w.writeContinueMu.Lock()
w.canWriteContinue.setFalse() w.canWriteContinue.SetFalse()
w.writeContinueMu.Unlock() w.writeContinueMu.Unlock()
} }
@ -207,13 +207,13 @@ func (w *response) WriteString(data string) (n int, err error) {
// either dataB or dataS is non-zero. // either dataB or dataS is non-zero.
func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err error) { func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err error) {
if w.canWriteContinue.isSet() { if w.canWriteContinue.IsSet() {
// Body reader wants to write 100 Continue but hasn't yet. // Body reader wants to write 100 Continue but hasn't yet.
// Tell it not to. The store must be done while holding the lock // Tell it not to. The store must be done while holding the lock
// because the lock makes sure that there is not an active write // because the lock makes sure that there is not an active write
// this very moment. // this very moment.
w.writeContinueMu.Lock() w.writeContinueMu.Lock()
w.canWriteContinue.setFalse() w.canWriteContinue.SetFalse()
w.writeContinueMu.Unlock() w.writeContinueMu.Unlock()
} }

View File

@ -480,7 +480,7 @@ func (c *conn) serve(ctx context.Context) {
if req.ProtoAtLeast(1, 1) && req.ContentLength != 0 { if req.ProtoAtLeast(1, 1) && req.ContentLength != 0 {
// Wrap the Body reader with one that replies on the connection // Wrap the Body reader with one that replies on the connection
req.Body = &expectContinueReader{readCloser: req.Body, resp: w} req.Body = &expectContinueReader{readCloser: req.Body, resp: w}
w.canWriteContinue.setTrue() w.canWriteContinue.SetTrue()
} }
} else if req.Header.Get("Expect") != "" { } else if req.Header.Get("Expect") != "" {
w.sendExpectationFailed() w.sendExpectationFailed()

View File

@ -7,6 +7,7 @@ import (
"time" "time"
"wwfc/common" "wwfc/common"
"wwfc/logging" "wwfc/logging"
"wwfc/nhttp"
"github.com/logrusorgru/aurora/v3" "github.com/logrusorgru/aurora/v3"
) )
@ -29,7 +30,7 @@ const (
var ( var (
masterConn net.PacketConn masterConn net.PacketConn
inShutdown = false inShutdown nhttp.AtomicBool
waitGroup = sync.WaitGroup{} waitGroup = sync.WaitGroup{}
) )
@ -44,7 +45,7 @@ func StartServer(reload bool) {
} }
masterConn = conn masterConn = conn
inShutdown = false inShutdown.SetFalse()
if reload { if reload {
err := loadSessions() err := loadSessions()
@ -79,7 +80,7 @@ func StartServer(reload bool) {
logging.Notice("QR2", "Listening on", aurora.BrightCyan(address)) logging.Notice("QR2", "Listening on", aurora.BrightCyan(address))
for { for {
if inShutdown { if inShutdown.IsSet() {
return return
} }
@ -97,7 +98,7 @@ func StartServer(reload bool) {
} }
func Shutdown() { func Shutdown() {
inShutdown = true inShutdown.SetTrue()
masterConn.Close() masterConn.Close()
waitGroup.Wait() waitGroup.Wait()