diff options
author | Michael Stapelberg <stapelberg@debian.org> | 2014-06-19 09:22:53 +0200 |
---|---|---|
committer | Michael Stapelberg <stapelberg@debian.org> | 2014-06-19 09:22:53 +0200 |
commit | 8a39ee361feb9bf46d728ff1ba4f07ca1d9610b1 (patch) | |
tree | 4449f2036cccf162e8417cc5841a35815b3e7ac5 /src/pkg/net/http | |
parent | c8bf49ef8a92e2337b69c14b9b88396efe498600 (diff) | |
download | golang-51f2ca399fb8da86b2e7b3a0582e083fab731a98.tar.gz |
Imported Upstream version 1.3upstream/1.3
Diffstat (limited to 'src/pkg/net/http')
37 files changed, 3337 insertions, 634 deletions
diff --git a/src/pkg/net/http/cgi/host.go b/src/pkg/net/http/cgi/host.go index d27cc4dc9..ec95a972c 100644 --- a/src/pkg/net/http/cgi/host.go +++ b/src/pkg/net/http/cgi/host.go @@ -214,12 +214,17 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { internalError(err) return } + if hook := testHookStartProcess; hook != nil { + hook(cmd.Process) + } defer cmd.Wait() defer stdoutRead.Close() linebody := bufio.NewReaderSize(stdoutRead, 1024) headers := make(http.Header) statusCode := 0 + headerLines := 0 + sawBlankLine := false for { line, isPrefix, err := linebody.ReadLine() if isPrefix { @@ -236,8 +241,10 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } if len(line) == 0 { + sawBlankLine = true break } + headerLines++ parts := strings.SplitN(string(line), ":", 2) if len(parts) < 2 { h.printf("cgi: bogus header line: %s", string(line)) @@ -263,6 +270,11 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { headers.Add(header, val) } } + if headerLines == 0 || !sawBlankLine { + rw.WriteHeader(http.StatusInternalServerError) + h.printf("cgi: no headers") + return + } if loc := headers.Get("Location"); loc != "" { if strings.HasPrefix(loc, "/") && h.PathLocationHandler != nil { @@ -274,6 +286,12 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } } + if statusCode == 0 && headers.Get("Content-Type") == "" { + rw.WriteHeader(http.StatusInternalServerError) + h.printf("cgi: missing required Content-Type in headers") + return + } + if statusCode == 0 { statusCode = http.StatusOK } @@ -292,6 +310,13 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { _, err = io.Copy(rw, linebody) if err != nil { h.printf("cgi: copy error: %v", err) + // And kill the child CGI process so we don't hang on + // the deferred cmd.Wait above if the error was just + // the client (rw) going away. If it was a read error + // (because the child died itself), then the extra + // kill of an already-dead process is harmless (the PID + // won't be reused until the Wait above). + cmd.Process.Kill() } } @@ -348,3 +373,5 @@ func upperCaseAndUnderscore(r rune) rune { // TODO: other transformations in spec or practice? return r } + +var testHookStartProcess func(*os.Process) // nil except for some tests diff --git a/src/pkg/net/http/cgi/matryoshka_test.go b/src/pkg/net/http/cgi/matryoshka_test.go index e1a78c8f6..18c4803e7 100644 --- a/src/pkg/net/http/cgi/matryoshka_test.go +++ b/src/pkg/net/http/cgi/matryoshka_test.go @@ -9,15 +9,25 @@ package cgi import ( + "bytes" + "errors" "fmt" + "io" "net/http" + "net/http/httptest" "os" + "runtime" "testing" + "time" ) // This test is a CGI host (testing host.go) that runs its own binary // as a child process testing the other half of CGI (child.go). func TestHostingOurselves(t *testing.T) { + if runtime.GOOS == "nacl" { + t.Skip("skipping on nacl") + } + h := &Handler{ Path: os.Args[0], Root: "/test.go", @@ -51,8 +61,88 @@ func TestHostingOurselves(t *testing.T) { } } -// Test that a child handler only writing headers works. +type customWriterRecorder struct { + w io.Writer + *httptest.ResponseRecorder +} + +func (r *customWriterRecorder) Write(p []byte) (n int, err error) { + return r.w.Write(p) +} + +type limitWriter struct { + w io.Writer + n int +} + +func (w *limitWriter) Write(p []byte) (n int, err error) { + if len(p) > w.n { + p = p[:w.n] + } + if len(p) > 0 { + n, err = w.w.Write(p) + w.n -= n + } + if w.n == 0 { + err = errors.New("past write limit") + } + return +} + +// If there's an error copying the child's output to the parent, test +// that we kill the child. +func TestKillChildAfterCopyError(t *testing.T) { + if runtime.GOOS == "nacl" { + t.Skip("skipping on nacl") + } + + defer func() { testHookStartProcess = nil }() + proc := make(chan *os.Process, 1) + testHookStartProcess = func(p *os.Process) { + proc <- p + } + + h := &Handler{ + Path: os.Args[0], + Root: "/test.go", + Args: []string{"-test.run=TestBeChildCGIProcess"}, + } + req, _ := http.NewRequest("GET", "http://example.com/test.cgi?write-forever=1", nil) + rec := httptest.NewRecorder() + var out bytes.Buffer + const writeLen = 50 << 10 + rw := &customWriterRecorder{&limitWriter{&out, writeLen}, rec} + + donec := make(chan bool, 1) + go func() { + h.ServeHTTP(rw, req) + donec <- true + }() + + select { + case <-donec: + if out.Len() != writeLen || out.Bytes()[0] != 'a' { + t.Errorf("unexpected output: %q", out.Bytes()) + } + case <-time.After(5 * time.Second): + t.Errorf("timeout. ServeHTTP hung and didn't kill the child process?") + select { + case p := <-proc: + p.Kill() + t.Logf("killed process") + default: + t.Logf("didn't kill process") + } + } +} + +// Test that a child handler writing only headers works. +// golang.org/issue/7196 func TestChildOnlyHeaders(t *testing.T) { + if runtime.GOOS == "nacl" { + t.Skip("skipping on nacl") + } + h := &Handler{ Path: os.Args[0], Root: "/test.go", @@ -67,18 +157,63 @@ func TestChildOnlyHeaders(t *testing.T) { } } +// golang.org/issue/7198 +func Test500WithNoHeaders(t *testing.T) { want500Test(t, "/immediate-disconnect") } +func Test500WithNoContentType(t *testing.T) { want500Test(t, "/no-content-type") } +func Test500WithEmptyHeaders(t *testing.T) { want500Test(t, "/empty-headers") } + +func want500Test(t *testing.T, path string) { + h := &Handler{ + Path: os.Args[0], + Root: "/test.go", + Args: []string{"-test.run=TestBeChildCGIProcess"}, + } + expectedMap := map[string]string{ + "_body": "", + } + replay := runCgiTest(t, h, "GET "+path+" HTTP/1.0\nHost: example.com\n\n", expectedMap) + if replay.Code != 500 { + t.Errorf("Got code %d; want 500", replay.Code) + } +} + +type neverEnding byte + +func (b neverEnding) Read(p []byte) (n int, err error) { + for i := range p { + p[i] = byte(b) + } + return len(p), nil +} + // Note: not actually a test. func TestBeChildCGIProcess(t *testing.T) { if os.Getenv("REQUEST_METHOD") == "" { // Not in a CGI environment; skipping test. return } + switch os.Getenv("REQUEST_URI") { + case "/immediate-disconnect": + os.Exit(0) + case "/no-content-type": + fmt.Printf("Content-Length: 6\n\nHello\n") + os.Exit(0) + case "/empty-headers": + fmt.Printf("\nHello") + os.Exit(0) + } Serve(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("X-Test-Header", "X-Test-Value") req.ParseForm() if req.FormValue("no-body") == "1" { return } + if req.FormValue("write-forever") == "1" { + io.Copy(rw, neverEnding('a')) + for { + time.Sleep(5 * time.Second) // hang forever, until killed + } + } fmt.Fprintf(rw, "test=Hello CGI-in-CGI\n") for k, vv := range req.Form { for _, v := range vv { diff --git a/src/pkg/net/http/chunked.go b/src/pkg/net/http/chunked.go index 91db01724..749f29d32 100644 --- a/src/pkg/net/http/chunked.go +++ b/src/pkg/net/http/chunked.go @@ -4,13 +4,14 @@ // The wire protocol for HTTP's "chunked" Transfer-Encoding. -// This code is duplicated in httputil/chunked.go. +// This code is duplicated in net/http and net/http/httputil. // Please make any changes in both files. package http import ( "bufio" + "bytes" "errors" "fmt" "io" @@ -57,26 +58,45 @@ func (cr *chunkedReader) beginChunk() { } } -func (cr *chunkedReader) Read(b []uint8) (n int, err error) { - if cr.err != nil { - return 0, cr.err +func (cr *chunkedReader) chunkHeaderAvailable() bool { + n := cr.r.Buffered() + if n > 0 { + peek, _ := cr.r.Peek(n) + return bytes.IndexByte(peek, '\n') >= 0 } - if cr.n == 0 { - cr.beginChunk() - if cr.err != nil { - return 0, cr.err + return false +} + +func (cr *chunkedReader) Read(b []uint8) (n int, err error) { + for cr.err == nil { + if cr.n == 0 { + if n > 0 && !cr.chunkHeaderAvailable() { + // We've read enough. Don't potentially block + // reading a new chunk header. + break + } + cr.beginChunk() + continue } - } - if uint64(len(b)) > cr.n { - b = b[0:cr.n] - } - n, cr.err = cr.r.Read(b) - cr.n -= uint64(n) - if cr.n == 0 && cr.err == nil { - // end of chunk (CRLF) - if _, cr.err = io.ReadFull(cr.r, cr.buf[:]); cr.err == nil { - if cr.buf[0] != '\r' || cr.buf[1] != '\n' { - cr.err = errors.New("malformed chunked encoding") + if len(b) == 0 { + break + } + rbuf := b + if uint64(len(rbuf)) > cr.n { + rbuf = rbuf[:cr.n] + } + var n0 int + n0, cr.err = cr.r.Read(rbuf) + n += n0 + b = b[n0:] + cr.n -= uint64(n0) + // If we're at the end of a chunk, read the next two + // bytes to verify they are "\r\n". + if cr.n == 0 && cr.err == nil { + if _, cr.err = io.ReadFull(cr.r, cr.buf[:2]); cr.err == nil { + if cr.buf[0] != '\r' || cr.buf[1] != '\n' { + cr.err = errors.New("malformed chunked encoding") + } } } } diff --git a/src/pkg/net/http/chunked_test.go b/src/pkg/net/http/chunked_test.go index 0b18c7b55..34544790a 100644 --- a/src/pkg/net/http/chunked_test.go +++ b/src/pkg/net/http/chunked_test.go @@ -2,17 +2,18 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This code is duplicated in httputil/chunked_test.go. +// This code is duplicated in net/http and net/http/httputil. // Please make any changes in both files. package http import ( + "bufio" "bytes" "fmt" "io" "io/ioutil" - "runtime" + "strings" "testing" ) @@ -41,9 +42,77 @@ func TestChunk(t *testing.T) { } } +func TestChunkReadMultiple(t *testing.T) { + // Bunch of small chunks, all read together. + { + var b bytes.Buffer + w := newChunkedWriter(&b) + w.Write([]byte("foo")) + w.Write([]byte("bar")) + w.Close() + + r := newChunkedReader(&b) + buf := make([]byte, 10) + n, err := r.Read(buf) + if n != 6 || err != io.EOF { + t.Errorf("Read = %d, %v; want 6, EOF", n, err) + } + buf = buf[:n] + if string(buf) != "foobar" { + t.Errorf("Read = %q; want %q", buf, "foobar") + } + } + + // One big chunk followed by a little chunk, but the small bufio.Reader size + // should prevent the second chunk header from being read. + { + var b bytes.Buffer + w := newChunkedWriter(&b) + // fillBufChunk is 11 bytes + 3 bytes header + 2 bytes footer = 16 bytes, + // the same as the bufio ReaderSize below (the minimum), so even + // though we're going to try to Read with a buffer larger enough to also + // receive "foo", the second chunk header won't be read yet. + const fillBufChunk = "0123456789a" + const shortChunk = "foo" + w.Write([]byte(fillBufChunk)) + w.Write([]byte(shortChunk)) + w.Close() + + r := newChunkedReader(bufio.NewReaderSize(&b, 16)) + buf := make([]byte, len(fillBufChunk)+len(shortChunk)) + n, err := r.Read(buf) + if n != len(fillBufChunk) || err != nil { + t.Errorf("Read = %d, %v; want %d, nil", n, err, len(fillBufChunk)) + } + buf = buf[:n] + if string(buf) != fillBufChunk { + t.Errorf("Read = %q; want %q", buf, fillBufChunk) + } + + n, err = r.Read(buf) + if n != len(shortChunk) || err != io.EOF { + t.Errorf("Read = %d, %v; want %d, EOF", n, err, len(shortChunk)) + } + } + + // And test that we see an EOF chunk, even though our buffer is already full: + { + r := newChunkedReader(bufio.NewReader(strings.NewReader("3\r\nfoo\r\n0\r\n"))) + buf := make([]byte, 3) + n, err := r.Read(buf) + if n != 3 || err != io.EOF { + t.Errorf("Read = %d, %v; want 3, EOF", n, err) + } + if string(buf) != "foo" { + t.Errorf("buf = %q; want foo", buf) + } + } +} + func TestChunkReaderAllocs(t *testing.T) { - // temporarily set GOMAXPROCS to 1 as we are testing memory allocations - defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1)) + if testing.Short() { + t.Skip("skipping in short mode") + } var buf bytes.Buffer w := newChunkedWriter(&buf) a, b, c := []byte("aaaaaa"), []byte("bbbbbbbbbbbb"), []byte("cccccccccccccccccccccccc") @@ -52,26 +121,23 @@ func TestChunkReaderAllocs(t *testing.T) { w.Write(c) w.Close() - r := newChunkedReader(&buf) readBuf := make([]byte, len(a)+len(b)+len(c)+1) - - var ms runtime.MemStats - runtime.ReadMemStats(&ms) - m0 := ms.Mallocs - - n, err := io.ReadFull(r, readBuf) - - runtime.ReadMemStats(&ms) - mallocs := ms.Mallocs - m0 - if mallocs > 1 { - t.Errorf("%d mallocs; want <= 1", mallocs) - } - - if n != len(readBuf)-1 { - t.Errorf("read %d bytes; want %d", n, len(readBuf)-1) - } - if err != io.ErrUnexpectedEOF { - t.Errorf("read error = %v; want ErrUnexpectedEOF", err) + byter := bytes.NewReader(buf.Bytes()) + bufr := bufio.NewReader(byter) + mallocs := testing.AllocsPerRun(100, func() { + byter.Seek(0, 0) + bufr.Reset(byter) + r := newChunkedReader(bufr) + n, err := io.ReadFull(r, readBuf) + if n != len(readBuf)-1 { + t.Fatalf("read %d bytes; want %d", n, len(readBuf)-1) + } + if err != io.ErrUnexpectedEOF { + t.Fatalf("read error = %v; want ErrUnexpectedEOF", err) + } + }) + if mallocs > 1.5 { + t.Errorf("mallocs = %v; want 1", mallocs) } } diff --git a/src/pkg/net/http/client.go b/src/pkg/net/http/client.go index 22f2e865c..a5a3abe61 100644 --- a/src/pkg/net/http/client.go +++ b/src/pkg/net/http/client.go @@ -14,9 +14,12 @@ import ( "errors" "fmt" "io" + "io/ioutil" "log" "net/url" "strings" + "sync" + "time" ) // A Client is an HTTP client. Its zero value (DefaultClient) is a @@ -52,6 +55,20 @@ type Client struct { // If Jar is nil, cookies are not sent in requests and ignored // in responses. Jar CookieJar + + // Timeout specifies a time limit for requests made by this + // Client. The timeout includes connection time, any + // redirects, and reading the response body. The timer remains + // running after Get, Head, Post, or Do return and will + // interrupt reading of the Response.Body. + // + // A Timeout of zero means no timeout. + // + // The Client's Transport must support the CancelRequest + // method or Client will return errors when attempting to make + // a request with Get, Head, Post, or Do. Client's default + // Transport (DefaultTransport) supports CancelRequest. + Timeout time.Duration } // DefaultClient is the default Client and is used by Get, Head, and Post. @@ -74,8 +91,9 @@ type RoundTripper interface { // authentication, or cookies. // // RoundTrip should not modify the request, except for - // consuming and closing the Body. The request's URL and - // Header fields are guaranteed to be initialized. + // consuming and closing the Body, including on errors. The + // request's URL and Header fields are guaranteed to be + // initialized. RoundTrip(*Request) (*Response, error) } @@ -97,7 +115,7 @@ func (c *Client) send(req *Request) (*Response, error) { req.AddCookie(cookie) } } - resp, err := send(req, c.Transport) + resp, err := send(req, c.transport()) if err != nil { return nil, err } @@ -123,6 +141,9 @@ func (c *Client) send(req *Request) (*Response, error) { // (typically Transport) may not be able to re-use a persistent TCP // connection to the server for a subsequent "keep-alive" request. // +// The request Body, if non-nil, will be closed by the underlying +// Transport, even on errors. +// // Generally Get, Post, or PostForm will be used instead of Do. func (c *Client) Do(req *Request) (resp *Response, err error) { if req.Method == "GET" || req.Method == "HEAD" { @@ -134,22 +155,28 @@ func (c *Client) Do(req *Request) (resp *Response, err error) { return c.send(req) } +func (c *Client) transport() RoundTripper { + if c.Transport != nil { + return c.Transport + } + return DefaultTransport +} + // send issues an HTTP request. // Caller should close resp.Body when done reading from it. func send(req *Request, t RoundTripper) (resp *Response, err error) { if t == nil { - t = DefaultTransport - if t == nil { - err = errors.New("http: no Client.Transport or DefaultTransport") - return - } + req.closeBody() + return nil, errors.New("http: no Client.Transport or DefaultTransport") } if req.URL == nil { + req.closeBody() return nil, errors.New("http: nil Request.URL") } if req.RequestURI != "" { + req.closeBody() return nil, errors.New("http: Request.RequestURI can't be set in client requests.") } @@ -257,21 +284,40 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo var via []*Request if ireq.URL == nil { + ireq.closeBody() return nil, errors.New("http: nil Request.URL") } + var reqmu sync.Mutex // guards req req := ireq + + var timer *time.Timer + if c.Timeout > 0 { + type canceler interface { + CancelRequest(*Request) + } + tr, ok := c.transport().(canceler) + if !ok { + return nil, fmt.Errorf("net/http: Client Transport of type %T doesn't support CancelRequest; Timeout not supported", c.transport()) + } + timer = time.AfterFunc(c.Timeout, func() { + reqmu.Lock() + defer reqmu.Unlock() + tr.CancelRequest(req) + }) + } + urlStr := "" // next relative or absolute URL to fetch (after first request) redirectFailed := false for redirect := 0; ; redirect++ { if redirect != 0 { - req = new(Request) - req.Method = ireq.Method + nreq := new(Request) + nreq.Method = ireq.Method if ireq.Method == "POST" || ireq.Method == "PUT" { - req.Method = "GET" + nreq.Method = "GET" } - req.Header = make(Header) - req.URL, err = base.Parse(urlStr) + nreq.Header = make(Header) + nreq.URL, err = base.Parse(urlStr) if err != nil { break } @@ -279,15 +325,18 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo // Add the Referer header. lastReq := via[len(via)-1] if lastReq.URL.Scheme != "https" { - req.Header.Set("Referer", lastReq.URL.String()) + nreq.Header.Set("Referer", lastReq.URL.String()) } - err = redirectChecker(req, via) + err = redirectChecker(nreq, via) if err != nil { redirectFailed = true break } } + reqmu.Lock() + req = nreq + reqmu.Unlock() } urlStr = req.URL.String() @@ -296,6 +345,12 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo } if shouldRedirect(resp.StatusCode) { + // Read the body if small so underlying TCP connection will be re-used. + // No need to check for errors: if it fails, Transport won't reuse it anyway. + const maxBodySlurpSize = 2 << 10 + if resp.ContentLength == -1 || resp.ContentLength <= maxBodySlurpSize { + io.CopyN(ioutil.Discard, resp.Body, maxBodySlurpSize) + } resp.Body.Close() if urlStr = resp.Header.Get("Location"); urlStr == "" { err = errors.New(fmt.Sprintf("%d response missing Location header", resp.StatusCode)) @@ -305,7 +360,10 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo via = append(via, req) continue } - return + if timer != nil { + resp.Body = &cancelTimerBody{timer, resp.Body} + } + return resp, nil } method := ireq.Method @@ -349,7 +407,7 @@ func Post(url string, bodyType string, body io.Reader) (resp *Response, err erro // Caller should close resp.Body when done reading from it. // // If the provided body is also an io.Closer, it is closed after the -// body is successfully written to the server. +// request. func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *Response, err error) { req, err := NewRequest("POST", url, body) if err != nil { @@ -408,3 +466,22 @@ func (c *Client) Head(url string) (resp *Response, err error) { } return c.doFollowingRedirects(req, shouldRedirectGet) } + +type cancelTimerBody struct { + t *time.Timer + rc io.ReadCloser +} + +func (b *cancelTimerBody) Read(p []byte) (n int, err error) { + n, err = b.rc.Read(p) + if err == io.EOF { + b.t.Stop() + } + return +} + +func (b *cancelTimerBody) Close() error { + err := b.rc.Close() + b.t.Stop() + return err +} diff --git a/src/pkg/net/http/client_test.go b/src/pkg/net/http/client_test.go index 997d04151..6392c1baf 100644 --- a/src/pkg/net/http/client_test.go +++ b/src/pkg/net/http/client_test.go @@ -15,14 +15,18 @@ import ( "fmt" "io" "io/ioutil" + "log" "net" . "net/http" "net/http/httptest" "net/url" + "reflect" + "sort" "strconv" "strings" "sync" "testing" + "time" ) var robotsTxtHandler = HandlerFunc(func(w ResponseWriter, r *Request) { @@ -54,6 +58,13 @@ func pedanticReadAll(r io.Reader) (b []byte, err error) { } } +type chanWriter chan string + +func (w chanWriter) Write(p []byte) (n int, err error) { + w <- string(p) + return len(p), nil +} + func TestClient(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(robotsTxtHandler) @@ -373,24 +384,6 @@ func (j *TestJar) Cookies(u *url.URL) []*Cookie { return j.perURL[u.Host] } -func TestRedirectCookiesOnRequest(t *testing.T) { - defer afterTest(t) - var ts *httptest.Server - ts = httptest.NewServer(echoCookiesRedirectHandler) - defer ts.Close() - c := &Client{} - req, _ := NewRequest("GET", ts.URL, nil) - req.AddCookie(expectedCookies[0]) - // TODO: Uncomment when an implementation of a RFC6265 cookie jar lands. - _ = c - // resp, _ := c.Do(req) - // matchReturnedCookies(t, expectedCookies, resp.Cookies()) - - req, _ = NewRequest("GET", ts.URL, nil) - // resp, _ = c.Do(req) - // matchReturnedCookies(t, expectedCookies[1:], resp.Cookies()) -} - func TestRedirectCookiesJar(t *testing.T) { defer afterTest(t) var ts *httptest.Server @@ -410,8 +403,8 @@ func TestRedirectCookiesJar(t *testing.T) { } func matchReturnedCookies(t *testing.T, expected, given []*Cookie) { - t.Logf("Received cookies: %v", given) if len(given) != len(expected) { + t.Logf("Received cookies: %v", given) t.Errorf("Expected %d cookies, got %d", len(expected), len(given)) } for _, ec := range expected { @@ -582,6 +575,8 @@ func TestClientInsecureTransport(t *testing.T) { ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello")) })) + errc := make(chanWriter, 10) // but only expecting 1 + ts.Config.ErrorLog = log.New(errc, "", 0) defer ts.Close() // TODO(bradfitz): add tests for skipping hostname checks too? @@ -603,6 +598,16 @@ func TestClientInsecureTransport(t *testing.T) { res.Body.Close() } } + + select { + case v := <-errc: + if !strings.Contains(v, "TLS handshake error") { + t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", v) + } + case <-time.After(5 * time.Second): + t.Errorf("timeout waiting for logged error") + } + } func TestClientErrorWithRequestURI(t *testing.T) { @@ -653,6 +658,8 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) { defer afterTest(t) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) defer ts.Close() + errc := make(chanWriter, 10) // but only expecting 1 + ts.Config.ErrorLog = log.New(errc, "", 0) trans := newTLSTransport(t, ts) trans.TLSClientConfig.ServerName = "badserver" @@ -664,6 +671,14 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) { if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") { t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err) } + select { + case v := <-errc: + if !strings.Contains(v, "TLS handshake error") { + t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", v) + } + case <-time.After(5 * time.Second): + t.Errorf("timeout waiting for logged error") + } } // Test for golang.org/issue/5829; the Transport should respect TLSClientConfig.ServerName @@ -696,6 +711,33 @@ func TestTransportUsesTLSConfigServerName(t *testing.T) { res.Body.Close() } +func TestResponseSetsTLSConnectionState(t *testing.T) { + defer afterTest(t) + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Write([]byte("Hello")) + })) + defer ts.Close() + + tr := newTLSTransport(t, ts) + tr.TLSClientConfig.CipherSuites = []uint16{tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA} + tr.Dial = func(netw, addr string) (net.Conn, error) { + return net.Dial(netw, ts.Listener.Addr().String()) + } + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + res, err := c.Get("https://example.com/") + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.TLS == nil { + t.Fatal("Response didn't set TLS Connection State.") + } + if got, want := res.TLS.CipherSuite, tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA; got != want { + t.Errorf("TLS Cipher Suite = %d; want %d", got, want) + } +} + // Verify Response.ContentLength is populated. http://golang.org/issue/4126 func TestClientHeadContentLength(t *testing.T) { defer afterTest(t) @@ -799,3 +841,198 @@ func TestBasicAuth(t *testing.T) { t.Errorf("Invalid auth %q", auth) } } + +func TestClientTimeout(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + defer afterTest(t) + sawRoot := make(chan bool, 1) + sawSlow := make(chan bool, 1) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.URL.Path == "/" { + sawRoot <- true + Redirect(w, r, "/slow", StatusFound) + return + } + if r.URL.Path == "/slow" { + w.Write([]byte("Hello")) + w.(Flusher).Flush() + sawSlow <- true + time.Sleep(2 * time.Second) + return + } + })) + defer ts.Close() + const timeout = 500 * time.Millisecond + c := &Client{ + Timeout: timeout, + } + + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + + select { + case <-sawRoot: + // good. + default: + t.Fatal("handler never got / request") + } + + select { + case <-sawSlow: + // good. + default: + t.Fatal("handler never got /slow request") + } + + errc := make(chan error, 1) + go func() { + _, err := ioutil.ReadAll(res.Body) + errc <- err + res.Body.Close() + }() + + const failTime = timeout * 2 + select { + case err := <-errc: + if err == nil { + t.Error("expected error from ReadAll") + } + // Expected error. + case <-time.After(failTime): + t.Errorf("timeout after %v waiting for timeout of %v", failTime, timeout) + } +} + +func TestClientRedirectEatsBody(t *testing.T) { + defer afterTest(t) + saw := make(chan string, 2) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + saw <- r.RemoteAddr + if r.URL.Path == "/" { + Redirect(w, r, "/foo", StatusFound) // which includes a body + } + })) + defer ts.Close() + + res, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + _, err = ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + + var first string + select { + case first = <-saw: + default: + t.Fatal("server didn't see a request") + } + + var second string + select { + case second = <-saw: + default: + t.Fatal("server didn't see a second request") + } + + if first != second { + t.Fatal("server saw different client ports before & after the redirect") + } +} + +// eofReaderFunc is an io.Reader that runs itself, and then returns io.EOF. +type eofReaderFunc func() + +func (f eofReaderFunc) Read(p []byte) (n int, err error) { + f() + return 0, io.EOF +} + +func TestClientTrailers(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Connection", "close") + w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B") + w.Header().Add("Trailer", "Server-Trailer-C") + + var decl []string + for k := range r.Trailer { + decl = append(decl, k) + } + sort.Strings(decl) + + slurp, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Errorf("Server reading request body: %v", err) + } + if string(slurp) != "foo" { + t.Errorf("Server read request body %q; want foo", slurp) + } + if r.Trailer == nil { + io.WriteString(w, "nil Trailer") + } else { + fmt.Fprintf(w, "decl: %v, vals: %s, %s", + decl, + r.Trailer.Get("Client-Trailer-A"), + r.Trailer.Get("Client-Trailer-B")) + } + + // TODO: golang.org/issue/7759: there's no way yet for + // the server to set trailers without hijacking, so do + // that for now, just to test the client. Later, in + // Go 1.4, it should be implicit that any mutations + // to w.Header() after the initial write are the + // trailers to be sent, if and only if they were + // previously declared with w.Header().Set("Trailer", + // ..keys..) + w.(Flusher).Flush() + conn, buf, _ := w.(Hijacker).Hijack() + t := Header{} + t.Set("Server-Trailer-A", "valuea") + t.Set("Server-Trailer-C", "valuec") // skipping B + buf.WriteString("0\r\n") // eof + t.Write(buf) + buf.WriteString("\r\n") // end of trailers + buf.Flush() + conn.Close() + })) + defer ts.Close() + + var req *Request + req, _ = NewRequest("POST", ts.URL, io.MultiReader( + eofReaderFunc(func() { + req.Trailer["Client-Trailer-A"] = []string{"valuea"} + }), + strings.NewReader("foo"), + eofReaderFunc(func() { + req.Trailer["Client-Trailer-B"] = []string{"valueb"} + }), + )) + req.Trailer = Header{ + "Client-Trailer-A": nil, // to be set later + "Client-Trailer-B": nil, // to be set later + } + req.ContentLength = -1 + res, err := DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil { + t.Error(err) + } + want := Header{ + "Server-Trailer-A": []string{"valuea"}, + "Server-Trailer-B": nil, + "Server-Trailer-C": []string{"valuec"}, + } + if !reflect.DeepEqual(res.Trailer, want) { + t.Errorf("Response trailers = %#v; want %#v", res.Trailer, want) + } +} diff --git a/src/pkg/net/http/cookie.go b/src/pkg/net/http/cookie.go index 8b01c508e..dc60ba87f 100644 --- a/src/pkg/net/http/cookie.go +++ b/src/pkg/net/http/cookie.go @@ -76,11 +76,7 @@ func readSetCookies(h Header) []*Cookie { attr, val = attr[:j], attr[j+1:] } lowerAttr := strings.ToLower(attr) - parseCookieValueFn := parseCookieValue - if lowerAttr == "expires" { - parseCookieValueFn = parseCookieExpiresValue - } - val, success = parseCookieValueFn(val) + val, success = parseCookieValue(val) if !success { c.Unparsed = append(c.Unparsed, parts[i]) continue @@ -94,7 +90,6 @@ func readSetCookies(h Header) []*Cookie { continue case "domain": c.Domain = val - // TODO: Add domain parsing continue case "max-age": secs, err := strconv.Atoi(val) @@ -121,7 +116,6 @@ func readSetCookies(h Header) []*Cookie { continue case "path": c.Path = val - // TODO: Add path parsing continue } c.Unparsed = append(c.Unparsed, parts[i]) @@ -300,12 +294,23 @@ func sanitizeCookieName(n string) string { // ; US-ASCII characters excluding CTLs, // ; whitespace DQUOTE, comma, semicolon, // ; and backslash +// We loosen this as spaces and commas are common in cookie values +// but we produce a quoted cookie-value in when value starts or ends +// with a comma or space. +// See http://golang.org/issue/7243 for the discussion. func sanitizeCookieValue(v string) string { - return sanitizeOrWarn("Cookie.Value", validCookieValueByte, v) + v = sanitizeOrWarn("Cookie.Value", validCookieValueByte, v) + if len(v) == 0 { + return v + } + if v[0] == ' ' || v[0] == ',' || v[len(v)-1] == ' ' || v[len(v)-1] == ',' { + return `"` + v + `"` + } + return v } func validCookieValueByte(b byte) bool { - return 0x20 < b && b < 0x7f && b != '"' && b != ',' && b != ';' && b != '\\' + return 0x20 <= b && b < 0x7f && b != '"' && b != ';' && b != '\\' } // path-av = "Path=" path-value @@ -340,38 +345,13 @@ func sanitizeOrWarn(fieldName string, valid func(byte) bool, v string) string { return string(buf) } -func unquoteCookieValue(v string) string { - if len(v) > 1 && v[0] == '"' && v[len(v)-1] == '"' { - return v[1 : len(v)-1] - } - return v -} - -func isCookieByte(c byte) bool { - switch { - case c == 0x21, 0x23 <= c && c <= 0x2b, 0x2d <= c && c <= 0x3a, - 0x3c <= c && c <= 0x5b, 0x5d <= c && c <= 0x7e: - return true - } - return false -} - -func isCookieExpiresByte(c byte) (ok bool) { - return isCookieByte(c) || c == ',' || c == ' ' -} - func parseCookieValue(raw string) (string, bool) { - return parseCookieValueUsing(raw, isCookieByte) -} - -func parseCookieExpiresValue(raw string) (string, bool) { - return parseCookieValueUsing(raw, isCookieExpiresByte) -} - -func parseCookieValueUsing(raw string, validByte func(byte) bool) (string, bool) { - raw = unquoteCookieValue(raw) + // Strip the quotes, if present. + if len(raw) > 1 && raw[0] == '"' && raw[len(raw)-1] == '"' { + raw = raw[1 : len(raw)-1] + } for i := 0; i < len(raw); i++ { - if !validByte(raw[i]) { + if !validCookieValueByte(raw[i]) { return "", false } } diff --git a/src/pkg/net/http/cookie_test.go b/src/pkg/net/http/cookie_test.go index 11b01cc57..f78f37299 100644 --- a/src/pkg/net/http/cookie_test.go +++ b/src/pkg/net/http/cookie_test.go @@ -5,9 +5,13 @@ package http import ( + "bytes" "encoding/json" "fmt" + "log" + "os" "reflect" + "strings" "testing" "time" ) @@ -48,15 +52,61 @@ var writeSetCookiesTests = []struct { &Cookie{Name: "cookie-8", Value: "eight", Domain: "::1"}, "cookie-8=eight", }, + // The "special" cookies have values containing commas or spaces which + // are disallowed by RFC 6265 but are common in the wild. + { + &Cookie{Name: "special-1", Value: "a z"}, + `special-1=a z`, + }, + { + &Cookie{Name: "special-2", Value: " z"}, + `special-2=" z"`, + }, + { + &Cookie{Name: "special-3", Value: "a "}, + `special-3="a "`, + }, + { + &Cookie{Name: "special-4", Value: " "}, + `special-4=" "`, + }, + { + &Cookie{Name: "special-5", Value: "a,z"}, + `special-5=a,z`, + }, + { + &Cookie{Name: "special-6", Value: ",z"}, + `special-6=",z"`, + }, + { + &Cookie{Name: "special-7", Value: "a,"}, + `special-7="a,"`, + }, + { + &Cookie{Name: "special-8", Value: ","}, + `special-8=","`, + }, + { + &Cookie{Name: "empty-value", Value: ""}, + `empty-value=`, + }, } func TestWriteSetCookies(t *testing.T) { + defer log.SetOutput(os.Stderr) + var logbuf bytes.Buffer + log.SetOutput(&logbuf) + for i, tt := range writeSetCookiesTests { if g, e := tt.Cookie.String(), tt.Raw; g != e { t.Errorf("Test %d, expecting:\n%s\nGot:\n%s\n", i, e, g) continue } } + + if got, sub := logbuf.String(), "dropping domain attribute"; !strings.Contains(got, sub) { + t.Errorf("Expected substring %q in log output. Got:\n%s", sub, got) + } } type headerOnlyResponseWriter Header @@ -166,6 +216,40 @@ var readSetCookiesTests = []struct { Raw: "ASP.NET_SessionId=foo; path=/; HttpOnly", }}, }, + // Make sure we can properly read back the Set-Cookie headers we create + // for values containing spaces or commas: + { + Header{"Set-Cookie": {`special-1=a z`}}, + []*Cookie{{Name: "special-1", Value: "a z", Raw: `special-1=a z`}}, + }, + { + Header{"Set-Cookie": {`special-2=" z"`}}, + []*Cookie{{Name: "special-2", Value: " z", Raw: `special-2=" z"`}}, + }, + { + Header{"Set-Cookie": {`special-3="a "`}}, + []*Cookie{{Name: "special-3", Value: "a ", Raw: `special-3="a "`}}, + }, + { + Header{"Set-Cookie": {`special-4=" "`}}, + []*Cookie{{Name: "special-4", Value: " ", Raw: `special-4=" "`}}, + }, + { + Header{"Set-Cookie": {`special-5=a,z`}}, + []*Cookie{{Name: "special-5", Value: "a,z", Raw: `special-5=a,z`}}, + }, + { + Header{"Set-Cookie": {`special-6=",z"`}}, + []*Cookie{{Name: "special-6", Value: ",z", Raw: `special-6=",z"`}}, + }, + { + Header{"Set-Cookie": {`special-7=a,`}}, + []*Cookie{{Name: "special-7", Value: "a,", Raw: `special-7=a,`}}, + }, + { + Header{"Set-Cookie": {`special-8=","`}}, + []*Cookie{{Name: "special-8", Value: ",", Raw: `special-8=","`}}, + }, // TODO(bradfitz): users have reported seeing this in the // wild, but do browsers handle it? RFC 6265 just says "don't @@ -244,22 +328,39 @@ func TestReadCookies(t *testing.T) { } func TestCookieSanitizeValue(t *testing.T) { + defer log.SetOutput(os.Stderr) + var logbuf bytes.Buffer + log.SetOutput(&logbuf) + tests := []struct { in, want string }{ {"foo", "foo"}, - {"foo bar", "foobar"}, + {"foo;bar", "foobar"}, + {"foo\\bar", "foobar"}, + {"foo\"bar", "foobar"}, {"\x00\x7e\x7f\x80", "\x7e"}, {`"withquotes"`, "withquotes"}, + {"a z", "a z"}, + {" z", `" z"`}, + {"a ", `"a "`}, } for _, tt := range tests { if got := sanitizeCookieValue(tt.in); got != tt.want { t.Errorf("sanitizeCookieValue(%q) = %q; want %q", tt.in, got, tt.want) } } + + if got, sub := logbuf.String(), "dropping invalid bytes"; !strings.Contains(got, sub) { + t.Errorf("Expected substring %q in log output. Got:\n%s", sub, got) + } } func TestCookieSanitizePath(t *testing.T) { + defer log.SetOutput(os.Stderr) + var logbuf bytes.Buffer + log.SetOutput(&logbuf) + tests := []struct { in, want string }{ @@ -272,4 +373,8 @@ func TestCookieSanitizePath(t *testing.T) { t.Errorf("sanitizeCookiePath(%q) = %q; want %q", tt.in, got, tt.want) } } + + if got, sub := logbuf.String(), "dropping invalid bytes"; !strings.Contains(got, sub) { + t.Errorf("Expected substring %q in log output. Got:\n%s", sub, got) + } } diff --git a/src/pkg/net/http/export_test.go b/src/pkg/net/http/export_test.go index 22b7f2796..960563b24 100644 --- a/src/pkg/net/http/export_test.go +++ b/src/pkg/net/http/export_test.go @@ -21,7 +21,7 @@ var ExportAppendTime = appendTime func (t *Transport) NumPendingRequestsForTesting() int { t.reqMu.Lock() defer t.reqMu.Unlock() - return len(t.reqConn) + return len(t.reqCanceler) } func (t *Transport) IdleConnKeysForTesting() (keys []string) { @@ -32,7 +32,7 @@ func (t *Transport) IdleConnKeysForTesting() (keys []string) { return } for key := range t.idleConn { - keys = append(keys, key) + keys = append(keys, key.String()) } return } @@ -43,11 +43,12 @@ func (t *Transport) IdleConnCountForTesting(cacheKey string) int { if t.idleConn == nil { return 0 } - conns, ok := t.idleConn[cacheKey] - if !ok { - return 0 + for k, conns := range t.idleConn { + if k.String() == cacheKey { + return len(conns) + } } - return len(conns) + return 0 } func (t *Transport) IdleConnChMapSizeForTesting() int { @@ -63,4 +64,9 @@ func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler { return &timeoutHandler{handler, f, ""} } +func ResetCachedEnvironment() { + httpProxyEnv.reset() + noProxyEnv.reset() +} + var DefaultUserAgent = defaultUserAgent diff --git a/src/pkg/net/http/fcgi/child.go b/src/pkg/net/http/fcgi/child.go index 60b794e07..a3beaa33a 100644 --- a/src/pkg/net/http/fcgi/child.go +++ b/src/pkg/net/http/fcgi/child.go @@ -16,6 +16,7 @@ import ( "net/http/cgi" "os" "strings" + "sync" "time" ) @@ -126,8 +127,10 @@ func (r *response) Close() error { } type child struct { - conn *conn - handler http.Handler + conn *conn + handler http.Handler + + mu sync.Mutex // protects requests: requests map[uint16]*request // keyed by request ID } @@ -157,7 +160,9 @@ var errCloseConn = errors.New("fcgi: connection should be closed") var emptyBody = ioutil.NopCloser(strings.NewReader("")) func (c *child) handleRecord(rec *record) error { + c.mu.Lock() req, ok := c.requests[rec.h.Id] + c.mu.Unlock() if !ok && rec.h.Type != typeBeginRequest && rec.h.Type != typeGetValues { // The spec says to ignore unknown request IDs. return nil @@ -179,7 +184,10 @@ func (c *child) handleRecord(rec *record) error { c.conn.writeEndRequest(rec.h.Id, 0, statusUnknownRole) return nil } - c.requests[rec.h.Id] = newRequest(rec.h.Id, br.flags) + req = newRequest(rec.h.Id, br.flags) + c.mu.Lock() + c.requests[rec.h.Id] = req + c.mu.Unlock() return nil case typeParams: // NOTE(eds): Technically a key-value pair can straddle the boundary @@ -220,7 +228,9 @@ func (c *child) handleRecord(rec *record) error { return nil case typeAbortRequest: println("abort") + c.mu.Lock() delete(c.requests, rec.h.Id) + c.mu.Unlock() c.conn.writeEndRequest(rec.h.Id, 0, statusRequestComplete) if !req.keepConn { // connection will close upon return @@ -247,6 +257,9 @@ func (c *child) serveRequest(req *request, body io.ReadCloser) { c.handler.ServeHTTP(r, httpReq) } r.Close() + c.mu.Lock() + delete(c.requests, req.reqId) + c.mu.Unlock() c.conn.writeEndRequest(req.reqId, 0, statusRequestComplete) // Consume the entire body, so the host isn't still writing to diff --git a/src/pkg/net/http/fs.go b/src/pkg/net/http/fs.go index 8b32ca1d0..8576cf844 100644 --- a/src/pkg/net/http/fs.go +++ b/src/pkg/net/http/fs.go @@ -13,6 +13,7 @@ import ( "mime" "mime/multipart" "net/textproto" + "net/url" "os" "path" "path/filepath" @@ -52,12 +53,14 @@ type FileSystem interface { // A File is returned by a FileSystem's Open method and can be // served by the FileServer implementation. +// +// The methods should behave the same as those on an *os.File. type File interface { - Close() error - Stat() (os.FileInfo, error) + io.Closer + io.Reader Readdir(count int) ([]os.FileInfo, error) - Read([]byte) (int, error) Seek(offset int64, whence int) (int64, error) + Stat() (os.FileInfo, error) } func dirList(w ResponseWriter, f File) { @@ -73,8 +76,11 @@ func dirList(w ResponseWriter, f File) { if d.IsDir() { name += "/" } - // TODO htmlescape - fmt.Fprintf(w, "<a href=\"%s\">%s</a>\n", name, name) + // name may contain '?' or '#', which must be escaped to remain + // part of the URL path, and not indicate the start of a query + // string or fragment. + url := url.URL{Path: name} + fmt.Fprintf(w, "<a href=\"%s\">%s</a>\n", url.String(), htmlReplacer.Replace(name)) } } fmt.Fprintf(w, "</pre>\n") @@ -521,7 +527,7 @@ func (w *countingWriter) Write(p []byte) (n int, err error) { return len(p), nil } -// rangesMIMESize returns the nunber of bytes it takes to encode the +// rangesMIMESize returns the number of bytes it takes to encode the // provided ranges as a multipart response. func rangesMIMESize(ranges []httpRange, contentType string, contentSize int64) (encSize int64) { var w countingWriter diff --git a/src/pkg/net/http/fs_test.go b/src/pkg/net/http/fs_test.go index ae54edf0c..f968565f9 100644 --- a/src/pkg/net/http/fs_test.go +++ b/src/pkg/net/http/fs_test.go @@ -227,6 +227,54 @@ func TestFileServerCleans(t *testing.T) { } } +func TestFileServerEscapesNames(t *testing.T) { + defer afterTest(t) + const dirListPrefix = "<pre>\n" + const dirListSuffix = "\n</pre>\n" + tests := []struct { + name, escaped string + }{ + {`simple_name`, `<a href="simple_name">simple_name</a>`}, + {`"'<>&`, `<a href="%22%27%3C%3E&">"'<>&</a>`}, + {`?foo=bar#baz`, `<a href="%3Ffoo=bar%23baz">?foo=bar#baz</a>`}, + {`<combo>?foo`, `<a href="%3Ccombo%3E%3Ffoo"><combo>?foo</a>`}, + } + + // We put each test file in its own directory in the fakeFS so we can look at it in isolation. + fs := make(fakeFS) + for i, test := range tests { + testFile := &fakeFileInfo{basename: test.name} + fs[fmt.Sprintf("/%d", i)] = &fakeFileInfo{ + dir: true, + modtime: time.Unix(1000000000, 0).UTC(), + ents: []*fakeFileInfo{testFile}, + } + fs[fmt.Sprintf("/%d/%s", i, test.name)] = testFile + } + + ts := httptest.NewServer(FileServer(&fs)) + defer ts.Close() + for i, test := range tests { + url := fmt.Sprintf("%s/%d", ts.URL, i) + res, err := Get(url) + if err != nil { + t.Fatalf("test %q: Get: %v", test.name, err) + } + b, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("test %q: read Body: %v", test.name, err) + } + s := string(b) + if !strings.HasPrefix(s, dirListPrefix) || !strings.HasSuffix(s, dirListSuffix) { + t.Errorf("test %q: listing dir, full output is %q, want prefix %q and suffix %q", test.name, s, dirListPrefix, dirListSuffix) + } + if trimmed := strings.TrimSuffix(strings.TrimPrefix(s, dirListPrefix), dirListSuffix); trimmed != test.escaped { + t.Errorf("test %q: listing dir, filename escaped to %q, want %q", test.name, trimmed, test.escaped) + } + res.Body.Close() + } +} + func mustRemoveAll(dir string) { err := os.RemoveAll(dir) if err != nil { @@ -457,8 +505,9 @@ func (f *fakeFileInfo) Mode() os.FileMode { type fakeFile struct { io.ReadSeeker - fi *fakeFileInfo - path string // as opened + fi *fakeFileInfo + path string // as opened + entpos int } func (f *fakeFile) Close() error { return nil } @@ -468,10 +517,20 @@ func (f *fakeFile) Readdir(count int) ([]os.FileInfo, error) { return nil, os.ErrInvalid } var fis []os.FileInfo - for _, fi := range f.fi.ents { - fis = append(fis, fi) + + limit := f.entpos + count + if count <= 0 || limit > len(f.fi.ents) { + limit = len(f.fi.ents) + } + for ; f.entpos < limit; f.entpos++ { + fis = append(fis, f.fi.ents[f.entpos]) + } + + if len(fis) == 0 && count > 0 { + return fis, io.EOF + } else { + return fis, nil } - return fis, nil } type fakeFS map[string]*fakeFileInfo @@ -480,7 +539,6 @@ func (fs fakeFS) Open(name string) (File, error) { name = path.Clean(name) f, ok := fs[name] if !ok { - println("fake filesystem didn't find file", name) return nil, os.ErrNotExist } return &fakeFile{ReadSeeker: strings.NewReader(f.contents), fi: f, path: name}, nil diff --git a/src/pkg/net/http/header.go b/src/pkg/net/http/header.go index ca1ae07c2..153b94370 100644 --- a/src/pkg/net/http/header.go +++ b/src/pkg/net/http/header.go @@ -9,9 +9,12 @@ import ( "net/textproto" "sort" "strings" + "sync" "time" ) +var raceEnabled = false // set by race.go + // A Header represents the key-value pairs in an HTTP header. type Header map[string][]string @@ -114,18 +117,15 @@ func (s *headerSorter) Len() int { return len(s.kvs) } func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] } func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key } -// TODO: convert this to a sync.Cache (issue 4720) -var headerSorterCache = make(chan *headerSorter, 8) +var headerSorterPool = sync.Pool{ + New: func() interface{} { return new(headerSorter) }, +} // sortedKeyValues returns h's keys sorted in the returned kvs // slice. The headerSorter used to sort is also returned, for possible // return to headerSorterCache. func (h Header) sortedKeyValues(exclude map[string]bool) (kvs []keyValues, hs *headerSorter) { - select { - case hs = <-headerSorterCache: - default: - hs = new(headerSorter) - } + hs = headerSorterPool.Get().(*headerSorter) if cap(hs.kvs) < len(h) { hs.kvs = make([]keyValues, 0, len(h)) } @@ -159,10 +159,7 @@ func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error { } } } - select { - case headerSorterCache <- sorter: - default: - } + headerSorterPool.Put(sorter) return nil } diff --git a/src/pkg/net/http/header_test.go b/src/pkg/net/http/header_test.go index 9fd9837a5..9dcd591fa 100644 --- a/src/pkg/net/http/header_test.go +++ b/src/pkg/net/http/header_test.go @@ -192,9 +192,12 @@ func BenchmarkHeaderWriteSubset(b *testing.B) { } } -func TestHeaderWriteSubsetMallocs(t *testing.T) { +func TestHeaderWriteSubsetAllocs(t *testing.T) { if testing.Short() { - t.Skip("skipping malloc count in short mode") + t.Skip("skipping alloc test in short mode") + } + if raceEnabled { + t.Skip("skipping test under race detector") } if runtime.GOMAXPROCS(0) > 1 { t.Skip("skipping; GOMAXPROCS>1") @@ -204,6 +207,6 @@ func TestHeaderWriteSubsetMallocs(t *testing.T) { testHeader.WriteSubset(&buf, nil) }) if n > 0 { - t.Errorf("mallocs = %g; want 0", n) + t.Errorf("allocs = %g; want 0", n) } } diff --git a/src/pkg/net/http/httptest/server_test.go b/src/pkg/net/http/httptest/server_test.go index 500a9f0b8..501cc8a99 100644 --- a/src/pkg/net/http/httptest/server_test.go +++ b/src/pkg/net/http/httptest/server_test.go @@ -8,6 +8,7 @@ import ( "io/ioutil" "net/http" "testing" + "time" ) func TestServer(t *testing.T) { @@ -27,3 +28,25 @@ func TestServer(t *testing.T) { t.Errorf("got %q, want hello", string(got)) } } + +func TestIssue7264(t *testing.T) { + for i := 0; i < 1000; i++ { + func() { + inHandler := make(chan bool, 1) + ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + inHandler <- true + })) + defer ts.Close() + tr := &http.Transport{ + ResponseHeaderTimeout: time.Nanosecond, + } + defer tr.CloseIdleConnections() + c := &http.Client{Transport: tr} + res, err := c.Get(ts.URL) + <-inHandler + if err == nil { + res.Body.Close() + } + }() + } +} diff --git a/src/pkg/net/http/httputil/chunked.go b/src/pkg/net/http/httputil/chunked.go index b66d40951..9632bfd19 100644 --- a/src/pkg/net/http/httputil/chunked.go +++ b/src/pkg/net/http/httputil/chunked.go @@ -4,15 +4,14 @@ // The wire protocol for HTTP's "chunked" Transfer-Encoding. -// This code is a duplicate of ../chunked.go with these edits: -// s/newChunked/NewChunked/g -// s/package http/package httputil/ +// This code is duplicated in net/http and net/http/httputil. // Please make any changes in both files. package httputil import ( "bufio" + "bytes" "errors" "fmt" "io" @@ -22,13 +21,13 @@ const maxLineLength = 4096 // assumed <= bufio.defaultBufSize var ErrLineTooLong = errors.New("header line too long") -// NewChunkedReader returns a new chunkedReader that translates the data read from r +// newChunkedReader returns a new chunkedReader that translates the data read from r // out of HTTP "chunked" format before returning it. // The chunkedReader returns io.EOF when the final 0-length chunk is read. // -// NewChunkedReader is not needed by normal applications. The http package +// newChunkedReader is not needed by normal applications. The http package // automatically decodes chunking when reading response bodies. -func NewChunkedReader(r io.Reader) io.Reader { +func newChunkedReader(r io.Reader) io.Reader { br, ok := r.(*bufio.Reader) if !ok { br = bufio.NewReader(r) @@ -59,26 +58,45 @@ func (cr *chunkedReader) beginChunk() { } } -func (cr *chunkedReader) Read(b []uint8) (n int, err error) { - if cr.err != nil { - return 0, cr.err +func (cr *chunkedReader) chunkHeaderAvailable() bool { + n := cr.r.Buffered() + if n > 0 { + peek, _ := cr.r.Peek(n) + return bytes.IndexByte(peek, '\n') >= 0 } - if cr.n == 0 { - cr.beginChunk() - if cr.err != nil { - return 0, cr.err + return false +} + +func (cr *chunkedReader) Read(b []uint8) (n int, err error) { + for cr.err == nil { + if cr.n == 0 { + if n > 0 && !cr.chunkHeaderAvailable() { + // We've read enough. Don't potentially block + // reading a new chunk header. + break + } + cr.beginChunk() + continue } - } - if uint64(len(b)) > cr.n { - b = b[0:cr.n] - } - n, cr.err = cr.r.Read(b) - cr.n -= uint64(n) - if cr.n == 0 && cr.err == nil { - // end of chunk (CRLF) - if _, cr.err = io.ReadFull(cr.r, cr.buf[:]); cr.err == nil { - if cr.buf[0] != '\r' || cr.buf[1] != '\n' { - cr.err = errors.New("malformed chunked encoding") + if len(b) == 0 { + break + } + rbuf := b + if uint64(len(rbuf)) > cr.n { + rbuf = rbuf[:cr.n] + } + var n0 int + n0, cr.err = cr.r.Read(rbuf) + n += n0 + b = b[n0:] + cr.n -= uint64(n0) + // If we're at the end of a chunk, read the next two + // bytes to verify they are "\r\n". + if cr.n == 0 && cr.err == nil { + if _, cr.err = io.ReadFull(cr.r, cr.buf[:2]); cr.err == nil { + if cr.buf[0] != '\r' || cr.buf[1] != '\n' { + cr.err = errors.New("malformed chunked encoding") + } } } } @@ -117,16 +135,16 @@ func isASCIISpace(b byte) bool { return b == ' ' || b == '\t' || b == '\n' || b == '\r' } -// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP +// newChunkedWriter returns a new chunkedWriter that translates writes into HTTP // "chunked" format before writing them to w. Closing the returned chunkedWriter // sends the final 0-length chunk that marks the end of the stream. // -// NewChunkedWriter is not needed by normal applications. The http +// newChunkedWriter is not needed by normal applications. The http // package adds chunking automatically if handlers don't set a -// Content-Length header. Using NewChunkedWriter inside a handler +// Content-Length header. Using newChunkedWriter inside a handler // would result in double chunking or chunking with a Content-Length // length, both of which are wrong. -func NewChunkedWriter(w io.Writer) io.WriteCloser { +func newChunkedWriter(w io.Writer) io.WriteCloser { return &chunkedWriter{w} } diff --git a/src/pkg/net/http/httputil/chunked_test.go b/src/pkg/net/http/httputil/chunked_test.go index a06bffad5..a7a577468 100644 --- a/src/pkg/net/http/httputil/chunked_test.go +++ b/src/pkg/net/http/httputil/chunked_test.go @@ -2,26 +2,25 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This code is a duplicate of ../chunked_test.go with these edits: -// s/newChunked/NewChunked/g -// s/package http/package httputil/ +// This code is duplicated in net/http and net/http/httputil. // Please make any changes in both files. package httputil import ( + "bufio" "bytes" "fmt" "io" "io/ioutil" - "runtime" + "strings" "testing" ) func TestChunk(t *testing.T) { var b bytes.Buffer - w := NewChunkedWriter(&b) + w := newChunkedWriter(&b) const chunk1 = "hello, " const chunk2 = "world! 0123456789abcdef" w.Write([]byte(chunk1)) @@ -32,7 +31,7 @@ func TestChunk(t *testing.T) { t.Fatalf("chunk writer wrote %q; want %q", g, e) } - r := NewChunkedReader(&b) + r := newChunkedReader(&b) data, err := ioutil.ReadAll(r) if err != nil { t.Logf(`data: "%s"`, data) @@ -43,37 +42,102 @@ func TestChunk(t *testing.T) { } } +func TestChunkReadMultiple(t *testing.T) { + // Bunch of small chunks, all read together. + { + var b bytes.Buffer + w := newChunkedWriter(&b) + w.Write([]byte("foo")) + w.Write([]byte("bar")) + w.Close() + + r := newChunkedReader(&b) + buf := make([]byte, 10) + n, err := r.Read(buf) + if n != 6 || err != io.EOF { + t.Errorf("Read = %d, %v; want 6, EOF", n, err) + } + buf = buf[:n] + if string(buf) != "foobar" { + t.Errorf("Read = %q; want %q", buf, "foobar") + } + } + + // One big chunk followed by a little chunk, but the small bufio.Reader size + // should prevent the second chunk header from being read. + { + var b bytes.Buffer + w := newChunkedWriter(&b) + // fillBufChunk is 11 bytes + 3 bytes header + 2 bytes footer = 16 bytes, + // the same as the bufio ReaderSize below (the minimum), so even + // though we're going to try to Read with a buffer larger enough to also + // receive "foo", the second chunk header won't be read yet. + const fillBufChunk = "0123456789a" + const shortChunk = "foo" + w.Write([]byte(fillBufChunk)) + w.Write([]byte(shortChunk)) + w.Close() + + r := newChunkedReader(bufio.NewReaderSize(&b, 16)) + buf := make([]byte, len(fillBufChunk)+len(shortChunk)) + n, err := r.Read(buf) + if n != len(fillBufChunk) || err != nil { + t.Errorf("Read = %d, %v; want %d, nil", n, err, len(fillBufChunk)) + } + buf = buf[:n] + if string(buf) != fillBufChunk { + t.Errorf("Read = %q; want %q", buf, fillBufChunk) + } + + n, err = r.Read(buf) + if n != len(shortChunk) || err != io.EOF { + t.Errorf("Read = %d, %v; want %d, EOF", n, err, len(shortChunk)) + } + } + + // And test that we see an EOF chunk, even though our buffer is already full: + { + r := newChunkedReader(bufio.NewReader(strings.NewReader("3\r\nfoo\r\n0\r\n"))) + buf := make([]byte, 3) + n, err := r.Read(buf) + if n != 3 || err != io.EOF { + t.Errorf("Read = %d, %v; want 3, EOF", n, err) + } + if string(buf) != "foo" { + t.Errorf("buf = %q; want foo", buf) + } + } +} + func TestChunkReaderAllocs(t *testing.T) { - // temporarily set GOMAXPROCS to 1 as we are testing memory allocations - defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1)) + if testing.Short() { + t.Skip("skipping in short mode") + } var buf bytes.Buffer - w := NewChunkedWriter(&buf) + w := newChunkedWriter(&buf) a, b, c := []byte("aaaaaa"), []byte("bbbbbbbbbbbb"), []byte("cccccccccccccccccccccccc") w.Write(a) w.Write(b) w.Write(c) w.Close() - r := NewChunkedReader(&buf) readBuf := make([]byte, len(a)+len(b)+len(c)+1) - - var ms runtime.MemStats - runtime.ReadMemStats(&ms) - m0 := ms.Mallocs - - n, err := io.ReadFull(r, readBuf) - - runtime.ReadMemStats(&ms) - mallocs := ms.Mallocs - m0 - if mallocs > 1 { - t.Errorf("%d mallocs; want <= 1", mallocs) - } - - if n != len(readBuf)-1 { - t.Errorf("read %d bytes; want %d", n, len(readBuf)-1) - } - if err != io.ErrUnexpectedEOF { - t.Errorf("read error = %v; want ErrUnexpectedEOF", err) + byter := bytes.NewReader(buf.Bytes()) + bufr := bufio.NewReader(byter) + mallocs := testing.AllocsPerRun(100, func() { + byter.Seek(0, 0) + bufr.Reset(byter) + r := newChunkedReader(bufr) + n, err := io.ReadFull(r, readBuf) + if n != len(readBuf)-1 { + t.Fatalf("read %d bytes; want %d", n, len(readBuf)-1) + } + if err != io.ErrUnexpectedEOF { + t.Fatalf("read error = %v; want ErrUnexpectedEOF", err) + } + }) + if mallocs > 1.5 { + t.Errorf("mallocs = %v; want 1", mallocs) } } diff --git a/src/pkg/net/http/httputil/dump.go b/src/pkg/net/http/httputil/dump.go index 265499fb0..2a7a413d0 100644 --- a/src/pkg/net/http/httputil/dump.go +++ b/src/pkg/net/http/httputil/dump.go @@ -7,6 +7,7 @@ package httputil import ( "bufio" "bytes" + "errors" "fmt" "io" "io/ioutil" @@ -29,7 +30,7 @@ func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) { if err = b.Close(); err != nil { return nil, nil, err } - return ioutil.NopCloser(&buf), ioutil.NopCloser(bytes.NewBuffer(buf.Bytes())), nil + return ioutil.NopCloser(&buf), ioutil.NopCloser(bytes.NewReader(buf.Bytes())), nil } // dumpConn is a net.Conn which writes to Writer and reads from Reader @@ -106,6 +107,7 @@ func DumpRequestOut(req *http.Request, body bool) ([]byte, error) { return &dumpConn{io.MultiWriter(&buf, pw), dr}, nil }, } + defer t.CloseIdleConnections() _, err := t.RoundTrip(reqSend) @@ -230,14 +232,31 @@ func DumpRequest(req *http.Request, body bool) (dump []byte, err error) { return } +// errNoBody is a sentinel error value used by failureToReadBody so we can detect +// that the lack of body was intentional. +var errNoBody = errors.New("sentinel error value") + +// failureToReadBody is a io.ReadCloser that just returns errNoBody on +// Read. It's swapped in when we don't actually want to consume the +// body, but need a non-nil one, and want to distinguish the error +// from reading the dummy body. +type failureToReadBody struct{} + +func (failureToReadBody) Read([]byte) (int, error) { return 0, errNoBody } +func (failureToReadBody) Close() error { return nil } + +var emptyBody = ioutil.NopCloser(strings.NewReader("")) + // DumpResponse is like DumpRequest but dumps a response. func DumpResponse(resp *http.Response, body bool) (dump []byte, err error) { var b bytes.Buffer save := resp.Body savecl := resp.ContentLength - if !body || resp.Body == nil { - resp.Body = nil - resp.ContentLength = 0 + + if !body { + resp.Body = failureToReadBody{} + } else if resp.Body == nil { + resp.Body = emptyBody } else { save, resp.Body, err = drainBody(resp.Body) if err != nil { @@ -245,11 +264,13 @@ func DumpResponse(resp *http.Response, body bool) (dump []byte, err error) { } } err = resp.Write(&b) + if err == errNoBody { + err = nil + } resp.Body = save resp.ContentLength = savecl if err != nil { - return + return nil, err } - dump = b.Bytes() - return + return b.Bytes(), nil } diff --git a/src/pkg/net/http/httputil/dump_test.go b/src/pkg/net/http/httputil/dump_test.go index 987a82048..e1ffb3935 100644 --- a/src/pkg/net/http/httputil/dump_test.go +++ b/src/pkg/net/http/httputil/dump_test.go @@ -11,6 +11,8 @@ import ( "io/ioutil" "net/http" "net/url" + "runtime" + "strings" "testing" ) @@ -112,6 +114,7 @@ var dumpTests = []dumpTest{ } func TestDumpRequest(t *testing.T) { + numg0 := runtime.NumGoroutine() for i, tt := range dumpTests { setBody := func() { if tt.Body == nil { @@ -119,7 +122,7 @@ func TestDumpRequest(t *testing.T) { } switch b := tt.Body.(type) { case []byte: - tt.Req.Body = ioutil.NopCloser(bytes.NewBuffer(b)) + tt.Req.Body = ioutil.NopCloser(bytes.NewReader(b)) case func() io.ReadCloser: tt.Req.Body = b() } @@ -155,6 +158,9 @@ func TestDumpRequest(t *testing.T) { } } } + if dg := runtime.NumGoroutine() - numg0; dg > 4 { + t.Errorf("Unexpectedly large number of new goroutines: %d new", dg) + } } func chunk(s string) string { @@ -176,3 +182,82 @@ func mustNewRequest(method, url string, body io.Reader) *http.Request { } return req } + +var dumpResTests = []struct { + res *http.Response + body bool + want string +}{ + { + res: &http.Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 50, + Header: http.Header{ + "Foo": []string{"Bar"}, + }, + Body: ioutil.NopCloser(strings.NewReader("foo")), // shouldn't be used + }, + body: false, // to verify we see 50, not empty or 3. + want: `HTTP/1.1 200 OK +Content-Length: 50 +Foo: Bar`, + }, + + { + res: &http.Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 3, + Body: ioutil.NopCloser(strings.NewReader("foo")), + }, + body: true, + want: `HTTP/1.1 200 OK +Content-Length: 3 + +foo`, + }, + + { + res: &http.Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: -1, + Body: ioutil.NopCloser(strings.NewReader("foo")), + TransferEncoding: []string{"chunked"}, + }, + body: true, + want: `HTTP/1.1 200 OK +Transfer-Encoding: chunked + +3 +foo +0`, + }, +} + +func TestDumpResponse(t *testing.T) { + for i, tt := range dumpResTests { + gotb, err := DumpResponse(tt.res, tt.body) + if err != nil { + t.Errorf("%d. DumpResponse = %v", i, err) + continue + } + got := string(gotb) + got = strings.TrimSpace(got) + got = strings.Replace(got, "\r", "", -1) + + if got != tt.want { + t.Errorf("%d.\nDumpResponse got:\n%s\n\nWant:\n%s\n", i, got, tt.want) + } + } +} diff --git a/src/pkg/net/http/httputil/httputil.go b/src/pkg/net/http/httputil/httputil.go new file mode 100644 index 000000000..74fb6c655 --- /dev/null +++ b/src/pkg/net/http/httputil/httputil.go @@ -0,0 +1,32 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package httputil provides HTTP utility functions, complementing the +// more common ones in the net/http package. +package httputil + +import "io" + +// NewChunkedReader returns a new chunkedReader that translates the data read from r +// out of HTTP "chunked" format before returning it. +// The chunkedReader returns io.EOF when the final 0-length chunk is read. +// +// NewChunkedReader is not needed by normal applications. The http package +// automatically decodes chunking when reading response bodies. +func NewChunkedReader(r io.Reader) io.Reader { + return newChunkedReader(r) +} + +// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP +// "chunked" format before writing them to w. Closing the returned chunkedWriter +// sends the final 0-length chunk that marks the end of the stream. +// +// NewChunkedWriter is not needed by normal applications. The http +// package adds chunking automatically if handlers don't set a +// Content-Length header. Using NewChunkedWriter inside a handler +// would result in double chunking or chunking with a Content-Length +// length, both of which are wrong. +func NewChunkedWriter(w io.Writer) io.WriteCloser { + return newChunkedWriter(w) +} diff --git a/src/pkg/net/http/httputil/persist.go b/src/pkg/net/http/httputil/persist.go index 507938aca..987bcc96b 100644 --- a/src/pkg/net/http/httputil/persist.go +++ b/src/pkg/net/http/httputil/persist.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package httputil provides HTTP utility functions, complementing the -// more common ones in the net/http package. package httputil import ( @@ -33,8 +31,8 @@ var errClosed = errors.New("i/o operation on closed connection") // i.e. requests can be read out of sync (but in the same order) while the // respective responses are sent. // -// ServerConn is low-level and should not be needed by most applications. -// See Server. +// ServerConn is low-level and old. Applications should instead use Server +// in the net/http package. type ServerConn struct { lk sync.Mutex // read-write protects the following fields c net.Conn @@ -47,8 +45,11 @@ type ServerConn struct { pipe textproto.Pipeline } -// NewServerConn returns a new ServerConn reading and writing c. If r is not +// NewServerConn returns a new ServerConn reading and writing c. If r is not // nil, it is the buffer to use when reading c. +// +// ServerConn is low-level and old. Applications should instead use Server +// in the net/http package. func NewServerConn(c net.Conn, r *bufio.Reader) *ServerConn { if r == nil { r = bufio.NewReader(c) @@ -223,8 +224,8 @@ func (sc *ServerConn) Write(req *http.Request, resp *http.Response) error { // supports hijacking the connection calling Hijack to // regain control of the underlying net.Conn and deal with it as desired. // -// ClientConn is low-level and should not be needed by most applications. -// See Client. +// ClientConn is low-level and old. Applications should instead use +// Client or Transport in the net/http package. type ClientConn struct { lk sync.Mutex // read-write protects the following fields c net.Conn @@ -240,6 +241,9 @@ type ClientConn struct { // NewClientConn returns a new ClientConn reading and writing c. If r is not // nil, it is the buffer to use when reading c. +// +// ClientConn is low-level and old. Applications should use Client or +// Transport in the net/http package. func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn { if r == nil { r = bufio.NewReader(c) @@ -254,6 +258,9 @@ func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn { // NewProxyClientConn works like NewClientConn but writes Requests // using Request's WriteProxy method. +// +// New code should not use NewProxyClientConn. See Client or +// Transport in the net/http package instead. func NewProxyClientConn(c net.Conn, r *bufio.Reader) *ClientConn { cc := NewClientConn(c, r) cc.writeReq = (*http.Request).WriteProxy diff --git a/src/pkg/net/http/httputil/reverseproxy.go b/src/pkg/net/http/httputil/reverseproxy.go index 1990f64db..48ada5f5f 100644 --- a/src/pkg/net/http/httputil/reverseproxy.go +++ b/src/pkg/net/http/httputil/reverseproxy.go @@ -144,6 +144,10 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } defer res.Body.Close() + for _, h := range hopHeaders { + res.Header.Del(h) + } + copyHeader(rw.Header(), res.Header) rw.WriteHeader(res.StatusCode) diff --git a/src/pkg/net/http/httputil/reverseproxy_test.go b/src/pkg/net/http/httputil/reverseproxy_test.go index 1c0444ec4..e9539b44b 100644 --- a/src/pkg/net/http/httputil/reverseproxy_test.go +++ b/src/pkg/net/http/httputil/reverseproxy_test.go @@ -16,6 +16,12 @@ import ( "time" ) +const fakeHopHeader = "X-Fake-Hop-Header-For-Test" + +func init() { + hopHeaders = append(hopHeaders, fakeHopHeader) +} + func TestReverseProxy(t *testing.T) { const backendResponse = "I am the backend" const backendStatus = 404 @@ -36,6 +42,10 @@ func TestReverseProxy(t *testing.T) { t.Errorf("backend got Host header %q, want %q", g, e) } w.Header().Set("X-Foo", "bar") + w.Header().Set("Upgrade", "foo") + w.Header().Set(fakeHopHeader, "foo") + w.Header().Add("X-Multi-Value", "foo") + w.Header().Add("X-Multi-Value", "bar") http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"}) w.WriteHeader(backendStatus) w.Write([]byte(backendResponse)) @@ -64,6 +74,12 @@ func TestReverseProxy(t *testing.T) { if g, e := res.Header.Get("X-Foo"), "bar"; g != e { t.Errorf("got X-Foo %q; expected %q", g, e) } + if c := res.Header.Get(fakeHopHeader); c != "" { + t.Errorf("got %s header value %q", fakeHopHeader, c) + } + if g, e := len(res.Header["X-Multi-Value"]), 2; g != e { + t.Errorf("got %d X-Multi-Value header values; expected %d", g, e) + } if g, e := len(res.Header["Set-Cookie"]), 1; g != e { t.Fatalf("got %d SetCookies, want %d", g, e) } diff --git a/src/pkg/net/http/proxy_test.go b/src/pkg/net/http/proxy_test.go index 449ccaeea..b6aed3792 100644 --- a/src/pkg/net/http/proxy_test.go +++ b/src/pkg/net/http/proxy_test.go @@ -35,12 +35,8 @@ var UseProxyTests = []struct { } func TestUseProxy(t *testing.T) { - oldenv := os.Getenv("NO_PROXY") - defer os.Setenv("NO_PROXY", oldenv) - - no_proxy := "foobar.com, .barbaz.net" - os.Setenv("NO_PROXY", no_proxy) - + ResetProxyEnv() + os.Setenv("NO_PROXY", "foobar.com, .barbaz.net") for _, test := range UseProxyTests { if useProxy(test.host+":80") != test.match { t.Errorf("useProxy(%v) = %v, want %v", test.host, !test.match, test.match) @@ -71,8 +67,15 @@ func TestCacheKeys(t *testing.T) { proxy = u } cm := connectMethod{proxy, tt.scheme, tt.addr} - if cm.String() != tt.key { - t.Fatalf("{%q, %q, %q} cache key %q; want %q", tt.proxy, tt.scheme, tt.addr, cm.String(), tt.key) + if got := cm.key().String(); got != tt.key { + t.Fatalf("{%q, %q, %q} cache key = %q; want %q", tt.proxy, tt.scheme, tt.addr, got, tt.key) } } } + +func ResetProxyEnv() { + for _, v := range []string{"HTTP_PROXY", "http_proxy", "NO_PROXY", "no_proxy"} { + os.Setenv(v, "") + } + ResetCachedEnvironment() +} diff --git a/src/pkg/net/http/race.go b/src/pkg/net/http/race.go new file mode 100644 index 000000000..766503967 --- /dev/null +++ b/src/pkg/net/http/race.go @@ -0,0 +1,11 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build race + +package http + +func init() { + raceEnabled = true +} diff --git a/src/pkg/net/http/request.go b/src/pkg/net/http/request.go index 57b5d0948..a67092066 100644 --- a/src/pkg/net/http/request.go +++ b/src/pkg/net/http/request.go @@ -20,6 +20,7 @@ import ( "net/url" "strconv" "strings" + "sync" ) const ( @@ -68,18 +69,31 @@ var reqWriteExcludeHeader = map[string]bool{ // A Request represents an HTTP request received by a server // or to be sent by a client. +// +// The field semantics differ slightly between client and server +// usage. In addition to the notes on the fields below, see the +// documentation for Request.Write and RoundTripper. type Request struct { - Method string // GET, POST, PUT, etc. + // Method specifies the HTTP method (GET, POST, PUT, etc.). + // For client requests an empty string means GET. + Method string - // URL is created from the URI supplied on the Request-Line - // as stored in RequestURI. + // URL specifies either the URI being requested (for server + // requests) or the URL to access (for client requests). + // + // For server requests the URL is parsed from the URI + // supplied on the Request-Line as stored in RequestURI. For + // most requests, fields other than Path and RawQuery will be + // empty. (See RFC 2616, Section 5.1.2) // - // For most requests, fields other than Path and RawQuery - // will be empty. (See RFC 2616, Section 5.1.2) + // For client requests, the URL's Host specifies the server to + // connect to, while the Request's Host field optionally + // specifies the Host header value to send in the HTTP + // request. URL *url.URL // The protocol version for incoming requests. - // Outgoing requests always use HTTP/1.1. + // Client requests always use HTTP/1.1. Proto string // "HTTP/1.0" ProtoMajor int // 1 ProtoMinor int // 0 @@ -103,15 +117,20 @@ type Request struct { // The request parser implements this by canonicalizing the // name, making the first character and any characters // following a hyphen uppercase and the rest lowercase. + // + // For client requests certain headers are automatically + // added and may override values in Header. + // + // See the documentation for the Request.Write method. Header Header // Body is the request's body. // - // For client requests, a nil body means the request has no + // For client requests a nil body means the request has no // body, such as a GET request. The HTTP Client's Transport // is responsible for calling the Close method. // - // For server requests, the Request Body is always non-nil + // For server requests the Request Body is always non-nil // but will return EOF immediately when no body is present. // The Server will close the request body. The ServeHTTP // Handler does not need to. @@ -121,7 +140,7 @@ type Request struct { // The value -1 indicates that the length is unknown. // Values >= 0 indicate that the given number of bytes may // be read from Body. - // For outgoing requests, a value of 0 means unknown if Body is not nil. + // For client requests, a value of 0 means unknown if Body is not nil. ContentLength int64 // TransferEncoding lists the transfer encodings from outermost to @@ -132,13 +151,18 @@ type Request struct { TransferEncoding []string // Close indicates whether to close the connection after - // replying to this request. + // replying to this request (for servers) or after sending + // the request (for clients). Close bool - // The host on which the URL is sought. - // Per RFC 2616, this is either the value of the Host: header - // or the host name given in the URL itself. + // For server requests Host specifies the host on which the + // URL is sought. Per RFC 2616, this is either the value of + // the "Host" header or the host name given in the URL itself. // It may be of the form "host:port". + // + // For client requests Host optionally overrides the Host + // header to send. If empty, the Request.Write method uses + // the value of URL.Host. Host string // Form contains the parsed form data, including both the URL @@ -158,12 +182,24 @@ type Request struct { // The HTTP client ignores MultipartForm and uses Body instead. MultipartForm *multipart.Form - // Trailer maps trailer keys to values. Like for Header, if the - // response has multiple trailer lines with the same key, they will be - // concatenated, delimited by commas. - // For server requests, Trailer is only populated after Body has been - // closed or fully consumed. - // Trailer support is only partially complete. + // Trailer specifies additional headers that are sent after the request + // body. + // + // For server requests the Trailer map initially contains only the + // trailer keys, with nil values. (The client declares which trailers it + // will later send.) While the handler is reading from Body, it must + // not reference Trailer. After reading from Body returns EOF, Trailer + // can be read again and will contain non-nil values, if they were sent + // by the client. + // + // For client requests Trailer must be initialized to a map containing + // the trailer keys to later send. The values may be nil or their final + // values. The ContentLength must be 0 or -1, to send a chunked request. + // After the HTTP request is sent the map values can be updated while + // the request body is read. Once the body returns EOF, the caller must + // not mutate Trailer. + // + // Few HTTP clients, servers, or proxies support HTTP trailers. Trailer Header // RemoteAddr allows HTTP servers and other software to record @@ -381,7 +417,6 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err return err } - // TODO: split long values? (If so, should share code with Conn.Write) err = req.Header.WriteSubset(w, reqWriteExcludeHeader) if err != nil { return err @@ -494,25 +529,20 @@ func parseRequestLine(line string) (method, requestURI, proto string, ok bool) { return line[:s1], line[s1+1 : s2], line[s2+1:], true } -// TODO(bradfitz): use a sync.Cache when available -var textprotoReaderCache = make(chan *textproto.Reader, 4) +var textprotoReaderPool sync.Pool func newTextprotoReader(br *bufio.Reader) *textproto.Reader { - select { - case r := <-textprotoReaderCache: - r.R = br - return r - default: - return textproto.NewReader(br) + if v := textprotoReaderPool.Get(); v != nil { + tr := v.(*textproto.Reader) + tr.R = br + return tr } + return textproto.NewReader(br) } func putTextprotoReader(r *textproto.Reader) { r.R = nil - select { - case textprotoReaderCache <- r: - default: - } + textprotoReaderPool.Put(r) } // ReadRequest reads and parses a request from b. @@ -588,32 +618,6 @@ func ReadRequest(b *bufio.Reader) (req *Request, err error) { fixPragmaCacheControl(req.Header) - // TODO: Parse specific header values: - // Accept - // Accept-Encoding - // Accept-Language - // Authorization - // Cache-Control - // Connection - // Date - // Expect - // From - // If-Match - // If-Modified-Since - // If-None-Match - // If-Range - // If-Unmodified-Since - // Max-Forwards - // Proxy-Authorization - // Referer [sic] - // TE (transfer-codings) - // Trailer - // Transfer-Encoding - // Upgrade - // User-Agent - // Via - // Warning - err = readTransfer(req, b) if err != nil { return nil, err @@ -677,6 +681,11 @@ func parsePostForm(r *Request) (vs url.Values, err error) { return } ct := r.Header.Get("Content-Type") + // RFC 2616, section 7.2.1 - empty type + // SHOULD be treated as application/octet-stream + if ct == "" { + ct = "application/octet-stream" + } ct, _, err = mime.ParseMediaType(ct) switch { case ct == "application/x-www-form-urlencoded": @@ -707,7 +716,7 @@ func parsePostForm(r *Request) (vs url.Values, err error) { // orders to call too many functions here. // Clean this up and write more tests. // request_test.go contains the start of this, - // in TestRequestMultipartCallOrder. + // in TestParseMultipartFormOrder and others. } return } @@ -727,7 +736,7 @@ func parsePostForm(r *Request) (vs url.Values, err error) { func (r *Request) ParseForm() error { var err error if r.PostForm == nil { - if r.Method == "POST" || r.Method == "PUT" { + if r.Method == "POST" || r.Method == "PUT" || r.Method == "PATCH" { r.PostForm, err = parsePostForm(r) } if r.PostForm == nil { @@ -780,9 +789,7 @@ func (r *Request) ParseMultipartForm(maxMemory int64) error { } mr, err := r.multipartReader() - if err == ErrNotMultipart { - return nil - } else if err != nil { + if err != nil { return err } @@ -860,3 +867,9 @@ func (r *Request) wantsHttp10KeepAlive() bool { func (r *Request) wantsClose() bool { return hasToken(r.Header.get("Connection"), "close") } + +func (r *Request) closeBody() { + if r.Body != nil { + r.Body.Close() + } +} diff --git a/src/pkg/net/http/request_test.go b/src/pkg/net/http/request_test.go index 89303c336..b9fa3c2bf 100644 --- a/src/pkg/net/http/request_test.go +++ b/src/pkg/net/http/request_test.go @@ -60,6 +60,37 @@ func TestPostQuery(t *testing.T) { } } +func TestPatchQuery(t *testing.T) { + req, _ := NewRequest("PATCH", "http://www.google.com/search?q=foo&q=bar&both=x&prio=1&empty=not", + strings.NewReader("z=post&both=y&prio=2&empty=")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + + if q := req.FormValue("q"); q != "foo" { + t.Errorf(`req.FormValue("q") = %q, want "foo"`, q) + } + if z := req.FormValue("z"); z != "post" { + t.Errorf(`req.FormValue("z") = %q, want "post"`, z) + } + if bq, found := req.PostForm["q"]; found { + t.Errorf(`req.PostForm["q"] = %q, want no entry in map`, bq) + } + if bz := req.PostFormValue("z"); bz != "post" { + t.Errorf(`req.PostFormValue("z") = %q, want "post"`, bz) + } + if qs := req.Form["q"]; !reflect.DeepEqual(qs, []string{"foo", "bar"}) { + t.Errorf(`req.Form["q"] = %q, want ["foo", "bar"]`, qs) + } + if both := req.Form["both"]; !reflect.DeepEqual(both, []string{"y", "x"}) { + t.Errorf(`req.Form["both"] = %q, want ["y", "x"]`, both) + } + if prio := req.FormValue("prio"); prio != "2" { + t.Errorf(`req.FormValue("prio") = %q, want "2" (from body)`, prio) + } + if empty := req.FormValue("empty"); empty != "" { + t.Errorf(`req.FormValue("empty") = %q, want "" (from body)`, empty) + } +} + type stringMap map[string][]string type parseContentTypeTest struct { shouldError bool @@ -68,8 +99,9 @@ type parseContentTypeTest struct { var parseContentTypeTests = []parseContentTypeTest{ {false, stringMap{"Content-Type": {"text/plain"}}}, - // Non-existent keys are not placed. The value nil is illegal. - {true, stringMap{}}, + // Empty content type is legal - shoult be treated as + // application/octet-stream (RFC 2616, section 7.2.1) + {false, stringMap{}}, {true, stringMap{"Content-Type": {"text/plain; boundary="}}}, {false, stringMap{"Content-Type": {"application/unknown"}}}, } @@ -79,7 +111,7 @@ func TestParseFormUnknownContentType(t *testing.T) { req := &Request{ Method: "POST", Header: Header(test.contentType), - Body: ioutil.NopCloser(bytes.NewBufferString("body")), + Body: ioutil.NopCloser(strings.NewReader("body")), } err := req.ParseForm() switch { @@ -122,7 +154,25 @@ func TestMultipartReader(t *testing.T) { req.Header = Header{"Content-Type": {"text/plain"}} multipart, err = req.MultipartReader() if multipart != nil { - t.Errorf("unexpected multipart for text/plain") + t.Error("unexpected multipart for text/plain") + } +} + +func TestParseMultipartForm(t *testing.T) { + req := &Request{ + Method: "POST", + Header: Header{"Content-Type": {`multipart/form-data; boundary="foo123"`}}, + Body: ioutil.NopCloser(new(bytes.Buffer)), + } + err := req.ParseMultipartForm(25) + if err == nil { + t.Error("expected multipart EOF, got nil") + } + + req.Header = Header{"Content-Type": {"text/plain"}} + err = req.ParseMultipartForm(25) + if err != ErrNotMultipart { + t.Error("expected ErrNotMultipart for text/plain") } } @@ -188,25 +238,72 @@ func TestMultipartRequestAuto(t *testing.T) { validateTestMultipartContents(t, req, true) } -func TestEmptyMultipartRequest(t *testing.T) { - // Test that FormValue and FormFile automatically invoke - // ParseMultipartForm and return the right values. - req, err := NewRequest("GET", "/", nil) - if err != nil { - t.Errorf("NewRequest err = %q", err) - } +func TestMissingFileMultipartRequest(t *testing.T) { + // Test that FormFile returns an error if + // the named file is missing. + req := newTestMultipartRequest(t) testMissingFile(t, req) } -func TestRequestMultipartCallOrder(t *testing.T) { +// Test that FormValue invokes ParseMultipartForm. +func TestFormValueCallsParseMultipartForm(t *testing.T) { + req, _ := NewRequest("POST", "http://www.google.com/", strings.NewReader("z=post")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + if req.Form != nil { + t.Fatal("Unexpected request Form, want nil") + } + req.FormValue("z") + if req.Form == nil { + t.Fatal("ParseMultipartForm not called by FormValue") + } +} + +// Test that FormFile invokes ParseMultipartForm. +func TestFormFileCallsParseMultipartForm(t *testing.T) { req := newTestMultipartRequest(t) - _, err := req.MultipartReader() - if err != nil { + if req.Form != nil { + t.Fatal("Unexpected request Form, want nil") + } + req.FormFile("") + if req.Form == nil { + t.Fatal("ParseMultipartForm not called by FormFile") + } +} + +// Test that ParseMultipartForm errors if called +// after MultipartReader on the same request. +func TestParseMultipartFormOrder(t *testing.T) { + req := newTestMultipartRequest(t) + if _, err := req.MultipartReader(); err != nil { t.Fatalf("MultipartReader: %v", err) } - err = req.ParseMultipartForm(1024) - if err == nil { - t.Errorf("expected an error from ParseMultipartForm after call to MultipartReader") + if err := req.ParseMultipartForm(1024); err == nil { + t.Fatal("expected an error from ParseMultipartForm after call to MultipartReader") + } +} + +// Test that MultipartReader errors if called +// after ParseMultipartForm on the same request. +func TestMultipartReaderOrder(t *testing.T) { + req := newTestMultipartRequest(t) + if err := req.ParseMultipartForm(25); err != nil { + t.Fatalf("ParseMultipartForm: %v", err) + } + defer req.MultipartForm.RemoveAll() + if _, err := req.MultipartReader(); err == nil { + t.Fatal("expected an error from MultipartReader after call to ParseMultipartForm") + } +} + +// Test that FormFile errors if called after +// MultipartReader on the same request. +func TestFormFileOrder(t *testing.T) { + req := newTestMultipartRequest(t) + if _, err := req.MultipartReader(); err != nil { + t.Fatalf("MultipartReader: %v", err) + } + if _, _, err := req.FormFile(""); err == nil { + t.Fatal("expected an error from FormFile after call to MultipartReader") } } @@ -343,7 +440,7 @@ func testMissingFile(t *testing.T, req *Request) { } func newTestMultipartRequest(t *testing.T) *Request { - b := bytes.NewBufferString(strings.Replace(message, "\n", "\r\n", -1)) + b := strings.NewReader(strings.Replace(message, "\n", "\r\n", -1)) req, err := NewRequest("POST", "/", b) if err != nil { t.Fatal("NewRequest:", err) diff --git a/src/pkg/net/http/requestwrite_test.go b/src/pkg/net/http/requestwrite_test.go index b27b1f7ce..dc0e204ca 100644 --- a/src/pkg/net/http/requestwrite_test.go +++ b/src/pkg/net/http/requestwrite_test.go @@ -310,6 +310,46 @@ var reqWriteTests = []reqWriteTest{ WantError: errors.New("http: Request.ContentLength=5 with nil Body"), }, + // Request with a 0 ContentLength and a body with 1 byte content and an error. + { + Req: Request{ + Method: "POST", + URL: mustParseURL("/"), + Host: "example.com", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 0, // as if unset by user + }, + + Body: func() io.ReadCloser { + err := errors.New("Custom reader error") + errReader := &errorReader{err} + return ioutil.NopCloser(io.MultiReader(strings.NewReader("x"), errReader)) + }, + + WantError: errors.New("Custom reader error"), + }, + + // Request with a 0 ContentLength and a body without content and an error. + { + Req: Request{ + Method: "POST", + URL: mustParseURL("/"), + Host: "example.com", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 0, // as if unset by user + }, + + Body: func() io.ReadCloser { + err := errors.New("Custom reader error") + errReader := &errorReader{err} + return ioutil.NopCloser(errReader) + }, + + WantError: errors.New("Custom reader error"), + }, + // Verify that DumpRequest preserves the HTTP version number, doesn't add a Host, // and doesn't add a User-Agent. { @@ -427,7 +467,7 @@ func TestRequestWrite(t *testing.T) { } switch b := tt.Body.(type) { case []byte: - tt.Req.Body = ioutil.NopCloser(bytes.NewBuffer(b)) + tt.Req.Body = ioutil.NopCloser(bytes.NewReader(b)) case func() io.ReadCloser: tt.Req.Body = b() } diff --git a/src/pkg/net/http/response.go b/src/pkg/net/http/response.go index 35d0ba3bb..5d2c39080 100644 --- a/src/pkg/net/http/response.go +++ b/src/pkg/net/http/response.go @@ -8,6 +8,8 @@ package http import ( "bufio" + "bytes" + "crypto/tls" "errors" "io" "net/textproto" @@ -45,7 +47,8 @@ type Response struct { // // The http Client and Transport guarantee that Body is always // non-nil, even on responses without a body or responses with - // a zero-lengthed body. + // a zero-length body. It is the caller's responsibility to + // close Body. // // The Body is automatically dechunked if the server replied // with a "chunked" Transfer-Encoding. @@ -74,6 +77,12 @@ type Response struct { // Request's Body is nil (having already been consumed). // This is only populated for Client requests. Request *Request + + // TLS contains information about the TLS connection on which the + // response was received. It is nil for unencrypted responses. + // The pointer is shared between responses and should not be + // modified. + TLS *tls.ConnectionState } // Cookies parses and returns the cookies set in the Set-Cookie headers. @@ -141,6 +150,9 @@ func ReadResponse(r *bufio.Reader, req *Request) (*Response, error) { // Parse the response headers. mimeHeader, err := tp.ReadMIMEHeader() if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } return nil, err } resp.Header = Header(mimeHeader) @@ -187,8 +199,8 @@ func (r *Response) ProtoAtLeast(major, minor int) bool { // ContentLength // Header, values for non-canonical keys will have unpredictable behavior // +// Body is closed after it is sent. func (r *Response) Write(w io.Writer) error { - // Status line text := r.Status if text == "" { @@ -201,10 +213,45 @@ func (r *Response) Write(w io.Writer) error { protoMajor, protoMinor := strconv.Itoa(r.ProtoMajor), strconv.Itoa(r.ProtoMinor) statusCode := strconv.Itoa(r.StatusCode) + " " text = strings.TrimPrefix(text, statusCode) - io.WriteString(w, "HTTP/"+protoMajor+"."+protoMinor+" "+statusCode+text+"\r\n") + if _, err := io.WriteString(w, "HTTP/"+protoMajor+"."+protoMinor+" "+statusCode+text+"\r\n"); err != nil { + return err + } + + // Clone it, so we can modify r1 as needed. + r1 := new(Response) + *r1 = *r + if r1.ContentLength == 0 && r1.Body != nil { + // Is it actually 0 length? Or just unknown? + var buf [1]byte + n, err := r1.Body.Read(buf[:]) + if err != nil && err != io.EOF { + return err + } + if n == 0 { + // Reset it to a known zero reader, in case underlying one + // is unhappy being read repeatedly. + r1.Body = eofReader + } else { + r1.ContentLength = -1 + r1.Body = struct { + io.Reader + io.Closer + }{ + io.MultiReader(bytes.NewReader(buf[:1]), r.Body), + r.Body, + } + } + } + // If we're sending a non-chunked HTTP/1.1 response without a + // content-length, the only way to do that is the old HTTP/1.0 + // way, by noting the EOF with a connection close, so we need + // to set Close. + if r1.ContentLength == -1 && !r1.Close && r1.ProtoAtLeast(1, 1) && !chunked(r1.TransferEncoding) { + r1.Close = true + } // Process Body,ContentLength,Close,Trailer - tw, err := newTransferWriter(r) + tw, err := newTransferWriter(r1) if err != nil { return err } @@ -219,8 +266,19 @@ func (r *Response) Write(w io.Writer) error { return err } + // contentLengthAlreadySent may have been already sent for + // POST/PUT requests, even if zero length. See Issue 8180. + contentLengthAlreadySent := tw.shouldSendContentLength() + if r1.ContentLength == 0 && !chunked(r1.TransferEncoding) && !contentLengthAlreadySent { + if _, err := io.WriteString(w, "Content-Length: 0\r\n"); err != nil { + return err + } + } + // End-of-header - io.WriteString(w, "\r\n") + if _, err := io.WriteString(w, "\r\n"); err != nil { + return err + } // Write body and trailer err = tw.WriteBody(w) diff --git a/src/pkg/net/http/response_test.go b/src/pkg/net/http/response_test.go index 5044306a8..4b8946f7a 100644 --- a/src/pkg/net/http/response_test.go +++ b/src/pkg/net/http/response_test.go @@ -14,6 +14,7 @@ import ( "io/ioutil" "net/url" "reflect" + "regexp" "strings" "testing" ) @@ -28,6 +29,10 @@ func dummyReq(method string) *Request { return &Request{Method: method} } +func dummyReq11(method string) *Request { + return &Request{Method: method, Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1} +} + var respTests = []respTest{ // Unchunked response without Content-Length. { @@ -406,8 +411,7 @@ func TestWriteResponse(t *testing.T) { t.Errorf("#%d: %v", i, err) continue } - bout := bytes.NewBuffer(nil) - err = resp.Write(bout) + err = resp.Write(ioutil.Discard) if err != nil { t.Errorf("#%d: %v", i, err) continue @@ -506,6 +510,9 @@ func TestReadResponseCloseInMiddle(t *testing.T) { rest, err := ioutil.ReadAll(bufr) checkErr(err, "ReadAll on remainder") if e, g := "Next Request Here", string(rest); e != g { + g = regexp.MustCompile(`(xx+)`).ReplaceAllStringFunc(g, func(match string) string { + return fmt.Sprintf("x(repeated x%d)", len(match)) + }) fatalf("remainder = %q, expected %q", g, e) } } @@ -615,6 +622,15 @@ func TestResponseContentLengthShortBody(t *testing.T) { } } +func TestReadResponseUnexpectedEOF(t *testing.T) { + br := bufio.NewReader(strings.NewReader("HTTP/1.1 301 Moved Permanently\r\n" + + "Location: http://example.com")) + _, err := ReadResponse(br, nil) + if err != io.ErrUnexpectedEOF { + t.Errorf("ReadResponse = %v; want io.ErrUnexpectedEOF", err) + } +} + func TestNeedsSniff(t *testing.T) { // needsSniff returns true with an empty response. r := &response{} diff --git a/src/pkg/net/http/responsewrite_test.go b/src/pkg/net/http/responsewrite_test.go index 5c10e2161..585b13b85 100644 --- a/src/pkg/net/http/responsewrite_test.go +++ b/src/pkg/net/http/responsewrite_test.go @@ -7,6 +7,7 @@ package http import ( "bytes" "io/ioutil" + "strings" "testing" ) @@ -25,7 +26,7 @@ func TestResponseWrite(t *testing.T) { ProtoMinor: 0, Request: dummyReq("GET"), Header: Header{}, - Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), + Body: ioutil.NopCloser(strings.NewReader("abcdef")), ContentLength: 6, }, @@ -41,13 +42,113 @@ func TestResponseWrite(t *testing.T) { ProtoMinor: 0, Request: dummyReq("GET"), Header: Header{}, - Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), + Body: ioutil.NopCloser(strings.NewReader("abcdef")), ContentLength: -1, }, "HTTP/1.0 200 OK\r\n" + "\r\n" + "abcdef", }, + // HTTP/1.1 response with unknown length and Connection: close + { + Response{ + StatusCode: 200, + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("GET"), + Header: Header{}, + Body: ioutil.NopCloser(strings.NewReader("abcdef")), + ContentLength: -1, + Close: true, + }, + "HTTP/1.1 200 OK\r\n" + + "Connection: close\r\n" + + "\r\n" + + "abcdef", + }, + // HTTP/1.1 response with unknown length and not setting connection: close + { + Response{ + StatusCode: 200, + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq11("GET"), + Header: Header{}, + Body: ioutil.NopCloser(strings.NewReader("abcdef")), + ContentLength: -1, + Close: false, + }, + "HTTP/1.1 200 OK\r\n" + + "Connection: close\r\n" + + "\r\n" + + "abcdef", + }, + // HTTP/1.1 response with unknown length and not setting connection: close, but + // setting chunked. + { + Response{ + StatusCode: 200, + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq11("GET"), + Header: Header{}, + Body: ioutil.NopCloser(strings.NewReader("abcdef")), + ContentLength: -1, + TransferEncoding: []string{"chunked"}, + Close: false, + }, + "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "6\r\nabcdef\r\n0\r\n\r\n", + }, + // HTTP/1.1 response 0 content-length, and nil body + { + Response{ + StatusCode: 200, + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq11("GET"), + Header: Header{}, + Body: nil, + ContentLength: 0, + Close: false, + }, + "HTTP/1.1 200 OK\r\n" + + "Content-Length: 0\r\n" + + "\r\n", + }, + // HTTP/1.1 response 0 content-length, and non-nil empty body + { + Response{ + StatusCode: 200, + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq11("GET"), + Header: Header{}, + Body: ioutil.NopCloser(strings.NewReader("")), + ContentLength: 0, + Close: false, + }, + "HTTP/1.1 200 OK\r\n" + + "Content-Length: 0\r\n" + + "\r\n", + }, + // HTTP/1.1 response 0 content-length, and non-nil non-empty body + { + Response{ + StatusCode: 200, + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq11("GET"), + Header: Header{}, + Body: ioutil.NopCloser(strings.NewReader("foo")), + ContentLength: 0, + Close: false, + }, + "HTTP/1.1 200 OK\r\n" + + "Connection: close\r\n" + + "\r\nfoo", + }, // HTTP/1.1, chunked coding; empty trailer; close { Response{ @@ -56,7 +157,7 @@ func TestResponseWrite(t *testing.T) { ProtoMinor: 1, Request: dummyReq("GET"), Header: Header{}, - Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), + Body: ioutil.NopCloser(strings.NewReader("abcdef")), ContentLength: 6, TransferEncoding: []string{"chunked"}, Close: true, @@ -90,6 +191,22 @@ func TestResponseWrite(t *testing.T) { "Foo: Bar Baz\r\n" + "\r\n", }, + + // Want a single Content-Length header. Fixing issue 8180 where + // there were two. + { + Response{ + StatusCode: StatusOK, + ProtoMajor: 1, + ProtoMinor: 1, + Request: &Request{Method: "POST"}, + Header: Header{}, + ContentLength: 0, + TransferEncoding: nil, + Body: nil, + }, + "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n", + }, } for i := range respWriteTests { diff --git a/src/pkg/net/http/serve_test.go b/src/pkg/net/http/serve_test.go index 955112bc2..9e4d226bf 100644 --- a/src/pkg/net/http/serve_test.go +++ b/src/pkg/net/http/serve_test.go @@ -419,7 +419,7 @@ func TestServeMuxHandlerRedirects(t *testing.T) { func TestMuxRedirectLeadingSlashes(t *testing.T) { paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"} for _, path := range paths { - req, err := ReadRequest(bufio.NewReader(bytes.NewBufferString("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n"))) + req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n"))) if err != nil { t.Errorf("%s", err) } @@ -441,6 +441,9 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) { } func TestServerTimeouts(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see http://golang.org/issue/7237") + } defer afterTest(t) reqNum := 0 ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { @@ -517,6 +520,9 @@ func TestServerTimeouts(t *testing.T) { // shouldn't cause a handler to block forever on reads (next HTTP // request) that will never happen. func TestOnlyWriteTimeout(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see http://golang.org/issue/7237") + } defer afterTest(t) var conn net.Conn var afterTimeoutErrc = make(chan error, 1) @@ -840,9 +846,14 @@ func TestHeadResponses(t *testing.T) { } func TestTLSHandshakeTimeout(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see http://golang.org/issue/7237") + } defer afterTest(t) ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + errc := make(chanWriter, 10) // but only expecting 1 ts.Config.ReadTimeout = 250 * time.Millisecond + ts.Config.ErrorLog = log.New(errc, "", 0) ts.StartTLS() defer ts.Close() conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -857,6 +868,14 @@ func TestTLSHandshakeTimeout(t *testing.T) { t.Errorf("Read = %d, %v; want an error and no bytes", n, err) } }) + select { + case v := <-errc: + if !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") { + t.Errorf("expected a TLS handshake timeout error; got %q", v) + } + case <-time.After(5 * time.Second): + t.Errorf("timeout waiting for logged error") + } } func TestTLSServer(t *testing.T) { @@ -869,6 +888,7 @@ func TestTLSServer(t *testing.T) { } } })) + ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0) defer ts.Close() // Connect an idle TCP connection to this server before we run @@ -913,31 +933,50 @@ func TestTLSServer(t *testing.T) { } type serverExpectTest struct { - contentLength int // of request body + contentLength int // of request body + chunked bool expectation string // e.g. "100-continue" readBody bool // whether handler should read the body (if false, sends StatusUnauthorized) expectedResponse string // expected substring in first line of http response } +func expectTest(contentLength int, expectation string, readBody bool, expectedResponse string) serverExpectTest { + return serverExpectTest{ + contentLength: contentLength, + expectation: expectation, + readBody: readBody, + expectedResponse: expectedResponse, + } +} + var serverExpectTests = []serverExpectTest{ // Normal 100-continues, case-insensitive. - {100, "100-continue", true, "100 Continue"}, - {100, "100-cOntInUE", true, "100 Continue"}, + expectTest(100, "100-continue", true, "100 Continue"), + expectTest(100, "100-cOntInUE", true, "100 Continue"), // No 100-continue. - {100, "", true, "200 OK"}, + expectTest(100, "", true, "200 OK"), // 100-continue but requesting client to deny us, // so it never reads the body. - {100, "100-continue", false, "401 Unauthorized"}, + expectTest(100, "100-continue", false, "401 Unauthorized"), // Likewise without 100-continue: - {100, "", false, "401 Unauthorized"}, + expectTest(100, "", false, "401 Unauthorized"), // Non-standard expectations are failures - {0, "a-pony", false, "417 Expectation Failed"}, + expectTest(0, "a-pony", false, "417 Expectation Failed"), - // Expect-100 requested but no body - {0, "100-continue", true, "400 Bad Request"}, + // Expect-100 requested but no body (is apparently okay: Issue 7625) + expectTest(0, "100-continue", true, "200 OK"), + // Expect-100 requested but handler doesn't read the body + expectTest(0, "100-continue", false, "401 Unauthorized"), + // Expect-100 continue with no body, but a chunked body. + { + expectation: "100-continue", + readBody: true, + chunked: true, + expectedResponse: "100 Continue", + }, } // Tests that the server responds to the "Expect" request header @@ -966,21 +1005,38 @@ func TestServerExpect(t *testing.T) { // Only send the body immediately if we're acting like an HTTP client // that doesn't send 100-continue expectations. - writeBody := test.contentLength > 0 && strings.ToLower(test.expectation) != "100-continue" + writeBody := test.contentLength != 0 && strings.ToLower(test.expectation) != "100-continue" go func() { + contentLen := fmt.Sprintf("Content-Length: %d", test.contentLength) + if test.chunked { + contentLen = "Transfer-Encoding: chunked" + } _, err := fmt.Fprintf(conn, "POST /?readbody=%v HTTP/1.1\r\n"+ "Connection: close\r\n"+ - "Content-Length: %d\r\n"+ + "%s\r\n"+ "Expect: %s\r\nHost: foo\r\n\r\n", - test.readBody, test.contentLength, test.expectation) + test.readBody, contentLen, test.expectation) if err != nil { t.Errorf("On test %#v, error writing request headers: %v", test, err) return } if writeBody { + var targ io.WriteCloser = struct { + io.Writer + io.Closer + }{ + conn, + ioutil.NopCloser(nil), + } + if test.chunked { + targ = httputil.NewChunkedWriter(conn) + } body := strings.Repeat("A", test.contentLength) - _, err = fmt.Fprint(conn, body) + _, err = fmt.Fprint(targ, body) + if err == nil { + err = targ.Close() + } if err != nil { if !test.readBody { // Server likely already hung up on us. @@ -1414,6 +1470,9 @@ func TestRequestBodyLimit(t *testing.T) { // TestClientWriteShutdown tests that if the client shuts down the write // side of their TCP connection, the server doesn't send a 400 Bad Request. func TestClientWriteShutdown(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see http://golang.org/issue/7237") + } defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) defer ts.Close() @@ -1934,6 +1993,31 @@ func TestWriteAfterHijack(t *testing.T) { } } +func TestDoubleHijack(t *testing.T) { + req := reqBytes("GET / HTTP/1.1\nHost: golang.org") + var buf bytes.Buffer + conn := &rwTestConn{ + Reader: bytes.NewReader(req), + Writer: &buf, + closec: make(chan bool, 1), + } + handler := HandlerFunc(func(rw ResponseWriter, r *Request) { + conn, _, err := rw.(Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + _, _, err = rw.(Hijacker).Hijack() + if err == nil { + t.Errorf("got err = nil; want err != nil") + } + conn.Close() + }) + ln := &oneConnListener{conn: conn} + go Serve(ln, handler) + <-conn.closec +} + // http://code.google.com/p/go/issues/detail?id=5955 // Note that this does not test the "request too large" // exit path from the http server. This is intentional; @@ -2037,31 +2121,160 @@ func TestServerReaderFromOrder(t *testing.T) { } } -// Issue 6157 -func TestNoContentTypeOnNotModified(t *testing.T) { +// Issue 6157, Issue 6685 +func TestCodesPreventingContentTypeAndBody(t *testing.T) { + for _, code := range []int{StatusNotModified, StatusNoContent, StatusContinue} { + ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.URL.Path == "/header" { + w.Header().Set("Content-Length", "123") + } + w.WriteHeader(code) + if r.URL.Path == "/more" { + w.Write([]byte("stuff")) + } + })) + for _, req := range []string{ + "GET / HTTP/1.0", + "GET /header HTTP/1.0", + "GET /more HTTP/1.0", + "GET / HTTP/1.1", + "GET /header HTTP/1.1", + "GET /more HTTP/1.1", + } { + got := ht.rawResponse(req) + wantStatus := fmt.Sprintf("%d %s", code, StatusText(code)) + if !strings.Contains(got, wantStatus) { + t.Errorf("Code %d: Wanted %q Modified for %q: %s", code, wantStatus, req, got) + } else if strings.Contains(got, "Content-Length") { + t.Errorf("Code %d: Got a Content-Length from %q: %s", code, req, got) + } else if strings.Contains(got, "stuff") { + t.Errorf("Code %d: Response contains a body from %q: %s", code, req, got) + } + } + } +} + +func TestContentTypeOkayOn204(t *testing.T) { ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) { - if r.URL.Path == "/header" { - w.Header().Set("Content-Length", "123") + w.Header().Set("Content-Length", "123") // suppressed + w.Header().Set("Content-Type", "foo/bar") + w.WriteHeader(204) + })) + got := ht.rawResponse("GET / HTTP/1.1") + if !strings.Contains(got, "Content-Type: foo/bar") { + t.Errorf("Response = %q; want Content-Type: foo/bar", got) + } + if strings.Contains(got, "Content-Length: 123") { + t.Errorf("Response = %q; don't want a Content-Length", got) + } +} + +// Issue 6995 +// A server Handler can receive a Request, and then turn around and +// give a copy of that Request.Body out to the Transport (e.g. any +// proxy). So then two people own that Request.Body (both the server +// and the http client), and both think they can close it on failure. +// Therefore, all incoming server requests Bodies need to be thread-safe. +func TestTransportAndServerSharedBodyRace(t *testing.T) { + defer afterTest(t) + + const bodySize = 1 << 20 + + unblockBackend := make(chan bool) + backend := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + io.CopyN(rw, req.Body, bodySize/2) + <-unblockBackend + })) + defer backend.Close() + + backendRespc := make(chan *Response, 1) + proxy := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + if req.RequestURI == "/foo" { + rw.Write([]byte("bar")) + return } - w.WriteHeader(StatusNotModified) - if r.URL.Path == "/more" { - w.Write([]byte("stuff")) + req2, _ := NewRequest("POST", backend.URL, req.Body) + req2.ContentLength = bodySize + + bresp, err := DefaultClient.Do(req2) + if err != nil { + t.Errorf("Proxy outbound request: %v", err) + return + } + _, err = io.CopyN(ioutil.Discard, bresp.Body, bodySize/4) + if err != nil { + t.Errorf("Proxy copy error: %v", err) + return } + backendRespc <- bresp // to close later + + // Try to cause a race: Both the DefaultTransport and the proxy handler's Server + // will try to read/close req.Body (aka req2.Body) + DefaultTransport.(*Transport).CancelRequest(req2) + rw.Write([]byte("OK")) + })) + defer proxy.Close() + + req, _ := NewRequest("POST", proxy.URL, io.LimitReader(neverEnding('a'), bodySize)) + res, err := DefaultClient.Do(req) + if err != nil { + t.Fatalf("Original request: %v", err) + } + + // Cleanup, so we don't leak goroutines. + res.Body.Close() + close(unblockBackend) + (<-backendRespc).Body.Close() +} + +// Test that a hanging Request.Body.Read from another goroutine can't +// cause the Handler goroutine's Request.Body.Close to block. +func TestRequestBodyCloseDoesntBlock(t *testing.T) { + t.Skipf("Skipping known issue; see golang.org/issue/7121") + if testing.Short() { + t.Skip("skipping in -short mode") + } + defer afterTest(t) + + readErrCh := make(chan error, 1) + errCh := make(chan error, 2) + + server := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + go func(body io.Reader) { + _, err := body.Read(make([]byte, 100)) + readErrCh <- err + }(req.Body) + time.Sleep(500 * time.Millisecond) })) - for _, req := range []string{ - "GET / HTTP/1.0", - "GET /header HTTP/1.0", - "GET /more HTTP/1.0", - "GET / HTTP/1.1", - "GET /header HTTP/1.1", - "GET /more HTTP/1.1", - } { - got := ht.rawResponse(req) - if !strings.Contains(got, "304 Not Modified") { - t.Errorf("Non-304 Not Modified for %q: %s", req, got) - } else if strings.Contains(got, "Content-Length") { - t.Errorf("Got a Content-Length from %q: %s", req, got) + defer server.Close() + + closeConn := make(chan bool) + defer close(closeConn) + go func() { + conn, err := net.Dial("tcp", server.Listener.Addr().String()) + if err != nil { + errCh <- err + return + } + defer conn.Close() + _, err = conn.Write([]byte("POST / HTTP/1.1\r\nConnection: close\r\nHost: foo\r\nContent-Length: 100000\r\n\r\n")) + if err != nil { + errCh <- err + return } + // And now just block, making the server block on our + // 100000 bytes of body that will never arrive. + <-closeConn + }() + select { + case err := <-readErrCh: + if err == nil { + t.Error("Read was nil. Expected error.") + } + case err := <-errCh: + t.Error(err) + case <-time.After(5 * time.Second): + t.Error("timeout") } } @@ -2073,8 +2286,8 @@ func TestResponseWriterWriteStringAllocs(t *testing.T) { w.Write([]byte("Hello world")) } })) - before := testing.AllocsPerRun(25, func() { ht.rawResponse("GET / HTTP/1.0") }) - after := testing.AllocsPerRun(25, func() { ht.rawResponse("GET /s HTTP/1.0") }) + before := testing.AllocsPerRun(50, func() { ht.rawResponse("GET / HTTP/1.0") }) + after := testing.AllocsPerRun(50, func() { ht.rawResponse("GET /s HTTP/1.0") }) if int(after) >= int(before) { t.Errorf("WriteString allocs of %v >= Write allocs of %v", after, before) } @@ -2093,6 +2306,230 @@ func TestAppendTime(t *testing.T) { } } +func TestServerConnState(t *testing.T) { + defer afterTest(t) + handler := map[string]func(w ResponseWriter, r *Request){ + "/": func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "Hello.") + }, + "/close": func(w ResponseWriter, r *Request) { + w.Header().Set("Connection", "close") + fmt.Fprintf(w, "Hello.") + }, + "/hijack": func(w ResponseWriter, r *Request) { + c, _, _ := w.(Hijacker).Hijack() + c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello.")) + c.Close() + }, + "/hijack-panic": func(w ResponseWriter, r *Request) { + c, _, _ := w.(Hijacker).Hijack() + c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello.")) + c.Close() + panic("intentional panic") + }, + } + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + handler[r.URL.Path](w, r) + })) + defer ts.Close() + + var mu sync.Mutex // guard stateLog and connID + var stateLog = map[int][]ConnState{} + var connID = map[net.Conn]int{} + + ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0) + ts.Config.ConnState = func(c net.Conn, state ConnState) { + if c == nil { + t.Errorf("nil conn seen in state %s", state) + return + } + mu.Lock() + defer mu.Unlock() + id, ok := connID[c] + if !ok { + id = len(connID) + 1 + connID[c] = id + } + stateLog[id] = append(stateLog[id], state) + } + ts.Start() + + mustGet(t, ts.URL+"/") + mustGet(t, ts.URL+"/close") + + mustGet(t, ts.URL+"/") + mustGet(t, ts.URL+"/", "Connection", "close") + + mustGet(t, ts.URL+"/hijack") + mustGet(t, ts.URL+"/hijack-panic") + + // New->Closed + { + c, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + c.Close() + } + + // New->Active->Closed + { + c, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + if _, err := io.WriteString(c, "BOGUS REQUEST\r\n\r\n"); err != nil { + t.Fatal(err) + } + c.Close() + } + + // New->Idle->Closed + { + c, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil { + t.Fatal(err) + } + res, err := ReadResponse(bufio.NewReader(c), nil) + if err != nil { + t.Fatal(err) + } + if _, err := io.Copy(ioutil.Discard, res.Body); err != nil { + t.Fatal(err) + } + c.Close() + } + + want := map[int][]ConnState{ + 1: []ConnState{StateNew, StateActive, StateIdle, StateActive, StateClosed}, + 2: []ConnState{StateNew, StateActive, StateIdle, StateActive, StateClosed}, + 3: []ConnState{StateNew, StateActive, StateHijacked}, + 4: []ConnState{StateNew, StateActive, StateHijacked}, + 5: []ConnState{StateNew, StateClosed}, + 6: []ConnState{StateNew, StateActive, StateClosed}, + 7: []ConnState{StateNew, StateActive, StateIdle, StateClosed}, + } + logString := func(m map[int][]ConnState) string { + var b bytes.Buffer + for id, l := range m { + fmt.Fprintf(&b, "Conn %d: ", id) + for _, s := range l { + fmt.Fprintf(&b, "%s ", s) + } + b.WriteString("\n") + } + return b.String() + } + + for i := 0; i < 5; i++ { + time.Sleep(time.Duration(i) * 50 * time.Millisecond) + mu.Lock() + match := reflect.DeepEqual(stateLog, want) + mu.Unlock() + if match { + return + } + } + + mu.Lock() + t.Errorf("Unexpected events.\nGot log: %s\n Want: %s\n", logString(stateLog), logString(want)) + mu.Unlock() +} + +func mustGet(t *testing.T, url string, headers ...string) { + req, err := NewRequest("GET", url, nil) + if err != nil { + t.Fatal(err) + } + for len(headers) > 0 { + req.Header.Add(headers[0], headers[1]) + headers = headers[2:] + } + res, err := DefaultClient.Do(req) + if err != nil { + t.Errorf("Error fetching %s: %v", url, err) + return + } + _, err = ioutil.ReadAll(res.Body) + defer res.Body.Close() + if err != nil { + t.Errorf("Error reading %s: %v", url, err) + } +} + +func TestServerKeepAlivesEnabled(t *testing.T) { + defer afterTest(t) + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + ts.Config.SetKeepAlivesEnabled(false) + ts.Start() + defer ts.Close() + res, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if !res.Close { + t.Errorf("Body.Close == false; want true") + } +} + +// golang.org/issue/7856 +func TestServerEmptyBodyRace(t *testing.T) { + defer afterTest(t) + var n int32 + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + atomic.AddInt32(&n, 1) + })) + defer ts.Close() + var wg sync.WaitGroup + const reqs = 20 + for i := 0; i < reqs; i++ { + wg.Add(1) + go func() { + defer wg.Done() + res, err := Get(ts.URL) + if err != nil { + t.Error(err) + return + } + defer res.Body.Close() + _, err = io.Copy(ioutil.Discard, res.Body) + if err != nil { + t.Error(err) + return + } + }() + } + wg.Wait() + if got := atomic.LoadInt32(&n); got != reqs { + t.Errorf("handler ran %d times; want %d", got, reqs) + } +} + +func TestServerConnStateNew(t *testing.T) { + sawNew := false // if the test is buggy, we'll race on this variable. + srv := &Server{ + ConnState: func(c net.Conn, state ConnState) { + if state == StateNew { + sawNew = true // testing that this write isn't racy + } + }, + Handler: HandlerFunc(func(w ResponseWriter, r *Request) {}), // irrelevant + } + srv.Serve(&oneConnListener{ + conn: &rwTestConn{ + Reader: strings.NewReader("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"), + Writer: ioutil.Discard, + }, + }) + if !sawNew { // testing that this read isn't racy + t.Error("StateNew not seen") + } +} + func BenchmarkClientServer(b *testing.B) { b.ReportAllocs() b.StopTimer() @@ -2108,6 +2545,7 @@ func BenchmarkClientServer(b *testing.B) { b.Fatal("Get:", err) } all, err := ioutil.ReadAll(res.Body) + res.Body.Close() if err != nil { b.Fatal("ReadAll:", err) } @@ -2128,41 +2566,33 @@ func BenchmarkClientServerParallel64(b *testing.B) { benchmarkClientServerParallel(b, 64) } -func benchmarkClientServerParallel(b *testing.B, conc int) { +func benchmarkClientServerParallel(b *testing.B, parallelism int) { b.ReportAllocs() - b.StopTimer() ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { fmt.Fprintf(rw, "Hello world.\n") })) defer ts.Close() - b.StartTimer() - - numProcs := runtime.GOMAXPROCS(-1) * conc - var wg sync.WaitGroup - wg.Add(numProcs) - n := int32(b.N) - for p := 0; p < numProcs; p++ { - go func() { - for atomic.AddInt32(&n, -1) >= 0 { - res, err := Get(ts.URL) - if err != nil { - b.Logf("Get: %v", err) - continue - } - all, err := ioutil.ReadAll(res.Body) - if err != nil { - b.Logf("ReadAll: %v", err) - continue - } - body := string(all) - if body != "Hello world.\n" { - panic("Got body: " + body) - } + b.ResetTimer() + b.SetParallelism(parallelism) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + res, err := Get(ts.URL) + if err != nil { + b.Logf("Get: %v", err) + continue } - wg.Done() - }() - } - wg.Wait() + all, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + b.Logf("ReadAll: %v", err) + continue + } + body := string(all) + if body != "Hello world.\n" { + panic("Got body: " + body) + } + } + }) } // A benchmark for profiling the server without the HTTP client code. @@ -2187,6 +2617,7 @@ func BenchmarkServer(b *testing.B) { log.Panicf("Get: %v", err) } all, err := ioutil.ReadAll(res.Body) + res.Body.Close() if err != nil { log.Panicf("ReadAll: %v", err) } @@ -2390,3 +2821,28 @@ Host: golang.org b.Errorf("b.N=%d but handled %d", b.N, handled) } } + +func BenchmarkServerHijack(b *testing.B) { + b.ReportAllocs() + req := reqBytes(`GET / HTTP/1.1 +Host: golang.org +`) + h := HandlerFunc(func(w ResponseWriter, r *Request) { + conn, _, err := w.(Hijacker).Hijack() + if err != nil { + panic(err) + } + conn.Close() + }) + conn := &rwTestConn{ + Writer: ioutil.Discard, + closec: make(chan bool, 1), + } + ln := &oneConnListener{conn: conn} + for i := 0; i < b.N; i++ { + conn.Reader = bytes.NewReader(req) + ln.conn = conn + Serve(ln, h) + <-conn.closec + } +} diff --git a/src/pkg/net/http/server.go b/src/pkg/net/http/server.go index 0e46863d5..eae097eb8 100644 --- a/src/pkg/net/http/server.go +++ b/src/pkg/net/http/server.go @@ -22,6 +22,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" ) @@ -138,6 +139,7 @@ func (c *conn) hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { buf = c.buf c.rwc = nil c.buf = nil + c.setState(rwc, StateHijacked) return } @@ -435,56 +437,52 @@ func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) { return c, nil } -// TODO: use a sync.Cache instead var ( - bufioReaderCache = make(chan *bufio.Reader, 4) - bufioWriterCache2k = make(chan *bufio.Writer, 4) - bufioWriterCache4k = make(chan *bufio.Writer, 4) + bufioReaderPool sync.Pool + bufioWriter2kPool sync.Pool + bufioWriter4kPool sync.Pool ) -func bufioWriterCache(size int) chan *bufio.Writer { +func bufioWriterPool(size int) *sync.Pool { switch size { case 2 << 10: - return bufioWriterCache2k + return &bufioWriter2kPool case 4 << 10: - return bufioWriterCache4k + return &bufioWriter4kPool } return nil } func newBufioReader(r io.Reader) *bufio.Reader { - select { - case p := <-bufioReaderCache: - p.Reset(r) - return p - default: - return bufio.NewReader(r) + if v := bufioReaderPool.Get(); v != nil { + br := v.(*bufio.Reader) + br.Reset(r) + return br } + return bufio.NewReader(r) } func putBufioReader(br *bufio.Reader) { br.Reset(nil) - select { - case bufioReaderCache <- br: - default: - } + bufioReaderPool.Put(br) } func newBufioWriterSize(w io.Writer, size int) *bufio.Writer { - select { - case p := <-bufioWriterCache(size): - p.Reset(w) - return p - default: - return bufio.NewWriterSize(w, size) + pool := bufioWriterPool(size) + if pool != nil { + if v := pool.Get(); v != nil { + bw := v.(*bufio.Writer) + bw.Reset(w) + return bw + } } + return bufio.NewWriterSize(w, size) } func putBufioWriter(bw *bufio.Writer) { bw.Reset(nil) - select { - case bufioWriterCache(bw.Available()) <- bw: - default: + if pool := bufioWriterPool(bw.Available()); pool != nil { + pool.Put(bw) } } @@ -500,6 +498,10 @@ func (srv *Server) maxHeaderBytes() int { return DefaultMaxHeaderBytes } +func (srv *Server) initialLimitedReaderSize() int64 { + return int64(srv.maxHeaderBytes()) + 4096 // bufio slop +} + // wrapper around io.ReaderCloser which on first read, sends an // HTTP/1.1 100 Continue header type expectContinueReader struct { @@ -570,7 +572,7 @@ func (c *conn) readRequest() (w *response, err error) { }() } - c.lr.N = int64(c.server.maxHeaderBytes()) + 4096 /* bufio slop */ + c.lr.N = c.server.initialLimitedReaderSize() var req *Request if req, err = ReadRequest(c.buf.Reader); err != nil { if c.lr.N == 0 { @@ -618,11 +620,11 @@ const maxPostHandlerReadBytes = 256 << 10 func (w *response) WriteHeader(code int) { if w.conn.hijacked() { - log.Print("http: response.WriteHeader on hijacked connection") + w.conn.server.logf("http: response.WriteHeader on hijacked connection") return } if w.wroteHeader { - log.Print("http: multiple response.WriteHeader calls") + w.conn.server.logf("http: multiple response.WriteHeader calls") return } w.wroteHeader = true @@ -637,7 +639,7 @@ func (w *response) WriteHeader(code int) { if err == nil && v >= 0 { w.contentLength = v } else { - log.Printf("http: invalid Content-Length of %q", cl) + w.conn.server.logf("http: invalid Content-Length of %q", cl) w.handlerHeader.Del("Content-Length") } } @@ -707,6 +709,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { cw.wroteHeader = true w := cw.res + keepAlivesEnabled := w.conn.server.doKeepAlives() isHEAD := w.req.Method == "HEAD" // header is written out to w.conn.buf below. Depending on the @@ -739,7 +742,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { // response header and this is our first (and last) write, set // it, even to zero. This helps HTTP/1.0 clients keep their // "keep-alive" connections alive. - // Exceptions: 304 responses never get Content-Length, and if + // Exceptions: 304/204/1xx responses never get Content-Length, and if // it was a HEAD request, we don't know the difference between // 0 actual bytes and 0 bytes because the handler noticed it // was a HEAD request and chose not to write anything. So for @@ -747,14 +750,14 @@ func (cw *chunkWriter) writeHeader(p []byte) { // write non-zero bytes. If it's actually 0 bytes and the // handler never looked at the Request.Method, we just don't // send a Content-Length header. - if w.handlerDone && w.status != StatusNotModified && header.get("Content-Length") == "" && (!isHEAD || len(p) > 0) { + if w.handlerDone && bodyAllowedForStatus(w.status) && header.get("Content-Length") == "" && (!isHEAD || len(p) > 0) { w.contentLength = int64(len(p)) setHeader.contentLength = strconv.AppendInt(cw.res.clenBuf[:0], int64(len(p)), 10) } // If this was an HTTP/1.0 request with keep-alive and we sent a // Content-Length back, we can make this a keep-alive response ... - if w.req.wantsHttp10KeepAlive() { + if w.req.wantsHttp10KeepAlive() && keepAlivesEnabled { sentLength := header.get("Content-Length") != "" if sentLength && header.get("Connection") == "keep-alive" { w.closeAfterReply = false @@ -773,7 +776,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { w.closeAfterReply = true } - if header.get("Connection") == "close" { + if header.get("Connection") == "close" || !keepAlivesEnabled { w.closeAfterReply = true } @@ -796,18 +799,16 @@ func (cw *chunkWriter) writeHeader(p []byte) { } code := w.status - if code == StatusNotModified { - // Must not have body. - // RFC 2616 section 10.3.5: "the response MUST NOT include other entity-headers" - for _, k := range []string{"Content-Type", "Content-Length", "Transfer-Encoding"} { - delHeader(k) - } - } else { + if bodyAllowedForStatus(code) { // If no content type, apply sniffing algorithm to body. _, haveType := header["Content-Type"] if !haveType { setHeader.contentType = DetectContentType(p) } + } else { + for _, k := range suppressedHeaders(code) { + delHeader(k) + } } if _, ok := header["Date"]; !ok { @@ -819,13 +820,13 @@ func (cw *chunkWriter) writeHeader(p []byte) { if hasCL && hasTE && te != "identity" { // TODO: return an error if WriteHeader gets a return parameter // For now just ignore the Content-Length. - log.Printf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d", + w.conn.server.logf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d", te, w.contentLength) delHeader("Content-Length") hasCL = false } - if w.req.Method == "HEAD" || code == StatusNotModified { + if w.req.Method == "HEAD" || !bodyAllowedForStatus(code) { // do nothing } else if code == StatusNoContent { delHeader("Transfer-Encoding") @@ -855,7 +856,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { return } - if w.closeAfterReply && !hasToken(cw.header.get("Connection"), "close") { + if w.closeAfterReply && (!keepAlivesEnabled || !hasToken(cw.header.get("Connection"), "close")) { delHeader("Connection") if w.req.ProtoAtLeast(1, 1) { setHeader.connection = "close" @@ -919,7 +920,7 @@ func (w *response) bodyAllowed() bool { if !w.wroteHeader { panic("") } - return w.status != StatusNotModified + return bodyAllowedForStatus(w.status) } // The Life Of A Write is like this: @@ -965,7 +966,7 @@ func (w *response) WriteString(data string) (n int, err error) { // either dataB or dataS is non-zero. func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err error) { if w.conn.hijacked() { - log.Print("http: response.Write on hijacked connection") + w.conn.server.logf("http: response.Write on hijacked connection") return 0, ErrHijacked } if !w.wroteHeader { @@ -1001,11 +1002,10 @@ func (w *response) finishRequest() { w.cw.close() w.conn.buf.Flush() - // Close the body, unless we're about to close the whole TCP connection - // anyway. - if !w.closeAfterReply { - w.req.Body.Close() - } + // Close the body (regardless of w.closeAfterReply) so we can + // re-use its bufio.Reader later safely. + w.req.Body.Close() + if w.req.MultipartForm != nil { w.req.MultipartForm.RemoveAll() } @@ -1084,17 +1084,25 @@ func validNPN(proto string) bool { return true } +func (c *conn) setState(nc net.Conn, state ConnState) { + if hook := c.server.ConnState; hook != nil { + hook(nc, state) + } +} + // Serve a new connection. func (c *conn) serve() { + origConn := c.rwc // copy it before it's set nil on Close or Hijack defer func() { if err := recover(); err != nil { - const size = 4096 + const size = 64 << 10 buf := make([]byte, size) buf = buf[:runtime.Stack(buf, false)] - log.Printf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf) + c.server.logf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf) } if !c.hijacked() { c.close() + c.setState(origConn, StateClosed) } }() @@ -1106,6 +1114,7 @@ func (c *conn) serve() { c.rwc.SetWriteDeadline(time.Now().Add(d)) } if err := tlsConn.Handshake(); err != nil { + c.server.logf("http: TLS handshake error from %s: %v", c.rwc.RemoteAddr(), err) return } c.tlsState = new(tls.ConnectionState) @@ -1121,6 +1130,10 @@ func (c *conn) serve() { for { w, err := c.readRequest() + if c.lr.N != c.server.initialLimitedReaderSize() { + // If we read any bytes off the wire, we're active. + c.setState(c.rwc, StateActive) + } if err != nil { if err == errTooLarge { // Their HTTP client may or may not be @@ -1143,16 +1156,10 @@ func (c *conn) serve() { // Expect 100 Continue support req := w.req if req.expectsContinue() { - if req.ProtoAtLeast(1, 1) { + if req.ProtoAtLeast(1, 1) && req.ContentLength != 0 { // Wrap the Body reader with one that replies on the connection req.Body = &expectContinueReader{readCloser: req.Body, resp: w} } - if req.ContentLength == 0 { - w.Header().Set("Connection", "close") - w.WriteHeader(StatusBadRequest) - w.finishRequest() - break - } req.Header.Del("Expect") } else if req.Header.get("Expect") != "" { w.sendExpectationFailed() @@ -1175,6 +1182,7 @@ func (c *conn) serve() { } break } + c.setState(c.rwc, StateIdle) } } @@ -1202,7 +1210,14 @@ func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { if w.wroteHeader { w.cw.flush() } - return w.conn.hijack() + // Release the bufioWriter that writes to the chunk writer, it is not + // used after a connection has been hijacked. + rwc, buf, err = w.conn.hijack() + if err == nil { + putBufioWriter(w.w) + w.w = nil + } + return rwc, buf, err } func (w *response) CloseNotify() <-chan bool { @@ -1562,6 +1577,7 @@ func Serve(l net.Listener, handler Handler) error { } // A Server defines parameters for running an HTTP server. +// The zero value for Server is a valid configuration. type Server struct { Addr string // TCP address to listen on, ":http" if empty Handler Handler // handler to invoke, http.DefaultServeMux if nil @@ -1578,6 +1594,66 @@ type Server struct { // and RemoteAddr if not already set. The connection is // automatically closed when the function returns. TLSNextProto map[string]func(*Server, *tls.Conn, Handler) + + // ConnState specifies an optional callback function that is + // called when a client connection changes state. See the + // ConnState type and associated constants for details. + ConnState func(net.Conn, ConnState) + + // ErrorLog specifies an optional logger for errors accepting + // connections and unexpected behavior from handlers. + // If nil, logging goes to os.Stderr via the log package's + // standard logger. + ErrorLog *log.Logger + + disableKeepAlives int32 // accessed atomically. +} + +// A ConnState represents the state of a client connection to a server. +// It's used by the optional Server.ConnState hook. +type ConnState int + +const ( + // StateNew represents a new connection that is expected to + // send a request immediately. Connections begin at this + // state and then transition to either StateActive or + // StateClosed. + StateNew ConnState = iota + + // StateActive represents a connection that has read 1 or more + // bytes of a request. The Server.ConnState hook for + // StateActive fires before the request has entered a handler + // and doesn't fire again until the request has been + // handled. After the request is handled, the state + // transitions to StateClosed, StateHijacked, or StateIdle. + StateActive + + // StateIdle represents a connection that has finished + // handling a request and is in the keep-alive state, waiting + // for a new request. Connections transition from StateIdle + // to either StateActive or StateClosed. + StateIdle + + // StateHijacked represents a hijacked connection. + // This is a terminal state. It does not transition to StateClosed. + StateHijacked + + // StateClosed represents a closed connection. + // This is a terminal state. Hijacked connections do not + // transition to StateClosed. + StateClosed +) + +var stateName = map[ConnState]string{ + StateNew: "new", + StateActive: "active", + StateIdle: "idle", + StateHijacked: "hijacked", + StateClosed: "closed", +} + +func (c ConnState) String() string { + return stateName[c] } // serverHandler delegates to either the server's Handler or @@ -1605,11 +1681,11 @@ func (srv *Server) ListenAndServe() error { if addr == "" { addr = ":http" } - l, e := net.Listen("tcp", addr) - if e != nil { - return e + ln, err := net.Listen("tcp", addr) + if err != nil { + return err } - return srv.Serve(l) + return srv.Serve(tcpKeepAliveListener{ln.(*net.TCPListener)}) } // Serve accepts incoming connections on the Listener l, creating a @@ -1630,7 +1706,7 @@ func (srv *Server) Serve(l net.Listener) error { if max := 1 * time.Second; tempDelay > max { tempDelay = max } - log.Printf("http: Accept error: %v; retrying in %v", e, tempDelay) + srv.logf("http: Accept error: %v; retrying in %v", e, tempDelay) time.Sleep(tempDelay) continue } @@ -1641,10 +1717,35 @@ func (srv *Server) Serve(l net.Listener) error { if err != nil { continue } + c.setState(c.rwc, StateNew) // before Serve can return go c.serve() } } +func (s *Server) doKeepAlives() bool { + return atomic.LoadInt32(&s.disableKeepAlives) == 0 +} + +// SetKeepAlivesEnabled controls whether HTTP keep-alives are enabled. +// By default, keep-alives are always enabled. Only very +// resource-constrained environments or servers in the process of +// shutting down should disable them. +func (s *Server) SetKeepAlivesEnabled(v bool) { + if v { + atomic.StoreInt32(&s.disableKeepAlives, 0) + } else { + atomic.StoreInt32(&s.disableKeepAlives, 1) + } +} + +func (s *Server) logf(format string, args ...interface{}) { + if s.ErrorLog != nil { + s.ErrorLog.Printf(format, args...) + } else { + log.Printf(format, args...) + } +} + // ListenAndServe listens on the TCP network address addr // and then calls Serve with handler to handle requests // on incoming connections. Handler is typically nil, @@ -1739,12 +1840,12 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { return err } - conn, err := net.Listen("tcp", addr) + ln, err := net.Listen("tcp", addr) if err != nil { return err } - tlsListener := tls.NewListener(conn, config) + tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config) return srv.Serve(tlsListener) } @@ -1834,6 +1935,24 @@ func (tw *timeoutWriter) WriteHeader(code int) { tw.w.WriteHeader(code) } +// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted +// connections. It's used by ListenAndServe and ListenAndServeTLS so +// dead TCP connections (e.g. closing laptop mid-download) eventually +// go away. +type tcpKeepAliveListener struct { + *net.TCPListener +} + +func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { + tc, err := ln.AcceptTCP() + if err != nil { + return + } + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(3 * time.Minute) + return tc, nil +} + // globalOptionsHandler responds to "OPTIONS *" requests. type globalOptionsHandler struct{} @@ -1850,17 +1969,24 @@ func (globalOptionsHandler) ServeHTTP(w ResponseWriter, r *Request) { } } +type eofReaderWithWriteTo struct{} + +func (eofReaderWithWriteTo) WriteTo(io.Writer) (int64, error) { return 0, nil } +func (eofReaderWithWriteTo) Read([]byte) (int, error) { return 0, io.EOF } + // eofReader is a non-nil io.ReadCloser that always returns EOF. -// It embeds a *strings.Reader so it still has a WriteTo method -// and io.Copy won't need a buffer. +// It has a WriteTo method so io.Copy won't need a buffer. var eofReader = &struct { - *strings.Reader + eofReaderWithWriteTo io.Closer }{ - strings.NewReader(""), + eofReaderWithWriteTo{}, ioutil.NopCloser(nil), } +// Verify that an io.Copy from an eofReader won't require a buffer. +var _ io.WriterTo = eofReader + // initNPNRequest is an HTTP handler that initializes certain // uninitialized fields in its *Request. Such partially-initialized // Requests come from NPN protocol handlers. diff --git a/src/pkg/net/http/transfer.go b/src/pkg/net/http/transfer.go index bacd83732..7f6368652 100644 --- a/src/pkg/net/http/transfer.go +++ b/src/pkg/net/http/transfer.go @@ -12,10 +12,20 @@ import ( "io" "io/ioutil" "net/textproto" + "sort" "strconv" "strings" + "sync" ) +type errorReader struct { + err error +} + +func (r *errorReader) Read(p []byte) (n int, err error) { + return 0, r.err +} + // transferWriter inspects the fields of a user-supplied Request or Response, // sanitizes them without changing the user object and provides methods for // writing the respective header, body and trailer in wire format. @@ -52,14 +62,17 @@ func newTransferWriter(r interface{}) (t *transferWriter, err error) { if t.ContentLength == 0 { // Test to see if it's actually zero or just unset. var buf [1]byte - n, _ := io.ReadFull(t.Body, buf[:]) - if n == 1 { + n, rerr := io.ReadFull(t.Body, buf[:]) + if rerr != nil && rerr != io.EOF { + t.ContentLength = -1 + t.Body = &errorReader{rerr} + } else if n == 1 { // Oh, guess there is data in this Body Reader after all. // The ContentLength field just wasn't set. // Stich the Body back together again, re-attaching our // consumed byte. t.ContentLength = -1 - t.Body = io.MultiReader(bytes.NewBuffer(buf[:]), t.Body) + t.Body = io.MultiReader(bytes.NewReader(buf[:]), t.Body) } else { // Body is actually empty. t.Body = nil @@ -131,11 +144,10 @@ func (t *transferWriter) shouldSendContentLength() bool { return false } -func (t *transferWriter) WriteHeader(w io.Writer) (err error) { +func (t *transferWriter) WriteHeader(w io.Writer) error { if t.Close { - _, err = io.WriteString(w, "Connection: close\r\n") - if err != nil { - return + if _, err := io.WriteString(w, "Connection: close\r\n"); err != nil { + return err } } @@ -143,43 +155,44 @@ func (t *transferWriter) WriteHeader(w io.Writer) (err error) { // function of the sanitized field triple (Body, ContentLength, // TransferEncoding) if t.shouldSendContentLength() { - io.WriteString(w, "Content-Length: ") - _, err = io.WriteString(w, strconv.FormatInt(t.ContentLength, 10)+"\r\n") - if err != nil { - return + if _, err := io.WriteString(w, "Content-Length: "); err != nil { + return err + } + if _, err := io.WriteString(w, strconv.FormatInt(t.ContentLength, 10)+"\r\n"); err != nil { + return err } } else if chunked(t.TransferEncoding) { - _, err = io.WriteString(w, "Transfer-Encoding: chunked\r\n") - if err != nil { - return + if _, err := io.WriteString(w, "Transfer-Encoding: chunked\r\n"); err != nil { + return err } } // Write Trailer header if t.Trailer != nil { - // TODO: At some point, there should be a generic mechanism for - // writing long headers, using HTTP line splitting - io.WriteString(w, "Trailer: ") - needComma := false + keys := make([]string, 0, len(t.Trailer)) for k := range t.Trailer { k = CanonicalHeaderKey(k) switch k { case "Transfer-Encoding", "Trailer", "Content-Length": return &badStringError{"invalid Trailer key", k} } - if needComma { - io.WriteString(w, ",") + keys = append(keys, k) + } + if len(keys) > 0 { + sort.Strings(keys) + // TODO: could do better allocation-wise here, but trailers are rare, + // so being lazy for now. + if _, err := io.WriteString(w, "Trailer: "+strings.Join(keys, ",")+"\r\n"); err != nil { + return err } - io.WriteString(w, k) - needComma = true } - _, err = io.WriteString(w, "\r\n") } - return + return nil } -func (t *transferWriter) WriteBody(w io.Writer) (err error) { +func (t *transferWriter) WriteBody(w io.Writer) error { + var err error var ncopy int64 // Write body @@ -216,11 +229,16 @@ func (t *transferWriter) WriteBody(w io.Writer) (err error) { // TODO(petar): Place trailer writer code here. if chunked(t.TransferEncoding) { + // Write Trailer header + if t.Trailer != nil { + if err := t.Trailer.Write(w); err != nil { + return err + } + } // Last chunk, empty trailer _, err = io.WriteString(w, "\r\n") } - - return + return err } type transferReader struct { @@ -252,6 +270,22 @@ func bodyAllowedForStatus(status int) bool { return true } +var ( + suppressedHeaders304 = []string{"Content-Type", "Content-Length", "Transfer-Encoding"} + suppressedHeadersNoBody = []string{"Content-Length", "Transfer-Encoding"} +) + +func suppressedHeaders(status int) []string { + switch { + case status == 304: + // RFC 2616 section 10.3.5: "the response MUST NOT include other entity-headers" + return suppressedHeaders304 + case !bodyAllowedForStatus(status): + return suppressedHeadersNoBody + } + return nil +} + // msg is *Request or *Response. func readTransfer(msg interface{}, r *bufio.Reader) (err error) { t := &transferReader{RequestMethod: "GET"} @@ -331,17 +365,17 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { if noBodyExpected(t.RequestMethod) { t.Body = eofReader } else { - t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close} + t.Body = &body{src: newChunkedReader(r), hdr: msg, r: r, closing: t.Close} } case realLength == 0: t.Body = eofReader case realLength > 0: - t.Body = &body{Reader: io.LimitReader(r, realLength), closing: t.Close} + t.Body = &body{src: io.LimitReader(r, realLength), closing: t.Close} default: // realLength < 0, i.e. "Content-Length" not mentioned in header if t.Close { // Close semantics (i.e. HTTP/1.0) - t.Body = &body{Reader: r, closing: t.Close} + t.Body = &body{src: r, closing: t.Close} } else { // Persistent connection (i.e. HTTP/1.1) t.Body = eofReader @@ -498,7 +532,7 @@ func fixTrailer(header Header, te []string) (Header, error) { case "Transfer-Encoding", "Trailer", "Content-Length": return nil, &badStringError{"bad trailer key", key} } - trailer.Del(key) + trailer[key] = nil } if len(trailer) == 0 { return nil, nil @@ -514,11 +548,13 @@ func fixTrailer(header Header, te []string) (Header, error) { // Close ensures that the body has been fully read // and then reads the trailer if necessary. type body struct { - io.Reader + src io.Reader hdr interface{} // non-nil (Response or Request) value means read trailer r *bufio.Reader // underlying wire-format reader for the trailer closing bool // is the connection to be closed after reading body? - closed bool + + mu sync.Mutex // guards closed, and calls to Read and Close + closed bool } // ErrBodyReadAfterClose is returned when reading a Request or Response @@ -528,10 +564,17 @@ type body struct { var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed Body") func (b *body) Read(p []byte) (n int, err error) { + b.mu.Lock() + defer b.mu.Unlock() if b.closed { return 0, ErrBodyReadAfterClose } - n, err = b.Reader.Read(p) + return b.readLocked(p) +} + +// Must hold b.mu. +func (b *body) readLocked(p []byte) (n int, err error) { + n, err = b.src.Read(p) if err == io.EOF { // Chunked case. Read the trailer. @@ -543,12 +586,23 @@ func (b *body) Read(p []byte) (n int, err error) { } else { // If the server declared the Content-Length, our body is a LimitedReader // and we need to check whether this EOF arrived early. - if lr, ok := b.Reader.(*io.LimitedReader); ok && lr.N > 0 { + if lr, ok := b.src.(*io.LimitedReader); ok && lr.N > 0 { err = io.ErrUnexpectedEOF } } } + // If we can return an EOF here along with the read data, do + // so. This is optional per the io.Reader contract, but doing + // so helps the HTTP transport code recycle its connection + // earlier (since it will see this EOF itself), even if the + // client doesn't do future reads or Close. + if err == nil && n > 0 { + if lr, ok := b.src.(*io.LimitedReader); ok && lr.N == 0 { + err = io.EOF + } + } + return n, err } @@ -610,14 +664,26 @@ func (b *body) readTrailer() error { } switch rr := b.hdr.(type) { case *Request: - rr.Trailer = Header(hdr) + mergeSetHeader(&rr.Trailer, Header(hdr)) case *Response: - rr.Trailer = Header(hdr) + mergeSetHeader(&rr.Trailer, Header(hdr)) } return nil } +func mergeSetHeader(dst *Header, src Header) { + if *dst == nil { + *dst = src + return + } + for k, vv := range src { + (*dst)[k] = vv + } +} + func (b *body) Close() error { + b.mu.Lock() + defer b.mu.Unlock() if b.closed { return nil } @@ -629,12 +695,25 @@ func (b *body) Close() error { default: // Fully consume the body, which will also lead to us reading // the trailer headers after the body, if present. - _, err = io.Copy(ioutil.Discard, b) + _, err = io.Copy(ioutil.Discard, bodyLocked{b}) } b.closed = true return err } +// bodyLocked is a io.Reader reading from a *body when its mutex is +// already held. +type bodyLocked struct { + b *body +} + +func (bl bodyLocked) Read(p []byte) (n int, err error) { + if bl.b.closed { + return 0, ErrBodyReadAfterClose + } + return bl.b.readLocked(p) +} + // parseContentLength trims whitespace from s and returns -1 if no value // is set, or the value if it's >= 0. func parseContentLength(cl string) (int64, error) { diff --git a/src/pkg/net/http/transfer_test.go b/src/pkg/net/http/transfer_test.go index 8627a374c..48cd540b9 100644 --- a/src/pkg/net/http/transfer_test.go +++ b/src/pkg/net/http/transfer_test.go @@ -6,15 +6,16 @@ package http import ( "bufio" + "io" "strings" "testing" ) func TestBodyReadBadTrailer(t *testing.T) { b := &body{ - Reader: strings.NewReader("foobar"), - hdr: true, // force reading the trailer - r: bufio.NewReader(strings.NewReader("")), + src: strings.NewReader("foobar"), + hdr: true, // force reading the trailer + r: bufio.NewReader(strings.NewReader("")), } buf := make([]byte, 7) n, err := b.Read(buf[:3]) @@ -35,3 +36,29 @@ func TestBodyReadBadTrailer(t *testing.T) { t.Errorf("final Read was successful (%q), expected error from trailer read", got) } } + +func TestFinalChunkedBodyReadEOF(t *testing.T) { + res, err := ReadResponse(bufio.NewReader(strings.NewReader( + "HTTP/1.1 200 OK\r\n"+ + "Transfer-Encoding: chunked\r\n"+ + "\r\n"+ + "0a\r\n"+ + "Body here\n\r\n"+ + "09\r\n"+ + "continued\r\n"+ + "0\r\n"+ + "\r\n")), nil) + if err != nil { + t.Fatal(err) + } + want := "Body here\ncontinued" + buf := make([]byte, len(want)) + n, err := res.Body.Read(buf) + if n != len(want) || err != io.EOF { + t.Logf("body = %#v", res.Body) + t.Errorf("Read = %v, %v; want %d, EOF", n, err, len(want)) + } + if string(buf) != want { + t.Errorf("buf = %q; want %q", buf, want) + } +} diff --git a/src/pkg/net/http/transport.go b/src/pkg/net/http/transport.go index f6871afac..b1cc632a7 100644 --- a/src/pkg/net/http/transport.go +++ b/src/pkg/net/http/transport.go @@ -30,7 +30,14 @@ import ( // and caches them for reuse by subsequent calls. It uses HTTP proxies // as directed by the $HTTP_PROXY and $NO_PROXY (or $http_proxy and // $no_proxy) environment variables. -var DefaultTransport RoundTripper = &Transport{Proxy: ProxyFromEnvironment} +var DefaultTransport RoundTripper = &Transport{ + Proxy: ProxyFromEnvironment, + Dial: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).Dial, + TLSHandshakeTimeout: 10 * time.Second, +} // DefaultMaxIdleConnsPerHost is the default value of Transport's // MaxIdleConnsPerHost. @@ -40,13 +47,13 @@ const DefaultMaxIdleConnsPerHost = 2 // https, and http proxies (for either http or https with CONNECT). // Transport can also cache connections for future re-use. type Transport struct { - idleMu sync.Mutex - idleConn map[string][]*persistConn - idleConnCh map[string]chan *persistConn - reqMu sync.Mutex - reqConn map[*Request]*persistConn - altMu sync.RWMutex - altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper + idleMu sync.Mutex + idleConn map[connectMethodKey][]*persistConn + idleConnCh map[connectMethodKey]chan *persistConn + reqMu sync.Mutex + reqCanceler map[*Request]func() + altMu sync.RWMutex + altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper // Proxy specifies a function to return a proxy for a given // Request. If the function returns a non-nil error, the @@ -63,6 +70,10 @@ type Transport struct { // tls.Client. If nil, the default configuration is used. TLSClientConfig *tls.Config + // TLSHandshakeTimeout specifies the maximum amount of time waiting to + // wait for a TLS handshake. Zero means no timeout. + TLSHandshakeTimeout time.Duration + // DisableKeepAlives, if true, prevents re-use of TCP connections // between different HTTP requests. DisableKeepAlives bool @@ -98,8 +109,11 @@ type Transport struct { // An error is returned if the proxy environment is invalid. // A nil URL and nil error are returned if no proxy is defined in the // environment, or a proxy should not be used for the given request. +// +// As a special case, if req.URL.Host is "localhost" (with or without +// a port number), then a nil URL and nil error will be returned. func ProxyFromEnvironment(req *Request) (*url.URL, error) { - proxy := getenvEitherCase("HTTP_PROXY") + proxy := httpProxyEnv.Get() if proxy == "" { return nil, nil } @@ -149,9 +163,11 @@ func (tr *transportRequest) extraHeaders() Header { // and redirects), see Get, Post, and the Client type. func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) { if req.URL == nil { + req.closeBody() return nil, errors.New("http: nil Request.URL") } if req.Header == nil { + req.closeBody() return nil, errors.New("http: nil Request.Header") } if req.URL.Scheme != "http" && req.URL.Scheme != "https" { @@ -162,16 +178,19 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) { } t.altMu.RUnlock() if rt == nil { + req.closeBody() return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme} } return rt.RoundTrip(req) } if req.URL.Host == "" { + req.closeBody() return nil, errors.New("http: no Host in request URL") } treq := &transportRequest{Request: req} cm, err := t.connectMethodForRequest(treq) if err != nil { + req.closeBody() return nil, err } @@ -179,8 +198,10 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) { // host (for http or https), the http proxy, or the http proxy // pre-CONNECTed to https server. In any case, we'll be ready // to send it requests. - pconn, err := t.getConn(cm) + pconn, err := t.getConn(req, cm) if err != nil { + t.setReqCanceler(req, nil) + req.closeBody() return nil, err } @@ -218,9 +239,6 @@ func (t *Transport) CloseIdleConnections() { t.idleConn = nil t.idleConnCh = nil t.idleMu.Unlock() - if m == nil { - return - } for _, conns := range m { for _, pconn := range conns { pconn.close() @@ -232,10 +250,10 @@ func (t *Transport) CloseIdleConnections() { // connection. func (t *Transport) CancelRequest(req *Request) { t.reqMu.Lock() - pc := t.reqConn[req] + cancel := t.reqCanceler[req] t.reqMu.Unlock() - if pc != nil { - pc.conn.Close() + if cancel != nil { + cancel() } } @@ -243,24 +261,49 @@ func (t *Transport) CancelRequest(req *Request) { // Private implementation past this point. // -func getenvEitherCase(k string) string { - if v := os.Getenv(strings.ToUpper(k)); v != "" { - return v +var ( + httpProxyEnv = &envOnce{ + names: []string{"HTTP_PROXY", "http_proxy"}, } - return os.Getenv(strings.ToLower(k)) + noProxyEnv = &envOnce{ + names: []string{"NO_PROXY", "no_proxy"}, + } +) + +// envOnce looks up an environment variable (optionally by multiple +// names) once. It mitigates expensive lookups on some platforms +// (e.g. Windows). +type envOnce struct { + names []string + once sync.Once + val string +} + +func (e *envOnce) Get() string { + e.once.Do(e.init) + return e.val } -func (t *Transport) connectMethodForRequest(treq *transportRequest) (*connectMethod, error) { - cm := &connectMethod{ - targetScheme: treq.URL.Scheme, - targetAddr: canonicalAddr(treq.URL), +func (e *envOnce) init() { + for _, n := range e.names { + e.val = os.Getenv(n) + if e.val != "" { + return + } } +} + +// reset is used by tests +func (e *envOnce) reset() { + e.once = sync.Once{} + e.val = "" +} + +func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) { + cm.targetScheme = treq.URL.Scheme + cm.targetAddr = canonicalAddr(treq.URL) if t.Proxy != nil { - var err error cm.proxyURL, err = t.Proxy(treq.Request) - if err != nil { - return nil, err - } } return cm, nil } @@ -316,7 +359,7 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool { } } if t.idleConn == nil { - t.idleConn = make(map[string][]*persistConn) + t.idleConn = make(map[connectMethodKey][]*persistConn) } if len(t.idleConn[key]) >= max { t.idleMu.Unlock() @@ -336,7 +379,7 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool { // getIdleConnCh returns a channel to receive and return idle // persistent connection for the given connectMethod. // It may return nil, if persistent connections are not being used. -func (t *Transport) getIdleConnCh(cm *connectMethod) chan *persistConn { +func (t *Transport) getIdleConnCh(cm connectMethod) chan *persistConn { if t.DisableKeepAlives { return nil } @@ -344,7 +387,7 @@ func (t *Transport) getIdleConnCh(cm *connectMethod) chan *persistConn { t.idleMu.Lock() defer t.idleMu.Unlock() if t.idleConnCh == nil { - t.idleConnCh = make(map[string]chan *persistConn) + t.idleConnCh = make(map[connectMethodKey]chan *persistConn) } ch, ok := t.idleConnCh[key] if !ok { @@ -354,7 +397,7 @@ func (t *Transport) getIdleConnCh(cm *connectMethod) chan *persistConn { return ch } -func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { +func (t *Transport) getIdleConn(cm connectMethod) (pconn *persistConn) { key := cm.key() t.idleMu.Lock() defer t.idleMu.Unlock() @@ -373,7 +416,7 @@ func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { // 2 or more cached connections; pop last // TODO: queue? pconn = pconns[len(pconns)-1] - t.idleConn[key] = pconns[0 : len(pconns)-1] + t.idleConn[key] = pconns[:len(pconns)-1] } if !pconn.isBroken() { return @@ -381,16 +424,16 @@ func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { } } -func (t *Transport) setReqConn(r *Request, pc *persistConn) { +func (t *Transport) setReqCanceler(r *Request, fn func()) { t.reqMu.Lock() defer t.reqMu.Unlock() - if t.reqConn == nil { - t.reqConn = make(map[*Request]*persistConn) + if t.reqCanceler == nil { + t.reqCanceler = make(map[*Request]func()) } - if pc != nil { - t.reqConn[r] = pc + if fn != nil { + t.reqCanceler[r] = fn } else { - delete(t.reqConn, r) + delete(t.reqCanceler, r) } } @@ -405,7 +448,7 @@ func (t *Transport) dial(network, addr string) (c net.Conn, err error) { // specified in the connectMethod. This includes doing a proxy CONNECT // and/or setting up TLS. If this doesn't return an error, the persistConn // is ready to write requests to. -func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { +func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error) { if pc := t.getIdleConn(cm); pc != nil { return pc, nil } @@ -415,6 +458,16 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { err error } dialc := make(chan dialRes) + + handlePendingDial := func() { + if v := <-dialc; v.err == nil { + t.putIdleConn(v.pc) + } + } + + cancelc := make(chan struct{}) + t.setReqCanceler(req, func() { close(cancelc) }) + go func() { pc, err := t.dialConn(cm) dialc <- dialRes{pc, err} @@ -431,16 +484,15 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { // else's dial that they didn't use. // But our dial is still going, so give it away // when it finishes: - go func() { - if v := <-dialc; v.err == nil { - t.putIdleConn(v.pc) - } - }() + go handlePendingDial() return pc, nil + case <-cancelc: + go handlePendingDial() + return nil, errors.New("net/http: request canceled while waiting for connection") } } -func (t *Transport) dialConn(cm *connectMethod) (*persistConn, error) { +func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) { conn, err := t.dial("tcp", cm.addr()) if err != nil { if cm.proxyURL != nil { @@ -452,12 +504,13 @@ func (t *Transport) dialConn(cm *connectMethod) (*persistConn, error) { pa := cm.proxyAuth() pconn := &persistConn{ - t: t, - cacheKey: cm.key(), - conn: conn, - reqch: make(chan requestAndChan, 50), - writech: make(chan writeRequest, 50), - closech: make(chan struct{}), + t: t, + cacheKey: cm.key(), + conn: conn, + reqch: make(chan requestAndChan, 1), + writech: make(chan writeRequest, 1), + closech: make(chan struct{}), + writeErrCh: make(chan error, 1), } switch { @@ -511,19 +564,38 @@ func (t *Transport) dialConn(cm *connectMethod) (*persistConn, error) { cfg = &clone } } - conn = tls.Client(conn, cfg) - if err = conn.(*tls.Conn).Handshake(); err != nil { + plainConn := conn + tlsConn := tls.Client(plainConn, cfg) + errc := make(chan error, 2) + var timer *time.Timer // for canceling TLS handshake + if d := t.TLSHandshakeTimeout; d != 0 { + timer = time.AfterFunc(d, func() { + errc <- tlsHandshakeTimeoutError{} + }) + } + go func() { + err := tlsConn.Handshake() + if timer != nil { + timer.Stop() + } + errc <- err + }() + if err := <-errc; err != nil { + plainConn.Close() return nil, err } if !cfg.InsecureSkipVerify { - if err = conn.(*tls.Conn).VerifyHostname(cfg.ServerName); err != nil { + if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { + plainConn.Close() return nil, err } } - pconn.conn = conn + cs := tlsConn.ConnectionState() + pconn.tlsState = &cs + pconn.conn = tlsConn } - pconn.br = bufio.NewReader(pconn.conn) + pconn.br = bufio.NewReader(noteEOFReader{pconn.conn, &pconn.sawEOF}) pconn.bw = bufio.NewWriter(pconn.conn) go pconn.readLoop() go pconn.writeLoop() @@ -550,7 +622,7 @@ func useProxy(addr string) bool { } } - no_proxy := getenvEitherCase("NO_PROXY") + no_proxy := noProxyEnv.Get() if no_proxy == "*" { return false } @@ -590,8 +662,8 @@ func useProxy(addr string) bool { // // Cache key form Description // ----------------- ------------------------- -// ||http|foo.com http directly to server, no proxy -// ||https|foo.com https directly to server, no proxy +// |http|foo.com http directly to server, no proxy +// |https|foo.com https directly to server, no proxy // http://proxy.com|https|foo.com http to proxy, then CONNECT to foo.com // http://proxy.com|http http to proxy, http to anywhere after that // @@ -603,20 +675,20 @@ type connectMethod struct { targetAddr string // Not used if proxy + http targetScheme (4th example in table) } -func (ck *connectMethod) key() string { - return ck.String() // TODO: use a struct type instead -} - -func (ck *connectMethod) String() string { +func (cm *connectMethod) key() connectMethodKey { proxyStr := "" - targetAddr := ck.targetAddr - if ck.proxyURL != nil { - proxyStr = ck.proxyURL.String() - if ck.targetScheme == "http" { + targetAddr := cm.targetAddr + if cm.proxyURL != nil { + proxyStr = cm.proxyURL.String() + if cm.targetScheme == "http" { targetAddr = "" } } - return strings.Join([]string{proxyStr, ck.targetScheme, targetAddr}, "|") + return connectMethodKey{ + proxy: proxyStr, + scheme: cm.targetScheme, + addr: targetAddr, + } } // addr returns the first hop "host:port" to which we need to TCP connect. @@ -637,22 +709,41 @@ func (cm *connectMethod) tlsHost() string { return h } +// connectMethodKey is the map key version of connectMethod, with a +// stringified proxy URL (or the empty string) instead of a pointer to +// a URL. +type connectMethodKey struct { + proxy, scheme, addr string +} + +func (k connectMethodKey) String() string { + // Only used by tests. + return fmt.Sprintf("%s|%s|%s", k.proxy, k.scheme, k.addr) +} + // persistConn wraps a connection, usually a persistent one // (but may be used for non-keep-alive requests as well) type persistConn struct { t *Transport - cacheKey string // its connectMethod.String() + cacheKey connectMethodKey conn net.Conn - closed bool // whether conn has been closed + tlsState *tls.ConnectionState br *bufio.Reader // from conn + sawEOF bool // whether we've seen EOF from conn; owned by readLoop bw *bufio.Writer // to conn reqch chan requestAndChan // written by roundTrip; read by readLoop writech chan writeRequest // written by roundTrip; read by writeLoop - closech chan struct{} // broadcast close when readLoop (TCP connection) closes + closech chan struct{} // closed when conn closed isProxy bool + // writeErrCh passes the request write error (usually nil) + // from the writeLoop goroutine to the readLoop which passes + // it off to the res.Body reader, which then uses it to decide + // whether or not a connection can be reused. Issue 7569. + writeErrCh chan error - lk sync.Mutex // guards following 3 fields + lk sync.Mutex // guards following fields numExpectedResponses int + closed bool // whether conn has been closed broken bool // an error has happened on this connection; marked broken so it's not reused. // mutateHeaderFunc is an optional func to modify extra // headers on each outbound request before it's written. (the @@ -660,6 +751,7 @@ type persistConn struct { mutateHeaderFunc func(Header) } +// isBroken reports whether this connection is in a known broken state. func (pc *persistConn) isBroken() bool { pc.lk.Lock() b := pc.broken @@ -667,6 +759,10 @@ func (pc *persistConn) isBroken() bool { return b } +func (pc *persistConn) cancelRequest() { + pc.conn.Close() +} + var remoteSideClosedFunc func(error) bool // or nil to use default func remoteSideClosed(err error) bool { @@ -680,7 +776,6 @@ func remoteSideClosed(err error) bool { } func (pc *persistConn) readLoop() { - defer close(pc.closech) alive := true for alive { @@ -688,12 +783,14 @@ func (pc *persistConn) readLoop() { pc.lk.Lock() if pc.numExpectedResponses == 0 { - pc.closeLocked() - pc.lk.Unlock() - if len(pb) > 0 { - log.Printf("Unsolicited response received on idle HTTP channel starting with %q; err=%v", - string(pb), err) + if !pc.closed { + pc.closeLocked() + if len(pb) > 0 { + log.Printf("Unsolicited response received on idle HTTP channel starting with %q; err=%v", + string(pb), err) + } } + pc.lk.Unlock() return } pc.lk.Unlock() @@ -712,6 +809,11 @@ func (pc *persistConn) readLoop() { resp, err = ReadResponse(pc.br, rc.req) } } + + if resp != nil { + resp.TLS = pc.tlsState + } + hasBody := resp != nil && rc.req.Method != "HEAD" && resp.ContentLength != 0 if err != nil { @@ -721,13 +823,7 @@ func (pc *persistConn) readLoop() { resp.Header.Del("Content-Encoding") resp.Header.Del("Content-Length") resp.ContentLength = -1 - gzReader, zerr := gzip.NewReader(resp.Body) - if zerr != nil { - pc.close() - err = zerr - } else { - resp.Body = &readerAndCloser{gzReader, resp.Body} - } + resp.Body = &gzipReader{body: resp.Body} } resp.Body = &bodyEOFSignal{body: resp.Body} } @@ -750,24 +846,18 @@ func (pc *persistConn) readLoop() { return nil } resp.Body.(*bodyEOFSignal).fn = func(err error) { - alive1 := alive - if err != nil { - alive1 = false - } - if alive1 && !pc.t.putIdleConn(pc) { - alive1 = false - } - if !alive1 || pc.isBroken() { - pc.close() - } - waitForBodyRead <- alive1 + waitForBodyRead <- alive && + err == nil && + !pc.sawEOF && + pc.wroteRequest() && + pc.t.putIdleConn(pc) } } if alive && !hasBody { - if !pc.t.putIdleConn(pc) { - alive = false - } + alive = !pc.sawEOF && + pc.wroteRequest() && + pc.t.putIdleConn(pc) } rc.ch <- responseAndError{resp, err} @@ -775,10 +865,14 @@ func (pc *persistConn) readLoop() { // Wait for the just-returned response body to be fully consumed // before we race and peek on the underlying bufio reader. if waitForBodyRead != nil { - alive = <-waitForBodyRead + select { + case alive = <-waitForBodyRead: + case <-pc.closech: + alive = false + } } - pc.t.setReqConn(rc.req, nil) + pc.t.setReqCanceler(rc.req, nil) if !alive { pc.close() @@ -800,14 +894,44 @@ func (pc *persistConn) writeLoop() { } if err != nil { pc.markBroken() + wr.req.Request.closeBody() } - wr.ch <- err + pc.writeErrCh <- err // to the body reader, which might recycle us + wr.ch <- err // to the roundTrip function case <-pc.closech: return } } } +// wroteRequest is a check before recycling a connection that the previous write +// (from writeLoop above) happened and was successful. +func (pc *persistConn) wroteRequest() bool { + select { + case err := <-pc.writeErrCh: + // Common case: the write happened well before the response, so + // avoid creating a timer. + return err == nil + default: + // Rare case: the request was written in writeLoop above but + // before it could send to pc.writeErrCh, the reader read it + // all, processed it, and called us here. In this case, give the + // write goroutine a bit of time to finish its send. + // + // Less rare case: We also get here in the legitimate case of + // Issue 7569, where the writer is still writing (or stalled), + // but the server has already replied. In this case, we don't + // want to wait too long, and we want to return false so this + // connection isn't re-used. + select { + case err := <-pc.writeErrCh: + return err == nil + case <-time.After(50 * time.Millisecond): + return false + } + } +} + type responseAndError struct { res *Response err error @@ -832,8 +956,20 @@ type writeRequest struct { ch chan<- error } +type httpError struct { + err string + timeout bool +} + +func (e *httpError) Error() string { return e.err } +func (e *httpError) Timeout() bool { return e.timeout } +func (e *httpError) Temporary() bool { return true } + +var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true} +var errClosed error = &httpError{err: "net/http: transport closed before response was received"} + func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { - pc.t.setReqConn(req.Request, pc) + pc.t.setReqCanceler(req.Request, pc.cancelRequest) pc.lk.Lock() pc.numExpectedResponses++ headerFn := pc.mutateHeaderFunc @@ -902,11 +1038,11 @@ WaitResponse: pconnDeadCh = nil // avoid spinning failTicker = time.After(100 * time.Millisecond) // arbitrary time to wait for resc case <-failTicker: - re = responseAndError{err: errors.New("net/http: transport closed before response was received")} + re = responseAndError{err: errClosed} break WaitResponse case <-respHeaderTimer: pc.close() - re = responseAndError{err: errors.New("net/http: timeout awaiting response headers")} + re = responseAndError{err: errTimeout} break WaitResponse case re = <-resc: break WaitResponse @@ -918,7 +1054,7 @@ WaitResponse: pc.lk.Unlock() if re.err != nil { - pc.t.setReqConn(req.Request, nil) + pc.t.setReqCanceler(req.Request, nil) } return re.res, re.err } @@ -943,6 +1079,7 @@ func (pc *persistConn) closeLocked() { if !pc.closed { pc.conn.Close() pc.closed = true + close(pc.closech) } pc.mutateHeaderFunc = nil } @@ -1025,7 +1162,47 @@ func (es *bodyEOFSignal) condfn(err error) { es.fn = nil } +// gzipReader wraps a response body so it can lazily +// call gzip.NewReader on the first call to Read +type gzipReader struct { + body io.ReadCloser // underlying Response.Body + zr io.Reader // lazily-initialized gzip reader +} + +func (gz *gzipReader) Read(p []byte) (n int, err error) { + if gz.zr == nil { + gz.zr, err = gzip.NewReader(gz.body) + if err != nil { + return 0, err + } + } + return gz.zr.Read(p) +} + +func (gz *gzipReader) Close() error { + return gz.body.Close() +} + type readerAndCloser struct { io.Reader io.Closer } + +type tlsHandshakeTimeoutError struct{} + +func (tlsHandshakeTimeoutError) Timeout() bool { return true } +func (tlsHandshakeTimeoutError) Temporary() bool { return true } +func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" } + +type noteEOFReader struct { + r io.Reader + sawEOF *bool +} + +func (nr noteEOFReader) Read(p []byte) (n int, err error) { + n, err = nr.r.Read(p) + if err == io.EOF { + *nr.sawEOF = true + } + return +} diff --git a/src/pkg/net/http/transport_test.go b/src/pkg/net/http/transport_test.go index e4df30a98..964ca0fca 100644 --- a/src/pkg/net/http/transport_test.go +++ b/src/pkg/net/http/transport_test.go @@ -11,9 +11,12 @@ import ( "bytes" "compress/gzip" "crypto/rand" + "crypto/tls" + "errors" "fmt" "io" "io/ioutil" + "log" "net" "net/http" . "net/http" @@ -54,21 +57,21 @@ func (c *testCloseConn) Close() error { // been closed. type testConnSet struct { t *testing.T + mu sync.Mutex // guards closed and list closed map[net.Conn]bool list []net.Conn // in order created - mutex sync.Mutex } func (tcs *testConnSet) insert(c net.Conn) { - tcs.mutex.Lock() - defer tcs.mutex.Unlock() + tcs.mu.Lock() + defer tcs.mu.Unlock() tcs.closed[c] = false tcs.list = append(tcs.list, c) } func (tcs *testConnSet) remove(c net.Conn) { - tcs.mutex.Lock() - defer tcs.mutex.Unlock() + tcs.mu.Lock() + defer tcs.mu.Unlock() tcs.closed[c] = true } @@ -91,11 +94,19 @@ func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, e } func (tcs *testConnSet) check(t *testing.T) { - tcs.mutex.Lock() - defer tcs.mutex.Unlock() - - for i, c := range tcs.list { - if !tcs.closed[c] { + tcs.mu.Lock() + defer tcs.mu.Unlock() + for i := 4; i >= 0; i-- { + for i, c := range tcs.list { + if tcs.closed[c] { + continue + } + if i != 0 { + tcs.mu.Unlock() + time.Sleep(50 * time.Millisecond) + tcs.mu.Lock() + continue + } t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list)) } } @@ -271,6 +282,58 @@ func TestTransportIdleCacheKeys(t *testing.T) { } } +// Tests that the HTTP transport re-uses connections when a client +// reads to the end of a response Body without closing it. +func TestTransportReadToEndReusesConn(t *testing.T) { + defer afterTest(t) + const msg = "foobar" + + var addrSeen map[string]int + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + addrSeen[r.RemoteAddr]++ + if r.URL.Path == "/chunked/" { + w.WriteHeader(200) + w.(http.Flusher).Flush() + } else { + w.Header().Set("Content-Type", strconv.Itoa(len(msg))) + w.WriteHeader(200) + } + w.Write([]byte(msg)) + })) + defer ts.Close() + + buf := make([]byte, len(msg)) + + for pi, path := range []string{"/content-length/", "/chunked/"} { + wantLen := []int{len(msg), -1}[pi] + addrSeen = make(map[string]int) + for i := 0; i < 3; i++ { + res, err := http.Get(ts.URL + path) + if err != nil { + t.Errorf("Get %s: %v", path, err) + continue + } + // We want to close this body eventually (before the + // defer afterTest at top runs), but not before the + // len(addrSeen) check at the bottom of this test, + // since Closing this early in the loop would risk + // making connections be re-used for the wrong reason. + defer res.Body.Close() + + if res.ContentLength != int64(wantLen) { + t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen) + } + n, err := res.Body.Read(buf) + if n != len(msg) || err != io.EOF { + t.Errorf("%s Read = %v, %v; want %d, EOF", path, n, err, len(msg)) + } + } + if len(addrSeen) != 1 { + t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen)) + } + } +} + func TestTransportMaxPerHostIdleConns(t *testing.T) { defer afterTest(t) resch := make(chan string) @@ -295,10 +358,11 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { resp, err := c.Get(ts.URL) if err != nil { t.Error(err) + return } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatalf("ReadAll: %v", err) + if _, err := ioutil.ReadAll(resp.Body); err != nil { + t.Errorf("ReadAll: %v", err) + return } donech <- true } @@ -739,8 +803,38 @@ func TestTransportGzipRecursive(t *testing.T) { } } +// golang.org/issue/7750: request fails when server replies with +// a short gzip body +func TestTransportGzipShort(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Encoding", "gzip") + w.Write([]byte{0x1f, 0x8b}) + })) + defer ts.Close() + + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + _, err = ioutil.ReadAll(res.Body) + if err == nil { + t.Fatal("Expect an error from reading a body.") + } + if err != io.ErrUnexpectedEOF { + t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err) + } +} + // tests that persistent goroutine connections shut down when no longer desired. func TestTransportPersistConnLeak(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see http://golang.org/issue/7237") + } defer afterTest(t) gotReqCh := make(chan bool) unblockCh := make(chan bool) @@ -798,8 +892,8 @@ func TestTransportPersistConnLeak(t *testing.T) { // We expect 0 or 1 extra goroutine, empirically. Allow up to 5. // Previously we were leaking one per numReq. - t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth) if int(growth) > 5 { + t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth) t.Error("too many new goroutines") } } @@ -807,6 +901,9 @@ func TestTransportPersistConnLeak(t *testing.T) { // golang.org/issue/4531: Transport leaks goroutines when // request.ContentLength is explicitly short func TestTransportPersistConnLeakShortBody(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see http://golang.org/issue/7237") + } defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { })) @@ -1014,6 +1111,9 @@ func TestTransportConcurrency(t *testing.T) { } func TestIssue4191_InfiniteGetTimeout(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see http://golang.org/issue/7237") + } defer afterTest(t) const debug = false mux := NewServeMux() @@ -1075,6 +1175,9 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) { } func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see http://golang.org/issue/7237") + } defer afterTest(t) const debug = false mux := NewServeMux() @@ -1147,9 +1250,13 @@ func TestTransportResponseHeaderTimeout(t *testing.T) { if testing.Short() { t.Skip("skipping timeout test in -short mode") } + inHandler := make(chan bool, 1) mux := NewServeMux() - mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {}) + mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) { + inHandler <- true + }) mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) { + inHandler <- true time.Sleep(2 * time.Second) }) ts := httptest.NewServer(mux) @@ -1172,7 +1279,27 @@ func TestTransportResponseHeaderTimeout(t *testing.T) { } for i, tt := range tests { res, err := c.Get(ts.URL + tt.path) + select { + case <-inHandler: + case <-time.After(5 * time.Second): + t.Errorf("never entered handler for test index %d, %s", i, tt.path) + continue + } if err != nil { + uerr, ok := err.(*url.Error) + if !ok { + t.Errorf("error is not an url.Error; got: %#v", err) + continue + } + nerr, ok := uerr.Err.(net.Error) + if !ok { + t.Errorf("error does not satisfy net.Error interface; got: %#v", err) + continue + } + if !nerr.Timeout() { + t.Errorf("want timeout error; got: %q", nerr) + continue + } if strings.Contains(err.Error(), tt.wantErr) { continue } @@ -1243,6 +1370,60 @@ func TestTransportCancelRequest(t *testing.T) { } } +func TestTransportCancelRequestInDial(t *testing.T) { + defer afterTest(t) + if testing.Short() { + t.Skip("skipping test in -short mode") + } + var logbuf bytes.Buffer + eventLog := log.New(&logbuf, "", 0) + + unblockDial := make(chan bool) + defer close(unblockDial) + + inDial := make(chan bool) + tr := &Transport{ + Dial: func(network, addr string) (net.Conn, error) { + eventLog.Println("dial: blocking") + inDial <- true + <-unblockDial + return nil, errors.New("nope") + }, + } + cl := &Client{Transport: tr} + gotres := make(chan bool) + req, _ := NewRequest("GET", "http://something.no-network.tld/", nil) + go func() { + _, err := cl.Do(req) + eventLog.Printf("Get = %v", err) + gotres <- true + }() + + select { + case <-inDial: + case <-time.After(5 * time.Second): + t.Fatal("timeout; never saw blocking dial") + } + + eventLog.Printf("canceling") + tr.CancelRequest(req) + + select { + case <-gotres: + case <-time.After(5 * time.Second): + panic("hang. events are: " + logbuf.String()) + } + + got := logbuf.String() + want := `dial: blocking +canceling +Get = Get http://something.no-network.tld/: net/http: request canceled while waiting for connection +` + if got != want { + t.Errorf("Got events:\n%s\nWant:\n%s", got, want) + } +} + // golang.org/issue/3672 -- Client can't close HTTP stream // Calling Close on a Response.Body used to just read until EOF. // Now it actually closes the TCP connection. @@ -1283,7 +1464,7 @@ func TestTransportCloseResponseBody(t *testing.T) { t.Fatal(err) } if !bytes.Equal(buf, want) { - t.Errorf("read %q; want %q", buf, want) + t.Fatalf("read %q; want %q", buf, want) } didClose := make(chan error, 1) go func() { @@ -1372,8 +1553,10 @@ func TestTransportSocketLateBinding(t *testing.T) { dialGate := make(chan bool, 1) tr := &Transport{ Dial: func(n, addr string) (net.Conn, error) { - <-dialGate - return net.Dial(n, addr) + if <-dialGate { + return net.Dial(n, addr) + } + return nil, errors.New("manually closed") }, DisableKeepAlives: false, } @@ -1408,7 +1591,7 @@ func TestTransportSocketLateBinding(t *testing.T) { t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr) } barRes.Body.Close() - dialGate <- true + dialGate <- false } // Issue 2184 @@ -1559,13 +1742,11 @@ var proxyFromEnvTests = []proxyFromEnvTest{ } func TestProxyFromEnvironment(t *testing.T) { - os.Setenv("HTTP_PROXY", "") - os.Setenv("http_proxy", "") - os.Setenv("NO_PROXY", "") - os.Setenv("no_proxy", "") + ResetProxyEnv() for _, tt := range proxyFromEnvTests { os.Setenv("HTTP_PROXY", tt.env) os.Setenv("NO_PROXY", tt.noenv) + ResetCachedEnvironment() reqURL := tt.req if reqURL == "" { reqURL = "http://example.com" @@ -1643,6 +1824,308 @@ func TestTransportClosesRequestBody(t *testing.T) { } } +func TestTransportTLSHandshakeTimeout(t *testing.T) { + defer afterTest(t) + if testing.Short() { + t.Skip("skipping in short mode") + } + ln := newLocalListener(t) + defer ln.Close() + testdonec := make(chan struct{}) + defer close(testdonec) + + go func() { + c, err := ln.Accept() + if err != nil { + t.Error(err) + return + } + <-testdonec + c.Close() + }() + + getdonec := make(chan struct{}) + go func() { + defer close(getdonec) + tr := &Transport{ + Dial: func(_, _ string) (net.Conn, error) { + return net.Dial("tcp", ln.Addr().String()) + }, + TLSHandshakeTimeout: 250 * time.Millisecond, + } + cl := &Client{Transport: tr} + _, err := cl.Get("https://dummy.tld/") + if err == nil { + t.Error("expected error") + return + } + ue, ok := err.(*url.Error) + if !ok { + t.Errorf("expected url.Error; got %#v", err) + return + } + ne, ok := ue.Err.(net.Error) + if !ok { + t.Errorf("expected net.Error; got %#v", err) + return + } + if !ne.Timeout() { + t.Errorf("expected timeout error; got %v", err) + } + if !strings.Contains(err.Error(), "handshake timeout") { + t.Errorf("expected 'handshake timeout' in error; got %v", err) + } + }() + select { + case <-getdonec: + case <-time.After(5 * time.Second): + t.Error("test timeout; TLS handshake hung?") + } +} + +// Trying to repro golang.org/issue/3514 +func TestTLSServerClosesConnection(t *testing.T) { + defer afterTest(t) + if runtime.GOOS == "windows" { + t.Skip("skipping flaky test on Windows; golang.org/issue/7634") + } + closedc := make(chan bool, 1) + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if strings.Contains(r.URL.Path, "/keep-alive-then-die") { + conn, _, _ := w.(Hijacker).Hijack() + conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) + conn.Close() + closedc <- true + return + } + fmt.Fprintf(w, "hello") + })) + defer ts.Close() + tr := &Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + } + defer tr.CloseIdleConnections() + client := &Client{Transport: tr} + + var nSuccess = 0 + var errs []error + const trials = 20 + for i := 0; i < trials; i++ { + tr.CloseIdleConnections() + res, err := client.Get(ts.URL + "/keep-alive-then-die") + if err != nil { + t.Fatal(err) + } + <-closedc + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if string(slurp) != "foo" { + t.Errorf("Got %q, want foo", slurp) + } + + // Now try again and see if we successfully + // pick a new connection. + res, err = client.Get(ts.URL + "/") + if err != nil { + errs = append(errs, err) + continue + } + slurp, err = ioutil.ReadAll(res.Body) + if err != nil { + errs = append(errs, err) + continue + } + nSuccess++ + } + if nSuccess > 0 { + t.Logf("successes = %d of %d", nSuccess, trials) + } else { + t.Errorf("All runs failed:") + } + for _, err := range errs { + t.Logf(" err: %v", err) + } +} + +// byteFromChanReader is an io.Reader that reads a single byte at a +// time from the channel. When the channel is closed, the reader +// returns io.EOF. +type byteFromChanReader chan byte + +func (c byteFromChanReader) Read(p []byte) (n int, err error) { + if len(p) == 0 { + return + } + b, ok := <-c + if !ok { + return 0, io.EOF + } + p[0] = b + return 1, nil +} + +// Verifies that the Transport doesn't reuse a connection in the case +// where the server replies before the request has been fully +// written. We still honor that reply (see TestIssue3595), but don't +// send future requests on the connection because it's then in a +// questionable state. +// golang.org/issue/7569 +func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { + defer afterTest(t) + var sconn struct { + sync.Mutex + c net.Conn + } + var getOkay bool + closeConn := func() { + sconn.Lock() + defer sconn.Unlock() + if sconn.c != nil { + sconn.c.Close() + sconn.c = nil + if !getOkay { + t.Logf("Closed server connection") + } + } + } + defer closeConn() + + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Method == "GET" { + io.WriteString(w, "bar") + return + } + conn, _, _ := w.(Hijacker).Hijack() + sconn.Lock() + sconn.c = conn + sconn.Unlock() + conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive + go io.Copy(ioutil.Discard, conn) + })) + defer ts.Close() + tr := &Transport{} + defer tr.CloseIdleConnections() + client := &Client{Transport: tr} + + const bodySize = 256 << 10 + finalBit := make(byteFromChanReader, 1) + req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit)) + req.ContentLength = bodySize + res, err := client.Do(req) + if err := wantBody(res, err, "foo"); err != nil { + t.Errorf("POST response: %v", err) + } + donec := make(chan bool) + go func() { + defer close(donec) + res, err = client.Get(ts.URL) + if err := wantBody(res, err, "bar"); err != nil { + t.Errorf("GET response: %v", err) + return + } + getOkay = true // suppress test noise + }() + time.AfterFunc(5*time.Second, closeConn) + select { + case <-donec: + finalBit <- 'x' // unblock the writeloop of the first Post + close(finalBit) + case <-time.After(7 * time.Second): + t.Fatal("timeout waiting for GET request to finish") + } +} + +type errorReader struct { + err error +} + +func (e errorReader) Read(p []byte) (int, error) { return 0, e.err } + +type closerFunc func() error + +func (f closerFunc) Close() error { return f() } + +// Issue 6981 +func TestTransportClosesBodyOnError(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see http://golang.org/issue/7782") + } + defer afterTest(t) + readBody := make(chan error, 1) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + _, err := ioutil.ReadAll(r.Body) + readBody <- err + })) + defer ts.Close() + fakeErr := errors.New("fake error") + didClose := make(chan bool, 1) + req, _ := NewRequest("POST", ts.URL, struct { + io.Reader + io.Closer + }{ + io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), errorReader{fakeErr}), + closerFunc(func() error { + select { + case didClose <- true: + default: + } + return nil + }), + }) + res, err := DefaultClient.Do(req) + if res != nil { + defer res.Body.Close() + } + if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) { + t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error()) + } + select { + case err := <-readBody: + if err == nil { + t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'") + } + case <-time.After(5 * time.Second): + t.Error("timeout waiting for server handler to complete") + } + select { + case <-didClose: + default: + t.Errorf("didn't see Body.Close") + } +} + +func wantBody(res *http.Response, err error, want string) error { + if err != nil { + return err + } + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + return fmt.Errorf("error reading body: %v", err) + } + if string(slurp) != want { + return fmt.Errorf("body = %q; want %q", slurp, want) + } + if err := res.Body.Close(); err != nil { + return fmt.Errorf("body Close = %v", err) + } + return nil +} + +func newLocalListener(t *testing.T) net.Listener { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + ln, err = net.Listen("tcp6", "[::1]:0") + } + if err != nil { + t.Fatal(err) + } + return ln +} + type countCloseReader struct { n *int io.Reader |