summaryrefslogtreecommitdiff
path: root/src/pkg/net/http
diff options
context:
space:
mode:
authorMichael Stapelberg <stapelberg@debian.org>2014-06-19 09:22:53 +0200
committerMichael Stapelberg <stapelberg@debian.org>2014-06-19 09:22:53 +0200
commit8a39ee361feb9bf46d728ff1ba4f07ca1d9610b1 (patch)
tree4449f2036cccf162e8417cc5841a35815b3e7ac5 /src/pkg/net/http
parentc8bf49ef8a92e2337b69c14b9b88396efe498600 (diff)
downloadgolang-51f2ca399fb8da86b2e7b3a0582e083fab731a98.tar.gz
Imported Upstream version 1.3upstream/1.3
Diffstat (limited to 'src/pkg/net/http')
-rw-r--r--src/pkg/net/http/cgi/host.go27
-rw-r--r--src/pkg/net/http/cgi/matryoshka_test.go137
-rw-r--r--src/pkg/net/http/chunked.go58
-rw-r--r--src/pkg/net/http/chunked_test.go112
-rw-r--r--src/pkg/net/http/client.go111
-rw-r--r--src/pkg/net/http/client_test.go275
-rw-r--r--src/pkg/net/http/cookie.go58
-rw-r--r--src/pkg/net/http/cookie_test.go107
-rw-r--r--src/pkg/net/http/export_test.go18
-rw-r--r--src/pkg/net/http/fcgi/child.go19
-rw-r--r--src/pkg/net/http/fs.go18
-rw-r--r--src/pkg/net/http/fs_test.go70
-rw-r--r--src/pkg/net/http/header.go19
-rw-r--r--src/pkg/net/http/header_test.go9
-rw-r--r--src/pkg/net/http/httptest/server_test.go23
-rw-r--r--src/pkg/net/http/httputil/chunked.go74
-rw-r--r--src/pkg/net/http/httputil/chunked_test.go120
-rw-r--r--src/pkg/net/http/httputil/dump.go35
-rw-r--r--src/pkg/net/http/httputil/dump_test.go87
-rw-r--r--src/pkg/net/http/httputil/httputil.go32
-rw-r--r--src/pkg/net/http/httputil/persist.go21
-rw-r--r--src/pkg/net/http/httputil/reverseproxy.go4
-rw-r--r--src/pkg/net/http/httputil/reverseproxy_test.go16
-rw-r--r--src/pkg/net/http/proxy_test.go19
-rw-r--r--src/pkg/net/http/race.go11
-rw-r--r--src/pkg/net/http/request.go139
-rw-r--r--src/pkg/net/http/request_test.go133
-rw-r--r--src/pkg/net/http/requestwrite_test.go42
-rw-r--r--src/pkg/net/http/response.go68
-rw-r--r--src/pkg/net/http/response_test.go20
-rw-r--r--src/pkg/net/http/responsewrite_test.go123
-rw-r--r--src/pkg/net/http/serve_test.go586
-rw-r--r--src/pkg/net/http/server.go270
-rw-r--r--src/pkg/net/http/transfer.go155
-rw-r--r--src/pkg/net/http/transfer_test.go33
-rw-r--r--src/pkg/net/http/transport.go393
-rw-r--r--src/pkg/net/http/transport_test.go529
37 files changed, 3337 insertions, 634 deletions
diff --git a/src/pkg/net/http/cgi/host.go b/src/pkg/net/http/cgi/host.go
index d27cc4dc9..ec95a972c 100644
--- a/src/pkg/net/http/cgi/host.go
+++ b/src/pkg/net/http/cgi/host.go
@@ -214,12 +214,17 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
internalError(err)
return
}
+ if hook := testHookStartProcess; hook != nil {
+ hook(cmd.Process)
+ }
defer cmd.Wait()
defer stdoutRead.Close()
linebody := bufio.NewReaderSize(stdoutRead, 1024)
headers := make(http.Header)
statusCode := 0
+ headerLines := 0
+ sawBlankLine := false
for {
line, isPrefix, err := linebody.ReadLine()
if isPrefix {
@@ -236,8 +241,10 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}
if len(line) == 0 {
+ sawBlankLine = true
break
}
+ headerLines++
parts := strings.SplitN(string(line), ":", 2)
if len(parts) < 2 {
h.printf("cgi: bogus header line: %s", string(line))
@@ -263,6 +270,11 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
headers.Add(header, val)
}
}
+ if headerLines == 0 || !sawBlankLine {
+ rw.WriteHeader(http.StatusInternalServerError)
+ h.printf("cgi: no headers")
+ return
+ }
if loc := headers.Get("Location"); loc != "" {
if strings.HasPrefix(loc, "/") && h.PathLocationHandler != nil {
@@ -274,6 +286,12 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
}
+ if statusCode == 0 && headers.Get("Content-Type") == "" {
+ rw.WriteHeader(http.StatusInternalServerError)
+ h.printf("cgi: missing required Content-Type in headers")
+ return
+ }
+
if statusCode == 0 {
statusCode = http.StatusOK
}
@@ -292,6 +310,13 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
_, err = io.Copy(rw, linebody)
if err != nil {
h.printf("cgi: copy error: %v", err)
+ // And kill the child CGI process so we don't hang on
+ // the deferred cmd.Wait above if the error was just
+ // the client (rw) going away. If it was a read error
+ // (because the child died itself), then the extra
+ // kill of an already-dead process is harmless (the PID
+ // won't be reused until the Wait above).
+ cmd.Process.Kill()
}
}
@@ -348,3 +373,5 @@ func upperCaseAndUnderscore(r rune) rune {
// TODO: other transformations in spec or practice?
return r
}
+
+var testHookStartProcess func(*os.Process) // nil except for some tests
diff --git a/src/pkg/net/http/cgi/matryoshka_test.go b/src/pkg/net/http/cgi/matryoshka_test.go
index e1a78c8f6..18c4803e7 100644
--- a/src/pkg/net/http/cgi/matryoshka_test.go
+++ b/src/pkg/net/http/cgi/matryoshka_test.go
@@ -9,15 +9,25 @@
package cgi
import (
+ "bytes"
+ "errors"
"fmt"
+ "io"
"net/http"
+ "net/http/httptest"
"os"
+ "runtime"
"testing"
+ "time"
)
// This test is a CGI host (testing host.go) that runs its own binary
// as a child process testing the other half of CGI (child.go).
func TestHostingOurselves(t *testing.T) {
+ if runtime.GOOS == "nacl" {
+ t.Skip("skipping on nacl")
+ }
+
h := &Handler{
Path: os.Args[0],
Root: "/test.go",
@@ -51,8 +61,88 @@ func TestHostingOurselves(t *testing.T) {
}
}
-// Test that a child handler only writing headers works.
+type customWriterRecorder struct {
+ w io.Writer
+ *httptest.ResponseRecorder
+}
+
+func (r *customWriterRecorder) Write(p []byte) (n int, err error) {
+ return r.w.Write(p)
+}
+
+type limitWriter struct {
+ w io.Writer
+ n int
+}
+
+func (w *limitWriter) Write(p []byte) (n int, err error) {
+ if len(p) > w.n {
+ p = p[:w.n]
+ }
+ if len(p) > 0 {
+ n, err = w.w.Write(p)
+ w.n -= n
+ }
+ if w.n == 0 {
+ err = errors.New("past write limit")
+ }
+ return
+}
+
+// If there's an error copying the child's output to the parent, test
+// that we kill the child.
+func TestKillChildAfterCopyError(t *testing.T) {
+ if runtime.GOOS == "nacl" {
+ t.Skip("skipping on nacl")
+ }
+
+ defer func() { testHookStartProcess = nil }()
+ proc := make(chan *os.Process, 1)
+ testHookStartProcess = func(p *os.Process) {
+ proc <- p
+ }
+
+ h := &Handler{
+ Path: os.Args[0],
+ Root: "/test.go",
+ Args: []string{"-test.run=TestBeChildCGIProcess"},
+ }
+ req, _ := http.NewRequest("GET", "http://example.com/test.cgi?write-forever=1", nil)
+ rec := httptest.NewRecorder()
+ var out bytes.Buffer
+ const writeLen = 50 << 10
+ rw := &customWriterRecorder{&limitWriter{&out, writeLen}, rec}
+
+ donec := make(chan bool, 1)
+ go func() {
+ h.ServeHTTP(rw, req)
+ donec <- true
+ }()
+
+ select {
+ case <-donec:
+ if out.Len() != writeLen || out.Bytes()[0] != 'a' {
+ t.Errorf("unexpected output: %q", out.Bytes())
+ }
+ case <-time.After(5 * time.Second):
+ t.Errorf("timeout. ServeHTTP hung and didn't kill the child process?")
+ select {
+ case p := <-proc:
+ p.Kill()
+ t.Logf("killed process")
+ default:
+ t.Logf("didn't kill process")
+ }
+ }
+}
+
+// Test that a child handler writing only headers works.
+// golang.org/issue/7196
func TestChildOnlyHeaders(t *testing.T) {
+ if runtime.GOOS == "nacl" {
+ t.Skip("skipping on nacl")
+ }
+
h := &Handler{
Path: os.Args[0],
Root: "/test.go",
@@ -67,18 +157,63 @@ func TestChildOnlyHeaders(t *testing.T) {
}
}
+// golang.org/issue/7198
+func Test500WithNoHeaders(t *testing.T) { want500Test(t, "/immediate-disconnect") }
+func Test500WithNoContentType(t *testing.T) { want500Test(t, "/no-content-type") }
+func Test500WithEmptyHeaders(t *testing.T) { want500Test(t, "/empty-headers") }
+
+func want500Test(t *testing.T, path string) {
+ h := &Handler{
+ Path: os.Args[0],
+ Root: "/test.go",
+ Args: []string{"-test.run=TestBeChildCGIProcess"},
+ }
+ expectedMap := map[string]string{
+ "_body": "",
+ }
+ replay := runCgiTest(t, h, "GET "+path+" HTTP/1.0\nHost: example.com\n\n", expectedMap)
+ if replay.Code != 500 {
+ t.Errorf("Got code %d; want 500", replay.Code)
+ }
+}
+
+type neverEnding byte
+
+func (b neverEnding) Read(p []byte) (n int, err error) {
+ for i := range p {
+ p[i] = byte(b)
+ }
+ return len(p), nil
+}
+
// Note: not actually a test.
func TestBeChildCGIProcess(t *testing.T) {
if os.Getenv("REQUEST_METHOD") == "" {
// Not in a CGI environment; skipping test.
return
}
+ switch os.Getenv("REQUEST_URI") {
+ case "/immediate-disconnect":
+ os.Exit(0)
+ case "/no-content-type":
+ fmt.Printf("Content-Length: 6\n\nHello\n")
+ os.Exit(0)
+ case "/empty-headers":
+ fmt.Printf("\nHello")
+ os.Exit(0)
+ }
Serve(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("X-Test-Header", "X-Test-Value")
req.ParseForm()
if req.FormValue("no-body") == "1" {
return
}
+ if req.FormValue("write-forever") == "1" {
+ io.Copy(rw, neverEnding('a'))
+ for {
+ time.Sleep(5 * time.Second) // hang forever, until killed
+ }
+ }
fmt.Fprintf(rw, "test=Hello CGI-in-CGI\n")
for k, vv := range req.Form {
for _, v := range vv {
diff --git a/src/pkg/net/http/chunked.go b/src/pkg/net/http/chunked.go
index 91db01724..749f29d32 100644
--- a/src/pkg/net/http/chunked.go
+++ b/src/pkg/net/http/chunked.go
@@ -4,13 +4,14 @@
// The wire protocol for HTTP's "chunked" Transfer-Encoding.
-// This code is duplicated in httputil/chunked.go.
+// This code is duplicated in net/http and net/http/httputil.
// Please make any changes in both files.
package http
import (
"bufio"
+ "bytes"
"errors"
"fmt"
"io"
@@ -57,26 +58,45 @@ func (cr *chunkedReader) beginChunk() {
}
}
-func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
- if cr.err != nil {
- return 0, cr.err
+func (cr *chunkedReader) chunkHeaderAvailable() bool {
+ n := cr.r.Buffered()
+ if n > 0 {
+ peek, _ := cr.r.Peek(n)
+ return bytes.IndexByte(peek, '\n') >= 0
}
- if cr.n == 0 {
- cr.beginChunk()
- if cr.err != nil {
- return 0, cr.err
+ return false
+}
+
+func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
+ for cr.err == nil {
+ if cr.n == 0 {
+ if n > 0 && !cr.chunkHeaderAvailable() {
+ // We've read enough. Don't potentially block
+ // reading a new chunk header.
+ break
+ }
+ cr.beginChunk()
+ continue
}
- }
- if uint64(len(b)) > cr.n {
- b = b[0:cr.n]
- }
- n, cr.err = cr.r.Read(b)
- cr.n -= uint64(n)
- if cr.n == 0 && cr.err == nil {
- // end of chunk (CRLF)
- if _, cr.err = io.ReadFull(cr.r, cr.buf[:]); cr.err == nil {
- if cr.buf[0] != '\r' || cr.buf[1] != '\n' {
- cr.err = errors.New("malformed chunked encoding")
+ if len(b) == 0 {
+ break
+ }
+ rbuf := b
+ if uint64(len(rbuf)) > cr.n {
+ rbuf = rbuf[:cr.n]
+ }
+ var n0 int
+ n0, cr.err = cr.r.Read(rbuf)
+ n += n0
+ b = b[n0:]
+ cr.n -= uint64(n0)
+ // If we're at the end of a chunk, read the next two
+ // bytes to verify they are "\r\n".
+ if cr.n == 0 && cr.err == nil {
+ if _, cr.err = io.ReadFull(cr.r, cr.buf[:2]); cr.err == nil {
+ if cr.buf[0] != '\r' || cr.buf[1] != '\n' {
+ cr.err = errors.New("malformed chunked encoding")
+ }
}
}
}
diff --git a/src/pkg/net/http/chunked_test.go b/src/pkg/net/http/chunked_test.go
index 0b18c7b55..34544790a 100644
--- a/src/pkg/net/http/chunked_test.go
+++ b/src/pkg/net/http/chunked_test.go
@@ -2,17 +2,18 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// This code is duplicated in httputil/chunked_test.go.
+// This code is duplicated in net/http and net/http/httputil.
// Please make any changes in both files.
package http
import (
+ "bufio"
"bytes"
"fmt"
"io"
"io/ioutil"
- "runtime"
+ "strings"
"testing"
)
@@ -41,9 +42,77 @@ func TestChunk(t *testing.T) {
}
}
+func TestChunkReadMultiple(t *testing.T) {
+ // Bunch of small chunks, all read together.
+ {
+ var b bytes.Buffer
+ w := newChunkedWriter(&b)
+ w.Write([]byte("foo"))
+ w.Write([]byte("bar"))
+ w.Close()
+
+ r := newChunkedReader(&b)
+ buf := make([]byte, 10)
+ n, err := r.Read(buf)
+ if n != 6 || err != io.EOF {
+ t.Errorf("Read = %d, %v; want 6, EOF", n, err)
+ }
+ buf = buf[:n]
+ if string(buf) != "foobar" {
+ t.Errorf("Read = %q; want %q", buf, "foobar")
+ }
+ }
+
+ // One big chunk followed by a little chunk, but the small bufio.Reader size
+ // should prevent the second chunk header from being read.
+ {
+ var b bytes.Buffer
+ w := newChunkedWriter(&b)
+ // fillBufChunk is 11 bytes + 3 bytes header + 2 bytes footer = 16 bytes,
+ // the same as the bufio ReaderSize below (the minimum), so even
+ // though we're going to try to Read with a buffer larger enough to also
+ // receive "foo", the second chunk header won't be read yet.
+ const fillBufChunk = "0123456789a"
+ const shortChunk = "foo"
+ w.Write([]byte(fillBufChunk))
+ w.Write([]byte(shortChunk))
+ w.Close()
+
+ r := newChunkedReader(bufio.NewReaderSize(&b, 16))
+ buf := make([]byte, len(fillBufChunk)+len(shortChunk))
+ n, err := r.Read(buf)
+ if n != len(fillBufChunk) || err != nil {
+ t.Errorf("Read = %d, %v; want %d, nil", n, err, len(fillBufChunk))
+ }
+ buf = buf[:n]
+ if string(buf) != fillBufChunk {
+ t.Errorf("Read = %q; want %q", buf, fillBufChunk)
+ }
+
+ n, err = r.Read(buf)
+ if n != len(shortChunk) || err != io.EOF {
+ t.Errorf("Read = %d, %v; want %d, EOF", n, err, len(shortChunk))
+ }
+ }
+
+ // And test that we see an EOF chunk, even though our buffer is already full:
+ {
+ r := newChunkedReader(bufio.NewReader(strings.NewReader("3\r\nfoo\r\n0\r\n")))
+ buf := make([]byte, 3)
+ n, err := r.Read(buf)
+ if n != 3 || err != io.EOF {
+ t.Errorf("Read = %d, %v; want 3, EOF", n, err)
+ }
+ if string(buf) != "foo" {
+ t.Errorf("buf = %q; want foo", buf)
+ }
+ }
+}
+
func TestChunkReaderAllocs(t *testing.T) {
- // temporarily set GOMAXPROCS to 1 as we are testing memory allocations
- defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1))
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
var buf bytes.Buffer
w := newChunkedWriter(&buf)
a, b, c := []byte("aaaaaa"), []byte("bbbbbbbbbbbb"), []byte("cccccccccccccccccccccccc")
@@ -52,26 +121,23 @@ func TestChunkReaderAllocs(t *testing.T) {
w.Write(c)
w.Close()
- r := newChunkedReader(&buf)
readBuf := make([]byte, len(a)+len(b)+len(c)+1)
-
- var ms runtime.MemStats
- runtime.ReadMemStats(&ms)
- m0 := ms.Mallocs
-
- n, err := io.ReadFull(r, readBuf)
-
- runtime.ReadMemStats(&ms)
- mallocs := ms.Mallocs - m0
- if mallocs > 1 {
- t.Errorf("%d mallocs; want <= 1", mallocs)
- }
-
- if n != len(readBuf)-1 {
- t.Errorf("read %d bytes; want %d", n, len(readBuf)-1)
- }
- if err != io.ErrUnexpectedEOF {
- t.Errorf("read error = %v; want ErrUnexpectedEOF", err)
+ byter := bytes.NewReader(buf.Bytes())
+ bufr := bufio.NewReader(byter)
+ mallocs := testing.AllocsPerRun(100, func() {
+ byter.Seek(0, 0)
+ bufr.Reset(byter)
+ r := newChunkedReader(bufr)
+ n, err := io.ReadFull(r, readBuf)
+ if n != len(readBuf)-1 {
+ t.Fatalf("read %d bytes; want %d", n, len(readBuf)-1)
+ }
+ if err != io.ErrUnexpectedEOF {
+ t.Fatalf("read error = %v; want ErrUnexpectedEOF", err)
+ }
+ })
+ if mallocs > 1.5 {
+ t.Errorf("mallocs = %v; want 1", mallocs)
}
}
diff --git a/src/pkg/net/http/client.go b/src/pkg/net/http/client.go
index 22f2e865c..a5a3abe61 100644
--- a/src/pkg/net/http/client.go
+++ b/src/pkg/net/http/client.go
@@ -14,9 +14,12 @@ import (
"errors"
"fmt"
"io"
+ "io/ioutil"
"log"
"net/url"
"strings"
+ "sync"
+ "time"
)
// A Client is an HTTP client. Its zero value (DefaultClient) is a
@@ -52,6 +55,20 @@ type Client struct {
// If Jar is nil, cookies are not sent in requests and ignored
// in responses.
Jar CookieJar
+
+ // Timeout specifies a time limit for requests made by this
+ // Client. The timeout includes connection time, any
+ // redirects, and reading the response body. The timer remains
+ // running after Get, Head, Post, or Do return and will
+ // interrupt reading of the Response.Body.
+ //
+ // A Timeout of zero means no timeout.
+ //
+ // The Client's Transport must support the CancelRequest
+ // method or Client will return errors when attempting to make
+ // a request with Get, Head, Post, or Do. Client's default
+ // Transport (DefaultTransport) supports CancelRequest.
+ Timeout time.Duration
}
// DefaultClient is the default Client and is used by Get, Head, and Post.
@@ -74,8 +91,9 @@ type RoundTripper interface {
// authentication, or cookies.
//
// RoundTrip should not modify the request, except for
- // consuming and closing the Body. The request's URL and
- // Header fields are guaranteed to be initialized.
+ // consuming and closing the Body, including on errors. The
+ // request's URL and Header fields are guaranteed to be
+ // initialized.
RoundTrip(*Request) (*Response, error)
}
@@ -97,7 +115,7 @@ func (c *Client) send(req *Request) (*Response, error) {
req.AddCookie(cookie)
}
}
- resp, err := send(req, c.Transport)
+ resp, err := send(req, c.transport())
if err != nil {
return nil, err
}
@@ -123,6 +141,9 @@ func (c *Client) send(req *Request) (*Response, error) {
// (typically Transport) may not be able to re-use a persistent TCP
// connection to the server for a subsequent "keep-alive" request.
//
+// The request Body, if non-nil, will be closed by the underlying
+// Transport, even on errors.
+//
// Generally Get, Post, or PostForm will be used instead of Do.
func (c *Client) Do(req *Request) (resp *Response, err error) {
if req.Method == "GET" || req.Method == "HEAD" {
@@ -134,22 +155,28 @@ func (c *Client) Do(req *Request) (resp *Response, err error) {
return c.send(req)
}
+func (c *Client) transport() RoundTripper {
+ if c.Transport != nil {
+ return c.Transport
+ }
+ return DefaultTransport
+}
+
// send issues an HTTP request.
// Caller should close resp.Body when done reading from it.
func send(req *Request, t RoundTripper) (resp *Response, err error) {
if t == nil {
- t = DefaultTransport
- if t == nil {
- err = errors.New("http: no Client.Transport or DefaultTransport")
- return
- }
+ req.closeBody()
+ return nil, errors.New("http: no Client.Transport or DefaultTransport")
}
if req.URL == nil {
+ req.closeBody()
return nil, errors.New("http: nil Request.URL")
}
if req.RequestURI != "" {
+ req.closeBody()
return nil, errors.New("http: Request.RequestURI can't be set in client requests.")
}
@@ -257,21 +284,40 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo
var via []*Request
if ireq.URL == nil {
+ ireq.closeBody()
return nil, errors.New("http: nil Request.URL")
}
+ var reqmu sync.Mutex // guards req
req := ireq
+
+ var timer *time.Timer
+ if c.Timeout > 0 {
+ type canceler interface {
+ CancelRequest(*Request)
+ }
+ tr, ok := c.transport().(canceler)
+ if !ok {
+ return nil, fmt.Errorf("net/http: Client Transport of type %T doesn't support CancelRequest; Timeout not supported", c.transport())
+ }
+ timer = time.AfterFunc(c.Timeout, func() {
+ reqmu.Lock()
+ defer reqmu.Unlock()
+ tr.CancelRequest(req)
+ })
+ }
+
urlStr := "" // next relative or absolute URL to fetch (after first request)
redirectFailed := false
for redirect := 0; ; redirect++ {
if redirect != 0 {
- req = new(Request)
- req.Method = ireq.Method
+ nreq := new(Request)
+ nreq.Method = ireq.Method
if ireq.Method == "POST" || ireq.Method == "PUT" {
- req.Method = "GET"
+ nreq.Method = "GET"
}
- req.Header = make(Header)
- req.URL, err = base.Parse(urlStr)
+ nreq.Header = make(Header)
+ nreq.URL, err = base.Parse(urlStr)
if err != nil {
break
}
@@ -279,15 +325,18 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo
// Add the Referer header.
lastReq := via[len(via)-1]
if lastReq.URL.Scheme != "https" {
- req.Header.Set("Referer", lastReq.URL.String())
+ nreq.Header.Set("Referer", lastReq.URL.String())
}
- err = redirectChecker(req, via)
+ err = redirectChecker(nreq, via)
if err != nil {
redirectFailed = true
break
}
}
+ reqmu.Lock()
+ req = nreq
+ reqmu.Unlock()
}
urlStr = req.URL.String()
@@ -296,6 +345,12 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo
}
if shouldRedirect(resp.StatusCode) {
+ // Read the body if small so underlying TCP connection will be re-used.
+ // No need to check for errors: if it fails, Transport won't reuse it anyway.
+ const maxBodySlurpSize = 2 << 10
+ if resp.ContentLength == -1 || resp.ContentLength <= maxBodySlurpSize {
+ io.CopyN(ioutil.Discard, resp.Body, maxBodySlurpSize)
+ }
resp.Body.Close()
if urlStr = resp.Header.Get("Location"); urlStr == "" {
err = errors.New(fmt.Sprintf("%d response missing Location header", resp.StatusCode))
@@ -305,7 +360,10 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo
via = append(via, req)
continue
}
- return
+ if timer != nil {
+ resp.Body = &cancelTimerBody{timer, resp.Body}
+ }
+ return resp, nil
}
method := ireq.Method
@@ -349,7 +407,7 @@ func Post(url string, bodyType string, body io.Reader) (resp *Response, err erro
// Caller should close resp.Body when done reading from it.
//
// If the provided body is also an io.Closer, it is closed after the
-// body is successfully written to the server.
+// request.
func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *Response, err error) {
req, err := NewRequest("POST", url, body)
if err != nil {
@@ -408,3 +466,22 @@ func (c *Client) Head(url string) (resp *Response, err error) {
}
return c.doFollowingRedirects(req, shouldRedirectGet)
}
+
+type cancelTimerBody struct {
+ t *time.Timer
+ rc io.ReadCloser
+}
+
+func (b *cancelTimerBody) Read(p []byte) (n int, err error) {
+ n, err = b.rc.Read(p)
+ if err == io.EOF {
+ b.t.Stop()
+ }
+ return
+}
+
+func (b *cancelTimerBody) Close() error {
+ err := b.rc.Close()
+ b.t.Stop()
+ return err
+}
diff --git a/src/pkg/net/http/client_test.go b/src/pkg/net/http/client_test.go
index 997d04151..6392c1baf 100644
--- a/src/pkg/net/http/client_test.go
+++ b/src/pkg/net/http/client_test.go
@@ -15,14 +15,18 @@ import (
"fmt"
"io"
"io/ioutil"
+ "log"
"net"
. "net/http"
"net/http/httptest"
"net/url"
+ "reflect"
+ "sort"
"strconv"
"strings"
"sync"
"testing"
+ "time"
)
var robotsTxtHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -54,6 +58,13 @@ func pedanticReadAll(r io.Reader) (b []byte, err error) {
}
}
+type chanWriter chan string
+
+func (w chanWriter) Write(p []byte) (n int, err error) {
+ w <- string(p)
+ return len(p), nil
+}
+
func TestClient(t *testing.T) {
defer afterTest(t)
ts := httptest.NewServer(robotsTxtHandler)
@@ -373,24 +384,6 @@ func (j *TestJar) Cookies(u *url.URL) []*Cookie {
return j.perURL[u.Host]
}
-func TestRedirectCookiesOnRequest(t *testing.T) {
- defer afterTest(t)
- var ts *httptest.Server
- ts = httptest.NewServer(echoCookiesRedirectHandler)
- defer ts.Close()
- c := &Client{}
- req, _ := NewRequest("GET", ts.URL, nil)
- req.AddCookie(expectedCookies[0])
- // TODO: Uncomment when an implementation of a RFC6265 cookie jar lands.
- _ = c
- // resp, _ := c.Do(req)
- // matchReturnedCookies(t, expectedCookies, resp.Cookies())
-
- req, _ = NewRequest("GET", ts.URL, nil)
- // resp, _ = c.Do(req)
- // matchReturnedCookies(t, expectedCookies[1:], resp.Cookies())
-}
-
func TestRedirectCookiesJar(t *testing.T) {
defer afterTest(t)
var ts *httptest.Server
@@ -410,8 +403,8 @@ func TestRedirectCookiesJar(t *testing.T) {
}
func matchReturnedCookies(t *testing.T, expected, given []*Cookie) {
- t.Logf("Received cookies: %v", given)
if len(given) != len(expected) {
+ t.Logf("Received cookies: %v", given)
t.Errorf("Expected %d cookies, got %d", len(expected), len(given))
}
for _, ec := range expected {
@@ -582,6 +575,8 @@ func TestClientInsecureTransport(t *testing.T) {
ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
w.Write([]byte("Hello"))
}))
+ errc := make(chanWriter, 10) // but only expecting 1
+ ts.Config.ErrorLog = log.New(errc, "", 0)
defer ts.Close()
// TODO(bradfitz): add tests for skipping hostname checks too?
@@ -603,6 +598,16 @@ func TestClientInsecureTransport(t *testing.T) {
res.Body.Close()
}
}
+
+ select {
+ case v := <-errc:
+ if !strings.Contains(v, "TLS handshake error") {
+ t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", v)
+ }
+ case <-time.After(5 * time.Second):
+ t.Errorf("timeout waiting for logged error")
+ }
+
}
func TestClientErrorWithRequestURI(t *testing.T) {
@@ -653,6 +658,8 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) {
defer afterTest(t)
ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
defer ts.Close()
+ errc := make(chanWriter, 10) // but only expecting 1
+ ts.Config.ErrorLog = log.New(errc, "", 0)
trans := newTLSTransport(t, ts)
trans.TLSClientConfig.ServerName = "badserver"
@@ -664,6 +671,14 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) {
if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") {
t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err)
}
+ select {
+ case v := <-errc:
+ if !strings.Contains(v, "TLS handshake error") {
+ t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", v)
+ }
+ case <-time.After(5 * time.Second):
+ t.Errorf("timeout waiting for logged error")
+ }
}
// Test for golang.org/issue/5829; the Transport should respect TLSClientConfig.ServerName
@@ -696,6 +711,33 @@ func TestTransportUsesTLSConfigServerName(t *testing.T) {
res.Body.Close()
}
+func TestResponseSetsTLSConnectionState(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Write([]byte("Hello"))
+ }))
+ defer ts.Close()
+
+ tr := newTLSTransport(t, ts)
+ tr.TLSClientConfig.CipherSuites = []uint16{tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA}
+ tr.Dial = func(netw, addr string) (net.Conn, error) {
+ return net.Dial(netw, ts.Listener.Addr().String())
+ }
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+ res, err := c.Get("https://example.com/")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if res.TLS == nil {
+ t.Fatal("Response didn't set TLS Connection State.")
+ }
+ if got, want := res.TLS.CipherSuite, tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA; got != want {
+ t.Errorf("TLS Cipher Suite = %d; want %d", got, want)
+ }
+}
+
// Verify Response.ContentLength is populated. http://golang.org/issue/4126
func TestClientHeadContentLength(t *testing.T) {
defer afterTest(t)
@@ -799,3 +841,198 @@ func TestBasicAuth(t *testing.T) {
t.Errorf("Invalid auth %q", auth)
}
}
+
+func TestClientTimeout(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ defer afterTest(t)
+ sawRoot := make(chan bool, 1)
+ sawSlow := make(chan bool, 1)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.URL.Path == "/" {
+ sawRoot <- true
+ Redirect(w, r, "/slow", StatusFound)
+ return
+ }
+ if r.URL.Path == "/slow" {
+ w.Write([]byte("Hello"))
+ w.(Flusher).Flush()
+ sawSlow <- true
+ time.Sleep(2 * time.Second)
+ return
+ }
+ }))
+ defer ts.Close()
+ const timeout = 500 * time.Millisecond
+ c := &Client{
+ Timeout: timeout,
+ }
+
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ select {
+ case <-sawRoot:
+ // good.
+ default:
+ t.Fatal("handler never got / request")
+ }
+
+ select {
+ case <-sawSlow:
+ // good.
+ default:
+ t.Fatal("handler never got /slow request")
+ }
+
+ errc := make(chan error, 1)
+ go func() {
+ _, err := ioutil.ReadAll(res.Body)
+ errc <- err
+ res.Body.Close()
+ }()
+
+ const failTime = timeout * 2
+ select {
+ case err := <-errc:
+ if err == nil {
+ t.Error("expected error from ReadAll")
+ }
+ // Expected error.
+ case <-time.After(failTime):
+ t.Errorf("timeout after %v waiting for timeout of %v", failTime, timeout)
+ }
+}
+
+func TestClientRedirectEatsBody(t *testing.T) {
+ defer afterTest(t)
+ saw := make(chan string, 2)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ saw <- r.RemoteAddr
+ if r.URL.Path == "/" {
+ Redirect(w, r, "/foo", StatusFound) // which includes a body
+ }
+ }))
+ defer ts.Close()
+
+ res, err := Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+
+ var first string
+ select {
+ case first = <-saw:
+ default:
+ t.Fatal("server didn't see a request")
+ }
+
+ var second string
+ select {
+ case second = <-saw:
+ default:
+ t.Fatal("server didn't see a second request")
+ }
+
+ if first != second {
+ t.Fatal("server saw different client ports before & after the redirect")
+ }
+}
+
+// eofReaderFunc is an io.Reader that runs itself, and then returns io.EOF.
+type eofReaderFunc func()
+
+func (f eofReaderFunc) Read(p []byte) (n int, err error) {
+ f()
+ return 0, io.EOF
+}
+
+func TestClientTrailers(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Connection", "close")
+ w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
+ w.Header().Add("Trailer", "Server-Trailer-C")
+
+ var decl []string
+ for k := range r.Trailer {
+ decl = append(decl, k)
+ }
+ sort.Strings(decl)
+
+ slurp, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ t.Errorf("Server reading request body: %v", err)
+ }
+ if string(slurp) != "foo" {
+ t.Errorf("Server read request body %q; want foo", slurp)
+ }
+ if r.Trailer == nil {
+ io.WriteString(w, "nil Trailer")
+ } else {
+ fmt.Fprintf(w, "decl: %v, vals: %s, %s",
+ decl,
+ r.Trailer.Get("Client-Trailer-A"),
+ r.Trailer.Get("Client-Trailer-B"))
+ }
+
+ // TODO: golang.org/issue/7759: there's no way yet for
+ // the server to set trailers without hijacking, so do
+ // that for now, just to test the client. Later, in
+ // Go 1.4, it should be implicit that any mutations
+ // to w.Header() after the initial write are the
+ // trailers to be sent, if and only if they were
+ // previously declared with w.Header().Set("Trailer",
+ // ..keys..)
+ w.(Flusher).Flush()
+ conn, buf, _ := w.(Hijacker).Hijack()
+ t := Header{}
+ t.Set("Server-Trailer-A", "valuea")
+ t.Set("Server-Trailer-C", "valuec") // skipping B
+ buf.WriteString("0\r\n") // eof
+ t.Write(buf)
+ buf.WriteString("\r\n") // end of trailers
+ buf.Flush()
+ conn.Close()
+ }))
+ defer ts.Close()
+
+ var req *Request
+ req, _ = NewRequest("POST", ts.URL, io.MultiReader(
+ eofReaderFunc(func() {
+ req.Trailer["Client-Trailer-A"] = []string{"valuea"}
+ }),
+ strings.NewReader("foo"),
+ eofReaderFunc(func() {
+ req.Trailer["Client-Trailer-B"] = []string{"valueb"}
+ }),
+ ))
+ req.Trailer = Header{
+ "Client-Trailer-A": nil, // to be set later
+ "Client-Trailer-B": nil, // to be set later
+ }
+ req.ContentLength = -1
+ res, err := DefaultClient.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil {
+ t.Error(err)
+ }
+ want := Header{
+ "Server-Trailer-A": []string{"valuea"},
+ "Server-Trailer-B": nil,
+ "Server-Trailer-C": []string{"valuec"},
+ }
+ if !reflect.DeepEqual(res.Trailer, want) {
+ t.Errorf("Response trailers = %#v; want %#v", res.Trailer, want)
+ }
+}
diff --git a/src/pkg/net/http/cookie.go b/src/pkg/net/http/cookie.go
index 8b01c508e..dc60ba87f 100644
--- a/src/pkg/net/http/cookie.go
+++ b/src/pkg/net/http/cookie.go
@@ -76,11 +76,7 @@ func readSetCookies(h Header) []*Cookie {
attr, val = attr[:j], attr[j+1:]
}
lowerAttr := strings.ToLower(attr)
- parseCookieValueFn := parseCookieValue
- if lowerAttr == "expires" {
- parseCookieValueFn = parseCookieExpiresValue
- }
- val, success = parseCookieValueFn(val)
+ val, success = parseCookieValue(val)
if !success {
c.Unparsed = append(c.Unparsed, parts[i])
continue
@@ -94,7 +90,6 @@ func readSetCookies(h Header) []*Cookie {
continue
case "domain":
c.Domain = val
- // TODO: Add domain parsing
continue
case "max-age":
secs, err := strconv.Atoi(val)
@@ -121,7 +116,6 @@ func readSetCookies(h Header) []*Cookie {
continue
case "path":
c.Path = val
- // TODO: Add path parsing
continue
}
c.Unparsed = append(c.Unparsed, parts[i])
@@ -300,12 +294,23 @@ func sanitizeCookieName(n string) string {
// ; US-ASCII characters excluding CTLs,
// ; whitespace DQUOTE, comma, semicolon,
// ; and backslash
+// We loosen this as spaces and commas are common in cookie values
+// but we produce a quoted cookie-value in when value starts or ends
+// with a comma or space.
+// See http://golang.org/issue/7243 for the discussion.
func sanitizeCookieValue(v string) string {
- return sanitizeOrWarn("Cookie.Value", validCookieValueByte, v)
+ v = sanitizeOrWarn("Cookie.Value", validCookieValueByte, v)
+ if len(v) == 0 {
+ return v
+ }
+ if v[0] == ' ' || v[0] == ',' || v[len(v)-1] == ' ' || v[len(v)-1] == ',' {
+ return `"` + v + `"`
+ }
+ return v
}
func validCookieValueByte(b byte) bool {
- return 0x20 < b && b < 0x7f && b != '"' && b != ',' && b != ';' && b != '\\'
+ return 0x20 <= b && b < 0x7f && b != '"' && b != ';' && b != '\\'
}
// path-av = "Path=" path-value
@@ -340,38 +345,13 @@ func sanitizeOrWarn(fieldName string, valid func(byte) bool, v string) string {
return string(buf)
}
-func unquoteCookieValue(v string) string {
- if len(v) > 1 && v[0] == '"' && v[len(v)-1] == '"' {
- return v[1 : len(v)-1]
- }
- return v
-}
-
-func isCookieByte(c byte) bool {
- switch {
- case c == 0x21, 0x23 <= c && c <= 0x2b, 0x2d <= c && c <= 0x3a,
- 0x3c <= c && c <= 0x5b, 0x5d <= c && c <= 0x7e:
- return true
- }
- return false
-}
-
-func isCookieExpiresByte(c byte) (ok bool) {
- return isCookieByte(c) || c == ',' || c == ' '
-}
-
func parseCookieValue(raw string) (string, bool) {
- return parseCookieValueUsing(raw, isCookieByte)
-}
-
-func parseCookieExpiresValue(raw string) (string, bool) {
- return parseCookieValueUsing(raw, isCookieExpiresByte)
-}
-
-func parseCookieValueUsing(raw string, validByte func(byte) bool) (string, bool) {
- raw = unquoteCookieValue(raw)
+ // Strip the quotes, if present.
+ if len(raw) > 1 && raw[0] == '"' && raw[len(raw)-1] == '"' {
+ raw = raw[1 : len(raw)-1]
+ }
for i := 0; i < len(raw); i++ {
- if !validByte(raw[i]) {
+ if !validCookieValueByte(raw[i]) {
return "", false
}
}
diff --git a/src/pkg/net/http/cookie_test.go b/src/pkg/net/http/cookie_test.go
index 11b01cc57..f78f37299 100644
--- a/src/pkg/net/http/cookie_test.go
+++ b/src/pkg/net/http/cookie_test.go
@@ -5,9 +5,13 @@
package http
import (
+ "bytes"
"encoding/json"
"fmt"
+ "log"
+ "os"
"reflect"
+ "strings"
"testing"
"time"
)
@@ -48,15 +52,61 @@ var writeSetCookiesTests = []struct {
&Cookie{Name: "cookie-8", Value: "eight", Domain: "::1"},
"cookie-8=eight",
},
+ // The "special" cookies have values containing commas or spaces which
+ // are disallowed by RFC 6265 but are common in the wild.
+ {
+ &Cookie{Name: "special-1", Value: "a z"},
+ `special-1=a z`,
+ },
+ {
+ &Cookie{Name: "special-2", Value: " z"},
+ `special-2=" z"`,
+ },
+ {
+ &Cookie{Name: "special-3", Value: "a "},
+ `special-3="a "`,
+ },
+ {
+ &Cookie{Name: "special-4", Value: " "},
+ `special-4=" "`,
+ },
+ {
+ &Cookie{Name: "special-5", Value: "a,z"},
+ `special-5=a,z`,
+ },
+ {
+ &Cookie{Name: "special-6", Value: ",z"},
+ `special-6=",z"`,
+ },
+ {
+ &Cookie{Name: "special-7", Value: "a,"},
+ `special-7="a,"`,
+ },
+ {
+ &Cookie{Name: "special-8", Value: ","},
+ `special-8=","`,
+ },
+ {
+ &Cookie{Name: "empty-value", Value: ""},
+ `empty-value=`,
+ },
}
func TestWriteSetCookies(t *testing.T) {
+ defer log.SetOutput(os.Stderr)
+ var logbuf bytes.Buffer
+ log.SetOutput(&logbuf)
+
for i, tt := range writeSetCookiesTests {
if g, e := tt.Cookie.String(), tt.Raw; g != e {
t.Errorf("Test %d, expecting:\n%s\nGot:\n%s\n", i, e, g)
continue
}
}
+
+ if got, sub := logbuf.String(), "dropping domain attribute"; !strings.Contains(got, sub) {
+ t.Errorf("Expected substring %q in log output. Got:\n%s", sub, got)
+ }
}
type headerOnlyResponseWriter Header
@@ -166,6 +216,40 @@ var readSetCookiesTests = []struct {
Raw: "ASP.NET_SessionId=foo; path=/; HttpOnly",
}},
},
+ // Make sure we can properly read back the Set-Cookie headers we create
+ // for values containing spaces or commas:
+ {
+ Header{"Set-Cookie": {`special-1=a z`}},
+ []*Cookie{{Name: "special-1", Value: "a z", Raw: `special-1=a z`}},
+ },
+ {
+ Header{"Set-Cookie": {`special-2=" z"`}},
+ []*Cookie{{Name: "special-2", Value: " z", Raw: `special-2=" z"`}},
+ },
+ {
+ Header{"Set-Cookie": {`special-3="a "`}},
+ []*Cookie{{Name: "special-3", Value: "a ", Raw: `special-3="a "`}},
+ },
+ {
+ Header{"Set-Cookie": {`special-4=" "`}},
+ []*Cookie{{Name: "special-4", Value: " ", Raw: `special-4=" "`}},
+ },
+ {
+ Header{"Set-Cookie": {`special-5=a,z`}},
+ []*Cookie{{Name: "special-5", Value: "a,z", Raw: `special-5=a,z`}},
+ },
+ {
+ Header{"Set-Cookie": {`special-6=",z"`}},
+ []*Cookie{{Name: "special-6", Value: ",z", Raw: `special-6=",z"`}},
+ },
+ {
+ Header{"Set-Cookie": {`special-7=a,`}},
+ []*Cookie{{Name: "special-7", Value: "a,", Raw: `special-7=a,`}},
+ },
+ {
+ Header{"Set-Cookie": {`special-8=","`}},
+ []*Cookie{{Name: "special-8", Value: ",", Raw: `special-8=","`}},
+ },
// TODO(bradfitz): users have reported seeing this in the
// wild, but do browsers handle it? RFC 6265 just says "don't
@@ -244,22 +328,39 @@ func TestReadCookies(t *testing.T) {
}
func TestCookieSanitizeValue(t *testing.T) {
+ defer log.SetOutput(os.Stderr)
+ var logbuf bytes.Buffer
+ log.SetOutput(&logbuf)
+
tests := []struct {
in, want string
}{
{"foo", "foo"},
- {"foo bar", "foobar"},
+ {"foo;bar", "foobar"},
+ {"foo\\bar", "foobar"},
+ {"foo\"bar", "foobar"},
{"\x00\x7e\x7f\x80", "\x7e"},
{`"withquotes"`, "withquotes"},
+ {"a z", "a z"},
+ {" z", `" z"`},
+ {"a ", `"a "`},
}
for _, tt := range tests {
if got := sanitizeCookieValue(tt.in); got != tt.want {
t.Errorf("sanitizeCookieValue(%q) = %q; want %q", tt.in, got, tt.want)
}
}
+
+ if got, sub := logbuf.String(), "dropping invalid bytes"; !strings.Contains(got, sub) {
+ t.Errorf("Expected substring %q in log output. Got:\n%s", sub, got)
+ }
}
func TestCookieSanitizePath(t *testing.T) {
+ defer log.SetOutput(os.Stderr)
+ var logbuf bytes.Buffer
+ log.SetOutput(&logbuf)
+
tests := []struct {
in, want string
}{
@@ -272,4 +373,8 @@ func TestCookieSanitizePath(t *testing.T) {
t.Errorf("sanitizeCookiePath(%q) = %q; want %q", tt.in, got, tt.want)
}
}
+
+ if got, sub := logbuf.String(), "dropping invalid bytes"; !strings.Contains(got, sub) {
+ t.Errorf("Expected substring %q in log output. Got:\n%s", sub, got)
+ }
}
diff --git a/src/pkg/net/http/export_test.go b/src/pkg/net/http/export_test.go
index 22b7f2796..960563b24 100644
--- a/src/pkg/net/http/export_test.go
+++ b/src/pkg/net/http/export_test.go
@@ -21,7 +21,7 @@ var ExportAppendTime = appendTime
func (t *Transport) NumPendingRequestsForTesting() int {
t.reqMu.Lock()
defer t.reqMu.Unlock()
- return len(t.reqConn)
+ return len(t.reqCanceler)
}
func (t *Transport) IdleConnKeysForTesting() (keys []string) {
@@ -32,7 +32,7 @@ func (t *Transport) IdleConnKeysForTesting() (keys []string) {
return
}
for key := range t.idleConn {
- keys = append(keys, key)
+ keys = append(keys, key.String())
}
return
}
@@ -43,11 +43,12 @@ func (t *Transport) IdleConnCountForTesting(cacheKey string) int {
if t.idleConn == nil {
return 0
}
- conns, ok := t.idleConn[cacheKey]
- if !ok {
- return 0
+ for k, conns := range t.idleConn {
+ if k.String() == cacheKey {
+ return len(conns)
+ }
}
- return len(conns)
+ return 0
}
func (t *Transport) IdleConnChMapSizeForTesting() int {
@@ -63,4 +64,9 @@ func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler {
return &timeoutHandler{handler, f, ""}
}
+func ResetCachedEnvironment() {
+ httpProxyEnv.reset()
+ noProxyEnv.reset()
+}
+
var DefaultUserAgent = defaultUserAgent
diff --git a/src/pkg/net/http/fcgi/child.go b/src/pkg/net/http/fcgi/child.go
index 60b794e07..a3beaa33a 100644
--- a/src/pkg/net/http/fcgi/child.go
+++ b/src/pkg/net/http/fcgi/child.go
@@ -16,6 +16,7 @@ import (
"net/http/cgi"
"os"
"strings"
+ "sync"
"time"
)
@@ -126,8 +127,10 @@ func (r *response) Close() error {
}
type child struct {
- conn *conn
- handler http.Handler
+ conn *conn
+ handler http.Handler
+
+ mu sync.Mutex // protects requests:
requests map[uint16]*request // keyed by request ID
}
@@ -157,7 +160,9 @@ var errCloseConn = errors.New("fcgi: connection should be closed")
var emptyBody = ioutil.NopCloser(strings.NewReader(""))
func (c *child) handleRecord(rec *record) error {
+ c.mu.Lock()
req, ok := c.requests[rec.h.Id]
+ c.mu.Unlock()
if !ok && rec.h.Type != typeBeginRequest && rec.h.Type != typeGetValues {
// The spec says to ignore unknown request IDs.
return nil
@@ -179,7 +184,10 @@ func (c *child) handleRecord(rec *record) error {
c.conn.writeEndRequest(rec.h.Id, 0, statusUnknownRole)
return nil
}
- c.requests[rec.h.Id] = newRequest(rec.h.Id, br.flags)
+ req = newRequest(rec.h.Id, br.flags)
+ c.mu.Lock()
+ c.requests[rec.h.Id] = req
+ c.mu.Unlock()
return nil
case typeParams:
// NOTE(eds): Technically a key-value pair can straddle the boundary
@@ -220,7 +228,9 @@ func (c *child) handleRecord(rec *record) error {
return nil
case typeAbortRequest:
println("abort")
+ c.mu.Lock()
delete(c.requests, rec.h.Id)
+ c.mu.Unlock()
c.conn.writeEndRequest(rec.h.Id, 0, statusRequestComplete)
if !req.keepConn {
// connection will close upon return
@@ -247,6 +257,9 @@ func (c *child) serveRequest(req *request, body io.ReadCloser) {
c.handler.ServeHTTP(r, httpReq)
}
r.Close()
+ c.mu.Lock()
+ delete(c.requests, req.reqId)
+ c.mu.Unlock()
c.conn.writeEndRequest(req.reqId, 0, statusRequestComplete)
// Consume the entire body, so the host isn't still writing to
diff --git a/src/pkg/net/http/fs.go b/src/pkg/net/http/fs.go
index 8b32ca1d0..8576cf844 100644
--- a/src/pkg/net/http/fs.go
+++ b/src/pkg/net/http/fs.go
@@ -13,6 +13,7 @@ import (
"mime"
"mime/multipart"
"net/textproto"
+ "net/url"
"os"
"path"
"path/filepath"
@@ -52,12 +53,14 @@ type FileSystem interface {
// A File is returned by a FileSystem's Open method and can be
// served by the FileServer implementation.
+//
+// The methods should behave the same as those on an *os.File.
type File interface {
- Close() error
- Stat() (os.FileInfo, error)
+ io.Closer
+ io.Reader
Readdir(count int) ([]os.FileInfo, error)
- Read([]byte) (int, error)
Seek(offset int64, whence int) (int64, error)
+ Stat() (os.FileInfo, error)
}
func dirList(w ResponseWriter, f File) {
@@ -73,8 +76,11 @@ func dirList(w ResponseWriter, f File) {
if d.IsDir() {
name += "/"
}
- // TODO htmlescape
- fmt.Fprintf(w, "<a href=\"%s\">%s</a>\n", name, name)
+ // name may contain '?' or '#', which must be escaped to remain
+ // part of the URL path, and not indicate the start of a query
+ // string or fragment.
+ url := url.URL{Path: name}
+ fmt.Fprintf(w, "<a href=\"%s\">%s</a>\n", url.String(), htmlReplacer.Replace(name))
}
}
fmt.Fprintf(w, "</pre>\n")
@@ -521,7 +527,7 @@ func (w *countingWriter) Write(p []byte) (n int, err error) {
return len(p), nil
}
-// rangesMIMESize returns the nunber of bytes it takes to encode the
+// rangesMIMESize returns the number of bytes it takes to encode the
// provided ranges as a multipart response.
func rangesMIMESize(ranges []httpRange, contentType string, contentSize int64) (encSize int64) {
var w countingWriter
diff --git a/src/pkg/net/http/fs_test.go b/src/pkg/net/http/fs_test.go
index ae54edf0c..f968565f9 100644
--- a/src/pkg/net/http/fs_test.go
+++ b/src/pkg/net/http/fs_test.go
@@ -227,6 +227,54 @@ func TestFileServerCleans(t *testing.T) {
}
}
+func TestFileServerEscapesNames(t *testing.T) {
+ defer afterTest(t)
+ const dirListPrefix = "<pre>\n"
+ const dirListSuffix = "\n</pre>\n"
+ tests := []struct {
+ name, escaped string
+ }{
+ {`simple_name`, `<a href="simple_name">simple_name</a>`},
+ {`"'<>&`, `<a href="%22%27%3C%3E&">&#34;&#39;&lt;&gt;&amp;</a>`},
+ {`?foo=bar#baz`, `<a href="%3Ffoo=bar%23baz">?foo=bar#baz</a>`},
+ {`<combo>?foo`, `<a href="%3Ccombo%3E%3Ffoo">&lt;combo&gt;?foo</a>`},
+ }
+
+ // We put each test file in its own directory in the fakeFS so we can look at it in isolation.
+ fs := make(fakeFS)
+ for i, test := range tests {
+ testFile := &fakeFileInfo{basename: test.name}
+ fs[fmt.Sprintf("/%d", i)] = &fakeFileInfo{
+ dir: true,
+ modtime: time.Unix(1000000000, 0).UTC(),
+ ents: []*fakeFileInfo{testFile},
+ }
+ fs[fmt.Sprintf("/%d/%s", i, test.name)] = testFile
+ }
+
+ ts := httptest.NewServer(FileServer(&fs))
+ defer ts.Close()
+ for i, test := range tests {
+ url := fmt.Sprintf("%s/%d", ts.URL, i)
+ res, err := Get(url)
+ if err != nil {
+ t.Fatalf("test %q: Get: %v", test.name, err)
+ }
+ b, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("test %q: read Body: %v", test.name, err)
+ }
+ s := string(b)
+ if !strings.HasPrefix(s, dirListPrefix) || !strings.HasSuffix(s, dirListSuffix) {
+ t.Errorf("test %q: listing dir, full output is %q, want prefix %q and suffix %q", test.name, s, dirListPrefix, dirListSuffix)
+ }
+ if trimmed := strings.TrimSuffix(strings.TrimPrefix(s, dirListPrefix), dirListSuffix); trimmed != test.escaped {
+ t.Errorf("test %q: listing dir, filename escaped to %q, want %q", test.name, trimmed, test.escaped)
+ }
+ res.Body.Close()
+ }
+}
+
func mustRemoveAll(dir string) {
err := os.RemoveAll(dir)
if err != nil {
@@ -457,8 +505,9 @@ func (f *fakeFileInfo) Mode() os.FileMode {
type fakeFile struct {
io.ReadSeeker
- fi *fakeFileInfo
- path string // as opened
+ fi *fakeFileInfo
+ path string // as opened
+ entpos int
}
func (f *fakeFile) Close() error { return nil }
@@ -468,10 +517,20 @@ func (f *fakeFile) Readdir(count int) ([]os.FileInfo, error) {
return nil, os.ErrInvalid
}
var fis []os.FileInfo
- for _, fi := range f.fi.ents {
- fis = append(fis, fi)
+
+ limit := f.entpos + count
+ if count <= 0 || limit > len(f.fi.ents) {
+ limit = len(f.fi.ents)
+ }
+ for ; f.entpos < limit; f.entpos++ {
+ fis = append(fis, f.fi.ents[f.entpos])
+ }
+
+ if len(fis) == 0 && count > 0 {
+ return fis, io.EOF
+ } else {
+ return fis, nil
}
- return fis, nil
}
type fakeFS map[string]*fakeFileInfo
@@ -480,7 +539,6 @@ func (fs fakeFS) Open(name string) (File, error) {
name = path.Clean(name)
f, ok := fs[name]
if !ok {
- println("fake filesystem didn't find file", name)
return nil, os.ErrNotExist
}
return &fakeFile{ReadSeeker: strings.NewReader(f.contents), fi: f, path: name}, nil
diff --git a/src/pkg/net/http/header.go b/src/pkg/net/http/header.go
index ca1ae07c2..153b94370 100644
--- a/src/pkg/net/http/header.go
+++ b/src/pkg/net/http/header.go
@@ -9,9 +9,12 @@ import (
"net/textproto"
"sort"
"strings"
+ "sync"
"time"
)
+var raceEnabled = false // set by race.go
+
// A Header represents the key-value pairs in an HTTP header.
type Header map[string][]string
@@ -114,18 +117,15 @@ func (s *headerSorter) Len() int { return len(s.kvs) }
func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] }
func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key }
-// TODO: convert this to a sync.Cache (issue 4720)
-var headerSorterCache = make(chan *headerSorter, 8)
+var headerSorterPool = sync.Pool{
+ New: func() interface{} { return new(headerSorter) },
+}
// sortedKeyValues returns h's keys sorted in the returned kvs
// slice. The headerSorter used to sort is also returned, for possible
// return to headerSorterCache.
func (h Header) sortedKeyValues(exclude map[string]bool) (kvs []keyValues, hs *headerSorter) {
- select {
- case hs = <-headerSorterCache:
- default:
- hs = new(headerSorter)
- }
+ hs = headerSorterPool.Get().(*headerSorter)
if cap(hs.kvs) < len(h) {
hs.kvs = make([]keyValues, 0, len(h))
}
@@ -159,10 +159,7 @@ func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error {
}
}
}
- select {
- case headerSorterCache <- sorter:
- default:
- }
+ headerSorterPool.Put(sorter)
return nil
}
diff --git a/src/pkg/net/http/header_test.go b/src/pkg/net/http/header_test.go
index 9fd9837a5..9dcd591fa 100644
--- a/src/pkg/net/http/header_test.go
+++ b/src/pkg/net/http/header_test.go
@@ -192,9 +192,12 @@ func BenchmarkHeaderWriteSubset(b *testing.B) {
}
}
-func TestHeaderWriteSubsetMallocs(t *testing.T) {
+func TestHeaderWriteSubsetAllocs(t *testing.T) {
if testing.Short() {
- t.Skip("skipping malloc count in short mode")
+ t.Skip("skipping alloc test in short mode")
+ }
+ if raceEnabled {
+ t.Skip("skipping test under race detector")
}
if runtime.GOMAXPROCS(0) > 1 {
t.Skip("skipping; GOMAXPROCS>1")
@@ -204,6 +207,6 @@ func TestHeaderWriteSubsetMallocs(t *testing.T) {
testHeader.WriteSubset(&buf, nil)
})
if n > 0 {
- t.Errorf("mallocs = %g; want 0", n)
+ t.Errorf("allocs = %g; want 0", n)
}
}
diff --git a/src/pkg/net/http/httptest/server_test.go b/src/pkg/net/http/httptest/server_test.go
index 500a9f0b8..501cc8a99 100644
--- a/src/pkg/net/http/httptest/server_test.go
+++ b/src/pkg/net/http/httptest/server_test.go
@@ -8,6 +8,7 @@ import (
"io/ioutil"
"net/http"
"testing"
+ "time"
)
func TestServer(t *testing.T) {
@@ -27,3 +28,25 @@ func TestServer(t *testing.T) {
t.Errorf("got %q, want hello", string(got))
}
}
+
+func TestIssue7264(t *testing.T) {
+ for i := 0; i < 1000; i++ {
+ func() {
+ inHandler := make(chan bool, 1)
+ ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ inHandler <- true
+ }))
+ defer ts.Close()
+ tr := &http.Transport{
+ ResponseHeaderTimeout: time.Nanosecond,
+ }
+ defer tr.CloseIdleConnections()
+ c := &http.Client{Transport: tr}
+ res, err := c.Get(ts.URL)
+ <-inHandler
+ if err == nil {
+ res.Body.Close()
+ }
+ }()
+ }
+}
diff --git a/src/pkg/net/http/httputil/chunked.go b/src/pkg/net/http/httputil/chunked.go
index b66d40951..9632bfd19 100644
--- a/src/pkg/net/http/httputil/chunked.go
+++ b/src/pkg/net/http/httputil/chunked.go
@@ -4,15 +4,14 @@
// The wire protocol for HTTP's "chunked" Transfer-Encoding.
-// This code is a duplicate of ../chunked.go with these edits:
-// s/newChunked/NewChunked/g
-// s/package http/package httputil/
+// This code is duplicated in net/http and net/http/httputil.
// Please make any changes in both files.
package httputil
import (
"bufio"
+ "bytes"
"errors"
"fmt"
"io"
@@ -22,13 +21,13 @@ const maxLineLength = 4096 // assumed <= bufio.defaultBufSize
var ErrLineTooLong = errors.New("header line too long")
-// NewChunkedReader returns a new chunkedReader that translates the data read from r
+// newChunkedReader returns a new chunkedReader that translates the data read from r
// out of HTTP "chunked" format before returning it.
// The chunkedReader returns io.EOF when the final 0-length chunk is read.
//
-// NewChunkedReader is not needed by normal applications. The http package
+// newChunkedReader is not needed by normal applications. The http package
// automatically decodes chunking when reading response bodies.
-func NewChunkedReader(r io.Reader) io.Reader {
+func newChunkedReader(r io.Reader) io.Reader {
br, ok := r.(*bufio.Reader)
if !ok {
br = bufio.NewReader(r)
@@ -59,26 +58,45 @@ func (cr *chunkedReader) beginChunk() {
}
}
-func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
- if cr.err != nil {
- return 0, cr.err
+func (cr *chunkedReader) chunkHeaderAvailable() bool {
+ n := cr.r.Buffered()
+ if n > 0 {
+ peek, _ := cr.r.Peek(n)
+ return bytes.IndexByte(peek, '\n') >= 0
}
- if cr.n == 0 {
- cr.beginChunk()
- if cr.err != nil {
- return 0, cr.err
+ return false
+}
+
+func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
+ for cr.err == nil {
+ if cr.n == 0 {
+ if n > 0 && !cr.chunkHeaderAvailable() {
+ // We've read enough. Don't potentially block
+ // reading a new chunk header.
+ break
+ }
+ cr.beginChunk()
+ continue
}
- }
- if uint64(len(b)) > cr.n {
- b = b[0:cr.n]
- }
- n, cr.err = cr.r.Read(b)
- cr.n -= uint64(n)
- if cr.n == 0 && cr.err == nil {
- // end of chunk (CRLF)
- if _, cr.err = io.ReadFull(cr.r, cr.buf[:]); cr.err == nil {
- if cr.buf[0] != '\r' || cr.buf[1] != '\n' {
- cr.err = errors.New("malformed chunked encoding")
+ if len(b) == 0 {
+ break
+ }
+ rbuf := b
+ if uint64(len(rbuf)) > cr.n {
+ rbuf = rbuf[:cr.n]
+ }
+ var n0 int
+ n0, cr.err = cr.r.Read(rbuf)
+ n += n0
+ b = b[n0:]
+ cr.n -= uint64(n0)
+ // If we're at the end of a chunk, read the next two
+ // bytes to verify they are "\r\n".
+ if cr.n == 0 && cr.err == nil {
+ if _, cr.err = io.ReadFull(cr.r, cr.buf[:2]); cr.err == nil {
+ if cr.buf[0] != '\r' || cr.buf[1] != '\n' {
+ cr.err = errors.New("malformed chunked encoding")
+ }
}
}
}
@@ -117,16 +135,16 @@ func isASCIISpace(b byte) bool {
return b == ' ' || b == '\t' || b == '\n' || b == '\r'
}
-// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP
+// newChunkedWriter returns a new chunkedWriter that translates writes into HTTP
// "chunked" format before writing them to w. Closing the returned chunkedWriter
// sends the final 0-length chunk that marks the end of the stream.
//
-// NewChunkedWriter is not needed by normal applications. The http
+// newChunkedWriter is not needed by normal applications. The http
// package adds chunking automatically if handlers don't set a
-// Content-Length header. Using NewChunkedWriter inside a handler
+// Content-Length header. Using newChunkedWriter inside a handler
// would result in double chunking or chunking with a Content-Length
// length, both of which are wrong.
-func NewChunkedWriter(w io.Writer) io.WriteCloser {
+func newChunkedWriter(w io.Writer) io.WriteCloser {
return &chunkedWriter{w}
}
diff --git a/src/pkg/net/http/httputil/chunked_test.go b/src/pkg/net/http/httputil/chunked_test.go
index a06bffad5..a7a577468 100644
--- a/src/pkg/net/http/httputil/chunked_test.go
+++ b/src/pkg/net/http/httputil/chunked_test.go
@@ -2,26 +2,25 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// This code is a duplicate of ../chunked_test.go with these edits:
-// s/newChunked/NewChunked/g
-// s/package http/package httputil/
+// This code is duplicated in net/http and net/http/httputil.
// Please make any changes in both files.
package httputil
import (
+ "bufio"
"bytes"
"fmt"
"io"
"io/ioutil"
- "runtime"
+ "strings"
"testing"
)
func TestChunk(t *testing.T) {
var b bytes.Buffer
- w := NewChunkedWriter(&b)
+ w := newChunkedWriter(&b)
const chunk1 = "hello, "
const chunk2 = "world! 0123456789abcdef"
w.Write([]byte(chunk1))
@@ -32,7 +31,7 @@ func TestChunk(t *testing.T) {
t.Fatalf("chunk writer wrote %q; want %q", g, e)
}
- r := NewChunkedReader(&b)
+ r := newChunkedReader(&b)
data, err := ioutil.ReadAll(r)
if err != nil {
t.Logf(`data: "%s"`, data)
@@ -43,37 +42,102 @@ func TestChunk(t *testing.T) {
}
}
+func TestChunkReadMultiple(t *testing.T) {
+ // Bunch of small chunks, all read together.
+ {
+ var b bytes.Buffer
+ w := newChunkedWriter(&b)
+ w.Write([]byte("foo"))
+ w.Write([]byte("bar"))
+ w.Close()
+
+ r := newChunkedReader(&b)
+ buf := make([]byte, 10)
+ n, err := r.Read(buf)
+ if n != 6 || err != io.EOF {
+ t.Errorf("Read = %d, %v; want 6, EOF", n, err)
+ }
+ buf = buf[:n]
+ if string(buf) != "foobar" {
+ t.Errorf("Read = %q; want %q", buf, "foobar")
+ }
+ }
+
+ // One big chunk followed by a little chunk, but the small bufio.Reader size
+ // should prevent the second chunk header from being read.
+ {
+ var b bytes.Buffer
+ w := newChunkedWriter(&b)
+ // fillBufChunk is 11 bytes + 3 bytes header + 2 bytes footer = 16 bytes,
+ // the same as the bufio ReaderSize below (the minimum), so even
+ // though we're going to try to Read with a buffer larger enough to also
+ // receive "foo", the second chunk header won't be read yet.
+ const fillBufChunk = "0123456789a"
+ const shortChunk = "foo"
+ w.Write([]byte(fillBufChunk))
+ w.Write([]byte(shortChunk))
+ w.Close()
+
+ r := newChunkedReader(bufio.NewReaderSize(&b, 16))
+ buf := make([]byte, len(fillBufChunk)+len(shortChunk))
+ n, err := r.Read(buf)
+ if n != len(fillBufChunk) || err != nil {
+ t.Errorf("Read = %d, %v; want %d, nil", n, err, len(fillBufChunk))
+ }
+ buf = buf[:n]
+ if string(buf) != fillBufChunk {
+ t.Errorf("Read = %q; want %q", buf, fillBufChunk)
+ }
+
+ n, err = r.Read(buf)
+ if n != len(shortChunk) || err != io.EOF {
+ t.Errorf("Read = %d, %v; want %d, EOF", n, err, len(shortChunk))
+ }
+ }
+
+ // And test that we see an EOF chunk, even though our buffer is already full:
+ {
+ r := newChunkedReader(bufio.NewReader(strings.NewReader("3\r\nfoo\r\n0\r\n")))
+ buf := make([]byte, 3)
+ n, err := r.Read(buf)
+ if n != 3 || err != io.EOF {
+ t.Errorf("Read = %d, %v; want 3, EOF", n, err)
+ }
+ if string(buf) != "foo" {
+ t.Errorf("buf = %q; want foo", buf)
+ }
+ }
+}
+
func TestChunkReaderAllocs(t *testing.T) {
- // temporarily set GOMAXPROCS to 1 as we are testing memory allocations
- defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1))
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
var buf bytes.Buffer
- w := NewChunkedWriter(&buf)
+ w := newChunkedWriter(&buf)
a, b, c := []byte("aaaaaa"), []byte("bbbbbbbbbbbb"), []byte("cccccccccccccccccccccccc")
w.Write(a)
w.Write(b)
w.Write(c)
w.Close()
- r := NewChunkedReader(&buf)
readBuf := make([]byte, len(a)+len(b)+len(c)+1)
-
- var ms runtime.MemStats
- runtime.ReadMemStats(&ms)
- m0 := ms.Mallocs
-
- n, err := io.ReadFull(r, readBuf)
-
- runtime.ReadMemStats(&ms)
- mallocs := ms.Mallocs - m0
- if mallocs > 1 {
- t.Errorf("%d mallocs; want <= 1", mallocs)
- }
-
- if n != len(readBuf)-1 {
- t.Errorf("read %d bytes; want %d", n, len(readBuf)-1)
- }
- if err != io.ErrUnexpectedEOF {
- t.Errorf("read error = %v; want ErrUnexpectedEOF", err)
+ byter := bytes.NewReader(buf.Bytes())
+ bufr := bufio.NewReader(byter)
+ mallocs := testing.AllocsPerRun(100, func() {
+ byter.Seek(0, 0)
+ bufr.Reset(byter)
+ r := newChunkedReader(bufr)
+ n, err := io.ReadFull(r, readBuf)
+ if n != len(readBuf)-1 {
+ t.Fatalf("read %d bytes; want %d", n, len(readBuf)-1)
+ }
+ if err != io.ErrUnexpectedEOF {
+ t.Fatalf("read error = %v; want ErrUnexpectedEOF", err)
+ }
+ })
+ if mallocs > 1.5 {
+ t.Errorf("mallocs = %v; want 1", mallocs)
}
}
diff --git a/src/pkg/net/http/httputil/dump.go b/src/pkg/net/http/httputil/dump.go
index 265499fb0..2a7a413d0 100644
--- a/src/pkg/net/http/httputil/dump.go
+++ b/src/pkg/net/http/httputil/dump.go
@@ -7,6 +7,7 @@ package httputil
import (
"bufio"
"bytes"
+ "errors"
"fmt"
"io"
"io/ioutil"
@@ -29,7 +30,7 @@ func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) {
if err = b.Close(); err != nil {
return nil, nil, err
}
- return ioutil.NopCloser(&buf), ioutil.NopCloser(bytes.NewBuffer(buf.Bytes())), nil
+ return ioutil.NopCloser(&buf), ioutil.NopCloser(bytes.NewReader(buf.Bytes())), nil
}
// dumpConn is a net.Conn which writes to Writer and reads from Reader
@@ -106,6 +107,7 @@ func DumpRequestOut(req *http.Request, body bool) ([]byte, error) {
return &dumpConn{io.MultiWriter(&buf, pw), dr}, nil
},
}
+ defer t.CloseIdleConnections()
_, err := t.RoundTrip(reqSend)
@@ -230,14 +232,31 @@ func DumpRequest(req *http.Request, body bool) (dump []byte, err error) {
return
}
+// errNoBody is a sentinel error value used by failureToReadBody so we can detect
+// that the lack of body was intentional.
+var errNoBody = errors.New("sentinel error value")
+
+// failureToReadBody is a io.ReadCloser that just returns errNoBody on
+// Read. It's swapped in when we don't actually want to consume the
+// body, but need a non-nil one, and want to distinguish the error
+// from reading the dummy body.
+type failureToReadBody struct{}
+
+func (failureToReadBody) Read([]byte) (int, error) { return 0, errNoBody }
+func (failureToReadBody) Close() error { return nil }
+
+var emptyBody = ioutil.NopCloser(strings.NewReader(""))
+
// DumpResponse is like DumpRequest but dumps a response.
func DumpResponse(resp *http.Response, body bool) (dump []byte, err error) {
var b bytes.Buffer
save := resp.Body
savecl := resp.ContentLength
- if !body || resp.Body == nil {
- resp.Body = nil
- resp.ContentLength = 0
+
+ if !body {
+ resp.Body = failureToReadBody{}
+ } else if resp.Body == nil {
+ resp.Body = emptyBody
} else {
save, resp.Body, err = drainBody(resp.Body)
if err != nil {
@@ -245,11 +264,13 @@ func DumpResponse(resp *http.Response, body bool) (dump []byte, err error) {
}
}
err = resp.Write(&b)
+ if err == errNoBody {
+ err = nil
+ }
resp.Body = save
resp.ContentLength = savecl
if err != nil {
- return
+ return nil, err
}
- dump = b.Bytes()
- return
+ return b.Bytes(), nil
}
diff --git a/src/pkg/net/http/httputil/dump_test.go b/src/pkg/net/http/httputil/dump_test.go
index 987a82048..e1ffb3935 100644
--- a/src/pkg/net/http/httputil/dump_test.go
+++ b/src/pkg/net/http/httputil/dump_test.go
@@ -11,6 +11,8 @@ import (
"io/ioutil"
"net/http"
"net/url"
+ "runtime"
+ "strings"
"testing"
)
@@ -112,6 +114,7 @@ var dumpTests = []dumpTest{
}
func TestDumpRequest(t *testing.T) {
+ numg0 := runtime.NumGoroutine()
for i, tt := range dumpTests {
setBody := func() {
if tt.Body == nil {
@@ -119,7 +122,7 @@ func TestDumpRequest(t *testing.T) {
}
switch b := tt.Body.(type) {
case []byte:
- tt.Req.Body = ioutil.NopCloser(bytes.NewBuffer(b))
+ tt.Req.Body = ioutil.NopCloser(bytes.NewReader(b))
case func() io.ReadCloser:
tt.Req.Body = b()
}
@@ -155,6 +158,9 @@ func TestDumpRequest(t *testing.T) {
}
}
}
+ if dg := runtime.NumGoroutine() - numg0; dg > 4 {
+ t.Errorf("Unexpectedly large number of new goroutines: %d new", dg)
+ }
}
func chunk(s string) string {
@@ -176,3 +182,82 @@ func mustNewRequest(method, url string, body io.Reader) *http.Request {
}
return req
}
+
+var dumpResTests = []struct {
+ res *http.Response
+ body bool
+ want string
+}{
+ {
+ res: &http.Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 50,
+ Header: http.Header{
+ "Foo": []string{"Bar"},
+ },
+ Body: ioutil.NopCloser(strings.NewReader("foo")), // shouldn't be used
+ },
+ body: false, // to verify we see 50, not empty or 3.
+ want: `HTTP/1.1 200 OK
+Content-Length: 50
+Foo: Bar`,
+ },
+
+ {
+ res: &http.Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 3,
+ Body: ioutil.NopCloser(strings.NewReader("foo")),
+ },
+ body: true,
+ want: `HTTP/1.1 200 OK
+Content-Length: 3
+
+foo`,
+ },
+
+ {
+ res: &http.Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: -1,
+ Body: ioutil.NopCloser(strings.NewReader("foo")),
+ TransferEncoding: []string{"chunked"},
+ },
+ body: true,
+ want: `HTTP/1.1 200 OK
+Transfer-Encoding: chunked
+
+3
+foo
+0`,
+ },
+}
+
+func TestDumpResponse(t *testing.T) {
+ for i, tt := range dumpResTests {
+ gotb, err := DumpResponse(tt.res, tt.body)
+ if err != nil {
+ t.Errorf("%d. DumpResponse = %v", i, err)
+ continue
+ }
+ got := string(gotb)
+ got = strings.TrimSpace(got)
+ got = strings.Replace(got, "\r", "", -1)
+
+ if got != tt.want {
+ t.Errorf("%d.\nDumpResponse got:\n%s\n\nWant:\n%s\n", i, got, tt.want)
+ }
+ }
+}
diff --git a/src/pkg/net/http/httputil/httputil.go b/src/pkg/net/http/httputil/httputil.go
new file mode 100644
index 000000000..74fb6c655
--- /dev/null
+++ b/src/pkg/net/http/httputil/httputil.go
@@ -0,0 +1,32 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package httputil provides HTTP utility functions, complementing the
+// more common ones in the net/http package.
+package httputil
+
+import "io"
+
+// NewChunkedReader returns a new chunkedReader that translates the data read from r
+// out of HTTP "chunked" format before returning it.
+// The chunkedReader returns io.EOF when the final 0-length chunk is read.
+//
+// NewChunkedReader is not needed by normal applications. The http package
+// automatically decodes chunking when reading response bodies.
+func NewChunkedReader(r io.Reader) io.Reader {
+ return newChunkedReader(r)
+}
+
+// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP
+// "chunked" format before writing them to w. Closing the returned chunkedWriter
+// sends the final 0-length chunk that marks the end of the stream.
+//
+// NewChunkedWriter is not needed by normal applications. The http
+// package adds chunking automatically if handlers don't set a
+// Content-Length header. Using NewChunkedWriter inside a handler
+// would result in double chunking or chunking with a Content-Length
+// length, both of which are wrong.
+func NewChunkedWriter(w io.Writer) io.WriteCloser {
+ return newChunkedWriter(w)
+}
diff --git a/src/pkg/net/http/httputil/persist.go b/src/pkg/net/http/httputil/persist.go
index 507938aca..987bcc96b 100644
--- a/src/pkg/net/http/httputil/persist.go
+++ b/src/pkg/net/http/httputil/persist.go
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Package httputil provides HTTP utility functions, complementing the
-// more common ones in the net/http package.
package httputil
import (
@@ -33,8 +31,8 @@ var errClosed = errors.New("i/o operation on closed connection")
// i.e. requests can be read out of sync (but in the same order) while the
// respective responses are sent.
//
-// ServerConn is low-level and should not be needed by most applications.
-// See Server.
+// ServerConn is low-level and old. Applications should instead use Server
+// in the net/http package.
type ServerConn struct {
lk sync.Mutex // read-write protects the following fields
c net.Conn
@@ -47,8 +45,11 @@ type ServerConn struct {
pipe textproto.Pipeline
}
-// NewServerConn returns a new ServerConn reading and writing c. If r is not
+// NewServerConn returns a new ServerConn reading and writing c. If r is not
// nil, it is the buffer to use when reading c.
+//
+// ServerConn is low-level and old. Applications should instead use Server
+// in the net/http package.
func NewServerConn(c net.Conn, r *bufio.Reader) *ServerConn {
if r == nil {
r = bufio.NewReader(c)
@@ -223,8 +224,8 @@ func (sc *ServerConn) Write(req *http.Request, resp *http.Response) error {
// supports hijacking the connection calling Hijack to
// regain control of the underlying net.Conn and deal with it as desired.
//
-// ClientConn is low-level and should not be needed by most applications.
-// See Client.
+// ClientConn is low-level and old. Applications should instead use
+// Client or Transport in the net/http package.
type ClientConn struct {
lk sync.Mutex // read-write protects the following fields
c net.Conn
@@ -240,6 +241,9 @@ type ClientConn struct {
// NewClientConn returns a new ClientConn reading and writing c. If r is not
// nil, it is the buffer to use when reading c.
+//
+// ClientConn is low-level and old. Applications should use Client or
+// Transport in the net/http package.
func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn {
if r == nil {
r = bufio.NewReader(c)
@@ -254,6 +258,9 @@ func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn {
// NewProxyClientConn works like NewClientConn but writes Requests
// using Request's WriteProxy method.
+//
+// New code should not use NewProxyClientConn. See Client or
+// Transport in the net/http package instead.
func NewProxyClientConn(c net.Conn, r *bufio.Reader) *ClientConn {
cc := NewClientConn(c, r)
cc.writeReq = (*http.Request).WriteProxy
diff --git a/src/pkg/net/http/httputil/reverseproxy.go b/src/pkg/net/http/httputil/reverseproxy.go
index 1990f64db..48ada5f5f 100644
--- a/src/pkg/net/http/httputil/reverseproxy.go
+++ b/src/pkg/net/http/httputil/reverseproxy.go
@@ -144,6 +144,10 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
defer res.Body.Close()
+ for _, h := range hopHeaders {
+ res.Header.Del(h)
+ }
+
copyHeader(rw.Header(), res.Header)
rw.WriteHeader(res.StatusCode)
diff --git a/src/pkg/net/http/httputil/reverseproxy_test.go b/src/pkg/net/http/httputil/reverseproxy_test.go
index 1c0444ec4..e9539b44b 100644
--- a/src/pkg/net/http/httputil/reverseproxy_test.go
+++ b/src/pkg/net/http/httputil/reverseproxy_test.go
@@ -16,6 +16,12 @@ import (
"time"
)
+const fakeHopHeader = "X-Fake-Hop-Header-For-Test"
+
+func init() {
+ hopHeaders = append(hopHeaders, fakeHopHeader)
+}
+
func TestReverseProxy(t *testing.T) {
const backendResponse = "I am the backend"
const backendStatus = 404
@@ -36,6 +42,10 @@ func TestReverseProxy(t *testing.T) {
t.Errorf("backend got Host header %q, want %q", g, e)
}
w.Header().Set("X-Foo", "bar")
+ w.Header().Set("Upgrade", "foo")
+ w.Header().Set(fakeHopHeader, "foo")
+ w.Header().Add("X-Multi-Value", "foo")
+ w.Header().Add("X-Multi-Value", "bar")
http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"})
w.WriteHeader(backendStatus)
w.Write([]byte(backendResponse))
@@ -64,6 +74,12 @@ func TestReverseProxy(t *testing.T) {
if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
t.Errorf("got X-Foo %q; expected %q", g, e)
}
+ if c := res.Header.Get(fakeHopHeader); c != "" {
+ t.Errorf("got %s header value %q", fakeHopHeader, c)
+ }
+ if g, e := len(res.Header["X-Multi-Value"]), 2; g != e {
+ t.Errorf("got %d X-Multi-Value header values; expected %d", g, e)
+ }
if g, e := len(res.Header["Set-Cookie"]), 1; g != e {
t.Fatalf("got %d SetCookies, want %d", g, e)
}
diff --git a/src/pkg/net/http/proxy_test.go b/src/pkg/net/http/proxy_test.go
index 449ccaeea..b6aed3792 100644
--- a/src/pkg/net/http/proxy_test.go
+++ b/src/pkg/net/http/proxy_test.go
@@ -35,12 +35,8 @@ var UseProxyTests = []struct {
}
func TestUseProxy(t *testing.T) {
- oldenv := os.Getenv("NO_PROXY")
- defer os.Setenv("NO_PROXY", oldenv)
-
- no_proxy := "foobar.com, .barbaz.net"
- os.Setenv("NO_PROXY", no_proxy)
-
+ ResetProxyEnv()
+ os.Setenv("NO_PROXY", "foobar.com, .barbaz.net")
for _, test := range UseProxyTests {
if useProxy(test.host+":80") != test.match {
t.Errorf("useProxy(%v) = %v, want %v", test.host, !test.match, test.match)
@@ -71,8 +67,15 @@ func TestCacheKeys(t *testing.T) {
proxy = u
}
cm := connectMethod{proxy, tt.scheme, tt.addr}
- if cm.String() != tt.key {
- t.Fatalf("{%q, %q, %q} cache key %q; want %q", tt.proxy, tt.scheme, tt.addr, cm.String(), tt.key)
+ if got := cm.key().String(); got != tt.key {
+ t.Fatalf("{%q, %q, %q} cache key = %q; want %q", tt.proxy, tt.scheme, tt.addr, got, tt.key)
}
}
}
+
+func ResetProxyEnv() {
+ for _, v := range []string{"HTTP_PROXY", "http_proxy", "NO_PROXY", "no_proxy"} {
+ os.Setenv(v, "")
+ }
+ ResetCachedEnvironment()
+}
diff --git a/src/pkg/net/http/race.go b/src/pkg/net/http/race.go
new file mode 100644
index 000000000..766503967
--- /dev/null
+++ b/src/pkg/net/http/race.go
@@ -0,0 +1,11 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build race
+
+package http
+
+func init() {
+ raceEnabled = true
+}
diff --git a/src/pkg/net/http/request.go b/src/pkg/net/http/request.go
index 57b5d0948..a67092066 100644
--- a/src/pkg/net/http/request.go
+++ b/src/pkg/net/http/request.go
@@ -20,6 +20,7 @@ import (
"net/url"
"strconv"
"strings"
+ "sync"
)
const (
@@ -68,18 +69,31 @@ var reqWriteExcludeHeader = map[string]bool{
// A Request represents an HTTP request received by a server
// or to be sent by a client.
+//
+// The field semantics differ slightly between client and server
+// usage. In addition to the notes on the fields below, see the
+// documentation for Request.Write and RoundTripper.
type Request struct {
- Method string // GET, POST, PUT, etc.
+ // Method specifies the HTTP method (GET, POST, PUT, etc.).
+ // For client requests an empty string means GET.
+ Method string
- // URL is created from the URI supplied on the Request-Line
- // as stored in RequestURI.
+ // URL specifies either the URI being requested (for server
+ // requests) or the URL to access (for client requests).
+ //
+ // For server requests the URL is parsed from the URI
+ // supplied on the Request-Line as stored in RequestURI. For
+ // most requests, fields other than Path and RawQuery will be
+ // empty. (See RFC 2616, Section 5.1.2)
//
- // For most requests, fields other than Path and RawQuery
- // will be empty. (See RFC 2616, Section 5.1.2)
+ // For client requests, the URL's Host specifies the server to
+ // connect to, while the Request's Host field optionally
+ // specifies the Host header value to send in the HTTP
+ // request.
URL *url.URL
// The protocol version for incoming requests.
- // Outgoing requests always use HTTP/1.1.
+ // Client requests always use HTTP/1.1.
Proto string // "HTTP/1.0"
ProtoMajor int // 1
ProtoMinor int // 0
@@ -103,15 +117,20 @@ type Request struct {
// The request parser implements this by canonicalizing the
// name, making the first character and any characters
// following a hyphen uppercase and the rest lowercase.
+ //
+ // For client requests certain headers are automatically
+ // added and may override values in Header.
+ //
+ // See the documentation for the Request.Write method.
Header Header
// Body is the request's body.
//
- // For client requests, a nil body means the request has no
+ // For client requests a nil body means the request has no
// body, such as a GET request. The HTTP Client's Transport
// is responsible for calling the Close method.
//
- // For server requests, the Request Body is always non-nil
+ // For server requests the Request Body is always non-nil
// but will return EOF immediately when no body is present.
// The Server will close the request body. The ServeHTTP
// Handler does not need to.
@@ -121,7 +140,7 @@ type Request struct {
// The value -1 indicates that the length is unknown.
// Values >= 0 indicate that the given number of bytes may
// be read from Body.
- // For outgoing requests, a value of 0 means unknown if Body is not nil.
+ // For client requests, a value of 0 means unknown if Body is not nil.
ContentLength int64
// TransferEncoding lists the transfer encodings from outermost to
@@ -132,13 +151,18 @@ type Request struct {
TransferEncoding []string
// Close indicates whether to close the connection after
- // replying to this request.
+ // replying to this request (for servers) or after sending
+ // the request (for clients).
Close bool
- // The host on which the URL is sought.
- // Per RFC 2616, this is either the value of the Host: header
- // or the host name given in the URL itself.
+ // For server requests Host specifies the host on which the
+ // URL is sought. Per RFC 2616, this is either the value of
+ // the "Host" header or the host name given in the URL itself.
// It may be of the form "host:port".
+ //
+ // For client requests Host optionally overrides the Host
+ // header to send. If empty, the Request.Write method uses
+ // the value of URL.Host.
Host string
// Form contains the parsed form data, including both the URL
@@ -158,12 +182,24 @@ type Request struct {
// The HTTP client ignores MultipartForm and uses Body instead.
MultipartForm *multipart.Form
- // Trailer maps trailer keys to values. Like for Header, if the
- // response has multiple trailer lines with the same key, they will be
- // concatenated, delimited by commas.
- // For server requests, Trailer is only populated after Body has been
- // closed or fully consumed.
- // Trailer support is only partially complete.
+ // Trailer specifies additional headers that are sent after the request
+ // body.
+ //
+ // For server requests the Trailer map initially contains only the
+ // trailer keys, with nil values. (The client declares which trailers it
+ // will later send.) While the handler is reading from Body, it must
+ // not reference Trailer. After reading from Body returns EOF, Trailer
+ // can be read again and will contain non-nil values, if they were sent
+ // by the client.
+ //
+ // For client requests Trailer must be initialized to a map containing
+ // the trailer keys to later send. The values may be nil or their final
+ // values. The ContentLength must be 0 or -1, to send a chunked request.
+ // After the HTTP request is sent the map values can be updated while
+ // the request body is read. Once the body returns EOF, the caller must
+ // not mutate Trailer.
+ //
+ // Few HTTP clients, servers, or proxies support HTTP trailers.
Trailer Header
// RemoteAddr allows HTTP servers and other software to record
@@ -381,7 +417,6 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err
return err
}
- // TODO: split long values? (If so, should share code with Conn.Write)
err = req.Header.WriteSubset(w, reqWriteExcludeHeader)
if err != nil {
return err
@@ -494,25 +529,20 @@ func parseRequestLine(line string) (method, requestURI, proto string, ok bool) {
return line[:s1], line[s1+1 : s2], line[s2+1:], true
}
-// TODO(bradfitz): use a sync.Cache when available
-var textprotoReaderCache = make(chan *textproto.Reader, 4)
+var textprotoReaderPool sync.Pool
func newTextprotoReader(br *bufio.Reader) *textproto.Reader {
- select {
- case r := <-textprotoReaderCache:
- r.R = br
- return r
- default:
- return textproto.NewReader(br)
+ if v := textprotoReaderPool.Get(); v != nil {
+ tr := v.(*textproto.Reader)
+ tr.R = br
+ return tr
}
+ return textproto.NewReader(br)
}
func putTextprotoReader(r *textproto.Reader) {
r.R = nil
- select {
- case textprotoReaderCache <- r:
- default:
- }
+ textprotoReaderPool.Put(r)
}
// ReadRequest reads and parses a request from b.
@@ -588,32 +618,6 @@ func ReadRequest(b *bufio.Reader) (req *Request, err error) {
fixPragmaCacheControl(req.Header)
- // TODO: Parse specific header values:
- // Accept
- // Accept-Encoding
- // Accept-Language
- // Authorization
- // Cache-Control
- // Connection
- // Date
- // Expect
- // From
- // If-Match
- // If-Modified-Since
- // If-None-Match
- // If-Range
- // If-Unmodified-Since
- // Max-Forwards
- // Proxy-Authorization
- // Referer [sic]
- // TE (transfer-codings)
- // Trailer
- // Transfer-Encoding
- // Upgrade
- // User-Agent
- // Via
- // Warning
-
err = readTransfer(req, b)
if err != nil {
return nil, err
@@ -677,6 +681,11 @@ func parsePostForm(r *Request) (vs url.Values, err error) {
return
}
ct := r.Header.Get("Content-Type")
+ // RFC 2616, section 7.2.1 - empty type
+ // SHOULD be treated as application/octet-stream
+ if ct == "" {
+ ct = "application/octet-stream"
+ }
ct, _, err = mime.ParseMediaType(ct)
switch {
case ct == "application/x-www-form-urlencoded":
@@ -707,7 +716,7 @@ func parsePostForm(r *Request) (vs url.Values, err error) {
// orders to call too many functions here.
// Clean this up and write more tests.
// request_test.go contains the start of this,
- // in TestRequestMultipartCallOrder.
+ // in TestParseMultipartFormOrder and others.
}
return
}
@@ -727,7 +736,7 @@ func parsePostForm(r *Request) (vs url.Values, err error) {
func (r *Request) ParseForm() error {
var err error
if r.PostForm == nil {
- if r.Method == "POST" || r.Method == "PUT" {
+ if r.Method == "POST" || r.Method == "PUT" || r.Method == "PATCH" {
r.PostForm, err = parsePostForm(r)
}
if r.PostForm == nil {
@@ -780,9 +789,7 @@ func (r *Request) ParseMultipartForm(maxMemory int64) error {
}
mr, err := r.multipartReader()
- if err == ErrNotMultipart {
- return nil
- } else if err != nil {
+ if err != nil {
return err
}
@@ -860,3 +867,9 @@ func (r *Request) wantsHttp10KeepAlive() bool {
func (r *Request) wantsClose() bool {
return hasToken(r.Header.get("Connection"), "close")
}
+
+func (r *Request) closeBody() {
+ if r.Body != nil {
+ r.Body.Close()
+ }
+}
diff --git a/src/pkg/net/http/request_test.go b/src/pkg/net/http/request_test.go
index 89303c336..b9fa3c2bf 100644
--- a/src/pkg/net/http/request_test.go
+++ b/src/pkg/net/http/request_test.go
@@ -60,6 +60,37 @@ func TestPostQuery(t *testing.T) {
}
}
+func TestPatchQuery(t *testing.T) {
+ req, _ := NewRequest("PATCH", "http://www.google.com/search?q=foo&q=bar&both=x&prio=1&empty=not",
+ strings.NewReader("z=post&both=y&prio=2&empty="))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
+
+ if q := req.FormValue("q"); q != "foo" {
+ t.Errorf(`req.FormValue("q") = %q, want "foo"`, q)
+ }
+ if z := req.FormValue("z"); z != "post" {
+ t.Errorf(`req.FormValue("z") = %q, want "post"`, z)
+ }
+ if bq, found := req.PostForm["q"]; found {
+ t.Errorf(`req.PostForm["q"] = %q, want no entry in map`, bq)
+ }
+ if bz := req.PostFormValue("z"); bz != "post" {
+ t.Errorf(`req.PostFormValue("z") = %q, want "post"`, bz)
+ }
+ if qs := req.Form["q"]; !reflect.DeepEqual(qs, []string{"foo", "bar"}) {
+ t.Errorf(`req.Form["q"] = %q, want ["foo", "bar"]`, qs)
+ }
+ if both := req.Form["both"]; !reflect.DeepEqual(both, []string{"y", "x"}) {
+ t.Errorf(`req.Form["both"] = %q, want ["y", "x"]`, both)
+ }
+ if prio := req.FormValue("prio"); prio != "2" {
+ t.Errorf(`req.FormValue("prio") = %q, want "2" (from body)`, prio)
+ }
+ if empty := req.FormValue("empty"); empty != "" {
+ t.Errorf(`req.FormValue("empty") = %q, want "" (from body)`, empty)
+ }
+}
+
type stringMap map[string][]string
type parseContentTypeTest struct {
shouldError bool
@@ -68,8 +99,9 @@ type parseContentTypeTest struct {
var parseContentTypeTests = []parseContentTypeTest{
{false, stringMap{"Content-Type": {"text/plain"}}},
- // Non-existent keys are not placed. The value nil is illegal.
- {true, stringMap{}},
+ // Empty content type is legal - shoult be treated as
+ // application/octet-stream (RFC 2616, section 7.2.1)
+ {false, stringMap{}},
{true, stringMap{"Content-Type": {"text/plain; boundary="}}},
{false, stringMap{"Content-Type": {"application/unknown"}}},
}
@@ -79,7 +111,7 @@ func TestParseFormUnknownContentType(t *testing.T) {
req := &Request{
Method: "POST",
Header: Header(test.contentType),
- Body: ioutil.NopCloser(bytes.NewBufferString("body")),
+ Body: ioutil.NopCloser(strings.NewReader("body")),
}
err := req.ParseForm()
switch {
@@ -122,7 +154,25 @@ func TestMultipartReader(t *testing.T) {
req.Header = Header{"Content-Type": {"text/plain"}}
multipart, err = req.MultipartReader()
if multipart != nil {
- t.Errorf("unexpected multipart for text/plain")
+ t.Error("unexpected multipart for text/plain")
+ }
+}
+
+func TestParseMultipartForm(t *testing.T) {
+ req := &Request{
+ Method: "POST",
+ Header: Header{"Content-Type": {`multipart/form-data; boundary="foo123"`}},
+ Body: ioutil.NopCloser(new(bytes.Buffer)),
+ }
+ err := req.ParseMultipartForm(25)
+ if err == nil {
+ t.Error("expected multipart EOF, got nil")
+ }
+
+ req.Header = Header{"Content-Type": {"text/plain"}}
+ err = req.ParseMultipartForm(25)
+ if err != ErrNotMultipart {
+ t.Error("expected ErrNotMultipart for text/plain")
}
}
@@ -188,25 +238,72 @@ func TestMultipartRequestAuto(t *testing.T) {
validateTestMultipartContents(t, req, true)
}
-func TestEmptyMultipartRequest(t *testing.T) {
- // Test that FormValue and FormFile automatically invoke
- // ParseMultipartForm and return the right values.
- req, err := NewRequest("GET", "/", nil)
- if err != nil {
- t.Errorf("NewRequest err = %q", err)
- }
+func TestMissingFileMultipartRequest(t *testing.T) {
+ // Test that FormFile returns an error if
+ // the named file is missing.
+ req := newTestMultipartRequest(t)
testMissingFile(t, req)
}
-func TestRequestMultipartCallOrder(t *testing.T) {
+// Test that FormValue invokes ParseMultipartForm.
+func TestFormValueCallsParseMultipartForm(t *testing.T) {
+ req, _ := NewRequest("POST", "http://www.google.com/", strings.NewReader("z=post"))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
+ if req.Form != nil {
+ t.Fatal("Unexpected request Form, want nil")
+ }
+ req.FormValue("z")
+ if req.Form == nil {
+ t.Fatal("ParseMultipartForm not called by FormValue")
+ }
+}
+
+// Test that FormFile invokes ParseMultipartForm.
+func TestFormFileCallsParseMultipartForm(t *testing.T) {
req := newTestMultipartRequest(t)
- _, err := req.MultipartReader()
- if err != nil {
+ if req.Form != nil {
+ t.Fatal("Unexpected request Form, want nil")
+ }
+ req.FormFile("")
+ if req.Form == nil {
+ t.Fatal("ParseMultipartForm not called by FormFile")
+ }
+}
+
+// Test that ParseMultipartForm errors if called
+// after MultipartReader on the same request.
+func TestParseMultipartFormOrder(t *testing.T) {
+ req := newTestMultipartRequest(t)
+ if _, err := req.MultipartReader(); err != nil {
t.Fatalf("MultipartReader: %v", err)
}
- err = req.ParseMultipartForm(1024)
- if err == nil {
- t.Errorf("expected an error from ParseMultipartForm after call to MultipartReader")
+ if err := req.ParseMultipartForm(1024); err == nil {
+ t.Fatal("expected an error from ParseMultipartForm after call to MultipartReader")
+ }
+}
+
+// Test that MultipartReader errors if called
+// after ParseMultipartForm on the same request.
+func TestMultipartReaderOrder(t *testing.T) {
+ req := newTestMultipartRequest(t)
+ if err := req.ParseMultipartForm(25); err != nil {
+ t.Fatalf("ParseMultipartForm: %v", err)
+ }
+ defer req.MultipartForm.RemoveAll()
+ if _, err := req.MultipartReader(); err == nil {
+ t.Fatal("expected an error from MultipartReader after call to ParseMultipartForm")
+ }
+}
+
+// Test that FormFile errors if called after
+// MultipartReader on the same request.
+func TestFormFileOrder(t *testing.T) {
+ req := newTestMultipartRequest(t)
+ if _, err := req.MultipartReader(); err != nil {
+ t.Fatalf("MultipartReader: %v", err)
+ }
+ if _, _, err := req.FormFile(""); err == nil {
+ t.Fatal("expected an error from FormFile after call to MultipartReader")
}
}
@@ -343,7 +440,7 @@ func testMissingFile(t *testing.T, req *Request) {
}
func newTestMultipartRequest(t *testing.T) *Request {
- b := bytes.NewBufferString(strings.Replace(message, "\n", "\r\n", -1))
+ b := strings.NewReader(strings.Replace(message, "\n", "\r\n", -1))
req, err := NewRequest("POST", "/", b)
if err != nil {
t.Fatal("NewRequest:", err)
diff --git a/src/pkg/net/http/requestwrite_test.go b/src/pkg/net/http/requestwrite_test.go
index b27b1f7ce..dc0e204ca 100644
--- a/src/pkg/net/http/requestwrite_test.go
+++ b/src/pkg/net/http/requestwrite_test.go
@@ -310,6 +310,46 @@ var reqWriteTests = []reqWriteTest{
WantError: errors.New("http: Request.ContentLength=5 with nil Body"),
},
+ // Request with a 0 ContentLength and a body with 1 byte content and an error.
+ {
+ Req: Request{
+ Method: "POST",
+ URL: mustParseURL("/"),
+ Host: "example.com",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 0, // as if unset by user
+ },
+
+ Body: func() io.ReadCloser {
+ err := errors.New("Custom reader error")
+ errReader := &errorReader{err}
+ return ioutil.NopCloser(io.MultiReader(strings.NewReader("x"), errReader))
+ },
+
+ WantError: errors.New("Custom reader error"),
+ },
+
+ // Request with a 0 ContentLength and a body without content and an error.
+ {
+ Req: Request{
+ Method: "POST",
+ URL: mustParseURL("/"),
+ Host: "example.com",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 0, // as if unset by user
+ },
+
+ Body: func() io.ReadCloser {
+ err := errors.New("Custom reader error")
+ errReader := &errorReader{err}
+ return ioutil.NopCloser(errReader)
+ },
+
+ WantError: errors.New("Custom reader error"),
+ },
+
// Verify that DumpRequest preserves the HTTP version number, doesn't add a Host,
// and doesn't add a User-Agent.
{
@@ -427,7 +467,7 @@ func TestRequestWrite(t *testing.T) {
}
switch b := tt.Body.(type) {
case []byte:
- tt.Req.Body = ioutil.NopCloser(bytes.NewBuffer(b))
+ tt.Req.Body = ioutil.NopCloser(bytes.NewReader(b))
case func() io.ReadCloser:
tt.Req.Body = b()
}
diff --git a/src/pkg/net/http/response.go b/src/pkg/net/http/response.go
index 35d0ba3bb..5d2c39080 100644
--- a/src/pkg/net/http/response.go
+++ b/src/pkg/net/http/response.go
@@ -8,6 +8,8 @@ package http
import (
"bufio"
+ "bytes"
+ "crypto/tls"
"errors"
"io"
"net/textproto"
@@ -45,7 +47,8 @@ type Response struct {
//
// The http Client and Transport guarantee that Body is always
// non-nil, even on responses without a body or responses with
- // a zero-lengthed body.
+ // a zero-length body. It is the caller's responsibility to
+ // close Body.
//
// The Body is automatically dechunked if the server replied
// with a "chunked" Transfer-Encoding.
@@ -74,6 +77,12 @@ type Response struct {
// Request's Body is nil (having already been consumed).
// This is only populated for Client requests.
Request *Request
+
+ // TLS contains information about the TLS connection on which the
+ // response was received. It is nil for unencrypted responses.
+ // The pointer is shared between responses and should not be
+ // modified.
+ TLS *tls.ConnectionState
}
// Cookies parses and returns the cookies set in the Set-Cookie headers.
@@ -141,6 +150,9 @@ func ReadResponse(r *bufio.Reader, req *Request) (*Response, error) {
// Parse the response headers.
mimeHeader, err := tp.ReadMIMEHeader()
if err != nil {
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
return nil, err
}
resp.Header = Header(mimeHeader)
@@ -187,8 +199,8 @@ func (r *Response) ProtoAtLeast(major, minor int) bool {
// ContentLength
// Header, values for non-canonical keys will have unpredictable behavior
//
+// Body is closed after it is sent.
func (r *Response) Write(w io.Writer) error {
-
// Status line
text := r.Status
if text == "" {
@@ -201,10 +213,45 @@ func (r *Response) Write(w io.Writer) error {
protoMajor, protoMinor := strconv.Itoa(r.ProtoMajor), strconv.Itoa(r.ProtoMinor)
statusCode := strconv.Itoa(r.StatusCode) + " "
text = strings.TrimPrefix(text, statusCode)
- io.WriteString(w, "HTTP/"+protoMajor+"."+protoMinor+" "+statusCode+text+"\r\n")
+ if _, err := io.WriteString(w, "HTTP/"+protoMajor+"."+protoMinor+" "+statusCode+text+"\r\n"); err != nil {
+ return err
+ }
+
+ // Clone it, so we can modify r1 as needed.
+ r1 := new(Response)
+ *r1 = *r
+ if r1.ContentLength == 0 && r1.Body != nil {
+ // Is it actually 0 length? Or just unknown?
+ var buf [1]byte
+ n, err := r1.Body.Read(buf[:])
+ if err != nil && err != io.EOF {
+ return err
+ }
+ if n == 0 {
+ // Reset it to a known zero reader, in case underlying one
+ // is unhappy being read repeatedly.
+ r1.Body = eofReader
+ } else {
+ r1.ContentLength = -1
+ r1.Body = struct {
+ io.Reader
+ io.Closer
+ }{
+ io.MultiReader(bytes.NewReader(buf[:1]), r.Body),
+ r.Body,
+ }
+ }
+ }
+ // If we're sending a non-chunked HTTP/1.1 response without a
+ // content-length, the only way to do that is the old HTTP/1.0
+ // way, by noting the EOF with a connection close, so we need
+ // to set Close.
+ if r1.ContentLength == -1 && !r1.Close && r1.ProtoAtLeast(1, 1) && !chunked(r1.TransferEncoding) {
+ r1.Close = true
+ }
// Process Body,ContentLength,Close,Trailer
- tw, err := newTransferWriter(r)
+ tw, err := newTransferWriter(r1)
if err != nil {
return err
}
@@ -219,8 +266,19 @@ func (r *Response) Write(w io.Writer) error {
return err
}
+ // contentLengthAlreadySent may have been already sent for
+ // POST/PUT requests, even if zero length. See Issue 8180.
+ contentLengthAlreadySent := tw.shouldSendContentLength()
+ if r1.ContentLength == 0 && !chunked(r1.TransferEncoding) && !contentLengthAlreadySent {
+ if _, err := io.WriteString(w, "Content-Length: 0\r\n"); err != nil {
+ return err
+ }
+ }
+
// End-of-header
- io.WriteString(w, "\r\n")
+ if _, err := io.WriteString(w, "\r\n"); err != nil {
+ return err
+ }
// Write body and trailer
err = tw.WriteBody(w)
diff --git a/src/pkg/net/http/response_test.go b/src/pkg/net/http/response_test.go
index 5044306a8..4b8946f7a 100644
--- a/src/pkg/net/http/response_test.go
+++ b/src/pkg/net/http/response_test.go
@@ -14,6 +14,7 @@ import (
"io/ioutil"
"net/url"
"reflect"
+ "regexp"
"strings"
"testing"
)
@@ -28,6 +29,10 @@ func dummyReq(method string) *Request {
return &Request{Method: method}
}
+func dummyReq11(method string) *Request {
+ return &Request{Method: method, Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1}
+}
+
var respTests = []respTest{
// Unchunked response without Content-Length.
{
@@ -406,8 +411,7 @@ func TestWriteResponse(t *testing.T) {
t.Errorf("#%d: %v", i, err)
continue
}
- bout := bytes.NewBuffer(nil)
- err = resp.Write(bout)
+ err = resp.Write(ioutil.Discard)
if err != nil {
t.Errorf("#%d: %v", i, err)
continue
@@ -506,6 +510,9 @@ func TestReadResponseCloseInMiddle(t *testing.T) {
rest, err := ioutil.ReadAll(bufr)
checkErr(err, "ReadAll on remainder")
if e, g := "Next Request Here", string(rest); e != g {
+ g = regexp.MustCompile(`(xx+)`).ReplaceAllStringFunc(g, func(match string) string {
+ return fmt.Sprintf("x(repeated x%d)", len(match))
+ })
fatalf("remainder = %q, expected %q", g, e)
}
}
@@ -615,6 +622,15 @@ func TestResponseContentLengthShortBody(t *testing.T) {
}
}
+func TestReadResponseUnexpectedEOF(t *testing.T) {
+ br := bufio.NewReader(strings.NewReader("HTTP/1.1 301 Moved Permanently\r\n" +
+ "Location: http://example.com"))
+ _, err := ReadResponse(br, nil)
+ if err != io.ErrUnexpectedEOF {
+ t.Errorf("ReadResponse = %v; want io.ErrUnexpectedEOF", err)
+ }
+}
+
func TestNeedsSniff(t *testing.T) {
// needsSniff returns true with an empty response.
r := &response{}
diff --git a/src/pkg/net/http/responsewrite_test.go b/src/pkg/net/http/responsewrite_test.go
index 5c10e2161..585b13b85 100644
--- a/src/pkg/net/http/responsewrite_test.go
+++ b/src/pkg/net/http/responsewrite_test.go
@@ -7,6 +7,7 @@ package http
import (
"bytes"
"io/ioutil"
+ "strings"
"testing"
)
@@ -25,7 +26,7 @@ func TestResponseWrite(t *testing.T) {
ProtoMinor: 0,
Request: dummyReq("GET"),
Header: Header{},
- Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")),
+ Body: ioutil.NopCloser(strings.NewReader("abcdef")),
ContentLength: 6,
},
@@ -41,13 +42,113 @@ func TestResponseWrite(t *testing.T) {
ProtoMinor: 0,
Request: dummyReq("GET"),
Header: Header{},
- Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")),
+ Body: ioutil.NopCloser(strings.NewReader("abcdef")),
ContentLength: -1,
},
"HTTP/1.0 200 OK\r\n" +
"\r\n" +
"abcdef",
},
+ // HTTP/1.1 response with unknown length and Connection: close
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Body: ioutil.NopCloser(strings.NewReader("abcdef")),
+ ContentLength: -1,
+ Close: true,
+ },
+ "HTTP/1.1 200 OK\r\n" +
+ "Connection: close\r\n" +
+ "\r\n" +
+ "abcdef",
+ },
+ // HTTP/1.1 response with unknown length and not setting connection: close
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq11("GET"),
+ Header: Header{},
+ Body: ioutil.NopCloser(strings.NewReader("abcdef")),
+ ContentLength: -1,
+ Close: false,
+ },
+ "HTTP/1.1 200 OK\r\n" +
+ "Connection: close\r\n" +
+ "\r\n" +
+ "abcdef",
+ },
+ // HTTP/1.1 response with unknown length and not setting connection: close, but
+ // setting chunked.
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq11("GET"),
+ Header: Header{},
+ Body: ioutil.NopCloser(strings.NewReader("abcdef")),
+ ContentLength: -1,
+ TransferEncoding: []string{"chunked"},
+ Close: false,
+ },
+ "HTTP/1.1 200 OK\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ "6\r\nabcdef\r\n0\r\n\r\n",
+ },
+ // HTTP/1.1 response 0 content-length, and nil body
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq11("GET"),
+ Header: Header{},
+ Body: nil,
+ ContentLength: 0,
+ Close: false,
+ },
+ "HTTP/1.1 200 OK\r\n" +
+ "Content-Length: 0\r\n" +
+ "\r\n",
+ },
+ // HTTP/1.1 response 0 content-length, and non-nil empty body
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq11("GET"),
+ Header: Header{},
+ Body: ioutil.NopCloser(strings.NewReader("")),
+ ContentLength: 0,
+ Close: false,
+ },
+ "HTTP/1.1 200 OK\r\n" +
+ "Content-Length: 0\r\n" +
+ "\r\n",
+ },
+ // HTTP/1.1 response 0 content-length, and non-nil non-empty body
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq11("GET"),
+ Header: Header{},
+ Body: ioutil.NopCloser(strings.NewReader("foo")),
+ ContentLength: 0,
+ Close: false,
+ },
+ "HTTP/1.1 200 OK\r\n" +
+ "Connection: close\r\n" +
+ "\r\nfoo",
+ },
// HTTP/1.1, chunked coding; empty trailer; close
{
Response{
@@ -56,7 +157,7 @@ func TestResponseWrite(t *testing.T) {
ProtoMinor: 1,
Request: dummyReq("GET"),
Header: Header{},
- Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")),
+ Body: ioutil.NopCloser(strings.NewReader("abcdef")),
ContentLength: 6,
TransferEncoding: []string{"chunked"},
Close: true,
@@ -90,6 +191,22 @@ func TestResponseWrite(t *testing.T) {
"Foo: Bar Baz\r\n" +
"\r\n",
},
+
+ // Want a single Content-Length header. Fixing issue 8180 where
+ // there were two.
+ {
+ Response{
+ StatusCode: StatusOK,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: &Request{Method: "POST"},
+ Header: Header{},
+ ContentLength: 0,
+ TransferEncoding: nil,
+ Body: nil,
+ },
+ "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n",
+ },
}
for i := range respWriteTests {
diff --git a/src/pkg/net/http/serve_test.go b/src/pkg/net/http/serve_test.go
index 955112bc2..9e4d226bf 100644
--- a/src/pkg/net/http/serve_test.go
+++ b/src/pkg/net/http/serve_test.go
@@ -419,7 +419,7 @@ func TestServeMuxHandlerRedirects(t *testing.T) {
func TestMuxRedirectLeadingSlashes(t *testing.T) {
paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"}
for _, path := range paths {
- req, err := ReadRequest(bufio.NewReader(bytes.NewBufferString("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n")))
+ req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n")))
if err != nil {
t.Errorf("%s", err)
}
@@ -441,6 +441,9 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) {
}
func TestServerTimeouts(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("skipping test; see http://golang.org/issue/7237")
+ }
defer afterTest(t)
reqNum := 0
ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) {
@@ -517,6 +520,9 @@ func TestServerTimeouts(t *testing.T) {
// shouldn't cause a handler to block forever on reads (next HTTP
// request) that will never happen.
func TestOnlyWriteTimeout(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("skipping test; see http://golang.org/issue/7237")
+ }
defer afterTest(t)
var conn net.Conn
var afterTimeoutErrc = make(chan error, 1)
@@ -840,9 +846,14 @@ func TestHeadResponses(t *testing.T) {
}
func TestTLSHandshakeTimeout(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("skipping test; see http://golang.org/issue/7237")
+ }
defer afterTest(t)
ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
+ errc := make(chanWriter, 10) // but only expecting 1
ts.Config.ReadTimeout = 250 * time.Millisecond
+ ts.Config.ErrorLog = log.New(errc, "", 0)
ts.StartTLS()
defer ts.Close()
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
@@ -857,6 +868,14 @@ func TestTLSHandshakeTimeout(t *testing.T) {
t.Errorf("Read = %d, %v; want an error and no bytes", n, err)
}
})
+ select {
+ case v := <-errc:
+ if !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") {
+ t.Errorf("expected a TLS handshake timeout error; got %q", v)
+ }
+ case <-time.After(5 * time.Second):
+ t.Errorf("timeout waiting for logged error")
+ }
}
func TestTLSServer(t *testing.T) {
@@ -869,6 +888,7 @@ func TestTLSServer(t *testing.T) {
}
}
}))
+ ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
defer ts.Close()
// Connect an idle TCP connection to this server before we run
@@ -913,31 +933,50 @@ func TestTLSServer(t *testing.T) {
}
type serverExpectTest struct {
- contentLength int // of request body
+ contentLength int // of request body
+ chunked bool
expectation string // e.g. "100-continue"
readBody bool // whether handler should read the body (if false, sends StatusUnauthorized)
expectedResponse string // expected substring in first line of http response
}
+func expectTest(contentLength int, expectation string, readBody bool, expectedResponse string) serverExpectTest {
+ return serverExpectTest{
+ contentLength: contentLength,
+ expectation: expectation,
+ readBody: readBody,
+ expectedResponse: expectedResponse,
+ }
+}
+
var serverExpectTests = []serverExpectTest{
// Normal 100-continues, case-insensitive.
- {100, "100-continue", true, "100 Continue"},
- {100, "100-cOntInUE", true, "100 Continue"},
+ expectTest(100, "100-continue", true, "100 Continue"),
+ expectTest(100, "100-cOntInUE", true, "100 Continue"),
// No 100-continue.
- {100, "", true, "200 OK"},
+ expectTest(100, "", true, "200 OK"),
// 100-continue but requesting client to deny us,
// so it never reads the body.
- {100, "100-continue", false, "401 Unauthorized"},
+ expectTest(100, "100-continue", false, "401 Unauthorized"),
// Likewise without 100-continue:
- {100, "", false, "401 Unauthorized"},
+ expectTest(100, "", false, "401 Unauthorized"),
// Non-standard expectations are failures
- {0, "a-pony", false, "417 Expectation Failed"},
+ expectTest(0, "a-pony", false, "417 Expectation Failed"),
- // Expect-100 requested but no body
- {0, "100-continue", true, "400 Bad Request"},
+ // Expect-100 requested but no body (is apparently okay: Issue 7625)
+ expectTest(0, "100-continue", true, "200 OK"),
+ // Expect-100 requested but handler doesn't read the body
+ expectTest(0, "100-continue", false, "401 Unauthorized"),
+ // Expect-100 continue with no body, but a chunked body.
+ {
+ expectation: "100-continue",
+ readBody: true,
+ chunked: true,
+ expectedResponse: "100 Continue",
+ },
}
// Tests that the server responds to the "Expect" request header
@@ -966,21 +1005,38 @@ func TestServerExpect(t *testing.T) {
// Only send the body immediately if we're acting like an HTTP client
// that doesn't send 100-continue expectations.
- writeBody := test.contentLength > 0 && strings.ToLower(test.expectation) != "100-continue"
+ writeBody := test.contentLength != 0 && strings.ToLower(test.expectation) != "100-continue"
go func() {
+ contentLen := fmt.Sprintf("Content-Length: %d", test.contentLength)
+ if test.chunked {
+ contentLen = "Transfer-Encoding: chunked"
+ }
_, err := fmt.Fprintf(conn, "POST /?readbody=%v HTTP/1.1\r\n"+
"Connection: close\r\n"+
- "Content-Length: %d\r\n"+
+ "%s\r\n"+
"Expect: %s\r\nHost: foo\r\n\r\n",
- test.readBody, test.contentLength, test.expectation)
+ test.readBody, contentLen, test.expectation)
if err != nil {
t.Errorf("On test %#v, error writing request headers: %v", test, err)
return
}
if writeBody {
+ var targ io.WriteCloser = struct {
+ io.Writer
+ io.Closer
+ }{
+ conn,
+ ioutil.NopCloser(nil),
+ }
+ if test.chunked {
+ targ = httputil.NewChunkedWriter(conn)
+ }
body := strings.Repeat("A", test.contentLength)
- _, err = fmt.Fprint(conn, body)
+ _, err = fmt.Fprint(targ, body)
+ if err == nil {
+ err = targ.Close()
+ }
if err != nil {
if !test.readBody {
// Server likely already hung up on us.
@@ -1414,6 +1470,9 @@ func TestRequestBodyLimit(t *testing.T) {
// TestClientWriteShutdown tests that if the client shuts down the write
// side of their TCP connection, the server doesn't send a 400 Bad Request.
func TestClientWriteShutdown(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("skipping test; see http://golang.org/issue/7237")
+ }
defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
defer ts.Close()
@@ -1934,6 +1993,31 @@ func TestWriteAfterHijack(t *testing.T) {
}
}
+func TestDoubleHijack(t *testing.T) {
+ req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
+ var buf bytes.Buffer
+ conn := &rwTestConn{
+ Reader: bytes.NewReader(req),
+ Writer: &buf,
+ closec: make(chan bool, 1),
+ }
+ handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
+ conn, _, err := rw.(Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ _, _, err = rw.(Hijacker).Hijack()
+ if err == nil {
+ t.Errorf("got err = nil; want err != nil")
+ }
+ conn.Close()
+ })
+ ln := &oneConnListener{conn: conn}
+ go Serve(ln, handler)
+ <-conn.closec
+}
+
// http://code.google.com/p/go/issues/detail?id=5955
// Note that this does not test the "request too large"
// exit path from the http server. This is intentional;
@@ -2037,31 +2121,160 @@ func TestServerReaderFromOrder(t *testing.T) {
}
}
-// Issue 6157
-func TestNoContentTypeOnNotModified(t *testing.T) {
+// Issue 6157, Issue 6685
+func TestCodesPreventingContentTypeAndBody(t *testing.T) {
+ for _, code := range []int{StatusNotModified, StatusNoContent, StatusContinue} {
+ ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.URL.Path == "/header" {
+ w.Header().Set("Content-Length", "123")
+ }
+ w.WriteHeader(code)
+ if r.URL.Path == "/more" {
+ w.Write([]byte("stuff"))
+ }
+ }))
+ for _, req := range []string{
+ "GET / HTTP/1.0",
+ "GET /header HTTP/1.0",
+ "GET /more HTTP/1.0",
+ "GET / HTTP/1.1",
+ "GET /header HTTP/1.1",
+ "GET /more HTTP/1.1",
+ } {
+ got := ht.rawResponse(req)
+ wantStatus := fmt.Sprintf("%d %s", code, StatusText(code))
+ if !strings.Contains(got, wantStatus) {
+ t.Errorf("Code %d: Wanted %q Modified for %q: %s", code, wantStatus, req, got)
+ } else if strings.Contains(got, "Content-Length") {
+ t.Errorf("Code %d: Got a Content-Length from %q: %s", code, req, got)
+ } else if strings.Contains(got, "stuff") {
+ t.Errorf("Code %d: Response contains a body from %q: %s", code, req, got)
+ }
+ }
+ }
+}
+
+func TestContentTypeOkayOn204(t *testing.T) {
ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
- if r.URL.Path == "/header" {
- w.Header().Set("Content-Length", "123")
+ w.Header().Set("Content-Length", "123") // suppressed
+ w.Header().Set("Content-Type", "foo/bar")
+ w.WriteHeader(204)
+ }))
+ got := ht.rawResponse("GET / HTTP/1.1")
+ if !strings.Contains(got, "Content-Type: foo/bar") {
+ t.Errorf("Response = %q; want Content-Type: foo/bar", got)
+ }
+ if strings.Contains(got, "Content-Length: 123") {
+ t.Errorf("Response = %q; don't want a Content-Length", got)
+ }
+}
+
+// Issue 6995
+// A server Handler can receive a Request, and then turn around and
+// give a copy of that Request.Body out to the Transport (e.g. any
+// proxy). So then two people own that Request.Body (both the server
+// and the http client), and both think they can close it on failure.
+// Therefore, all incoming server requests Bodies need to be thread-safe.
+func TestTransportAndServerSharedBodyRace(t *testing.T) {
+ defer afterTest(t)
+
+ const bodySize = 1 << 20
+
+ unblockBackend := make(chan bool)
+ backend := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ io.CopyN(rw, req.Body, bodySize/2)
+ <-unblockBackend
+ }))
+ defer backend.Close()
+
+ backendRespc := make(chan *Response, 1)
+ proxy := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ if req.RequestURI == "/foo" {
+ rw.Write([]byte("bar"))
+ return
}
- w.WriteHeader(StatusNotModified)
- if r.URL.Path == "/more" {
- w.Write([]byte("stuff"))
+ req2, _ := NewRequest("POST", backend.URL, req.Body)
+ req2.ContentLength = bodySize
+
+ bresp, err := DefaultClient.Do(req2)
+ if err != nil {
+ t.Errorf("Proxy outbound request: %v", err)
+ return
+ }
+ _, err = io.CopyN(ioutil.Discard, bresp.Body, bodySize/4)
+ if err != nil {
+ t.Errorf("Proxy copy error: %v", err)
+ return
}
+ backendRespc <- bresp // to close later
+
+ // Try to cause a race: Both the DefaultTransport and the proxy handler's Server
+ // will try to read/close req.Body (aka req2.Body)
+ DefaultTransport.(*Transport).CancelRequest(req2)
+ rw.Write([]byte("OK"))
+ }))
+ defer proxy.Close()
+
+ req, _ := NewRequest("POST", proxy.URL, io.LimitReader(neverEnding('a'), bodySize))
+ res, err := DefaultClient.Do(req)
+ if err != nil {
+ t.Fatalf("Original request: %v", err)
+ }
+
+ // Cleanup, so we don't leak goroutines.
+ res.Body.Close()
+ close(unblockBackend)
+ (<-backendRespc).Body.Close()
+}
+
+// Test that a hanging Request.Body.Read from another goroutine can't
+// cause the Handler goroutine's Request.Body.Close to block.
+func TestRequestBodyCloseDoesntBlock(t *testing.T) {
+ t.Skipf("Skipping known issue; see golang.org/issue/7121")
+ if testing.Short() {
+ t.Skip("skipping in -short mode")
+ }
+ defer afterTest(t)
+
+ readErrCh := make(chan error, 1)
+ errCh := make(chan error, 2)
+
+ server := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ go func(body io.Reader) {
+ _, err := body.Read(make([]byte, 100))
+ readErrCh <- err
+ }(req.Body)
+ time.Sleep(500 * time.Millisecond)
}))
- for _, req := range []string{
- "GET / HTTP/1.0",
- "GET /header HTTP/1.0",
- "GET /more HTTP/1.0",
- "GET / HTTP/1.1",
- "GET /header HTTP/1.1",
- "GET /more HTTP/1.1",
- } {
- got := ht.rawResponse(req)
- if !strings.Contains(got, "304 Not Modified") {
- t.Errorf("Non-304 Not Modified for %q: %s", req, got)
- } else if strings.Contains(got, "Content-Length") {
- t.Errorf("Got a Content-Length from %q: %s", req, got)
+ defer server.Close()
+
+ closeConn := make(chan bool)
+ defer close(closeConn)
+ go func() {
+ conn, err := net.Dial("tcp", server.Listener.Addr().String())
+ if err != nil {
+ errCh <- err
+ return
+ }
+ defer conn.Close()
+ _, err = conn.Write([]byte("POST / HTTP/1.1\r\nConnection: close\r\nHost: foo\r\nContent-Length: 100000\r\n\r\n"))
+ if err != nil {
+ errCh <- err
+ return
}
+ // And now just block, making the server block on our
+ // 100000 bytes of body that will never arrive.
+ <-closeConn
+ }()
+ select {
+ case err := <-readErrCh:
+ if err == nil {
+ t.Error("Read was nil. Expected error.")
+ }
+ case err := <-errCh:
+ t.Error(err)
+ case <-time.After(5 * time.Second):
+ t.Error("timeout")
}
}
@@ -2073,8 +2286,8 @@ func TestResponseWriterWriteStringAllocs(t *testing.T) {
w.Write([]byte("Hello world"))
}
}))
- before := testing.AllocsPerRun(25, func() { ht.rawResponse("GET / HTTP/1.0") })
- after := testing.AllocsPerRun(25, func() { ht.rawResponse("GET /s HTTP/1.0") })
+ before := testing.AllocsPerRun(50, func() { ht.rawResponse("GET / HTTP/1.0") })
+ after := testing.AllocsPerRun(50, func() { ht.rawResponse("GET /s HTTP/1.0") })
if int(after) >= int(before) {
t.Errorf("WriteString allocs of %v >= Write allocs of %v", after, before)
}
@@ -2093,6 +2306,230 @@ func TestAppendTime(t *testing.T) {
}
}
+func TestServerConnState(t *testing.T) {
+ defer afterTest(t)
+ handler := map[string]func(w ResponseWriter, r *Request){
+ "/": func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "Hello.")
+ },
+ "/close": func(w ResponseWriter, r *Request) {
+ w.Header().Set("Connection", "close")
+ fmt.Fprintf(w, "Hello.")
+ },
+ "/hijack": func(w ResponseWriter, r *Request) {
+ c, _, _ := w.(Hijacker).Hijack()
+ c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
+ c.Close()
+ },
+ "/hijack-panic": func(w ResponseWriter, r *Request) {
+ c, _, _ := w.(Hijacker).Hijack()
+ c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
+ c.Close()
+ panic("intentional panic")
+ },
+ }
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ handler[r.URL.Path](w, r)
+ }))
+ defer ts.Close()
+
+ var mu sync.Mutex // guard stateLog and connID
+ var stateLog = map[int][]ConnState{}
+ var connID = map[net.Conn]int{}
+
+ ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
+ ts.Config.ConnState = func(c net.Conn, state ConnState) {
+ if c == nil {
+ t.Errorf("nil conn seen in state %s", state)
+ return
+ }
+ mu.Lock()
+ defer mu.Unlock()
+ id, ok := connID[c]
+ if !ok {
+ id = len(connID) + 1
+ connID[c] = id
+ }
+ stateLog[id] = append(stateLog[id], state)
+ }
+ ts.Start()
+
+ mustGet(t, ts.URL+"/")
+ mustGet(t, ts.URL+"/close")
+
+ mustGet(t, ts.URL+"/")
+ mustGet(t, ts.URL+"/", "Connection", "close")
+
+ mustGet(t, ts.URL+"/hijack")
+ mustGet(t, ts.URL+"/hijack-panic")
+
+ // New->Closed
+ {
+ c, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.Close()
+ }
+
+ // New->Active->Closed
+ {
+ c, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := io.WriteString(c, "BOGUS REQUEST\r\n\r\n"); err != nil {
+ t.Fatal(err)
+ }
+ c.Close()
+ }
+
+ // New->Idle->Closed
+ {
+ c, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
+ t.Fatal(err)
+ }
+ res, err := ReadResponse(bufio.NewReader(c), nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := io.Copy(ioutil.Discard, res.Body); err != nil {
+ t.Fatal(err)
+ }
+ c.Close()
+ }
+
+ want := map[int][]ConnState{
+ 1: []ConnState{StateNew, StateActive, StateIdle, StateActive, StateClosed},
+ 2: []ConnState{StateNew, StateActive, StateIdle, StateActive, StateClosed},
+ 3: []ConnState{StateNew, StateActive, StateHijacked},
+ 4: []ConnState{StateNew, StateActive, StateHijacked},
+ 5: []ConnState{StateNew, StateClosed},
+ 6: []ConnState{StateNew, StateActive, StateClosed},
+ 7: []ConnState{StateNew, StateActive, StateIdle, StateClosed},
+ }
+ logString := func(m map[int][]ConnState) string {
+ var b bytes.Buffer
+ for id, l := range m {
+ fmt.Fprintf(&b, "Conn %d: ", id)
+ for _, s := range l {
+ fmt.Fprintf(&b, "%s ", s)
+ }
+ b.WriteString("\n")
+ }
+ return b.String()
+ }
+
+ for i := 0; i < 5; i++ {
+ time.Sleep(time.Duration(i) * 50 * time.Millisecond)
+ mu.Lock()
+ match := reflect.DeepEqual(stateLog, want)
+ mu.Unlock()
+ if match {
+ return
+ }
+ }
+
+ mu.Lock()
+ t.Errorf("Unexpected events.\nGot log: %s\n Want: %s\n", logString(stateLog), logString(want))
+ mu.Unlock()
+}
+
+func mustGet(t *testing.T, url string, headers ...string) {
+ req, err := NewRequest("GET", url, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ for len(headers) > 0 {
+ req.Header.Add(headers[0], headers[1])
+ headers = headers[2:]
+ }
+ res, err := DefaultClient.Do(req)
+ if err != nil {
+ t.Errorf("Error fetching %s: %v", url, err)
+ return
+ }
+ _, err = ioutil.ReadAll(res.Body)
+ defer res.Body.Close()
+ if err != nil {
+ t.Errorf("Error reading %s: %v", url, err)
+ }
+}
+
+func TestServerKeepAlivesEnabled(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
+ ts.Config.SetKeepAlivesEnabled(false)
+ ts.Start()
+ defer ts.Close()
+ res, err := Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if !res.Close {
+ t.Errorf("Body.Close == false; want true")
+ }
+}
+
+// golang.org/issue/7856
+func TestServerEmptyBodyRace(t *testing.T) {
+ defer afterTest(t)
+ var n int32
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ atomic.AddInt32(&n, 1)
+ }))
+ defer ts.Close()
+ var wg sync.WaitGroup
+ const reqs = 20
+ for i := 0; i < reqs; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ res, err := Get(ts.URL)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer res.Body.Close()
+ _, err = io.Copy(ioutil.Discard, res.Body)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ }()
+ }
+ wg.Wait()
+ if got := atomic.LoadInt32(&n); got != reqs {
+ t.Errorf("handler ran %d times; want %d", got, reqs)
+ }
+}
+
+func TestServerConnStateNew(t *testing.T) {
+ sawNew := false // if the test is buggy, we'll race on this variable.
+ srv := &Server{
+ ConnState: func(c net.Conn, state ConnState) {
+ if state == StateNew {
+ sawNew = true // testing that this write isn't racy
+ }
+ },
+ Handler: HandlerFunc(func(w ResponseWriter, r *Request) {}), // irrelevant
+ }
+ srv.Serve(&oneConnListener{
+ conn: &rwTestConn{
+ Reader: strings.NewReader("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"),
+ Writer: ioutil.Discard,
+ },
+ })
+ if !sawNew { // testing that this read isn't racy
+ t.Error("StateNew not seen")
+ }
+}
+
func BenchmarkClientServer(b *testing.B) {
b.ReportAllocs()
b.StopTimer()
@@ -2108,6 +2545,7 @@ func BenchmarkClientServer(b *testing.B) {
b.Fatal("Get:", err)
}
all, err := ioutil.ReadAll(res.Body)
+ res.Body.Close()
if err != nil {
b.Fatal("ReadAll:", err)
}
@@ -2128,41 +2566,33 @@ func BenchmarkClientServerParallel64(b *testing.B) {
benchmarkClientServerParallel(b, 64)
}
-func benchmarkClientServerParallel(b *testing.B, conc int) {
+func benchmarkClientServerParallel(b *testing.B, parallelism int) {
b.ReportAllocs()
- b.StopTimer()
ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
fmt.Fprintf(rw, "Hello world.\n")
}))
defer ts.Close()
- b.StartTimer()
-
- numProcs := runtime.GOMAXPROCS(-1) * conc
- var wg sync.WaitGroup
- wg.Add(numProcs)
- n := int32(b.N)
- for p := 0; p < numProcs; p++ {
- go func() {
- for atomic.AddInt32(&n, -1) >= 0 {
- res, err := Get(ts.URL)
- if err != nil {
- b.Logf("Get: %v", err)
- continue
- }
- all, err := ioutil.ReadAll(res.Body)
- if err != nil {
- b.Logf("ReadAll: %v", err)
- continue
- }
- body := string(all)
- if body != "Hello world.\n" {
- panic("Got body: " + body)
- }
+ b.ResetTimer()
+ b.SetParallelism(parallelism)
+ b.RunParallel(func(pb *testing.PB) {
+ for pb.Next() {
+ res, err := Get(ts.URL)
+ if err != nil {
+ b.Logf("Get: %v", err)
+ continue
}
- wg.Done()
- }()
- }
- wg.Wait()
+ all, err := ioutil.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ b.Logf("ReadAll: %v", err)
+ continue
+ }
+ body := string(all)
+ if body != "Hello world.\n" {
+ panic("Got body: " + body)
+ }
+ }
+ })
}
// A benchmark for profiling the server without the HTTP client code.
@@ -2187,6 +2617,7 @@ func BenchmarkServer(b *testing.B) {
log.Panicf("Get: %v", err)
}
all, err := ioutil.ReadAll(res.Body)
+ res.Body.Close()
if err != nil {
log.Panicf("ReadAll: %v", err)
}
@@ -2390,3 +2821,28 @@ Host: golang.org
b.Errorf("b.N=%d but handled %d", b.N, handled)
}
}
+
+func BenchmarkServerHijack(b *testing.B) {
+ b.ReportAllocs()
+ req := reqBytes(`GET / HTTP/1.1
+Host: golang.org
+`)
+ h := HandlerFunc(func(w ResponseWriter, r *Request) {
+ conn, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ panic(err)
+ }
+ conn.Close()
+ })
+ conn := &rwTestConn{
+ Writer: ioutil.Discard,
+ closec: make(chan bool, 1),
+ }
+ ln := &oneConnListener{conn: conn}
+ for i := 0; i < b.N; i++ {
+ conn.Reader = bytes.NewReader(req)
+ ln.conn = conn
+ Serve(ln, h)
+ <-conn.closec
+ }
+}
diff --git a/src/pkg/net/http/server.go b/src/pkg/net/http/server.go
index 0e46863d5..eae097eb8 100644
--- a/src/pkg/net/http/server.go
+++ b/src/pkg/net/http/server.go
@@ -22,6 +22,7 @@ import (
"strconv"
"strings"
"sync"
+ "sync/atomic"
"time"
)
@@ -138,6 +139,7 @@ func (c *conn) hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
buf = c.buf
c.rwc = nil
c.buf = nil
+ c.setState(rwc, StateHijacked)
return
}
@@ -435,56 +437,52 @@ func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) {
return c, nil
}
-// TODO: use a sync.Cache instead
var (
- bufioReaderCache = make(chan *bufio.Reader, 4)
- bufioWriterCache2k = make(chan *bufio.Writer, 4)
- bufioWriterCache4k = make(chan *bufio.Writer, 4)
+ bufioReaderPool sync.Pool
+ bufioWriter2kPool sync.Pool
+ bufioWriter4kPool sync.Pool
)
-func bufioWriterCache(size int) chan *bufio.Writer {
+func bufioWriterPool(size int) *sync.Pool {
switch size {
case 2 << 10:
- return bufioWriterCache2k
+ return &bufioWriter2kPool
case 4 << 10:
- return bufioWriterCache4k
+ return &bufioWriter4kPool
}
return nil
}
func newBufioReader(r io.Reader) *bufio.Reader {
- select {
- case p := <-bufioReaderCache:
- p.Reset(r)
- return p
- default:
- return bufio.NewReader(r)
+ if v := bufioReaderPool.Get(); v != nil {
+ br := v.(*bufio.Reader)
+ br.Reset(r)
+ return br
}
+ return bufio.NewReader(r)
}
func putBufioReader(br *bufio.Reader) {
br.Reset(nil)
- select {
- case bufioReaderCache <- br:
- default:
- }
+ bufioReaderPool.Put(br)
}
func newBufioWriterSize(w io.Writer, size int) *bufio.Writer {
- select {
- case p := <-bufioWriterCache(size):
- p.Reset(w)
- return p
- default:
- return bufio.NewWriterSize(w, size)
+ pool := bufioWriterPool(size)
+ if pool != nil {
+ if v := pool.Get(); v != nil {
+ bw := v.(*bufio.Writer)
+ bw.Reset(w)
+ return bw
+ }
}
+ return bufio.NewWriterSize(w, size)
}
func putBufioWriter(bw *bufio.Writer) {
bw.Reset(nil)
- select {
- case bufioWriterCache(bw.Available()) <- bw:
- default:
+ if pool := bufioWriterPool(bw.Available()); pool != nil {
+ pool.Put(bw)
}
}
@@ -500,6 +498,10 @@ func (srv *Server) maxHeaderBytes() int {
return DefaultMaxHeaderBytes
}
+func (srv *Server) initialLimitedReaderSize() int64 {
+ return int64(srv.maxHeaderBytes()) + 4096 // bufio slop
+}
+
// wrapper around io.ReaderCloser which on first read, sends an
// HTTP/1.1 100 Continue header
type expectContinueReader struct {
@@ -570,7 +572,7 @@ func (c *conn) readRequest() (w *response, err error) {
}()
}
- c.lr.N = int64(c.server.maxHeaderBytes()) + 4096 /* bufio slop */
+ c.lr.N = c.server.initialLimitedReaderSize()
var req *Request
if req, err = ReadRequest(c.buf.Reader); err != nil {
if c.lr.N == 0 {
@@ -618,11 +620,11 @@ const maxPostHandlerReadBytes = 256 << 10
func (w *response) WriteHeader(code int) {
if w.conn.hijacked() {
- log.Print("http: response.WriteHeader on hijacked connection")
+ w.conn.server.logf("http: response.WriteHeader on hijacked connection")
return
}
if w.wroteHeader {
- log.Print("http: multiple response.WriteHeader calls")
+ w.conn.server.logf("http: multiple response.WriteHeader calls")
return
}
w.wroteHeader = true
@@ -637,7 +639,7 @@ func (w *response) WriteHeader(code int) {
if err == nil && v >= 0 {
w.contentLength = v
} else {
- log.Printf("http: invalid Content-Length of %q", cl)
+ w.conn.server.logf("http: invalid Content-Length of %q", cl)
w.handlerHeader.Del("Content-Length")
}
}
@@ -707,6 +709,7 @@ func (cw *chunkWriter) writeHeader(p []byte) {
cw.wroteHeader = true
w := cw.res
+ keepAlivesEnabled := w.conn.server.doKeepAlives()
isHEAD := w.req.Method == "HEAD"
// header is written out to w.conn.buf below. Depending on the
@@ -739,7 +742,7 @@ func (cw *chunkWriter) writeHeader(p []byte) {
// response header and this is our first (and last) write, set
// it, even to zero. This helps HTTP/1.0 clients keep their
// "keep-alive" connections alive.
- // Exceptions: 304 responses never get Content-Length, and if
+ // Exceptions: 304/204/1xx responses never get Content-Length, and if
// it was a HEAD request, we don't know the difference between
// 0 actual bytes and 0 bytes because the handler noticed it
// was a HEAD request and chose not to write anything. So for
@@ -747,14 +750,14 @@ func (cw *chunkWriter) writeHeader(p []byte) {
// write non-zero bytes. If it's actually 0 bytes and the
// handler never looked at the Request.Method, we just don't
// send a Content-Length header.
- if w.handlerDone && w.status != StatusNotModified && header.get("Content-Length") == "" && (!isHEAD || len(p) > 0) {
+ if w.handlerDone && bodyAllowedForStatus(w.status) && header.get("Content-Length") == "" && (!isHEAD || len(p) > 0) {
w.contentLength = int64(len(p))
setHeader.contentLength = strconv.AppendInt(cw.res.clenBuf[:0], int64(len(p)), 10)
}
// If this was an HTTP/1.0 request with keep-alive and we sent a
// Content-Length back, we can make this a keep-alive response ...
- if w.req.wantsHttp10KeepAlive() {
+ if w.req.wantsHttp10KeepAlive() && keepAlivesEnabled {
sentLength := header.get("Content-Length") != ""
if sentLength && header.get("Connection") == "keep-alive" {
w.closeAfterReply = false
@@ -773,7 +776,7 @@ func (cw *chunkWriter) writeHeader(p []byte) {
w.closeAfterReply = true
}
- if header.get("Connection") == "close" {
+ if header.get("Connection") == "close" || !keepAlivesEnabled {
w.closeAfterReply = true
}
@@ -796,18 +799,16 @@ func (cw *chunkWriter) writeHeader(p []byte) {
}
code := w.status
- if code == StatusNotModified {
- // Must not have body.
- // RFC 2616 section 10.3.5: "the response MUST NOT include other entity-headers"
- for _, k := range []string{"Content-Type", "Content-Length", "Transfer-Encoding"} {
- delHeader(k)
- }
- } else {
+ if bodyAllowedForStatus(code) {
// If no content type, apply sniffing algorithm to body.
_, haveType := header["Content-Type"]
if !haveType {
setHeader.contentType = DetectContentType(p)
}
+ } else {
+ for _, k := range suppressedHeaders(code) {
+ delHeader(k)
+ }
}
if _, ok := header["Date"]; !ok {
@@ -819,13 +820,13 @@ func (cw *chunkWriter) writeHeader(p []byte) {
if hasCL && hasTE && te != "identity" {
// TODO: return an error if WriteHeader gets a return parameter
// For now just ignore the Content-Length.
- log.Printf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d",
+ w.conn.server.logf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d",
te, w.contentLength)
delHeader("Content-Length")
hasCL = false
}
- if w.req.Method == "HEAD" || code == StatusNotModified {
+ if w.req.Method == "HEAD" || !bodyAllowedForStatus(code) {
// do nothing
} else if code == StatusNoContent {
delHeader("Transfer-Encoding")
@@ -855,7 +856,7 @@ func (cw *chunkWriter) writeHeader(p []byte) {
return
}
- if w.closeAfterReply && !hasToken(cw.header.get("Connection"), "close") {
+ if w.closeAfterReply && (!keepAlivesEnabled || !hasToken(cw.header.get("Connection"), "close")) {
delHeader("Connection")
if w.req.ProtoAtLeast(1, 1) {
setHeader.connection = "close"
@@ -919,7 +920,7 @@ func (w *response) bodyAllowed() bool {
if !w.wroteHeader {
panic("")
}
- return w.status != StatusNotModified
+ return bodyAllowedForStatus(w.status)
}
// The Life Of A Write is like this:
@@ -965,7 +966,7 @@ func (w *response) WriteString(data string) (n int, err error) {
// either dataB or dataS is non-zero.
func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err error) {
if w.conn.hijacked() {
- log.Print("http: response.Write on hijacked connection")
+ w.conn.server.logf("http: response.Write on hijacked connection")
return 0, ErrHijacked
}
if !w.wroteHeader {
@@ -1001,11 +1002,10 @@ func (w *response) finishRequest() {
w.cw.close()
w.conn.buf.Flush()
- // Close the body, unless we're about to close the whole TCP connection
- // anyway.
- if !w.closeAfterReply {
- w.req.Body.Close()
- }
+ // Close the body (regardless of w.closeAfterReply) so we can
+ // re-use its bufio.Reader later safely.
+ w.req.Body.Close()
+
if w.req.MultipartForm != nil {
w.req.MultipartForm.RemoveAll()
}
@@ -1084,17 +1084,25 @@ func validNPN(proto string) bool {
return true
}
+func (c *conn) setState(nc net.Conn, state ConnState) {
+ if hook := c.server.ConnState; hook != nil {
+ hook(nc, state)
+ }
+}
+
// Serve a new connection.
func (c *conn) serve() {
+ origConn := c.rwc // copy it before it's set nil on Close or Hijack
defer func() {
if err := recover(); err != nil {
- const size = 4096
+ const size = 64 << 10
buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)]
- log.Printf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf)
+ c.server.logf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf)
}
if !c.hijacked() {
c.close()
+ c.setState(origConn, StateClosed)
}
}()
@@ -1106,6 +1114,7 @@ func (c *conn) serve() {
c.rwc.SetWriteDeadline(time.Now().Add(d))
}
if err := tlsConn.Handshake(); err != nil {
+ c.server.logf("http: TLS handshake error from %s: %v", c.rwc.RemoteAddr(), err)
return
}
c.tlsState = new(tls.ConnectionState)
@@ -1121,6 +1130,10 @@ func (c *conn) serve() {
for {
w, err := c.readRequest()
+ if c.lr.N != c.server.initialLimitedReaderSize() {
+ // If we read any bytes off the wire, we're active.
+ c.setState(c.rwc, StateActive)
+ }
if err != nil {
if err == errTooLarge {
// Their HTTP client may or may not be
@@ -1143,16 +1156,10 @@ func (c *conn) serve() {
// Expect 100 Continue support
req := w.req
if req.expectsContinue() {
- if req.ProtoAtLeast(1, 1) {
+ if req.ProtoAtLeast(1, 1) && req.ContentLength != 0 {
// Wrap the Body reader with one that replies on the connection
req.Body = &expectContinueReader{readCloser: req.Body, resp: w}
}
- if req.ContentLength == 0 {
- w.Header().Set("Connection", "close")
- w.WriteHeader(StatusBadRequest)
- w.finishRequest()
- break
- }
req.Header.Del("Expect")
} else if req.Header.get("Expect") != "" {
w.sendExpectationFailed()
@@ -1175,6 +1182,7 @@ func (c *conn) serve() {
}
break
}
+ c.setState(c.rwc, StateIdle)
}
}
@@ -1202,7 +1210,14 @@ func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
if w.wroteHeader {
w.cw.flush()
}
- return w.conn.hijack()
+ // Release the bufioWriter that writes to the chunk writer, it is not
+ // used after a connection has been hijacked.
+ rwc, buf, err = w.conn.hijack()
+ if err == nil {
+ putBufioWriter(w.w)
+ w.w = nil
+ }
+ return rwc, buf, err
}
func (w *response) CloseNotify() <-chan bool {
@@ -1562,6 +1577,7 @@ func Serve(l net.Listener, handler Handler) error {
}
// A Server defines parameters for running an HTTP server.
+// The zero value for Server is a valid configuration.
type Server struct {
Addr string // TCP address to listen on, ":http" if empty
Handler Handler // handler to invoke, http.DefaultServeMux if nil
@@ -1578,6 +1594,66 @@ type Server struct {
// and RemoteAddr if not already set. The connection is
// automatically closed when the function returns.
TLSNextProto map[string]func(*Server, *tls.Conn, Handler)
+
+ // ConnState specifies an optional callback function that is
+ // called when a client connection changes state. See the
+ // ConnState type and associated constants for details.
+ ConnState func(net.Conn, ConnState)
+
+ // ErrorLog specifies an optional logger for errors accepting
+ // connections and unexpected behavior from handlers.
+ // If nil, logging goes to os.Stderr via the log package's
+ // standard logger.
+ ErrorLog *log.Logger
+
+ disableKeepAlives int32 // accessed atomically.
+}
+
+// A ConnState represents the state of a client connection to a server.
+// It's used by the optional Server.ConnState hook.
+type ConnState int
+
+const (
+ // StateNew represents a new connection that is expected to
+ // send a request immediately. Connections begin at this
+ // state and then transition to either StateActive or
+ // StateClosed.
+ StateNew ConnState = iota
+
+ // StateActive represents a connection that has read 1 or more
+ // bytes of a request. The Server.ConnState hook for
+ // StateActive fires before the request has entered a handler
+ // and doesn't fire again until the request has been
+ // handled. After the request is handled, the state
+ // transitions to StateClosed, StateHijacked, or StateIdle.
+ StateActive
+
+ // StateIdle represents a connection that has finished
+ // handling a request and is in the keep-alive state, waiting
+ // for a new request. Connections transition from StateIdle
+ // to either StateActive or StateClosed.
+ StateIdle
+
+ // StateHijacked represents a hijacked connection.
+ // This is a terminal state. It does not transition to StateClosed.
+ StateHijacked
+
+ // StateClosed represents a closed connection.
+ // This is a terminal state. Hijacked connections do not
+ // transition to StateClosed.
+ StateClosed
+)
+
+var stateName = map[ConnState]string{
+ StateNew: "new",
+ StateActive: "active",
+ StateIdle: "idle",
+ StateHijacked: "hijacked",
+ StateClosed: "closed",
+}
+
+func (c ConnState) String() string {
+ return stateName[c]
}
// serverHandler delegates to either the server's Handler or
@@ -1605,11 +1681,11 @@ func (srv *Server) ListenAndServe() error {
if addr == "" {
addr = ":http"
}
- l, e := net.Listen("tcp", addr)
- if e != nil {
- return e
+ ln, err := net.Listen("tcp", addr)
+ if err != nil {
+ return err
}
- return srv.Serve(l)
+ return srv.Serve(tcpKeepAliveListener{ln.(*net.TCPListener)})
}
// Serve accepts incoming connections on the Listener l, creating a
@@ -1630,7 +1706,7 @@ func (srv *Server) Serve(l net.Listener) error {
if max := 1 * time.Second; tempDelay > max {
tempDelay = max
}
- log.Printf("http: Accept error: %v; retrying in %v", e, tempDelay)
+ srv.logf("http: Accept error: %v; retrying in %v", e, tempDelay)
time.Sleep(tempDelay)
continue
}
@@ -1641,10 +1717,35 @@ func (srv *Server) Serve(l net.Listener) error {
if err != nil {
continue
}
+ c.setState(c.rwc, StateNew) // before Serve can return
go c.serve()
}
}
+func (s *Server) doKeepAlives() bool {
+ return atomic.LoadInt32(&s.disableKeepAlives) == 0
+}
+
+// SetKeepAlivesEnabled controls whether HTTP keep-alives are enabled.
+// By default, keep-alives are always enabled. Only very
+// resource-constrained environments or servers in the process of
+// shutting down should disable them.
+func (s *Server) SetKeepAlivesEnabled(v bool) {
+ if v {
+ atomic.StoreInt32(&s.disableKeepAlives, 0)
+ } else {
+ atomic.StoreInt32(&s.disableKeepAlives, 1)
+ }
+}
+
+func (s *Server) logf(format string, args ...interface{}) {
+ if s.ErrorLog != nil {
+ s.ErrorLog.Printf(format, args...)
+ } else {
+ log.Printf(format, args...)
+ }
+}
+
// ListenAndServe listens on the TCP network address addr
// and then calls Serve with handler to handle requests
// on incoming connections. Handler is typically nil,
@@ -1739,12 +1840,12 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error {
return err
}
- conn, err := net.Listen("tcp", addr)
+ ln, err := net.Listen("tcp", addr)
if err != nil {
return err
}
- tlsListener := tls.NewListener(conn, config)
+ tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config)
return srv.Serve(tlsListener)
}
@@ -1834,6 +1935,24 @@ func (tw *timeoutWriter) WriteHeader(code int) {
tw.w.WriteHeader(code)
}
+// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
+// connections. It's used by ListenAndServe and ListenAndServeTLS so
+// dead TCP connections (e.g. closing laptop mid-download) eventually
+// go away.
+type tcpKeepAliveListener struct {
+ *net.TCPListener
+}
+
+func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
+ tc, err := ln.AcceptTCP()
+ if err != nil {
+ return
+ }
+ tc.SetKeepAlive(true)
+ tc.SetKeepAlivePeriod(3 * time.Minute)
+ return tc, nil
+}
+
// globalOptionsHandler responds to "OPTIONS *" requests.
type globalOptionsHandler struct{}
@@ -1850,17 +1969,24 @@ func (globalOptionsHandler) ServeHTTP(w ResponseWriter, r *Request) {
}
}
+type eofReaderWithWriteTo struct{}
+
+func (eofReaderWithWriteTo) WriteTo(io.Writer) (int64, error) { return 0, nil }
+func (eofReaderWithWriteTo) Read([]byte) (int, error) { return 0, io.EOF }
+
// eofReader is a non-nil io.ReadCloser that always returns EOF.
-// It embeds a *strings.Reader so it still has a WriteTo method
-// and io.Copy won't need a buffer.
+// It has a WriteTo method so io.Copy won't need a buffer.
var eofReader = &struct {
- *strings.Reader
+ eofReaderWithWriteTo
io.Closer
}{
- strings.NewReader(""),
+ eofReaderWithWriteTo{},
ioutil.NopCloser(nil),
}
+// Verify that an io.Copy from an eofReader won't require a buffer.
+var _ io.WriterTo = eofReader
+
// initNPNRequest is an HTTP handler that initializes certain
// uninitialized fields in its *Request. Such partially-initialized
// Requests come from NPN protocol handlers.
diff --git a/src/pkg/net/http/transfer.go b/src/pkg/net/http/transfer.go
index bacd83732..7f6368652 100644
--- a/src/pkg/net/http/transfer.go
+++ b/src/pkg/net/http/transfer.go
@@ -12,10 +12,20 @@ import (
"io"
"io/ioutil"
"net/textproto"
+ "sort"
"strconv"
"strings"
+ "sync"
)
+type errorReader struct {
+ err error
+}
+
+func (r *errorReader) Read(p []byte) (n int, err error) {
+ return 0, r.err
+}
+
// transferWriter inspects the fields of a user-supplied Request or Response,
// sanitizes them without changing the user object and provides methods for
// writing the respective header, body and trailer in wire format.
@@ -52,14 +62,17 @@ func newTransferWriter(r interface{}) (t *transferWriter, err error) {
if t.ContentLength == 0 {
// Test to see if it's actually zero or just unset.
var buf [1]byte
- n, _ := io.ReadFull(t.Body, buf[:])
- if n == 1 {
+ n, rerr := io.ReadFull(t.Body, buf[:])
+ if rerr != nil && rerr != io.EOF {
+ t.ContentLength = -1
+ t.Body = &errorReader{rerr}
+ } else if n == 1 {
// Oh, guess there is data in this Body Reader after all.
// The ContentLength field just wasn't set.
// Stich the Body back together again, re-attaching our
// consumed byte.
t.ContentLength = -1
- t.Body = io.MultiReader(bytes.NewBuffer(buf[:]), t.Body)
+ t.Body = io.MultiReader(bytes.NewReader(buf[:]), t.Body)
} else {
// Body is actually empty.
t.Body = nil
@@ -131,11 +144,10 @@ func (t *transferWriter) shouldSendContentLength() bool {
return false
}
-func (t *transferWriter) WriteHeader(w io.Writer) (err error) {
+func (t *transferWriter) WriteHeader(w io.Writer) error {
if t.Close {
- _, err = io.WriteString(w, "Connection: close\r\n")
- if err != nil {
- return
+ if _, err := io.WriteString(w, "Connection: close\r\n"); err != nil {
+ return err
}
}
@@ -143,43 +155,44 @@ func (t *transferWriter) WriteHeader(w io.Writer) (err error) {
// function of the sanitized field triple (Body, ContentLength,
// TransferEncoding)
if t.shouldSendContentLength() {
- io.WriteString(w, "Content-Length: ")
- _, err = io.WriteString(w, strconv.FormatInt(t.ContentLength, 10)+"\r\n")
- if err != nil {
- return
+ if _, err := io.WriteString(w, "Content-Length: "); err != nil {
+ return err
+ }
+ if _, err := io.WriteString(w, strconv.FormatInt(t.ContentLength, 10)+"\r\n"); err != nil {
+ return err
}
} else if chunked(t.TransferEncoding) {
- _, err = io.WriteString(w, "Transfer-Encoding: chunked\r\n")
- if err != nil {
- return
+ if _, err := io.WriteString(w, "Transfer-Encoding: chunked\r\n"); err != nil {
+ return err
}
}
// Write Trailer header
if t.Trailer != nil {
- // TODO: At some point, there should be a generic mechanism for
- // writing long headers, using HTTP line splitting
- io.WriteString(w, "Trailer: ")
- needComma := false
+ keys := make([]string, 0, len(t.Trailer))
for k := range t.Trailer {
k = CanonicalHeaderKey(k)
switch k {
case "Transfer-Encoding", "Trailer", "Content-Length":
return &badStringError{"invalid Trailer key", k}
}
- if needComma {
- io.WriteString(w, ",")
+ keys = append(keys, k)
+ }
+ if len(keys) > 0 {
+ sort.Strings(keys)
+ // TODO: could do better allocation-wise here, but trailers are rare,
+ // so being lazy for now.
+ if _, err := io.WriteString(w, "Trailer: "+strings.Join(keys, ",")+"\r\n"); err != nil {
+ return err
}
- io.WriteString(w, k)
- needComma = true
}
- _, err = io.WriteString(w, "\r\n")
}
- return
+ return nil
}
-func (t *transferWriter) WriteBody(w io.Writer) (err error) {
+func (t *transferWriter) WriteBody(w io.Writer) error {
+ var err error
var ncopy int64
// Write body
@@ -216,11 +229,16 @@ func (t *transferWriter) WriteBody(w io.Writer) (err error) {
// TODO(petar): Place trailer writer code here.
if chunked(t.TransferEncoding) {
+ // Write Trailer header
+ if t.Trailer != nil {
+ if err := t.Trailer.Write(w); err != nil {
+ return err
+ }
+ }
// Last chunk, empty trailer
_, err = io.WriteString(w, "\r\n")
}
-
- return
+ return err
}
type transferReader struct {
@@ -252,6 +270,22 @@ func bodyAllowedForStatus(status int) bool {
return true
}
+var (
+ suppressedHeaders304 = []string{"Content-Type", "Content-Length", "Transfer-Encoding"}
+ suppressedHeadersNoBody = []string{"Content-Length", "Transfer-Encoding"}
+)
+
+func suppressedHeaders(status int) []string {
+ switch {
+ case status == 304:
+ // RFC 2616 section 10.3.5: "the response MUST NOT include other entity-headers"
+ return suppressedHeaders304
+ case !bodyAllowedForStatus(status):
+ return suppressedHeadersNoBody
+ }
+ return nil
+}
+
// msg is *Request or *Response.
func readTransfer(msg interface{}, r *bufio.Reader) (err error) {
t := &transferReader{RequestMethod: "GET"}
@@ -331,17 +365,17 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) {
if noBodyExpected(t.RequestMethod) {
t.Body = eofReader
} else {
- t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close}
+ t.Body = &body{src: newChunkedReader(r), hdr: msg, r: r, closing: t.Close}
}
case realLength == 0:
t.Body = eofReader
case realLength > 0:
- t.Body = &body{Reader: io.LimitReader(r, realLength), closing: t.Close}
+ t.Body = &body{src: io.LimitReader(r, realLength), closing: t.Close}
default:
// realLength < 0, i.e. "Content-Length" not mentioned in header
if t.Close {
// Close semantics (i.e. HTTP/1.0)
- t.Body = &body{Reader: r, closing: t.Close}
+ t.Body = &body{src: r, closing: t.Close}
} else {
// Persistent connection (i.e. HTTP/1.1)
t.Body = eofReader
@@ -498,7 +532,7 @@ func fixTrailer(header Header, te []string) (Header, error) {
case "Transfer-Encoding", "Trailer", "Content-Length":
return nil, &badStringError{"bad trailer key", key}
}
- trailer.Del(key)
+ trailer[key] = nil
}
if len(trailer) == 0 {
return nil, nil
@@ -514,11 +548,13 @@ func fixTrailer(header Header, te []string) (Header, error) {
// Close ensures that the body has been fully read
// and then reads the trailer if necessary.
type body struct {
- io.Reader
+ src io.Reader
hdr interface{} // non-nil (Response or Request) value means read trailer
r *bufio.Reader // underlying wire-format reader for the trailer
closing bool // is the connection to be closed after reading body?
- closed bool
+
+ mu sync.Mutex // guards closed, and calls to Read and Close
+ closed bool
}
// ErrBodyReadAfterClose is returned when reading a Request or Response
@@ -528,10 +564,17 @@ type body struct {
var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed Body")
func (b *body) Read(p []byte) (n int, err error) {
+ b.mu.Lock()
+ defer b.mu.Unlock()
if b.closed {
return 0, ErrBodyReadAfterClose
}
- n, err = b.Reader.Read(p)
+ return b.readLocked(p)
+}
+
+// Must hold b.mu.
+func (b *body) readLocked(p []byte) (n int, err error) {
+ n, err = b.src.Read(p)
if err == io.EOF {
// Chunked case. Read the trailer.
@@ -543,12 +586,23 @@ func (b *body) Read(p []byte) (n int, err error) {
} else {
// If the server declared the Content-Length, our body is a LimitedReader
// and we need to check whether this EOF arrived early.
- if lr, ok := b.Reader.(*io.LimitedReader); ok && lr.N > 0 {
+ if lr, ok := b.src.(*io.LimitedReader); ok && lr.N > 0 {
err = io.ErrUnexpectedEOF
}
}
}
+ // If we can return an EOF here along with the read data, do
+ // so. This is optional per the io.Reader contract, but doing
+ // so helps the HTTP transport code recycle its connection
+ // earlier (since it will see this EOF itself), even if the
+ // client doesn't do future reads or Close.
+ if err == nil && n > 0 {
+ if lr, ok := b.src.(*io.LimitedReader); ok && lr.N == 0 {
+ err = io.EOF
+ }
+ }
+
return n, err
}
@@ -610,14 +664,26 @@ func (b *body) readTrailer() error {
}
switch rr := b.hdr.(type) {
case *Request:
- rr.Trailer = Header(hdr)
+ mergeSetHeader(&rr.Trailer, Header(hdr))
case *Response:
- rr.Trailer = Header(hdr)
+ mergeSetHeader(&rr.Trailer, Header(hdr))
}
return nil
}
+func mergeSetHeader(dst *Header, src Header) {
+ if *dst == nil {
+ *dst = src
+ return
+ }
+ for k, vv := range src {
+ (*dst)[k] = vv
+ }
+}
+
func (b *body) Close() error {
+ b.mu.Lock()
+ defer b.mu.Unlock()
if b.closed {
return nil
}
@@ -629,12 +695,25 @@ func (b *body) Close() error {
default:
// Fully consume the body, which will also lead to us reading
// the trailer headers after the body, if present.
- _, err = io.Copy(ioutil.Discard, b)
+ _, err = io.Copy(ioutil.Discard, bodyLocked{b})
}
b.closed = true
return err
}
+// bodyLocked is a io.Reader reading from a *body when its mutex is
+// already held.
+type bodyLocked struct {
+ b *body
+}
+
+func (bl bodyLocked) Read(p []byte) (n int, err error) {
+ if bl.b.closed {
+ return 0, ErrBodyReadAfterClose
+ }
+ return bl.b.readLocked(p)
+}
+
// parseContentLength trims whitespace from s and returns -1 if no value
// is set, or the value if it's >= 0.
func parseContentLength(cl string) (int64, error) {
diff --git a/src/pkg/net/http/transfer_test.go b/src/pkg/net/http/transfer_test.go
index 8627a374c..48cd540b9 100644
--- a/src/pkg/net/http/transfer_test.go
+++ b/src/pkg/net/http/transfer_test.go
@@ -6,15 +6,16 @@ package http
import (
"bufio"
+ "io"
"strings"
"testing"
)
func TestBodyReadBadTrailer(t *testing.T) {
b := &body{
- Reader: strings.NewReader("foobar"),
- hdr: true, // force reading the trailer
- r: bufio.NewReader(strings.NewReader("")),
+ src: strings.NewReader("foobar"),
+ hdr: true, // force reading the trailer
+ r: bufio.NewReader(strings.NewReader("")),
}
buf := make([]byte, 7)
n, err := b.Read(buf[:3])
@@ -35,3 +36,29 @@ func TestBodyReadBadTrailer(t *testing.T) {
t.Errorf("final Read was successful (%q), expected error from trailer read", got)
}
}
+
+func TestFinalChunkedBodyReadEOF(t *testing.T) {
+ res, err := ReadResponse(bufio.NewReader(strings.NewReader(
+ "HTTP/1.1 200 OK\r\n"+
+ "Transfer-Encoding: chunked\r\n"+
+ "\r\n"+
+ "0a\r\n"+
+ "Body here\n\r\n"+
+ "09\r\n"+
+ "continued\r\n"+
+ "0\r\n"+
+ "\r\n")), nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ want := "Body here\ncontinued"
+ buf := make([]byte, len(want))
+ n, err := res.Body.Read(buf)
+ if n != len(want) || err != io.EOF {
+ t.Logf("body = %#v", res.Body)
+ t.Errorf("Read = %v, %v; want %d, EOF", n, err, len(want))
+ }
+ if string(buf) != want {
+ t.Errorf("buf = %q; want %q", buf, want)
+ }
+}
diff --git a/src/pkg/net/http/transport.go b/src/pkg/net/http/transport.go
index f6871afac..b1cc632a7 100644
--- a/src/pkg/net/http/transport.go
+++ b/src/pkg/net/http/transport.go
@@ -30,7 +30,14 @@ import (
// and caches them for reuse by subsequent calls. It uses HTTP proxies
// as directed by the $HTTP_PROXY and $NO_PROXY (or $http_proxy and
// $no_proxy) environment variables.
-var DefaultTransport RoundTripper = &Transport{Proxy: ProxyFromEnvironment}
+var DefaultTransport RoundTripper = &Transport{
+ Proxy: ProxyFromEnvironment,
+ Dial: (&net.Dialer{
+ Timeout: 30 * time.Second,
+ KeepAlive: 30 * time.Second,
+ }).Dial,
+ TLSHandshakeTimeout: 10 * time.Second,
+}
// DefaultMaxIdleConnsPerHost is the default value of Transport's
// MaxIdleConnsPerHost.
@@ -40,13 +47,13 @@ const DefaultMaxIdleConnsPerHost = 2
// https, and http proxies (for either http or https with CONNECT).
// Transport can also cache connections for future re-use.
type Transport struct {
- idleMu sync.Mutex
- idleConn map[string][]*persistConn
- idleConnCh map[string]chan *persistConn
- reqMu sync.Mutex
- reqConn map[*Request]*persistConn
- altMu sync.RWMutex
- altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper
+ idleMu sync.Mutex
+ idleConn map[connectMethodKey][]*persistConn
+ idleConnCh map[connectMethodKey]chan *persistConn
+ reqMu sync.Mutex
+ reqCanceler map[*Request]func()
+ altMu sync.RWMutex
+ altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper
// Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the
@@ -63,6 +70,10 @@ type Transport struct {
// tls.Client. If nil, the default configuration is used.
TLSClientConfig *tls.Config
+ // TLSHandshakeTimeout specifies the maximum amount of time waiting to
+ // wait for a TLS handshake. Zero means no timeout.
+ TLSHandshakeTimeout time.Duration
+
// DisableKeepAlives, if true, prevents re-use of TCP connections
// between different HTTP requests.
DisableKeepAlives bool
@@ -98,8 +109,11 @@ type Transport struct {
// An error is returned if the proxy environment is invalid.
// A nil URL and nil error are returned if no proxy is defined in the
// environment, or a proxy should not be used for the given request.
+//
+// As a special case, if req.URL.Host is "localhost" (with or without
+// a port number), then a nil URL and nil error will be returned.
func ProxyFromEnvironment(req *Request) (*url.URL, error) {
- proxy := getenvEitherCase("HTTP_PROXY")
+ proxy := httpProxyEnv.Get()
if proxy == "" {
return nil, nil
}
@@ -149,9 +163,11 @@ func (tr *transportRequest) extraHeaders() Header {
// and redirects), see Get, Post, and the Client type.
func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) {
if req.URL == nil {
+ req.closeBody()
return nil, errors.New("http: nil Request.URL")
}
if req.Header == nil {
+ req.closeBody()
return nil, errors.New("http: nil Request.Header")
}
if req.URL.Scheme != "http" && req.URL.Scheme != "https" {
@@ -162,16 +178,19 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) {
}
t.altMu.RUnlock()
if rt == nil {
+ req.closeBody()
return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme}
}
return rt.RoundTrip(req)
}
if req.URL.Host == "" {
+ req.closeBody()
return nil, errors.New("http: no Host in request URL")
}
treq := &transportRequest{Request: req}
cm, err := t.connectMethodForRequest(treq)
if err != nil {
+ req.closeBody()
return nil, err
}
@@ -179,8 +198,10 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) {
// host (for http or https), the http proxy, or the http proxy
// pre-CONNECTed to https server. In any case, we'll be ready
// to send it requests.
- pconn, err := t.getConn(cm)
+ pconn, err := t.getConn(req, cm)
if err != nil {
+ t.setReqCanceler(req, nil)
+ req.closeBody()
return nil, err
}
@@ -218,9 +239,6 @@ func (t *Transport) CloseIdleConnections() {
t.idleConn = nil
t.idleConnCh = nil
t.idleMu.Unlock()
- if m == nil {
- return
- }
for _, conns := range m {
for _, pconn := range conns {
pconn.close()
@@ -232,10 +250,10 @@ func (t *Transport) CloseIdleConnections() {
// connection.
func (t *Transport) CancelRequest(req *Request) {
t.reqMu.Lock()
- pc := t.reqConn[req]
+ cancel := t.reqCanceler[req]
t.reqMu.Unlock()
- if pc != nil {
- pc.conn.Close()
+ if cancel != nil {
+ cancel()
}
}
@@ -243,24 +261,49 @@ func (t *Transport) CancelRequest(req *Request) {
// Private implementation past this point.
//
-func getenvEitherCase(k string) string {
- if v := os.Getenv(strings.ToUpper(k)); v != "" {
- return v
+var (
+ httpProxyEnv = &envOnce{
+ names: []string{"HTTP_PROXY", "http_proxy"},
}
- return os.Getenv(strings.ToLower(k))
+ noProxyEnv = &envOnce{
+ names: []string{"NO_PROXY", "no_proxy"},
+ }
+)
+
+// envOnce looks up an environment variable (optionally by multiple
+// names) once. It mitigates expensive lookups on some platforms
+// (e.g. Windows).
+type envOnce struct {
+ names []string
+ once sync.Once
+ val string
+}
+
+func (e *envOnce) Get() string {
+ e.once.Do(e.init)
+ return e.val
}
-func (t *Transport) connectMethodForRequest(treq *transportRequest) (*connectMethod, error) {
- cm := &connectMethod{
- targetScheme: treq.URL.Scheme,
- targetAddr: canonicalAddr(treq.URL),
+func (e *envOnce) init() {
+ for _, n := range e.names {
+ e.val = os.Getenv(n)
+ if e.val != "" {
+ return
+ }
}
+}
+
+// reset is used by tests
+func (e *envOnce) reset() {
+ e.once = sync.Once{}
+ e.val = ""
+}
+
+func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) {
+ cm.targetScheme = treq.URL.Scheme
+ cm.targetAddr = canonicalAddr(treq.URL)
if t.Proxy != nil {
- var err error
cm.proxyURL, err = t.Proxy(treq.Request)
- if err != nil {
- return nil, err
- }
}
return cm, nil
}
@@ -316,7 +359,7 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool {
}
}
if t.idleConn == nil {
- t.idleConn = make(map[string][]*persistConn)
+ t.idleConn = make(map[connectMethodKey][]*persistConn)
}
if len(t.idleConn[key]) >= max {
t.idleMu.Unlock()
@@ -336,7 +379,7 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool {
// getIdleConnCh returns a channel to receive and return idle
// persistent connection for the given connectMethod.
// It may return nil, if persistent connections are not being used.
-func (t *Transport) getIdleConnCh(cm *connectMethod) chan *persistConn {
+func (t *Transport) getIdleConnCh(cm connectMethod) chan *persistConn {
if t.DisableKeepAlives {
return nil
}
@@ -344,7 +387,7 @@ func (t *Transport) getIdleConnCh(cm *connectMethod) chan *persistConn {
t.idleMu.Lock()
defer t.idleMu.Unlock()
if t.idleConnCh == nil {
- t.idleConnCh = make(map[string]chan *persistConn)
+ t.idleConnCh = make(map[connectMethodKey]chan *persistConn)
}
ch, ok := t.idleConnCh[key]
if !ok {
@@ -354,7 +397,7 @@ func (t *Transport) getIdleConnCh(cm *connectMethod) chan *persistConn {
return ch
}
-func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) {
+func (t *Transport) getIdleConn(cm connectMethod) (pconn *persistConn) {
key := cm.key()
t.idleMu.Lock()
defer t.idleMu.Unlock()
@@ -373,7 +416,7 @@ func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) {
// 2 or more cached connections; pop last
// TODO: queue?
pconn = pconns[len(pconns)-1]
- t.idleConn[key] = pconns[0 : len(pconns)-1]
+ t.idleConn[key] = pconns[:len(pconns)-1]
}
if !pconn.isBroken() {
return
@@ -381,16 +424,16 @@ func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) {
}
}
-func (t *Transport) setReqConn(r *Request, pc *persistConn) {
+func (t *Transport) setReqCanceler(r *Request, fn func()) {
t.reqMu.Lock()
defer t.reqMu.Unlock()
- if t.reqConn == nil {
- t.reqConn = make(map[*Request]*persistConn)
+ if t.reqCanceler == nil {
+ t.reqCanceler = make(map[*Request]func())
}
- if pc != nil {
- t.reqConn[r] = pc
+ if fn != nil {
+ t.reqCanceler[r] = fn
} else {
- delete(t.reqConn, r)
+ delete(t.reqCanceler, r)
}
}
@@ -405,7 +448,7 @@ func (t *Transport) dial(network, addr string) (c net.Conn, err error) {
// specified in the connectMethod. This includes doing a proxy CONNECT
// and/or setting up TLS. If this doesn't return an error, the persistConn
// is ready to write requests to.
-func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) {
+func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error) {
if pc := t.getIdleConn(cm); pc != nil {
return pc, nil
}
@@ -415,6 +458,16 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) {
err error
}
dialc := make(chan dialRes)
+
+ handlePendingDial := func() {
+ if v := <-dialc; v.err == nil {
+ t.putIdleConn(v.pc)
+ }
+ }
+
+ cancelc := make(chan struct{})
+ t.setReqCanceler(req, func() { close(cancelc) })
+
go func() {
pc, err := t.dialConn(cm)
dialc <- dialRes{pc, err}
@@ -431,16 +484,15 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) {
// else's dial that they didn't use.
// But our dial is still going, so give it away
// when it finishes:
- go func() {
- if v := <-dialc; v.err == nil {
- t.putIdleConn(v.pc)
- }
- }()
+ go handlePendingDial()
return pc, nil
+ case <-cancelc:
+ go handlePendingDial()
+ return nil, errors.New("net/http: request canceled while waiting for connection")
}
}
-func (t *Transport) dialConn(cm *connectMethod) (*persistConn, error) {
+func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
conn, err := t.dial("tcp", cm.addr())
if err != nil {
if cm.proxyURL != nil {
@@ -452,12 +504,13 @@ func (t *Transport) dialConn(cm *connectMethod) (*persistConn, error) {
pa := cm.proxyAuth()
pconn := &persistConn{
- t: t,
- cacheKey: cm.key(),
- conn: conn,
- reqch: make(chan requestAndChan, 50),
- writech: make(chan writeRequest, 50),
- closech: make(chan struct{}),
+ t: t,
+ cacheKey: cm.key(),
+ conn: conn,
+ reqch: make(chan requestAndChan, 1),
+ writech: make(chan writeRequest, 1),
+ closech: make(chan struct{}),
+ writeErrCh: make(chan error, 1),
}
switch {
@@ -511,19 +564,38 @@ func (t *Transport) dialConn(cm *connectMethod) (*persistConn, error) {
cfg = &clone
}
}
- conn = tls.Client(conn, cfg)
- if err = conn.(*tls.Conn).Handshake(); err != nil {
+ plainConn := conn
+ tlsConn := tls.Client(plainConn, cfg)
+ errc := make(chan error, 2)
+ var timer *time.Timer // for canceling TLS handshake
+ if d := t.TLSHandshakeTimeout; d != 0 {
+ timer = time.AfterFunc(d, func() {
+ errc <- tlsHandshakeTimeoutError{}
+ })
+ }
+ go func() {
+ err := tlsConn.Handshake()
+ if timer != nil {
+ timer.Stop()
+ }
+ errc <- err
+ }()
+ if err := <-errc; err != nil {
+ plainConn.Close()
return nil, err
}
if !cfg.InsecureSkipVerify {
- if err = conn.(*tls.Conn).VerifyHostname(cfg.ServerName); err != nil {
+ if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
+ plainConn.Close()
return nil, err
}
}
- pconn.conn = conn
+ cs := tlsConn.ConnectionState()
+ pconn.tlsState = &cs
+ pconn.conn = tlsConn
}
- pconn.br = bufio.NewReader(pconn.conn)
+ pconn.br = bufio.NewReader(noteEOFReader{pconn.conn, &pconn.sawEOF})
pconn.bw = bufio.NewWriter(pconn.conn)
go pconn.readLoop()
go pconn.writeLoop()
@@ -550,7 +622,7 @@ func useProxy(addr string) bool {
}
}
- no_proxy := getenvEitherCase("NO_PROXY")
+ no_proxy := noProxyEnv.Get()
if no_proxy == "*" {
return false
}
@@ -590,8 +662,8 @@ func useProxy(addr string) bool {
//
// Cache key form Description
// ----------------- -------------------------
-// ||http|foo.com http directly to server, no proxy
-// ||https|foo.com https directly to server, no proxy
+// |http|foo.com http directly to server, no proxy
+// |https|foo.com https directly to server, no proxy
// http://proxy.com|https|foo.com http to proxy, then CONNECT to foo.com
// http://proxy.com|http http to proxy, http to anywhere after that
//
@@ -603,20 +675,20 @@ type connectMethod struct {
targetAddr string // Not used if proxy + http targetScheme (4th example in table)
}
-func (ck *connectMethod) key() string {
- return ck.String() // TODO: use a struct type instead
-}
-
-func (ck *connectMethod) String() string {
+func (cm *connectMethod) key() connectMethodKey {
proxyStr := ""
- targetAddr := ck.targetAddr
- if ck.proxyURL != nil {
- proxyStr = ck.proxyURL.String()
- if ck.targetScheme == "http" {
+ targetAddr := cm.targetAddr
+ if cm.proxyURL != nil {
+ proxyStr = cm.proxyURL.String()
+ if cm.targetScheme == "http" {
targetAddr = ""
}
}
- return strings.Join([]string{proxyStr, ck.targetScheme, targetAddr}, "|")
+ return connectMethodKey{
+ proxy: proxyStr,
+ scheme: cm.targetScheme,
+ addr: targetAddr,
+ }
}
// addr returns the first hop "host:port" to which we need to TCP connect.
@@ -637,22 +709,41 @@ func (cm *connectMethod) tlsHost() string {
return h
}
+// connectMethodKey is the map key version of connectMethod, with a
+// stringified proxy URL (or the empty string) instead of a pointer to
+// a URL.
+type connectMethodKey struct {
+ proxy, scheme, addr string
+}
+
+func (k connectMethodKey) String() string {
+ // Only used by tests.
+ return fmt.Sprintf("%s|%s|%s", k.proxy, k.scheme, k.addr)
+}
+
// persistConn wraps a connection, usually a persistent one
// (but may be used for non-keep-alive requests as well)
type persistConn struct {
t *Transport
- cacheKey string // its connectMethod.String()
+ cacheKey connectMethodKey
conn net.Conn
- closed bool // whether conn has been closed
+ tlsState *tls.ConnectionState
br *bufio.Reader // from conn
+ sawEOF bool // whether we've seen EOF from conn; owned by readLoop
bw *bufio.Writer // to conn
reqch chan requestAndChan // written by roundTrip; read by readLoop
writech chan writeRequest // written by roundTrip; read by writeLoop
- closech chan struct{} // broadcast close when readLoop (TCP connection) closes
+ closech chan struct{} // closed when conn closed
isProxy bool
+ // writeErrCh passes the request write error (usually nil)
+ // from the writeLoop goroutine to the readLoop which passes
+ // it off to the res.Body reader, which then uses it to decide
+ // whether or not a connection can be reused. Issue 7569.
+ writeErrCh chan error
- lk sync.Mutex // guards following 3 fields
+ lk sync.Mutex // guards following fields
numExpectedResponses int
+ closed bool // whether conn has been closed
broken bool // an error has happened on this connection; marked broken so it's not reused.
// mutateHeaderFunc is an optional func to modify extra
// headers on each outbound request before it's written. (the
@@ -660,6 +751,7 @@ type persistConn struct {
mutateHeaderFunc func(Header)
}
+// isBroken reports whether this connection is in a known broken state.
func (pc *persistConn) isBroken() bool {
pc.lk.Lock()
b := pc.broken
@@ -667,6 +759,10 @@ func (pc *persistConn) isBroken() bool {
return b
}
+func (pc *persistConn) cancelRequest() {
+ pc.conn.Close()
+}
+
var remoteSideClosedFunc func(error) bool // or nil to use default
func remoteSideClosed(err error) bool {
@@ -680,7 +776,6 @@ func remoteSideClosed(err error) bool {
}
func (pc *persistConn) readLoop() {
- defer close(pc.closech)
alive := true
for alive {
@@ -688,12 +783,14 @@ func (pc *persistConn) readLoop() {
pc.lk.Lock()
if pc.numExpectedResponses == 0 {
- pc.closeLocked()
- pc.lk.Unlock()
- if len(pb) > 0 {
- log.Printf("Unsolicited response received on idle HTTP channel starting with %q; err=%v",
- string(pb), err)
+ if !pc.closed {
+ pc.closeLocked()
+ if len(pb) > 0 {
+ log.Printf("Unsolicited response received on idle HTTP channel starting with %q; err=%v",
+ string(pb), err)
+ }
}
+ pc.lk.Unlock()
return
}
pc.lk.Unlock()
@@ -712,6 +809,11 @@ func (pc *persistConn) readLoop() {
resp, err = ReadResponse(pc.br, rc.req)
}
}
+
+ if resp != nil {
+ resp.TLS = pc.tlsState
+ }
+
hasBody := resp != nil && rc.req.Method != "HEAD" && resp.ContentLength != 0
if err != nil {
@@ -721,13 +823,7 @@ func (pc *persistConn) readLoop() {
resp.Header.Del("Content-Encoding")
resp.Header.Del("Content-Length")
resp.ContentLength = -1
- gzReader, zerr := gzip.NewReader(resp.Body)
- if zerr != nil {
- pc.close()
- err = zerr
- } else {
- resp.Body = &readerAndCloser{gzReader, resp.Body}
- }
+ resp.Body = &gzipReader{body: resp.Body}
}
resp.Body = &bodyEOFSignal{body: resp.Body}
}
@@ -750,24 +846,18 @@ func (pc *persistConn) readLoop() {
return nil
}
resp.Body.(*bodyEOFSignal).fn = func(err error) {
- alive1 := alive
- if err != nil {
- alive1 = false
- }
- if alive1 && !pc.t.putIdleConn(pc) {
- alive1 = false
- }
- if !alive1 || pc.isBroken() {
- pc.close()
- }
- waitForBodyRead <- alive1
+ waitForBodyRead <- alive &&
+ err == nil &&
+ !pc.sawEOF &&
+ pc.wroteRequest() &&
+ pc.t.putIdleConn(pc)
}
}
if alive && !hasBody {
- if !pc.t.putIdleConn(pc) {
- alive = false
- }
+ alive = !pc.sawEOF &&
+ pc.wroteRequest() &&
+ pc.t.putIdleConn(pc)
}
rc.ch <- responseAndError{resp, err}
@@ -775,10 +865,14 @@ func (pc *persistConn) readLoop() {
// Wait for the just-returned response body to be fully consumed
// before we race and peek on the underlying bufio reader.
if waitForBodyRead != nil {
- alive = <-waitForBodyRead
+ select {
+ case alive = <-waitForBodyRead:
+ case <-pc.closech:
+ alive = false
+ }
}
- pc.t.setReqConn(rc.req, nil)
+ pc.t.setReqCanceler(rc.req, nil)
if !alive {
pc.close()
@@ -800,14 +894,44 @@ func (pc *persistConn) writeLoop() {
}
if err != nil {
pc.markBroken()
+ wr.req.Request.closeBody()
}
- wr.ch <- err
+ pc.writeErrCh <- err // to the body reader, which might recycle us
+ wr.ch <- err // to the roundTrip function
case <-pc.closech:
return
}
}
}
+// wroteRequest is a check before recycling a connection that the previous write
+// (from writeLoop above) happened and was successful.
+func (pc *persistConn) wroteRequest() bool {
+ select {
+ case err := <-pc.writeErrCh:
+ // Common case: the write happened well before the response, so
+ // avoid creating a timer.
+ return err == nil
+ default:
+ // Rare case: the request was written in writeLoop above but
+ // before it could send to pc.writeErrCh, the reader read it
+ // all, processed it, and called us here. In this case, give the
+ // write goroutine a bit of time to finish its send.
+ //
+ // Less rare case: We also get here in the legitimate case of
+ // Issue 7569, where the writer is still writing (or stalled),
+ // but the server has already replied. In this case, we don't
+ // want to wait too long, and we want to return false so this
+ // connection isn't re-used.
+ select {
+ case err := <-pc.writeErrCh:
+ return err == nil
+ case <-time.After(50 * time.Millisecond):
+ return false
+ }
+ }
+}
+
type responseAndError struct {
res *Response
err error
@@ -832,8 +956,20 @@ type writeRequest struct {
ch chan<- error
}
+type httpError struct {
+ err string
+ timeout bool
+}
+
+func (e *httpError) Error() string { return e.err }
+func (e *httpError) Timeout() bool { return e.timeout }
+func (e *httpError) Temporary() bool { return true }
+
+var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true}
+var errClosed error = &httpError{err: "net/http: transport closed before response was received"}
+
func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
- pc.t.setReqConn(req.Request, pc)
+ pc.t.setReqCanceler(req.Request, pc.cancelRequest)
pc.lk.Lock()
pc.numExpectedResponses++
headerFn := pc.mutateHeaderFunc
@@ -902,11 +1038,11 @@ WaitResponse:
pconnDeadCh = nil // avoid spinning
failTicker = time.After(100 * time.Millisecond) // arbitrary time to wait for resc
case <-failTicker:
- re = responseAndError{err: errors.New("net/http: transport closed before response was received")}
+ re = responseAndError{err: errClosed}
break WaitResponse
case <-respHeaderTimer:
pc.close()
- re = responseAndError{err: errors.New("net/http: timeout awaiting response headers")}
+ re = responseAndError{err: errTimeout}
break WaitResponse
case re = <-resc:
break WaitResponse
@@ -918,7 +1054,7 @@ WaitResponse:
pc.lk.Unlock()
if re.err != nil {
- pc.t.setReqConn(req.Request, nil)
+ pc.t.setReqCanceler(req.Request, nil)
}
return re.res, re.err
}
@@ -943,6 +1079,7 @@ func (pc *persistConn) closeLocked() {
if !pc.closed {
pc.conn.Close()
pc.closed = true
+ close(pc.closech)
}
pc.mutateHeaderFunc = nil
}
@@ -1025,7 +1162,47 @@ func (es *bodyEOFSignal) condfn(err error) {
es.fn = nil
}
+// gzipReader wraps a response body so it can lazily
+// call gzip.NewReader on the first call to Read
+type gzipReader struct {
+ body io.ReadCloser // underlying Response.Body
+ zr io.Reader // lazily-initialized gzip reader
+}
+
+func (gz *gzipReader) Read(p []byte) (n int, err error) {
+ if gz.zr == nil {
+ gz.zr, err = gzip.NewReader(gz.body)
+ if err != nil {
+ return 0, err
+ }
+ }
+ return gz.zr.Read(p)
+}
+
+func (gz *gzipReader) Close() error {
+ return gz.body.Close()
+}
+
type readerAndCloser struct {
io.Reader
io.Closer
}
+
+type tlsHandshakeTimeoutError struct{}
+
+func (tlsHandshakeTimeoutError) Timeout() bool { return true }
+func (tlsHandshakeTimeoutError) Temporary() bool { return true }
+func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" }
+
+type noteEOFReader struct {
+ r io.Reader
+ sawEOF *bool
+}
+
+func (nr noteEOFReader) Read(p []byte) (n int, err error) {
+ n, err = nr.r.Read(p)
+ if err == io.EOF {
+ *nr.sawEOF = true
+ }
+ return
+}
diff --git a/src/pkg/net/http/transport_test.go b/src/pkg/net/http/transport_test.go
index e4df30a98..964ca0fca 100644
--- a/src/pkg/net/http/transport_test.go
+++ b/src/pkg/net/http/transport_test.go
@@ -11,9 +11,12 @@ import (
"bytes"
"compress/gzip"
"crypto/rand"
+ "crypto/tls"
+ "errors"
"fmt"
"io"
"io/ioutil"
+ "log"
"net"
"net/http"
. "net/http"
@@ -54,21 +57,21 @@ func (c *testCloseConn) Close() error {
// been closed.
type testConnSet struct {
t *testing.T
+ mu sync.Mutex // guards closed and list
closed map[net.Conn]bool
list []net.Conn // in order created
- mutex sync.Mutex
}
func (tcs *testConnSet) insert(c net.Conn) {
- tcs.mutex.Lock()
- defer tcs.mutex.Unlock()
+ tcs.mu.Lock()
+ defer tcs.mu.Unlock()
tcs.closed[c] = false
tcs.list = append(tcs.list, c)
}
func (tcs *testConnSet) remove(c net.Conn) {
- tcs.mutex.Lock()
- defer tcs.mutex.Unlock()
+ tcs.mu.Lock()
+ defer tcs.mu.Unlock()
tcs.closed[c] = true
}
@@ -91,11 +94,19 @@ func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, e
}
func (tcs *testConnSet) check(t *testing.T) {
- tcs.mutex.Lock()
- defer tcs.mutex.Unlock()
-
- for i, c := range tcs.list {
- if !tcs.closed[c] {
+ tcs.mu.Lock()
+ defer tcs.mu.Unlock()
+ for i := 4; i >= 0; i-- {
+ for i, c := range tcs.list {
+ if tcs.closed[c] {
+ continue
+ }
+ if i != 0 {
+ tcs.mu.Unlock()
+ time.Sleep(50 * time.Millisecond)
+ tcs.mu.Lock()
+ continue
+ }
t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
}
}
@@ -271,6 +282,58 @@ func TestTransportIdleCacheKeys(t *testing.T) {
}
}
+// Tests that the HTTP transport re-uses connections when a client
+// reads to the end of a response Body without closing it.
+func TestTransportReadToEndReusesConn(t *testing.T) {
+ defer afterTest(t)
+ const msg = "foobar"
+
+ var addrSeen map[string]int
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ addrSeen[r.RemoteAddr]++
+ if r.URL.Path == "/chunked/" {
+ w.WriteHeader(200)
+ w.(http.Flusher).Flush()
+ } else {
+ w.Header().Set("Content-Type", strconv.Itoa(len(msg)))
+ w.WriteHeader(200)
+ }
+ w.Write([]byte(msg))
+ }))
+ defer ts.Close()
+
+ buf := make([]byte, len(msg))
+
+ for pi, path := range []string{"/content-length/", "/chunked/"} {
+ wantLen := []int{len(msg), -1}[pi]
+ addrSeen = make(map[string]int)
+ for i := 0; i < 3; i++ {
+ res, err := http.Get(ts.URL + path)
+ if err != nil {
+ t.Errorf("Get %s: %v", path, err)
+ continue
+ }
+ // We want to close this body eventually (before the
+ // defer afterTest at top runs), but not before the
+ // len(addrSeen) check at the bottom of this test,
+ // since Closing this early in the loop would risk
+ // making connections be re-used for the wrong reason.
+ defer res.Body.Close()
+
+ if res.ContentLength != int64(wantLen) {
+ t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen)
+ }
+ n, err := res.Body.Read(buf)
+ if n != len(msg) || err != io.EOF {
+ t.Errorf("%s Read = %v, %v; want %d, EOF", path, n, err, len(msg))
+ }
+ }
+ if len(addrSeen) != 1 {
+ t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen))
+ }
+ }
+}
+
func TestTransportMaxPerHostIdleConns(t *testing.T) {
defer afterTest(t)
resch := make(chan string)
@@ -295,10 +358,11 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) {
resp, err := c.Get(ts.URL)
if err != nil {
t.Error(err)
+ return
}
- _, err = ioutil.ReadAll(resp.Body)
- if err != nil {
- t.Fatalf("ReadAll: %v", err)
+ if _, err := ioutil.ReadAll(resp.Body); err != nil {
+ t.Errorf("ReadAll: %v", err)
+ return
}
donech <- true
}
@@ -739,8 +803,38 @@ func TestTransportGzipRecursive(t *testing.T) {
}
}
+// golang.org/issue/7750: request fails when server replies with
+// a short gzip body
+func TestTransportGzipShort(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Encoding", "gzip")
+ w.Write([]byte{0x1f, 0x8b})
+ }))
+ defer ts.Close()
+
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ _, err = ioutil.ReadAll(res.Body)
+ if err == nil {
+ t.Fatal("Expect an error from reading a body.")
+ }
+ if err != io.ErrUnexpectedEOF {
+ t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err)
+ }
+}
+
// tests that persistent goroutine connections shut down when no longer desired.
func TestTransportPersistConnLeak(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("skipping test; see http://golang.org/issue/7237")
+ }
defer afterTest(t)
gotReqCh := make(chan bool)
unblockCh := make(chan bool)
@@ -798,8 +892,8 @@ func TestTransportPersistConnLeak(t *testing.T) {
// We expect 0 or 1 extra goroutine, empirically. Allow up to 5.
// Previously we were leaking one per numReq.
- t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
if int(growth) > 5 {
+ t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
t.Error("too many new goroutines")
}
}
@@ -807,6 +901,9 @@ func TestTransportPersistConnLeak(t *testing.T) {
// golang.org/issue/4531: Transport leaks goroutines when
// request.ContentLength is explicitly short
func TestTransportPersistConnLeakShortBody(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("skipping test; see http://golang.org/issue/7237")
+ }
defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
}))
@@ -1014,6 +1111,9 @@ func TestTransportConcurrency(t *testing.T) {
}
func TestIssue4191_InfiniteGetTimeout(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("skipping test; see http://golang.org/issue/7237")
+ }
defer afterTest(t)
const debug = false
mux := NewServeMux()
@@ -1075,6 +1175,9 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) {
}
func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("skipping test; see http://golang.org/issue/7237")
+ }
defer afterTest(t)
const debug = false
mux := NewServeMux()
@@ -1147,9 +1250,13 @@ func TestTransportResponseHeaderTimeout(t *testing.T) {
if testing.Short() {
t.Skip("skipping timeout test in -short mode")
}
+ inHandler := make(chan bool, 1)
mux := NewServeMux()
- mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {})
+ mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {
+ inHandler <- true
+ })
mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
+ inHandler <- true
time.Sleep(2 * time.Second)
})
ts := httptest.NewServer(mux)
@@ -1172,7 +1279,27 @@ func TestTransportResponseHeaderTimeout(t *testing.T) {
}
for i, tt := range tests {
res, err := c.Get(ts.URL + tt.path)
+ select {
+ case <-inHandler:
+ case <-time.After(5 * time.Second):
+ t.Errorf("never entered handler for test index %d, %s", i, tt.path)
+ continue
+ }
if err != nil {
+ uerr, ok := err.(*url.Error)
+ if !ok {
+ t.Errorf("error is not an url.Error; got: %#v", err)
+ continue
+ }
+ nerr, ok := uerr.Err.(net.Error)
+ if !ok {
+ t.Errorf("error does not satisfy net.Error interface; got: %#v", err)
+ continue
+ }
+ if !nerr.Timeout() {
+ t.Errorf("want timeout error; got: %q", nerr)
+ continue
+ }
if strings.Contains(err.Error(), tt.wantErr) {
continue
}
@@ -1243,6 +1370,60 @@ func TestTransportCancelRequest(t *testing.T) {
}
}
+func TestTransportCancelRequestInDial(t *testing.T) {
+ defer afterTest(t)
+ if testing.Short() {
+ t.Skip("skipping test in -short mode")
+ }
+ var logbuf bytes.Buffer
+ eventLog := log.New(&logbuf, "", 0)
+
+ unblockDial := make(chan bool)
+ defer close(unblockDial)
+
+ inDial := make(chan bool)
+ tr := &Transport{
+ Dial: func(network, addr string) (net.Conn, error) {
+ eventLog.Println("dial: blocking")
+ inDial <- true
+ <-unblockDial
+ return nil, errors.New("nope")
+ },
+ }
+ cl := &Client{Transport: tr}
+ gotres := make(chan bool)
+ req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
+ go func() {
+ _, err := cl.Do(req)
+ eventLog.Printf("Get = %v", err)
+ gotres <- true
+ }()
+
+ select {
+ case <-inDial:
+ case <-time.After(5 * time.Second):
+ t.Fatal("timeout; never saw blocking dial")
+ }
+
+ eventLog.Printf("canceling")
+ tr.CancelRequest(req)
+
+ select {
+ case <-gotres:
+ case <-time.After(5 * time.Second):
+ panic("hang. events are: " + logbuf.String())
+ }
+
+ got := logbuf.String()
+ want := `dial: blocking
+canceling
+Get = Get http://something.no-network.tld/: net/http: request canceled while waiting for connection
+`
+ if got != want {
+ t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
+ }
+}
+
// golang.org/issue/3672 -- Client can't close HTTP stream
// Calling Close on a Response.Body used to just read until EOF.
// Now it actually closes the TCP connection.
@@ -1283,7 +1464,7 @@ func TestTransportCloseResponseBody(t *testing.T) {
t.Fatal(err)
}
if !bytes.Equal(buf, want) {
- t.Errorf("read %q; want %q", buf, want)
+ t.Fatalf("read %q; want %q", buf, want)
}
didClose := make(chan error, 1)
go func() {
@@ -1372,8 +1553,10 @@ func TestTransportSocketLateBinding(t *testing.T) {
dialGate := make(chan bool, 1)
tr := &Transport{
Dial: func(n, addr string) (net.Conn, error) {
- <-dialGate
- return net.Dial(n, addr)
+ if <-dialGate {
+ return net.Dial(n, addr)
+ }
+ return nil, errors.New("manually closed")
},
DisableKeepAlives: false,
}
@@ -1408,7 +1591,7 @@ func TestTransportSocketLateBinding(t *testing.T) {
t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr)
}
barRes.Body.Close()
- dialGate <- true
+ dialGate <- false
}
// Issue 2184
@@ -1559,13 +1742,11 @@ var proxyFromEnvTests = []proxyFromEnvTest{
}
func TestProxyFromEnvironment(t *testing.T) {
- os.Setenv("HTTP_PROXY", "")
- os.Setenv("http_proxy", "")
- os.Setenv("NO_PROXY", "")
- os.Setenv("no_proxy", "")
+ ResetProxyEnv()
for _, tt := range proxyFromEnvTests {
os.Setenv("HTTP_PROXY", tt.env)
os.Setenv("NO_PROXY", tt.noenv)
+ ResetCachedEnvironment()
reqURL := tt.req
if reqURL == "" {
reqURL = "http://example.com"
@@ -1643,6 +1824,308 @@ func TestTransportClosesRequestBody(t *testing.T) {
}
}
+func TestTransportTLSHandshakeTimeout(t *testing.T) {
+ defer afterTest(t)
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ ln := newLocalListener(t)
+ defer ln.Close()
+ testdonec := make(chan struct{})
+ defer close(testdonec)
+
+ go func() {
+ c, err := ln.Accept()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ <-testdonec
+ c.Close()
+ }()
+
+ getdonec := make(chan struct{})
+ go func() {
+ defer close(getdonec)
+ tr := &Transport{
+ Dial: func(_, _ string) (net.Conn, error) {
+ return net.Dial("tcp", ln.Addr().String())
+ },
+ TLSHandshakeTimeout: 250 * time.Millisecond,
+ }
+ cl := &Client{Transport: tr}
+ _, err := cl.Get("https://dummy.tld/")
+ if err == nil {
+ t.Error("expected error")
+ return
+ }
+ ue, ok := err.(*url.Error)
+ if !ok {
+ t.Errorf("expected url.Error; got %#v", err)
+ return
+ }
+ ne, ok := ue.Err.(net.Error)
+ if !ok {
+ t.Errorf("expected net.Error; got %#v", err)
+ return
+ }
+ if !ne.Timeout() {
+ t.Errorf("expected timeout error; got %v", err)
+ }
+ if !strings.Contains(err.Error(), "handshake timeout") {
+ t.Errorf("expected 'handshake timeout' in error; got %v", err)
+ }
+ }()
+ select {
+ case <-getdonec:
+ case <-time.After(5 * time.Second):
+ t.Error("test timeout; TLS handshake hung?")
+ }
+}
+
+// Trying to repro golang.org/issue/3514
+func TestTLSServerClosesConnection(t *testing.T) {
+ defer afterTest(t)
+ if runtime.GOOS == "windows" {
+ t.Skip("skipping flaky test on Windows; golang.org/issue/7634")
+ }
+ closedc := make(chan bool, 1)
+ ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if strings.Contains(r.URL.Path, "/keep-alive-then-die") {
+ conn, _, _ := w.(Hijacker).Hijack()
+ conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
+ conn.Close()
+ closedc <- true
+ return
+ }
+ fmt.Fprintf(w, "hello")
+ }))
+ defer ts.Close()
+ tr := &Transport{
+ TLSClientConfig: &tls.Config{
+ InsecureSkipVerify: true,
+ },
+ }
+ defer tr.CloseIdleConnections()
+ client := &Client{Transport: tr}
+
+ var nSuccess = 0
+ var errs []error
+ const trials = 20
+ for i := 0; i < trials; i++ {
+ tr.CloseIdleConnections()
+ res, err := client.Get(ts.URL + "/keep-alive-then-die")
+ if err != nil {
+ t.Fatal(err)
+ }
+ <-closedc
+ slurp, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(slurp) != "foo" {
+ t.Errorf("Got %q, want foo", slurp)
+ }
+
+ // Now try again and see if we successfully
+ // pick a new connection.
+ res, err = client.Get(ts.URL + "/")
+ if err != nil {
+ errs = append(errs, err)
+ continue
+ }
+ slurp, err = ioutil.ReadAll(res.Body)
+ if err != nil {
+ errs = append(errs, err)
+ continue
+ }
+ nSuccess++
+ }
+ if nSuccess > 0 {
+ t.Logf("successes = %d of %d", nSuccess, trials)
+ } else {
+ t.Errorf("All runs failed:")
+ }
+ for _, err := range errs {
+ t.Logf(" err: %v", err)
+ }
+}
+
+// byteFromChanReader is an io.Reader that reads a single byte at a
+// time from the channel. When the channel is closed, the reader
+// returns io.EOF.
+type byteFromChanReader chan byte
+
+func (c byteFromChanReader) Read(p []byte) (n int, err error) {
+ if len(p) == 0 {
+ return
+ }
+ b, ok := <-c
+ if !ok {
+ return 0, io.EOF
+ }
+ p[0] = b
+ return 1, nil
+}
+
+// Verifies that the Transport doesn't reuse a connection in the case
+// where the server replies before the request has been fully
+// written. We still honor that reply (see TestIssue3595), but don't
+// send future requests on the connection because it's then in a
+// questionable state.
+// golang.org/issue/7569
+func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
+ defer afterTest(t)
+ var sconn struct {
+ sync.Mutex
+ c net.Conn
+ }
+ var getOkay bool
+ closeConn := func() {
+ sconn.Lock()
+ defer sconn.Unlock()
+ if sconn.c != nil {
+ sconn.c.Close()
+ sconn.c = nil
+ if !getOkay {
+ t.Logf("Closed server connection")
+ }
+ }
+ }
+ defer closeConn()
+
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method == "GET" {
+ io.WriteString(w, "bar")
+ return
+ }
+ conn, _, _ := w.(Hijacker).Hijack()
+ sconn.Lock()
+ sconn.c = conn
+ sconn.Unlock()
+ conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive
+ go io.Copy(ioutil.Discard, conn)
+ }))
+ defer ts.Close()
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+ client := &Client{Transport: tr}
+
+ const bodySize = 256 << 10
+ finalBit := make(byteFromChanReader, 1)
+ req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
+ req.ContentLength = bodySize
+ res, err := client.Do(req)
+ if err := wantBody(res, err, "foo"); err != nil {
+ t.Errorf("POST response: %v", err)
+ }
+ donec := make(chan bool)
+ go func() {
+ defer close(donec)
+ res, err = client.Get(ts.URL)
+ if err := wantBody(res, err, "bar"); err != nil {
+ t.Errorf("GET response: %v", err)
+ return
+ }
+ getOkay = true // suppress test noise
+ }()
+ time.AfterFunc(5*time.Second, closeConn)
+ select {
+ case <-donec:
+ finalBit <- 'x' // unblock the writeloop of the first Post
+ close(finalBit)
+ case <-time.After(7 * time.Second):
+ t.Fatal("timeout waiting for GET request to finish")
+ }
+}
+
+type errorReader struct {
+ err error
+}
+
+func (e errorReader) Read(p []byte) (int, error) { return 0, e.err }
+
+type closerFunc func() error
+
+func (f closerFunc) Close() error { return f() }
+
+// Issue 6981
+func TestTransportClosesBodyOnError(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("skipping test; see http://golang.org/issue/7782")
+ }
+ defer afterTest(t)
+ readBody := make(chan error, 1)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ _, err := ioutil.ReadAll(r.Body)
+ readBody <- err
+ }))
+ defer ts.Close()
+ fakeErr := errors.New("fake error")
+ didClose := make(chan bool, 1)
+ req, _ := NewRequest("POST", ts.URL, struct {
+ io.Reader
+ io.Closer
+ }{
+ io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), errorReader{fakeErr}),
+ closerFunc(func() error {
+ select {
+ case didClose <- true:
+ default:
+ }
+ return nil
+ }),
+ })
+ res, err := DefaultClient.Do(req)
+ if res != nil {
+ defer res.Body.Close()
+ }
+ if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) {
+ t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error())
+ }
+ select {
+ case err := <-readBody:
+ if err == nil {
+ t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'")
+ }
+ case <-time.After(5 * time.Second):
+ t.Error("timeout waiting for server handler to complete")
+ }
+ select {
+ case <-didClose:
+ default:
+ t.Errorf("didn't see Body.Close")
+ }
+}
+
+func wantBody(res *http.Response, err error, want string) error {
+ if err != nil {
+ return err
+ }
+ slurp, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ return fmt.Errorf("error reading body: %v", err)
+ }
+ if string(slurp) != want {
+ return fmt.Errorf("body = %q; want %q", slurp, want)
+ }
+ if err := res.Body.Close(); err != nil {
+ return fmt.Errorf("body Close = %v", err)
+ }
+ return nil
+}
+
+func newLocalListener(t *testing.T) net.Listener {
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ ln, err = net.Listen("tcp6", "[::1]:0")
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ return ln
+}
+
type countCloseReader struct {
n *int
io.Reader