| field | value |
|---|---|
| author | Michael Stapelberg <stapelberg@debian.org>, 2013-03-04 21:27:36 +0100 |
| committer | Michael Stapelberg <michael@stapelberg.de>, 2013-03-04 21:27:36 +0100 |
| commit | 04b08da9af0c450d645ab7389d1467308cfc2db8 (patch) |
| tree | db247935fa4f2f94408edc3acd5d0d4f997aa0d8 /src/pkg/net/http/server.go |
| parent | 917c5fb8ec48e22459d77e3849e6d388f93d3260 (diff) |
| download | golang-upstream/1.1_hg20130304.tar.gz |

Imported Upstream version 1.1~hg20130304 (tag: upstream/1.1_hg20130304)
Diffstat (limited to 'src/pkg/net/http/server.go')
| -rw-r--r-- | src/pkg/net/http/server.go | 873 |

1 file changed, 626 insertions(+), 247 deletions(-)
diff --git a/src/pkg/net/http/server.go b/src/pkg/net/http/server.go
index 0572b4ae3..b6ab78228 100644
--- a/src/pkg/net/http/server.go
+++ b/src/pkg/net/http/server.go
@@ -11,7 +11,6 @@ package http

 import (
     "bufio"
-    "bytes"
     "crypto/tls"
     "errors"
     "fmt"
@@ -21,7 +20,7 @@ import (
     "net"
     "net/url"
     "path"
-    "runtime/debug"
+    "runtime"
     "strconv"
     "strings"
     "sync"
@@ -94,30 +93,188 @@ type Hijacker interface {
     Hijack() (net.Conn, *bufio.ReadWriter, error)
 }

+// The CloseNotifier interface is implemented by ResponseWriters which
+// allow detecting when the underlying connection has gone away.
+//
+// This mechanism can be used to cancel long operations on the server
+// if the client has disconnected before the response is ready.
+type CloseNotifier interface {
+    // CloseNotify returns a channel that receives a single value
+    // when the client connection has gone away.
+    CloseNotify() <-chan bool
+}
+
 // A conn represents the server side of an HTTP connection.
 type conn struct {
     remoteAddr string               // network address of remote side
     server     *Server              // the Server on which the connection arrived
     rwc        net.Conn             // i/o connection
-    lr         *io.LimitedReader    // io.LimitReader(rwc)
-    buf        *bufio.ReadWriter    // buffered(lr,rwc), reading from bufio->limitReader->rwc
-    hijacked   bool                 // connection has been hijacked by handler
+    sr         switchReader         // where the LimitReader reads from; usually the rwc
+    lr         *io.LimitedReader    // io.LimitReader(sr)
+    buf        *bufio.ReadWriter    // buffered(lr,rwc), reading from bufio->limitReader->sr->rwc
     tlsState   *tls.ConnectionState // or nil when not using TLS
-    body       []byte
+
+    mu           sync.Mutex // guards the following
+    clientGone   bool       // if client has disconnected mid-request
+    closeNotifyc chan bool  // made lazily
+    hijackedv    bool       // connection has been hijacked by handler
+}
+
+func (c *conn) hijacked() bool {
+    c.mu.Lock()
+    defer c.mu.Unlock()
+    return c.hijackedv
+}
+
+func (c *conn) hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
+    c.mu.Lock()
+    defer c.mu.Unlock()
+    if c.hijackedv {
+        return nil, nil, ErrHijacked
+    }
+    if c.closeNotifyc != nil {
+        return nil, nil, errors.New("http: Hijack is incompatible with use of CloseNotifier")
+    }
+    c.hijackedv = true
+    rwc = c.rwc
+    buf = c.buf
+    c.rwc = nil
+    c.buf = nil
+    return
+}
+
+func (c *conn) closeNotify() <-chan bool {
+    c.mu.Lock()
+    defer c.mu.Unlock()
+    if c.closeNotifyc == nil {
+        c.closeNotifyc = make(chan bool)
+        if c.hijackedv {
+            // to obey the function signature, even though
+            // it'll never receive a value.
+            return c.closeNotifyc
+        }
+        pr, pw := io.Pipe()
+
+        readSource := c.sr.r
+        c.sr.Lock()
+        c.sr.r = pr
+        c.sr.Unlock()
+        go func() {
+            _, err := io.Copy(pw, readSource)
+            if err == nil {
+                err = io.EOF
+            }
+            pw.CloseWithError(err)
+            c.noteClientGone()
+        }()
+    }
+    return c.closeNotifyc
+}
+
+func (c *conn) noteClientGone() {
+    c.mu.Lock()
+    defer c.mu.Unlock()
+    if c.closeNotifyc != nil && !c.clientGone {
+        c.closeNotifyc <- true
+    }
+    c.clientGone = true
+}
+
+type switchReader struct {
+    sync.Mutex
+    r io.Reader
+}
+
+func (sr *switchReader) Read(p []byte) (n int, err error) {
+    sr.Lock()
+    r := sr.r
+    sr.Unlock()
+    return r.Read(p)
+}
+
+// This should be >= 512 bytes for DetectContentType,
+// but otherwise it's somewhat arbitrary.
+const bufferBeforeChunkingSize = 2048
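The CloseNotifier plumbing above works by swapping the connection's read source (the switchReader) over to an io.Pipe and copying the real connection into it from a goroutine; EOF or any read error then surfaces as a single value on the channel. From a handler's point of view it looks like this; a minimal sketch, with the route, delay, and messages invented for illustration:

```go
package main

import (
	"fmt"
	"log"
	"net/http"
	"time"
)

func slow(w http.ResponseWriter, r *http.Request) {
	cn, ok := w.(http.CloseNotifier) // the interface this patch introduces
	if !ok {
		http.Error(w, "close notification unsupported", http.StatusInternalServerError)
		return
	}
	select {
	case <-time.After(10 * time.Second): // stand-in for expensive work
		fmt.Fprintln(w, "done")
	case <-cn.CloseNotify(): // one value once the client goes away
		log.Println("client disconnected; abandoning work")
	}
}

func main() {
	http.HandleFunc("/slow", slow)
	log.Fatal(http.ListenAndServe(":8080", nil))
}
```

Note the constraint encoded in conn.hijack above: once CloseNotify has rerouted reads through the pipe, the connection can no longer be hijacked.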
+
+// chunkWriter writes to a response's conn buffer, and is the writer
+// wrapped by the response.bufw buffered writer.
+//
+// chunkWriter also is responsible for finalizing the Header, including
+// conditionally setting the Content-Type and setting a Content-Length
+// in cases where the handler's final output is smaller than the buffer
+// size. It also conditionally adds chunk headers, when in chunking mode.
+//
+// See the comment above (*response).Write for the entire write flow.
+type chunkWriter struct {
+    res *response
+
+    header      Header // a deep copy of r.Header, once WriteHeader is called
+    wroteHeader bool   // whether the header's been sent
+
+    // set by the writeHeader method:
+    chunking bool // using chunked transfer encoding for reply body
+}
+
+var crlf = []byte("\r\n")
+
+func (cw *chunkWriter) Write(p []byte) (n int, err error) {
+    if !cw.wroteHeader {
+        cw.writeHeader(p)
+    }
+    if cw.chunking {
+        _, err = fmt.Fprintf(cw.res.conn.buf, "%x\r\n", len(p))
+        if err != nil {
+            cw.res.conn.rwc.Close()
+            return
+        }
+    }
+    n, err = cw.res.conn.buf.Write(p)
+    if cw.chunking && err == nil {
+        _, err = cw.res.conn.buf.Write(crlf)
+    }
+    if err != nil {
+        cw.res.conn.rwc.Close()
+    }
+    return
+}
+
+func (cw *chunkWriter) flush() {
+    if !cw.wroteHeader {
+        cw.writeHeader(nil)
+    }
+    cw.res.conn.buf.Flush()
+}
+
+func (cw *chunkWriter) close() {
+    if !cw.wroteHeader {
+        cw.writeHeader(nil)
+    }
+    if cw.chunking {
+        // zero EOF chunk, trailer key/value pairs (currently
+        // unsupported in Go's server), followed by a blank
+        // line.
+        io.WriteString(cw.res.conn.buf, "0\r\n\r\n")
+    }
 }

 // A response represents the server side of an HTTP response.
 type response struct {
     conn          *conn
     req           *Request // request for this response
-    chunking      bool     // using chunked transfer encoding for reply body
-    wroteHeader   bool     // reply header has been written
+    wroteHeader   bool     // reply header has been (logically) written
     wroteContinue bool     // 100 Continue response was written
-    header        Header   // reply header parameters
-    written       int64    // number of bytes written in body
-    contentLength int64    // explicitly-declared Content-Length; or -1
-    status        int      // status code passed to WriteHeader
-    needSniff     bool     // need to sniff to find Content-Type
+
+    w  *bufio.Writer // buffers output in chunks to chunkWriter
+    cw *chunkWriter
+
+    // handlerHeader is the Header that Handlers get access to,
+    // which may be retained and mutated even after WriteHeader.
+    // handlerHeader is copied into cw.header at WriteHeader
+    // time, and privately mutated thereafter.
+    handlerHeader Header
+
+    written       int64 // number of bytes written in body
+    contentLength int64 // explicitly-declared Content-Length; or -1
+    status        int   // status code passed to WriteHeader

     // close connection after this reply.  set on request and
     // updated after response from handler if there's a
@@ -127,12 +284,14 @@ type response struct {

     // requestBodyLimitHit is set by requestTooLarge when
     // maxBytesReader hits its max size. It is checked in
-    // WriteHeader, to make sure we don't consume the the
+    // WriteHeader, to make sure we don't consume the
     // remaining request body to try to advance to the next HTTP
-    // request. Instead, when this is set, we stop doing
+    // request. Instead, when this is set, we stop reading
     // subsequent requests on this connection and stop reading
     // input from it.
     requestBodyLimitHit bool
+
+    handlerDone bool // set true when the handler exits
 }

 // requestTooLarge is called by maxBytesReader when too much input has
@@ -145,42 +304,68 @@ func (w *response) requestTooLarge() {
     }
 }
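What chunkWriter emits on the wire is ordinary HTTP/1.1 chunked framing: cw.Write prefixes each buffer with its length in hex plus CRLF and appends CRLF, and cw.close terminates the body with a zero-length chunk and a blank line. A standalone sketch of that byte layout (the helper is mine, not from the patch):

```go
package main

import (
	"bytes"
	"fmt"
)

// writeChunk frames p the way chunkWriter does: "<len-hex>\r\n<p>\r\n".
func writeChunk(buf *bytes.Buffer, p []byte) {
	fmt.Fprintf(buf, "%x\r\n", len(p))
	buf.Write(p)
	buf.WriteString("\r\n")
}

func main() {
	var buf bytes.Buffer
	writeChunk(&buf, []byte("hello, "))
	writeChunk(&buf, []byte("world"))
	buf.WriteString("0\r\n\r\n") // zero-length chunk: end of body, as in cw.close
	fmt.Printf("%q\n", buf.String())
	// Output: "7\r\nhello, \r\n5\r\nworld\r\n0\r\n\r\n"
}
```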
+// needsSniff returns whether a Content-Type still needs to be sniffed.
+func (w *response) needsSniff() bool {
+    return !w.cw.wroteHeader && w.handlerHeader.Get("Content-Type") == "" && w.written < sniffLen
+}
+
 type writerOnly struct {
     io.Writer
 }

 func (w *response) ReadFrom(src io.Reader) (n int64, err error) {
-    // Call WriteHeader before checking w.chunking if it hasn't
-    // been called yet, since WriteHeader is what sets w.chunking.
     if !w.wroteHeader {
         w.WriteHeader(StatusOK)
     }
-    if !w.chunking && w.bodyAllowed() && !w.needSniff {
-        w.Flush()
+
+    if w.needsSniff() {
+        n0, err := io.Copy(writerOnly{w}, io.LimitReader(src, sniffLen))
+        n += n0
+        if err != nil {
+            return n, err
+        }
+    }
+
+    w.w.Flush()  // get rid of any previous writes
+    w.cw.flush() // make sure Header is written; flush data to rwc
+
+    // Now that cw has been flushed, its chunking field is guaranteed initialized.
+    if !w.cw.chunking && w.bodyAllowed() {
         if rf, ok := w.conn.rwc.(io.ReaderFrom); ok {
-            n, err = rf.ReadFrom(src)
-            w.written += n
-            return
+            n0, err := rf.ReadFrom(src)
+            n += n0
+            w.written += n0
+            return n, err
         }
     }
+
+    // Fall back to default io.Copy implementation.
     // Use wrapper to hide w.ReadFrom from io.Copy.
-    return io.Copy(writerOnly{w}, src)
+    n0, err := io.Copy(writerOnly{w}, src)
+    n += n0
+    return n, err
 }

 // noLimit is an effective infinite upper bound for io.LimitedReader
 const noLimit int64 = (1 << 63) - 1

+// debugServerConnections controls whether all server connections are wrapped
+// with a verbose logging wrapper.
+const debugServerConnections = false
+
 // Create new connection from rwc.
 func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) {
     c = new(conn)
     c.remoteAddr = rwc.RemoteAddr().String()
     c.server = srv
     c.rwc = rwc
-    c.body = make([]byte, sniffLen)
-    c.lr = io.LimitReader(rwc, noLimit).(*io.LimitedReader)
+    if debugServerConnections {
+        c.rwc = newLoggingConn("server", c.rwc)
+    }
+    c.sr = switchReader{r: c.rwc}
+    c.lr = io.LimitReader(&c.sr, noLimit).(*io.LimitedReader)
     br := bufio.NewReader(c.lr)
-    bw := bufio.NewWriter(rwc)
+    bw := bufio.NewWriter(c.rwc)
     c.buf = bufio.NewReadWriter(br, bw)
     return c, nil
 }
@@ -207,9 +392,9 @@ type expectContinueReader struct {

 func (ecr *expectContinueReader) Read(p []byte) (n int, err error) {
     if ecr.closed {
-        return 0, errors.New("http: Read after Close on request Body")
+        return 0, ErrBodyReadAfterClose
     }
-    if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked {
+    if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked() {
         ecr.resp.wroteContinue = true
         io.WriteString(ecr.resp.conn.buf, "HTTP/1.1 100 Continue\r\n\r\n")
         ecr.resp.conn.buf.Flush()
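One subtlety in ReadFrom above: io.Copy hands the whole transfer to the destination's ReadFrom method when one exists, so copying into the bare *response would recurse; the writerOnly wrapper hides the method so the fallback path uses plain Writes. A toy demonstration of that dispatch rule (the counter type is invented; io.LimitReader is used so the source has no WriteTo method to short-circuit io.Copy the other way):

```go
package main

import (
	"fmt"
	"io"
	"strings"
)

type counter struct{ n int64 }

func (c *counter) Write(p []byte) (int, error) { c.n += int64(len(p)); return len(p), nil }

// ReadFrom makes counter an io.ReaderFrom, so io.Copy hands the whole
// transfer to it rather than looping over Write.
func (c *counter) ReadFrom(r io.Reader) (int64, error) {
	fmt.Println("ReadFrom fast path taken")
	// Hide ReadFrom behind an anonymous struct, as writerOnly does,
	// so this inner io.Copy falls back to the plain Write loop.
	return io.Copy(struct{ io.Writer }{c}, r)
}

func main() {
	var c counter
	src := io.LimitReader(strings.NewReader("hello"), 5)
	n, err := io.Copy(&c, src) // dispatches to (*counter).ReadFrom
	fmt.Println(n, err, c.n)   // 5 <nil> 5
}
```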
@@ -232,9 +417,19 @@ var errTooLarge = errors.New("http: request too large")

 // Read next request from connection.
 func (c *conn) readRequest() (w *response, err error) {
-    if c.hijacked {
+    if c.hijacked() {
         return nil, ErrHijacked
     }
+
+    if d := c.server.ReadTimeout; d != 0 {
+        c.rwc.SetReadDeadline(time.Now().Add(d))
+    }
+    if d := c.server.WriteTimeout; d != 0 {
+        defer func() {
+            c.rwc.SetWriteDeadline(time.Now().Add(d))
+        }()
+    }
+
     c.lr.N = int64(c.server.maxHeaderBytes()) + 4096 /* bufio slop */
     var req *Request
     if req, err = ReadRequest(c.buf.Reader); err != nil {
@@ -248,17 +443,20 @@ func (c *conn) readRequest() (w *response, err error) {
     req.RemoteAddr = c.remoteAddr
     req.TLS = c.tlsState

-    w = new(response)
-    w.conn = c
-    w.req = req
-    w.header = make(Header)
-    w.contentLength = -1
-    c.body = c.body[:0]
+    w = &response{
+        conn:          c,
+        req:           req,
+        handlerHeader: make(Header),
+        contentLength: -1,
+        cw:            new(chunkWriter),
+    }
+    w.cw.res = w
+    w.w = bufio.NewWriterSize(w.cw, bufferBeforeChunkingSize)
     return w, nil
 }

 func (w *response) Header() Header {
-    return w.header
+    return w.handlerHeader
 }

 // maxPostHandlerReadBytes is the max number of Request.Body bytes not
@@ -273,7 +471,7 @@ func (w *response) Header() Header {
 const maxPostHandlerReadBytes = 256 << 10

 func (w *response) WriteHeader(code int) {
-    if w.conn.hijacked {
+    if w.conn.hijacked() {
         log.Print("http: response.WriteHeader on hijacked connection")
         return
     }
@@ -284,31 +482,68 @@ func (w *response) WriteHeader(code int) {
     w.wroteHeader = true
     w.status = code

-    // Check for a explicit (and valid) Content-Length header.
-    var hasCL bool
-    var contentLength int64
-    if clenStr := w.header.Get("Content-Length"); clenStr != "" {
-        var err error
-        contentLength, err = strconv.ParseInt(clenStr, 10, 64)
-        if err == nil {
-            hasCL = true
+    w.cw.header = w.handlerHeader.clone()
+
+    if cl := w.cw.header.get("Content-Length"); cl != "" {
+        v, err := strconv.ParseInt(cl, 10, 64)
+        if err == nil && v >= 0 {
+            w.contentLength = v
         } else {
-            log.Printf("http: invalid Content-Length of %q sent", clenStr)
-            w.header.Del("Content-Length")
+            log.Printf("http: invalid Content-Length of %q", cl)
+            w.cw.header.Del("Content-Length")
+        }
+    }
+}
+
+// writeHeader finalizes the header sent to the client and writes it
+// to cw.res.conn.buf.
+//
+// p is not written by writeHeader, but is the first chunk of the body
+// that will be written. It is sniffed for a Content-Type if none is
+// set explicitly. It's also used to set the Content-Length, if the
+// total body size was small and the handler has already finished
+// running.
+func (cw *chunkWriter) writeHeader(p []byte) {
+    if cw.wroteHeader {
+        return
+    }
+    cw.wroteHeader = true
+
+    w := cw.res
+    code := w.status
+    done := w.handlerDone
+
+    // If the handler is done but never sent a Content-Length
+    // response header and this is our first (and last) write, set
+    // it, even to zero. This helps HTTP/1.0 clients keep their
+    // "keep-alive" connections alive.
+    if done && cw.header.get("Content-Length") == "" && w.req.Method != "HEAD" {
+        w.contentLength = int64(len(p))
+        cw.header.Set("Content-Length", strconv.Itoa(len(p)))
+    }
+
+    // If this was an HTTP/1.0 request with keep-alive and we sent a
+    // Content-Length back, we can make this a keep-alive response ...
+    if w.req.wantsHttp10KeepAlive() {
+        sentLength := cw.header.get("Content-Length") != ""
+        if sentLength && cw.header.get("Connection") == "keep-alive" {
+            w.closeAfterReply = false
         }
     }
+
+    // Check for an explicit (and valid) Content-Length header.
+    hasCL := w.contentLength != -1
+
     if w.req.wantsHttp10KeepAlive() && (w.req.Method == "HEAD" || hasCL) {
-        _, connectionHeaderSet := w.header["Connection"]
+        _, connectionHeaderSet := cw.header["Connection"]
         if !connectionHeaderSet {
-            w.header.Set("Connection", "keep-alive")
+            cw.header.Set("Connection", "keep-alive")
         }
-    } else if !w.req.ProtoAtLeast(1, 1) {
-        // Client did not ask to keep connection alive.
+    } else if !w.req.ProtoAtLeast(1, 1) || w.req.wantsClose() {
         w.closeAfterReply = true
     }

-    if w.header.Get("Connection") == "close" {
+    if cw.header.get("Connection") == "close" {
         w.closeAfterReply = true
     }
@@ -322,7 +557,7 @@ func (w *response) WriteHeader(code int) {
         n, _ := io.CopyN(ioutil.Discard, w.req.Body, maxPostHandlerReadBytes+1)
         if n >= maxPostHandlerReadBytes {
             w.requestTooLarge()
-            w.header.Set("Connection", "close")
+            cw.header.Set("Connection", "close")
         } else {
             w.req.Body.Close()
         }
@@ -332,64 +567,67 @@
     if code == StatusNotModified {
         // Must not have body.
         for _, header := range []string{"Content-Type", "Content-Length", "Transfer-Encoding"} {
-            if w.header.Get(header) != "" {
-                // TODO: return an error if WriteHeader gets a return parameter
-                // or set a flag on w to make future Writes() write an error page?
-                // for now just log and drop the header.
-                log.Printf("http: StatusNotModified response with header %q defined", header)
-                w.header.Del(header)
+            // RFC 2616 section 10.3.5: "the response MUST NOT include other entity-headers"
+            if cw.header.get(header) != "" {
+                cw.header.Del(header)
             }
         }
     } else {
         // If no content type, apply sniffing algorithm to body.
-        if w.header.Get("Content-Type") == "" && w.req.Method != "HEAD" {
-            w.needSniff = true
+        if cw.header.get("Content-Type") == "" && w.req.Method != "HEAD" {
+            cw.header.Set("Content-Type", DetectContentType(p))
         }
     }

-    if _, ok := w.header["Date"]; !ok {
-        w.Header().Set("Date", time.Now().UTC().Format(TimeFormat))
+    if _, ok := cw.header["Date"]; !ok {
+        cw.header.Set("Date", time.Now().UTC().Format(TimeFormat))
     }

-    te := w.header.Get("Transfer-Encoding")
+    te := cw.header.get("Transfer-Encoding")
     hasTE := te != ""
     if hasCL && hasTE && te != "identity" {
         // TODO: return an error if WriteHeader gets a return parameter
         // For now just ignore the Content-Length.
         log.Printf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d",
-            te, contentLength)
-        w.header.Del("Content-Length")
+            te, w.contentLength)
+        cw.header.Del("Content-Length")
         hasCL = false
     }

     if w.req.Method == "HEAD" || code == StatusNotModified {
         // do nothing
+    } else if code == StatusNoContent {
+        cw.header.Del("Transfer-Encoding")
     } else if hasCL {
-        w.contentLength = contentLength
-        w.header.Del("Transfer-Encoding")
+        cw.header.Del("Transfer-Encoding")
     } else if w.req.ProtoAtLeast(1, 1) {
         // HTTP/1.1 or greater: use chunked transfer encoding
         // to avoid closing the connection at EOF.
         // TODO: this blows away any custom or stacked Transfer-Encoding they
         // might have set.  Deal with that as need arises once we have a valid
         // use case.
-        w.chunking = true
-        w.header.Set("Transfer-Encoding", "chunked")
+        cw.chunking = true
+        cw.header.Set("Transfer-Encoding", "chunked")
     } else {
         // HTTP version < 1.1: cannot do chunked transfer
         // encoding and we don't know the Content-Length so
         // signal EOF by closing connection.
         w.closeAfterReply = true
-        w.header.Del("Transfer-Encoding") // in case already set
+        cw.header.Del("Transfer-Encoding") // in case already set
     }

     // Cannot use Content-Length with non-identity Transfer-Encoding.
-    if w.chunking {
-        w.header.Del("Content-Length")
+    if cw.chunking {
+        cw.header.Del("Content-Length")
     }

     if !w.req.ProtoAtLeast(1, 0) {
         return
     }
+
+    if w.closeAfterReply && !hasToken(cw.header.get("Connection"), "close") {
+        cw.header.Set("Connection", "close")
+    }
+
     proto := "HTTP/1.0"
     if w.req.ProtoAtLeast(1, 1) {
         proto = "HTTP/1.1"
@@ -400,37 +638,8 @@
         text = "status code " + codestring
     }
     io.WriteString(w.conn.buf, proto+" "+codestring+" "+text+"\r\n")
-    w.header.Write(w.conn.buf)
-
-    // If we need to sniff the body, leave the header open.
-    // Otherwise, end it here.
-    if !w.needSniff {
-        io.WriteString(w.conn.buf, "\r\n")
-    }
-}
-
-// sniff uses the first block of written data,
-// stored in w.conn.body, to decide the Content-Type
-// for the HTTP body.
-func (w *response) sniff() {
-    if !w.needSniff {
-        return
-    }
-    w.needSniff = false
-
-    data := w.conn.body
-    fmt.Fprintf(w.conn.buf, "Content-Type: %s\r\n\r\n", DetectContentType(data))
-
-    if len(data) == 0 {
-        return
-    }
-    if w.chunking {
-        fmt.Fprintf(w.conn.buf, "%x\r\n", len(data))
-    }
-    _, err := w.conn.buf.Write(data)
-    if w.chunking && err == nil {
-        io.WriteString(w.conn.buf, "\r\n")
-    }
+    cw.header.Write(w.conn.buf)
+    w.conn.buf.Write(crlf)
 }

 // bodyAllowed returns true if a Write is allowed for this response type.
@@ -442,8 +651,40 @@ func (w *response) bodyAllowed() bool {
     return w.status != StatusNotModified && w.req.Method != "HEAD"
 }
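Taken together, writeHeader means a handler's output strategy is decided by how much it writes before returning: if the handler finishes while everything still fits in the 2048-byte buffer, the server fills in Content-Length itself; once the buffer overflows mid-handler, the response is committed as chunked. A hedged sketch of both cases (routes and bodies are illustrative):

```go
package main

import (
	"log"
	"net/http"
	"strings"
)

func main() {
	// Handler returns before the 2048-byte buffer fills: writeHeader
	// runs with handlerDone set and emits "Content-Length: 2".
	http.HandleFunc("/small", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("hi"))
	})
	// 64 KB overflows bufferBeforeChunkingSize mid-handler, so the
	// header goes out early with "Transfer-Encoding: chunked".
	http.HandleFunc("/large", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte(strings.Repeat("x", 64<<10)))
	})
	log.Fatal(http.ListenAndServe(":8080", nil))
}
```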
+// The Life Of A Write is like this:
+//
+// Handler starts. No header has been sent. The handler can either
+// write a header, or just start writing.  Writing before sending a header
+// sends an implicitly empty 200 OK header.
+//
+// If the handler didn't declare a Content-Length up front, we either
+// go into chunking mode or, if the handler finishes running before
+// the chunking buffer size, we compute a Content-Length and send that
+// in the header instead.
+//
+// Likewise, if the handler didn't set a Content-Type, we sniff that
+// from the initial chunk of output.
+//
+// The Writers are wired together like:
+//
+// 1. *response (the ResponseWriter) ->
+// 2. (*response).w, a *bufio.Writer of bufferBeforeChunkingSize bytes
+// 3. chunkWriter.Writer (whose writeHeader finalizes Content-Length/Type)
+//    and which writes the chunk headers, if needed.
+// 4. conn.buf, a bufio.Writer of default (4kB) bytes
+// 5. the rwc, the net.Conn.
+//
+// TODO(bradfitz): short-circuit some of the buffering when the
+// initial header contains both a Content-Type and Content-Length.
+// Also short-circuit in (1) when the header's been sent and not in
+// chunking mode, writing directly to (4) instead, if (2) has no
+// buffered data.  More generally, we could short-circuit from (1) to
+// (3) even in chunking mode if the write size from (1) is over some
+// threshold and nothing is in (2).  The answer might be mostly making
+// bufferBeforeChunkingSize smaller and having bufio's fast-paths deal
+// with this instead.
 func (w *response) Write(data []byte) (n int, err error) {
-    if w.conn.hijacked {
+    if w.conn.hijacked() {
         log.Print("http: response.Write on hijacked connection")
         return 0, ErrHijacked
     }
@@ -461,73 +702,20 @@ func (w *response) Write(data []byte) (n int, err error) {
     if w.contentLength != -1 && w.written > w.contentLength {
         return 0, ErrContentLength
     }
-
-    var m int
-    if w.needSniff {
-        // We need to sniff the beginning of the output to
-        // determine the content type.  Accumulate the
-        // initial writes in w.conn.body.
-        // Cap m so that append won't allocate.
-        m = cap(w.conn.body) - len(w.conn.body)
-        if m > len(data) {
-            m = len(data)
-        }
-        w.conn.body = append(w.conn.body, data[:m]...)
-        data = data[m:]
-        if len(data) == 0 {
-            // Copied everything into the buffer.
-            // Wait for next write.
-            return m, nil
-        }
-
-        // Filled the buffer; more data remains.
-        // Sniff the content (flushes the buffer)
-        // and then proceed with the remainder
-        // of the data as a normal Write.
-        // Calling sniff clears needSniff.
-        w.sniff()
-    }
-
-    // TODO(rsc): if chunking happened after the buffering,
-    // then there would be fewer chunk headers.
-    // On the other hand, it would make hijacking more difficult.
-    if w.chunking {
-        fmt.Fprintf(w.conn.buf, "%x\r\n", len(data)) // TODO(rsc): use strconv not fmt
-    }
-    n, err = w.conn.buf.Write(data)
-    if err == nil && w.chunking {
-        if n != len(data) {
-            err = io.ErrShortWrite
-        }
-        if err == nil {
-            io.WriteString(w.conn.buf, "\r\n")
-        }
-    }
-
-    return m + n, err
+    return w.w.Write(data)
 }

 func (w *response) finishRequest() {
-    // If this was an HTTP/1.0 request with keep-alive and we sent a Content-Length
-    // back, we can make this a keep-alive response ...
-    if w.req.wantsHttp10KeepAlive() {
-        sentLength := w.header.Get("Content-Length") != ""
-        if sentLength && w.header.Get("Connection") == "keep-alive" {
-            w.closeAfterReply = false
-        }
-    }
+    w.handlerDone = true
+
     if !w.wroteHeader {
         w.WriteHeader(StatusOK)
     }
-    if w.needSniff {
-        w.sniff()
-    }
-    if w.chunking {
-        io.WriteString(w.conn.buf, "0\r\n")
-        // trailer key/value pairs, followed by blank line
-        io.WriteString(w.conn.buf, "\r\n")
-    }
+
+    w.w.Flush()
+    w.cw.close()
     w.conn.buf.Flush()
+
     // Close the body, unless we're about to close the whole TCP connection
     // anyway.
     if !w.closeAfterReply {
@@ -537,7 +725,7 @@ func (w *response) finishRequest() {
         w.req.MultipartForm.RemoveAll()
     }

-    if w.contentLength != -1 && w.contentLength != w.written {
+    if w.contentLength != -1 && w.bodyAllowed() && w.contentLength != w.written {
         // Did not write enough. Avoid getting out of sync.
         w.closeAfterReply = true
     }
@@ -547,66 +735,114 @@
 func (w *response) Flush() {
     if !w.wroteHeader {
         w.WriteHeader(StatusOK)
     }
-    w.sniff()
-    w.conn.buf.Flush()
+    w.w.Flush()
+    w.cw.flush()
 }

-// Close the connection.
-func (c *conn) close() {
+func (c *conn) finalFlush() {
     if c.buf != nil {
         c.buf.Flush()
         c.buf = nil
     }
+}
+
+// Close the connection.
+func (c *conn) close() {
+    c.finalFlush()
     if c.rwc != nil {
         c.rwc.Close()
         c.rwc = nil
     }
 }
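Flush now drains the two buffers in order, the handler-side bufio.Writer first and then the chunkWriter, which finalizes the header on first use. Practically, an explicit Flush commits the status line and headers and forces chunking even for a small body, which is what streaming handlers want; a small sketch (endpoint and tick interval are mine):

```go
package main

import (
	"fmt"
	"log"
	"net/http"
	"time"
)

func main() {
	http.HandleFunc("/ticks", func(w http.ResponseWriter, r *http.Request) {
		f, ok := w.(http.Flusher)
		if !ok {
			http.Error(w, "streaming unsupported", http.StatusInternalServerError)
			return
		}
		for i := 0; i < 3; i++ {
			fmt.Fprintf(w, "tick %d\n", i)
			f.Flush() // header and each line go out immediately, chunked
			time.Sleep(time.Second)
		}
	})
	log.Fatal(http.ListenAndServe(":8080", nil))
}
```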
+// rstAvoidanceDelay is the amount of time we sleep after closing the
+// write side of a TCP connection before closing the entire socket.
+// By sleeping, we increase the chances that the client sees our FIN
+// and processes its final data before they process the subsequent RST
+// from closing a connection with known unread data.
+// This RST seems to occur mostly on BSD systems. (And Windows?)
+// This timeout is somewhat arbitrary (~latency around the planet).
+const rstAvoidanceDelay = 500 * time.Millisecond
+
+// closeWrite flushes any outstanding data and sends a FIN packet (if
+// client is connected via TCP), signalling that we're done. We then
+// pause for a bit, hoping the client processes it before any
+// subsequent RST.
+//
+// See http://golang.org/issue/3595
+func (c *conn) closeWriteAndWait() {
+    c.finalFlush()
+    if tcp, ok := c.rwc.(*net.TCPConn); ok {
+        tcp.CloseWrite()
+    }
+    time.Sleep(rstAvoidanceDelay)
+}
+
+// validNPN returns whether the proto is not a blacklisted Next
+// Protocol Negotiation protocol. Empty and built-in protocol types
+// are blacklisted and can't be overridden with alternate
+// implementations.
+func validNPN(proto string) bool {
+    switch proto {
+    case "", "http/1.1", "http/1.0":
+        return false
+    }
+    return true
+}
+
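closeWriteAndWait is the standard TCP half-close idiom: CloseWrite sends our FIN while leaving the read side open, and the fixed sleep gives a slow peer a chance to drain buffered response bytes before a full Close could turn known-unread data into an RST. The same idiom from the client side of any TCP conversation (peer address and request are illustrative only):

```go
package main

import (
	"io"
	"log"
	"net"
	"os"
	"time"
)

func main() {
	conn, err := net.Dial("tcp", "example.com:80") // illustrative peer
	if err != nil {
		log.Fatal(err)
	}
	io.WriteString(conn, "HEAD / HTTP/1.0\r\nHost: example.com\r\n\r\n")
	if tcp, ok := conn.(*net.TCPConn); ok {
		tcp.CloseWrite() // send FIN: done writing, still readable
	}
	io.Copy(os.Stdout, conn)           // drain until the peer closes
	time.Sleep(500 * time.Millisecond) // cf. rstAvoidanceDelay
	conn.Close()
}
```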
 // Serve a new connection.
 func (c *conn) serve() {
     defer func() {
-        err := recover()
-        if err == nil {
-            return
+        if err := recover(); err != nil {
+            const size = 4096
+            buf := make([]byte, size)
+            buf = buf[:runtime.Stack(buf, false)]
+            log.Printf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf)
         }
-
-        var buf bytes.Buffer
-        fmt.Fprintf(&buf, "http: panic serving %v: %v\n", c.remoteAddr, err)
-        buf.Write(debug.Stack())
-        log.Print(buf.String())
-
-        if c.rwc != nil { // may be nil if connection hijacked
-            c.rwc.Close()
+        if !c.hijacked() {
+            c.close()
         }
     }()

     if tlsConn, ok := c.rwc.(*tls.Conn); ok {
+        if d := c.server.ReadTimeout; d != 0 {
+            c.rwc.SetReadDeadline(time.Now().Add(d))
+        }
+        if d := c.server.WriteTimeout; d != 0 {
+            c.rwc.SetWriteDeadline(time.Now().Add(d))
+        }
         if err := tlsConn.Handshake(); err != nil {
-            c.close()
             return
         }
         c.tlsState = new(tls.ConnectionState)
         *c.tlsState = tlsConn.ConnectionState()
+        if proto := c.tlsState.NegotiatedProtocol; validNPN(proto) {
+            if fn := c.server.TLSNextProto[proto]; fn != nil {
+                h := initNPNRequest{tlsConn, serverHandler{c.server}}
+                fn(c.server, tlsConn, h)
+            }
+            return
+        }
     }

     for {
         w, err := c.readRequest()
         if err != nil {
-            msg := "400 Bad Request"
             if err == errTooLarge {
                 // Their HTTP client may or may not be
                 // able to read this if we're
                 // responding to them and hanging up
                 // while they're still writing their
                 // request.  Undefined behavior.
-                msg = "413 Request Entity Too Large"
+                io.WriteString(c.rwc, "HTTP/1.1 413 Request Entity Too Large\r\n\r\n")
+                c.closeWriteAndWait()
+                break
             } else if err == io.EOF {
                 break // Don't reply
             } else if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
                 break // Don't reply
             }
-            fmt.Fprintf(c.rwc, "HTTP/1.1 %s\r\n\r\n", msg)
+            io.WriteString(c.rwc, "HTTP/1.1 400 Bad Request\r\n\r\n")
             break
         }
@@ -624,59 +860,59 @@ func (c *conn) serve() {
                 break
             }
             req.Header.Del("Expect")
-        } else if req.Header.Get("Expect") != "" {
-            // TODO(bradfitz): let ServeHTTP handlers handle
-            // requests with non-standard expectation[s]? Seems
-            // theoretical at best, and doesn't fit into the
-            // current ServeHTTP model anyway.  We'd need to
-            // make the ResponseWriter an optional
-            // "ExpectReplier" interface or something.
-            //
-            // For now we'll just obey RFC 2616 14.20 which says
-            // "If a server receives a request containing an
-            // Expect field that includes an expectation-
-            // extension that it does not support, it MUST
-            // respond with a 417 (Expectation Failed) status."
-            w.Header().Set("Connection", "close")
-            w.WriteHeader(StatusExpectationFailed)
-            w.finishRequest()
+        } else if req.Header.get("Expect") != "" {
+            w.sendExpectationFailed()
             break
         }

-        handler := c.server.Handler
-        if handler == nil {
-            handler = DefaultServeMux
-        }
-
         // HTTP cannot have multiple simultaneous active requests.[*]
         // Until the server replies to this request, it can't read another,
         // so we might as well run the handler in this goroutine.
         // [*] Not strictly true: HTTP pipelining.  We could let them all process
         // in parallel even if their responses need to be serialized.
-        handler.ServeHTTP(w, w.req)
-        if c.hijacked {
+        serverHandler{c.server}.ServeHTTP(w, w.req)
+        if c.hijacked() {
             return
         }
         w.finishRequest()
         if w.closeAfterReply {
+            if w.requestBodyLimitHit {
+                c.closeWriteAndWait()
+            }
             break
         }
     }
-    c.close()
+}
+
+func (w *response) sendExpectationFailed() {
+    // TODO(bradfitz): let ServeHTTP handlers handle
+    // requests with non-standard expectation[s]? Seems
+    // theoretical at best, and doesn't fit into the
+    // current ServeHTTP model anyway.  We'd need to
+    // make the ResponseWriter an optional
+    // "ExpectReplier" interface or something.
+    //
+    // For now we'll just obey RFC 2616 14.20 which says
+    // "If a server receives a request containing an
+    // Expect field that includes an expectation-
+    // extension that it does not support, it MUST
+    // respond with a 417 (Expectation Failed) status."
+    w.Header().Set("Connection", "close")
+    w.WriteHeader(StatusExpectationFailed)
+    w.finishRequest()
 }

 // Hijack implements the Hijacker.Hijack method. Our response is both a ResponseWriter
 // and a Hijacker.
 func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
-    if w.conn.hijacked {
-        return nil, nil, ErrHijacked
+    if w.wroteHeader {
+        w.cw.flush()
     }
-    w.conn.hijacked = true
-    rwc = w.conn.rwc
-    buf = w.conn.buf
-    w.conn.rwc = nil
-    w.conn.buf = nil
-    return
+    return w.conn.hijack()
+}
+
+func (w *response) CloseNotify() <-chan bool {
+    return w.conn.closeNotify()
 }

 // The HandlerFunc type is an adapter to allow the use of
@@ -817,13 +1053,13 @@ func RedirectHandler(url string, code int) Handler {
 // patterns and calls the handler for the pattern that
 // most closely matches the URL.
 //
-// Patterns named fixed, rooted paths, like "/favicon.ico",
+// Patterns name fixed, rooted paths, like "/favicon.ico",
 // or rooted subtrees, like "/images/" (note the trailing slash).
 // Longer patterns take precedence over shorter ones, so that
 // if there are handlers registered for both "/images/"
 // and "/images/thumbnails/", the latter handler will be
 // called for paths beginning "/images/thumbnails/" and the
-// former will receiver requests for any other paths in the
+// former will receive requests for any other paths in the
 // "/images/" subtree.
 //
 // Patterns may optionally begin with a host name, restricting matches to
@@ -836,13 +1072,15 @@ func RedirectHandler(url string, code int) Handler {
 // redirecting any request containing . or .. elements to an
 // equivalent .- and ..-free URL.
 type ServeMux struct {
-    mu sync.RWMutex
-    m  map[string]muxEntry
+    mu    sync.RWMutex
+    m     map[string]muxEntry
+    hosts bool // whether any patterns contain hostnames
 }

 type muxEntry struct {
     explicit bool
     h        Handler
+    pattern  string
 }

 // NewServeMux allocates and returns a new ServeMux.
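The new hosts flag and the pattern field recorded in muxEntry feed the matching rules documented above: longest pattern wins, and host-qualified patterns are consulted before generic ones. A registration sketch exercising all three pattern kinds (handlers and host names are invented):

```go
package main

import (
	"fmt"
	"log"
	"net/http"
)

func main() {
	mux := http.NewServeMux()
	// Rooted subtree; Handle also installs the implicit
	// "/images" -> "/images/" permanent redirect.
	mux.HandleFunc("/images/", func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintln(w, "generic images handler")
	})
	// Longer pattern: wins for everything under /images/thumbnails/.
	mux.HandleFunc("/images/thumbnails/", func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintln(w, "thumbnails handler")
	})
	// Host-qualified pattern: sets the mux's hosts flag and takes
	// precedence over path-only patterns for this host.
	mux.HandleFunc("admin.example.com/", func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintln(w, "admin site")
	})
	log.Fatal(http.ListenAndServe(":8080", mux))
}
```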
@@ -883,8 +1121,7 @@ func cleanPath(p string) string {

 // Find a handler on a handler map given a path string
 // Most-specific (longest) pattern wins
-func (mux *ServeMux) match(path string) Handler {
-    var h Handler
+func (mux *ServeMux) match(path string) (h Handler, pattern string) {
     var n = 0
     for k, v := range mux.m {
         if !pathMatch(k, path) {
@@ -893,37 +1130,64 @@ func (mux *ServeMux) match(path string) Handler {
         if h == nil || len(k) > n {
             n = len(k)
             h = v.h
+            pattern = v.pattern
+        }
+    }
+    return
+}
+
+// Handler returns the handler to use for the given request,
+// consulting r.Method, r.Host, and r.URL.Path. It always returns
+// a non-nil handler. If the path is not in its canonical form, the
+// handler will be an internally-generated handler that redirects
+// to the canonical path.
+//
+// Handler also returns the registered pattern that matches the
+// request or, in the case of internally-generated redirects,
+// the pattern that will match after following the redirect.
+//
+// If there is no registered handler that applies to the request,
+// Handler returns a ``page not found'' handler and an empty pattern.
+func (mux *ServeMux) Handler(r *Request) (h Handler, pattern string) {
+    if r.Method != "CONNECT" {
+        if p := cleanPath(r.URL.Path); p != r.URL.Path {
+            _, pattern = mux.handler(r.Host, p)
+            return RedirectHandler(p, StatusMovedPermanently), pattern
         }
     }
-    return h
+
+    return mux.handler(r.Host, r.URL.Path)
 }

-// handler returns the handler to use for the request r.
-func (mux *ServeMux) handler(r *Request) Handler {
+// handler is the main implementation of Handler.
+// The path is known to be in canonical form, except for CONNECT methods.
+func (mux *ServeMux) handler(host, path string) (h Handler, pattern string) {
     mux.mu.RLock()
     defer mux.mu.RUnlock()

     // Host-specific pattern takes precedence over generic ones
-    h := mux.match(r.Host + r.URL.Path)
+    if mux.hosts {
+        h, pattern = mux.match(host + path)
+    }
     if h == nil {
-        h = mux.match(r.URL.Path)
+        h, pattern = mux.match(path)
     }
     if h == nil {
-        h = NotFoundHandler()
+        h, pattern = NotFoundHandler(), ""
     }
-    return h
+    return
 }

 // ServeHTTP dispatches the request to the handler whose
 // pattern most closely matches the request URL.
 func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) {
-    // Clean path to canonical form and redirect.
-    if p := cleanPath(r.URL.Path); p != r.URL.Path {
-        w.Header().Set("Location", p)
-        w.WriteHeader(StatusMovedPermanently)
+    if r.RequestURI == "*" {
+        w.Header().Set("Connection", "close")
+        w.WriteHeader(StatusBadRequest)
         return
     }
-    mux.handler(r).ServeHTTP(w, r)
+    h, _ := mux.Handler(r)
+    h.ServeHTTP(w, r)
 }
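The newly exported ServeMux.Handler makes the mux's routing decision inspectable without serving anything, including the canonical-path redirect case that ServeHTTP used to handle inline. A small sketch, with the request URL and pattern purely illustrative:

```go
package main

import (
	"fmt"
	"log"
	"net/http"
)

func main() {
	mux := http.NewServeMux()
	mux.HandleFunc("/images/", func(http.ResponseWriter, *http.Request) {})

	// A non-canonical path: Handler reports the redirect it would
	// issue and the pattern that applies after the redirect.
	r, err := http.NewRequest("GET", "http://example.com/images/a/../logo.png", nil)
	if err != nil {
		log.Fatal(err)
	}
	h, pattern := mux.Handler(r) // pure lookup; nothing is written
	fmt.Printf("pattern=%q handler=%T\n", pattern, h)
	// Expected: pattern "/images/" and an internally-generated
	// redirect handler pointing at /images/logo.png.
}
```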

 // Handle registers the handler for the given pattern.
@@ -942,14 +1206,26 @@ func (mux *ServeMux) Handle(pattern string, handler Handler) {
         panic("http: multiple registrations for " + pattern)
     }

-    mux.m[pattern] = muxEntry{explicit: true, h: handler}
+    mux.m[pattern] = muxEntry{explicit: true, h: handler, pattern: pattern}
+
+    if pattern[0] != '/' {
+        mux.hosts = true
+    }

     // Helpful behavior:
     // If pattern is /tree/, insert an implicit permanent redirect for /tree.
     // It can be overridden by an explicit registration.
     n := len(pattern)
     if n > 0 && pattern[n-1] == '/' && !mux.m[pattern[0:n-1]].explicit {
-        mux.m[pattern[0:n-1]] = muxEntry{h: RedirectHandler(pattern, StatusMovedPermanently)}
+        // If pattern contains a host name, strip it and use remaining
+        // path for redirect.
+        path := pattern
+        if pattern[0] != '/' {
+            // In pattern, at least the last character is a '/', so
+            // strings.Index can't be -1.
+            path = pattern[strings.Index(pattern, "/"):]
+        }
+        mux.m[pattern[0:n-1]] = muxEntry{h: RedirectHandler(path, StatusMovedPermanently), pattern: pattern}
     }
 }

@@ -971,7 +1247,7 @@ func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) {
 }

 // Serve accepts incoming HTTP connections on the listener l,
-// creating a new service thread for each.  The service threads
+// creating a new service goroutine for each.  The service goroutines
 // read requests and then call handler to reply to them.
 // Handler is typically nil, in which case the DefaultServeMux is used.
 func Serve(l net.Listener, handler Handler) error {
@@ -987,6 +1263,32 @@ type Server struct {
     WriteTimeout   time.Duration // maximum duration before timing out write of the response
     MaxHeaderBytes int           // maximum size of request headers, DefaultMaxHeaderBytes if 0
     TLSConfig      *tls.Config   // optional TLS config, used by ListenAndServeTLS
+
+    // TLSNextProto optionally specifies a function to take over
+    // ownership of the provided TLS connection when an NPN
+    // protocol upgrade has occurred. The map key is the protocol
+    // name negotiated. The Handler argument should be used to
+    // handle HTTP requests and will initialize the Request's TLS
+    // and RemoteAddr if not already set. The connection is
+    // automatically closed when the function returns.
+    TLSNextProto map[string]func(*Server, *tls.Conn, Handler)
+}
+
+// serverHandler delegates to either the server's Handler or
+// DefaultServeMux and also handles "OPTIONS *" requests.
+type serverHandler struct {
+    srv *Server
+}
+
+func (sh serverHandler) ServeHTTP(rw ResponseWriter, req *Request) {
+    handler := sh.srv.Handler
+    if handler == nil {
+        handler = DefaultServeMux
+    }
+    if req.RequestURI == "*" && req.Method == "OPTIONS" {
+        handler = globalOptionsHandler{}
+    }
+    handler.ServeHTTP(rw, req)
 }

 // ListenAndServe listens on the TCP network address srv.Addr and then
@@ -1005,7 +1307,7 @@ func (srv *Server) ListenAndServe() error {
 }

 // Serve accepts incoming connections on the Listener l, creating a
-// new service thread for each.  The service threads read requests and
+// new service goroutine for each.  The service goroutines read requests and
 // then call srv.Handler to reply to them.
 func (srv *Server) Serve(l net.Listener) error {
     defer l.Close()
@@ -1029,12 +1331,6 @@ func (srv *Server) Serve(l net.Listener) error {
             return e
         }
         tempDelay = 0
-        if srv.ReadTimeout != 0 {
-            rw.SetReadDeadline(time.Now().Add(srv.ReadTimeout))
-        }
-        if srv.WriteTimeout != 0 {
-            rw.SetWriteDeadline(time.Now().Add(srv.WriteTimeout))
-        }
         c, err := srv.newConn(rw)
         if err != nil {
             continue
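With the deadline calls deleted from the accept loop above, ReadTimeout and WriteTimeout are now armed in readRequest, once per request rather than once per connection's lifetime. Nothing changes for callers; a typical configuration (values are illustrative):

```go
package main

import (
	"log"
	"net/http"
	"time"
)

func main() {
	srv := &http.Server{
		Addr:           ":8080",
		Handler:        nil,              // nil means DefaultServeMux, via serverHandler
		ReadTimeout:    10 * time.Second, // re-armed per request in readRequest
		WriteTimeout:   10 * time.Second, // re-armed after each request is read
		MaxHeaderBytes: 1 << 20,
	}
	log.Fatal(srv.ListenAndServe())
}
```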
@@ -1150,7 +1446,7 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error {
 // TimeoutHandler returns a Handler that runs h with the given time limit.
 //
 // The new Handler calls h.ServeHTTP to handle each request, but if a
-// call runs for more than ns nanoseconds, the handler responds with
+// call runs for longer than its time limit, the handler responds with
 // a 503 Service Unavailable error and the given message in its body.
 // (If msg is empty, a suitable default message will be sent.)
 // After such a timeout, writes by h to its ResponseWriter will return
@@ -1180,7 +1476,7 @@ func (h *timeoutHandler) errorBody() string {
 }

 func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) {
-    done := make(chan bool)
+    done := make(chan bool, 1)
     tw := &timeoutWriter{w: w}
     go func() {
         h.handler.ServeHTTP(tw, r)
@@ -1232,3 +1528,86 @@ func (tw *timeoutWriter) WriteHeader(code int) {
     tw.mu.Unlock()
     tw.w.WriteHeader(code)
 }
+
+// globalOptionsHandler responds to "OPTIONS *" requests.
+type globalOptionsHandler struct{}
+
+func (globalOptionsHandler) ServeHTTP(w ResponseWriter, r *Request) {
+    w.Header().Set("Content-Length", "0")
+    if r.ContentLength != 0 {
+        // Read up to 4KB of OPTIONS body (as mentioned in the
+        // spec as being reserved for future use), but anything
+        // over that is considered a waste of server resources
+        // (or an attack) and we abort and close the connection,
+        // courtesy of MaxBytesReader's EOF behavior.
+        mb := MaxBytesReader(w, r.Body, 4<<10)
+        io.Copy(ioutil.Discard, mb)
+    }
+}
+
+// eofReader is a non-nil io.ReadCloser that always returns EOF.
+var eofReader = ioutil.NopCloser(strings.NewReader(""))
+
+// initNPNRequest is an HTTP handler that initializes certain
+// uninitialized fields in its *Request. Such partially-initialized
+// Requests come from NPN protocol handlers.
+type initNPNRequest struct {
+    c *tls.Conn
+    h serverHandler
+}
+
+func (h initNPNRequest) ServeHTTP(rw ResponseWriter, req *Request) {
+    if req.TLS == nil {
+        req.TLS = &tls.ConnectionState{}
+        *req.TLS = h.c.ConnectionState()
+    }
+    if req.Body == nil {
+        req.Body = eofReader
+    }
+    if req.RemoteAddr == "" {
+        req.RemoteAddr = h.c.RemoteAddr().String()
+    }
+    h.h.ServeHTTP(rw, req)
+}
+
+// loggingConn is used for debugging.
+type loggingConn struct {
+    name string
+    net.Conn
+}
+
+var (
+    uniqNameMu   sync.Mutex
+    uniqNameNext = make(map[string]int)
+)
+
+func newLoggingConn(baseName string, c net.Conn) net.Conn {
+    uniqNameMu.Lock()
+    defer uniqNameMu.Unlock()
+    uniqNameNext[baseName]++
+    return &loggingConn{
+        name: fmt.Sprintf("%s-%d", baseName, uniqNameNext[baseName]),
+        Conn: c,
+    }
+}
+
+func (c *loggingConn) Write(p []byte) (n int, err error) {
+    log.Printf("%s.Write(%d) = ....", c.name, len(p))
+    n, err = c.Conn.Write(p)
+    log.Printf("%s.Write(%d) = %d, %v", c.name, len(p), n, err)
+    return
+}
+
+func (c *loggingConn) Read(p []byte) (n int, err error) {
+    log.Printf("%s.Read(%d) = ....", c.name, len(p))
+    n, err = c.Conn.Read(p)
+    log.Printf("%s.Read(%d) = %d, %v", c.name, len(p), n, err)
+    return
+}
+
+func (c *loggingConn) Close() (err error) {
+    log.Printf("%s.Close() = ...", c.name)
+    err = c.Conn.Close()
+    log.Printf("%s.Close() = %v", c.name, err)
+    return
+}
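The TLSNextProto hook pairs with the NextProtos list in the tls.Config passed to ListenAndServeTLS: when NPN negotiates a protocol that validNPN accepts, the registered function takes over the connection and the HTTP serve loop never runs. A skeletal sketch under invented assumptions (the protocol name "echo/1", its behavior, and the cert paths are all made up):

```go
package main

import (
	"bufio"
	"crypto/tls"
	"log"
	"net/http"
)

func main() {
	srv := &http.Server{
		Addr:      ":8443",
		TLSConfig: &tls.Config{NextProtos: []string{"echo/1", "http/1.1"}},
		TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){
			// Runs instead of the HTTP loop when NPN settles on "echo/1";
			// per the field docs, the conn is closed when we return.
			"echo/1": func(s *http.Server, c *tls.Conn, h http.Handler) {
				br := bufio.NewReader(c)
				for {
					line, err := br.ReadString('\n')
					if err != nil {
						return
					}
					if _, err := c.Write([]byte(line)); err != nil {
						return
					}
				}
			},
		},
	}
	log.Fatal(srv.ListenAndServeTLS("cert.pem", "key.pem")) // placeholder paths
}
```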