diff options
Diffstat (limited to 'src/pkg/net/http')
54 files changed, 7051 insertions, 1246 deletions
diff --git a/src/pkg/net/http/cgi/child.go b/src/pkg/net/http/cgi/child.go index 1ba7bec5f..100b8b777 100644 --- a/src/pkg/net/http/cgi/child.go +++ b/src/pkg/net/http/cgi/child.go @@ -91,10 +91,19 @@ func RequestFromMap(params map[string]string) (*http.Request, error) { // TODO: cookies. parsing them isn't exported, though. + uriStr := params["REQUEST_URI"] + if uriStr == "" { + // Fallback to SCRIPT_NAME, PATH_INFO and QUERY_STRING. + uriStr = params["SCRIPT_NAME"] + params["PATH_INFO"] + s := params["QUERY_STRING"] + if s != "" { + uriStr += "?" + s + } + } if r.Host != "" { // Hostname is provided, so we can reasonably construct a URL, // even if we have to assume 'http' for the scheme. - rawurl := "http://" + r.Host + params["REQUEST_URI"] + rawurl := "http://" + r.Host + uriStr url, err := url.Parse(rawurl) if err != nil { return nil, errors.New("cgi: failed to parse host and REQUEST_URI into a URL: " + rawurl) @@ -104,7 +113,6 @@ func RequestFromMap(params map[string]string) (*http.Request, error) { // Fallback logic if we don't have a Host header or the URL // failed to parse if r.URL == nil { - uriStr := params["REQUEST_URI"] url, err := url.Parse(uriStr) if err != nil { return nil, errors.New("cgi: failed to parse REQUEST_URI into a URL: " + uriStr) diff --git a/src/pkg/net/http/cgi/child_test.go b/src/pkg/net/http/cgi/child_test.go index ec53ab851..74e068014 100644 --- a/src/pkg/net/http/cgi/child_test.go +++ b/src/pkg/net/http/cgi/child_test.go @@ -82,6 +82,28 @@ func TestRequestWithoutHost(t *testing.T) { t.Fatalf("unexpected nil URL") } if g, e := req.URL.String(), "/path?a=b"; e != g { - t.Errorf("expected URL %q; got %q", e, g) + t.Errorf("URL = %q; want %q", g, e) + } +} + +func TestRequestWithoutRequestURI(t *testing.T) { + env := map[string]string{ + "SERVER_PROTOCOL": "HTTP/1.1", + "HTTP_HOST": "example.com", + "REQUEST_METHOD": "GET", + "SCRIPT_NAME": "/dir/scriptname", + "PATH_INFO": "/p1/p2", + "QUERY_STRING": "a=1&b=2", + "CONTENT_LENGTH": "123", + } + req, err := RequestFromMap(env) + if err != nil { + t.Fatalf("RequestFromMap: %v", err) + } + if req.URL == nil { + t.Fatalf("unexpected nil URL") + } + if g, e := req.URL.String(), "http://example.com/dir/scriptname/p1/p2?a=1&b=2"; e != g { + t.Errorf("URL = %q; want %q", g, e) } } diff --git a/src/pkg/net/http/cgi/host_test.go b/src/pkg/net/http/cgi/host_test.go index 4db3d850c..8c16e6897 100644 --- a/src/pkg/net/http/cgi/host_test.go +++ b/src/pkg/net/http/cgi/host_test.go @@ -19,7 +19,6 @@ import ( "runtime" "strconv" "strings" - "syscall" "testing" "time" ) @@ -63,17 +62,25 @@ readlines: } for key, expected := range expectedMap { - if got := m[key]; got != expected { + got := m[key] + if key == "cwd" { + // For Windows. golang.org/issue/4645. + fi1, _ := os.Stat(got) + fi2, _ := os.Stat(expected) + if os.SameFile(fi1, fi2) { + got = expected + } + } + if got != expected { t.Errorf("for key %q got %q; expected %q", key, got, expected) } } return rw } -var cgiTested = false -var cgiWorks bool +var cgiTested, cgiWorks bool -func skipTest(t *testing.T) bool { +func check(t *testing.T) { if !cgiTested { cgiTested = true cgiWorks = exec.Command("./testdata/test.cgi").Run() == nil @@ -81,16 +88,12 @@ func skipTest(t *testing.T) bool { if !cgiWorks { // No Perl on Windows, needed by test.cgi // TODO: make the child process be Go, not Perl. - t.Logf("Skipping test: test.cgi failed.") - return true + t.Skip("Skipping test: test.cgi failed.") } - return false } func TestCGIBasicGet(t *testing.T) { - if skipTest(t) { - return - } + check(t) h := &Handler{ Path: "testdata/test.cgi", Root: "/test.cgi", @@ -124,9 +127,7 @@ func TestCGIBasicGet(t *testing.T) { } func TestCGIBasicGetAbsPath(t *testing.T) { - if skipTest(t) { - return - } + check(t) pwd, err := os.Getwd() if err != nil { t.Fatalf("getwd error: %v", err) @@ -144,9 +145,7 @@ func TestCGIBasicGetAbsPath(t *testing.T) { } func TestPathInfo(t *testing.T) { - if skipTest(t) { - return - } + check(t) h := &Handler{ Path: "testdata/test.cgi", Root: "/test.cgi", @@ -163,9 +162,7 @@ func TestPathInfo(t *testing.T) { } func TestPathInfoDirRoot(t *testing.T) { - if skipTest(t) { - return - } + check(t) h := &Handler{ Path: "testdata/test.cgi", Root: "/myscript/", @@ -181,9 +178,7 @@ func TestPathInfoDirRoot(t *testing.T) { } func TestDupHeaders(t *testing.T) { - if skipTest(t) { - return - } + check(t) h := &Handler{ Path: "testdata/test.cgi", } @@ -203,9 +198,7 @@ func TestDupHeaders(t *testing.T) { } func TestPathInfoNoRoot(t *testing.T) { - if skipTest(t) { - return - } + check(t) h := &Handler{ Path: "testdata/test.cgi", Root: "", @@ -221,9 +214,7 @@ func TestPathInfoNoRoot(t *testing.T) { } func TestCGIBasicPost(t *testing.T) { - if skipTest(t) { - return - } + check(t) postReq := `POST /test.cgi?a=b HTTP/1.0 Host: example.com Content-Type: application/x-www-form-urlencoded @@ -250,9 +241,7 @@ func chunk(s string) string { // The CGI spec doesn't allow chunked requests. func TestCGIPostChunked(t *testing.T) { - if skipTest(t) { - return - } + check(t) postReq := `POST /test.cgi?a=b HTTP/1.1 Host: example.com Content-Type: application/x-www-form-urlencoded @@ -273,9 +262,7 @@ Transfer-Encoding: chunked } func TestRedirect(t *testing.T) { - if skipTest(t) { - return - } + check(t) h := &Handler{ Path: "testdata/test.cgi", Root: "/test.cgi", @@ -290,9 +277,7 @@ func TestRedirect(t *testing.T) { } func TestInternalRedirect(t *testing.T) { - if skipTest(t) { - return - } + check(t) baseHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { fmt.Fprintf(rw, "basepath=%s\n", req.URL.Path) fmt.Fprintf(rw, "remoteaddr=%s\n", req.RemoteAddr) @@ -312,8 +297,9 @@ func TestInternalRedirect(t *testing.T) { // TestCopyError tests that we kill the process if there's an error copying // its output. (for example, from the client having gone away) func TestCopyError(t *testing.T) { - if skipTest(t) || runtime.GOOS == "windows" { - return + check(t) + if runtime.GOOS == "windows" { + t.Skipf("skipping test on %q", runtime.GOOS) } h := &Handler{ Path: "testdata/test.cgi", @@ -353,11 +339,7 @@ func TestCopyError(t *testing.T) { } childRunning := func() bool { - p, err := os.FindProcess(pid) - if err != nil { - return false - } - return p.Signal(syscall.Signal(0)) == nil + return isProcessRunning(t, pid) } if !childRunning() { @@ -376,10 +358,10 @@ func TestCopyError(t *testing.T) { } func TestDirUnix(t *testing.T) { - if skipTest(t) || runtime.GOOS == "windows" { - return + check(t) + if runtime.GOOS == "windows" { + t.Skipf("skipping test on %q", runtime.GOOS) } - cwd, _ := os.Getwd() h := &Handler{ Path: "testdata/test.cgi", @@ -404,8 +386,8 @@ func TestDirUnix(t *testing.T) { } func TestDirWindows(t *testing.T) { - if skipTest(t) || runtime.GOOS != "windows" { - return + if runtime.GOOS != "windows" { + t.Skip("Skipping windows specific test.") } cgifile, _ := filepath.Abs("testdata/test.cgi") @@ -414,7 +396,7 @@ func TestDirWindows(t *testing.T) { var err error perl, err = exec.LookPath("perl") if err != nil { - return + t.Skip("Skipping test: perl not found.") } perl, _ = filepath.Abs(perl) @@ -456,7 +438,7 @@ func TestEnvOverride(t *testing.T) { var err error perl, err = exec.LookPath("perl") if err != nil { - return + t.Skipf("Skipping test: perl not found.") } perl, _ = filepath.Abs(perl) diff --git a/src/pkg/net/http/cgi/plan9_test.go b/src/pkg/net/http/cgi/plan9_test.go new file mode 100644 index 000000000..c8235831b --- /dev/null +++ b/src/pkg/net/http/cgi/plan9_test.go @@ -0,0 +1,18 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build plan9 + +package cgi + +import ( + "os" + "strconv" + "testing" +) + +func isProcessRunning(t *testing.T, pid int) bool { + _, err := os.Stat("/proc/" + strconv.Itoa(pid)) + return err == nil +} diff --git a/src/pkg/net/http/cgi/posix_test.go b/src/pkg/net/http/cgi/posix_test.go new file mode 100644 index 000000000..5ff9e7d5e --- /dev/null +++ b/src/pkg/net/http/cgi/posix_test.go @@ -0,0 +1,21 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !plan9 + +package cgi + +import ( + "os" + "syscall" + "testing" +) + +func isProcessRunning(t *testing.T, pid int) bool { + p, err := os.FindProcess(pid) + if err != nil { + return false + } + return p.Signal(syscall.Signal(0)) == nil +} diff --git a/src/pkg/net/http/cgi/testdata/test.cgi b/src/pkg/net/http/cgi/testdata/test.cgi index b46b1330f..3214df6f0 100755 --- a/src/pkg/net/http/cgi/testdata/test.cgi +++ b/src/pkg/net/http/cgi/testdata/test.cgi @@ -8,6 +8,8 @@ use strict; use Cwd; +binmode STDOUT; + my $q = MiniCGI->new; my $params = $q->Vars; @@ -16,51 +18,44 @@ if ($params->{"loc"}) { exit(0); } -my $NL = "\r\n"; -$NL = "\n" if $params->{mode} eq "NL"; - -my $p = sub { - print "$_[0]$NL"; -}; - -# With carriage returns -$p->("Content-Type: text/html"); -$p->("X-CGI-Pid: $$"); -$p->("X-Test-Header: X-Test-Value"); -$p->(""); +print "Content-Type: text/html\r\n"; +print "X-CGI-Pid: $$\r\n"; +print "X-Test-Header: X-Test-Value\r\n"; +print "\r\n"; if ($params->{"bigresponse"}) { - for (1..1024) { - print "A" x 1024, "\n"; + # 17 MB, for OS X: golang.org/issue/4958 + for (1..(17 * 1024)) { + print "A" x 1024, "\r\n"; } exit 0; } -print "test=Hello CGI\n"; +print "test=Hello CGI\r\n"; foreach my $k (sort keys %$params) { - print "param-$k=$params->{$k}\n"; + print "param-$k=$params->{$k}\r\n"; } foreach my $k (sort keys %ENV) { - my $clean_env = $ENV{$k}; - $clean_env =~ s/[\n\r]//g; - print "env-$k=$clean_env\n"; + my $clean_env = $ENV{$k}; + $clean_env =~ s/[\n\r]//g; + print "env-$k=$clean_env\r\n"; } -# NOTE: don't call getcwd() for windows. -# msys return /c/go/src/... not C:\go\... -my $dir; +# NOTE: msys perl returns /c/go/src/... not C:\go\.... +my $dir = getcwd(); if ($^O eq 'MSWin32' || $^O eq 'msys') { - my $cmd = $ENV{'COMSPEC'} || 'c:\\windows\\system32\\cmd.exe'; - $cmd =~ s!\\!/!g; - $dir = `$cmd /c cd`; - chomp $dir; -} else { - $dir = getcwd(); + if ($dir =~ /^.:/) { + $dir =~ s!/!\\!g; + } else { + my $cmd = $ENV{'COMSPEC'} || 'c:\\windows\\system32\\cmd.exe'; + $cmd =~ s!\\!/!g; + $dir = `$cmd /c cd`; + chomp $dir; + } } -print "cwd=$dir\n"; - +print "cwd=$dir\r\n"; # A minimal version of CGI.pm, for people without the perl-modules # package installed. (CGI.pm used to be part of the Perl core, but diff --git a/src/pkg/net/http/chunked.go b/src/pkg/net/http/chunked.go index 60a478fd8..91db01724 100644 --- a/src/pkg/net/http/chunked.go +++ b/src/pkg/net/http/chunked.go @@ -11,10 +11,9 @@ package http import ( "bufio" - "bytes" "errors" + "fmt" "io" - "strconv" ) const maxLineLength = 4096 // assumed <= bufio.defaultBufSize @@ -22,7 +21,7 @@ const maxLineLength = 4096 // assumed <= bufio.defaultBufSize var ErrLineTooLong = errors.New("header line too long") // newChunkedReader returns a new chunkedReader that translates the data read from r -// out of HTTP "chunked" format before returning it. +// out of HTTP "chunked" format before returning it. // The chunkedReader returns io.EOF when the final 0-length chunk is read. // // newChunkedReader is not needed by normal applications. The http package @@ -39,16 +38,17 @@ type chunkedReader struct { r *bufio.Reader n uint64 // unread bytes in chunk err error + buf [2]byte } func (cr *chunkedReader) beginChunk() { // chunk-size CRLF - var line string + var line []byte line, cr.err = readLine(cr.r) if cr.err != nil { return } - cr.n, cr.err = strconv.ParseUint(line, 16, 64) + cr.n, cr.err = parseHexUint(line) if cr.err != nil { return } @@ -74,9 +74,8 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) { cr.n -= uint64(n) if cr.n == 0 && cr.err == nil { // end of chunk (CRLF) - b := make([]byte, 2) - if _, cr.err = io.ReadFull(cr.r, b); cr.err == nil { - if b[0] != '\r' || b[1] != '\n' { + if _, cr.err = io.ReadFull(cr.r, cr.buf[:]); cr.err == nil { + if cr.buf[0] != '\r' || cr.buf[1] != '\n' { cr.err = errors.New("malformed chunked encoding") } } @@ -88,7 +87,7 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) { // Give up if the line exceeds maxLineLength. // The returned bytes are a pointer into storage in // the bufio, so they are only valid until the next bufio read. -func readLineBytes(b *bufio.Reader) (p []byte, err error) { +func readLine(b *bufio.Reader) (p []byte, err error) { if p, err = b.ReadSlice('\n'); err != nil { // We always know when EOF is coming. // If the caller asked for a line, there should be a line. @@ -102,20 +101,18 @@ func readLineBytes(b *bufio.Reader) (p []byte, err error) { if len(p) >= maxLineLength { return nil, ErrLineTooLong } - - // Chop off trailing white space. - p = bytes.TrimRight(p, " \r\t\n") - - return p, nil + return trimTrailingWhitespace(p), nil } -// readLineBytes, but convert the bytes into a string. -func readLine(b *bufio.Reader) (s string, err error) { - p, e := readLineBytes(b) - if e != nil { - return "", e +func trimTrailingWhitespace(b []byte) []byte { + for len(b) > 0 && isASCIISpace(b[len(b)-1]) { + b = b[:len(b)-1] } - return string(p), nil + return b +} + +func isASCIISpace(b byte) bool { + return b == ' ' || b == '\t' || b == '\n' || b == '\r' } // newChunkedWriter returns a new chunkedWriter that translates writes into HTTP @@ -147,9 +144,7 @@ func (cw *chunkedWriter) Write(data []byte) (n int, err error) { return 0, nil } - head := strconv.FormatInt(int64(len(data)), 16) + "\r\n" - - if _, err = io.WriteString(cw.Wire, head); err != nil { + if _, err = fmt.Fprintf(cw.Wire, "%x\r\n", len(data)); err != nil { return 0, err } if n, err = cw.Wire.Write(data); err != nil { @@ -168,3 +163,21 @@ func (cw *chunkedWriter) Close() error { _, err := io.WriteString(cw.Wire, "0\r\n") return err } + +func parseHexUint(v []byte) (n uint64, err error) { + for _, b := range v { + n <<= 4 + switch { + case '0' <= b && b <= '9': + b = b - '0' + case 'a' <= b && b <= 'f': + b = b - 'a' + 10 + case 'A' <= b && b <= 'F': + b = b - 'A' + 10 + default: + return 0, errors.New("invalid byte in chunk length") + } + n |= uint64(b) + } + return +} diff --git a/src/pkg/net/http/chunked_test.go b/src/pkg/net/http/chunked_test.go index b77ee2ff2..0b18c7b55 100644 --- a/src/pkg/net/http/chunked_test.go +++ b/src/pkg/net/http/chunked_test.go @@ -9,7 +9,10 @@ package http import ( "bytes" + "fmt" + "io" "io/ioutil" + "runtime" "testing" ) @@ -37,3 +40,54 @@ func TestChunk(t *testing.T) { t.Errorf("chunk reader read %q; want %q", g, e) } } + +func TestChunkReaderAllocs(t *testing.T) { + // temporarily set GOMAXPROCS to 1 as we are testing memory allocations + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1)) + var buf bytes.Buffer + w := newChunkedWriter(&buf) + a, b, c := []byte("aaaaaa"), []byte("bbbbbbbbbbbb"), []byte("cccccccccccccccccccccccc") + w.Write(a) + w.Write(b) + w.Write(c) + w.Close() + + r := newChunkedReader(&buf) + readBuf := make([]byte, len(a)+len(b)+len(c)+1) + + var ms runtime.MemStats + runtime.ReadMemStats(&ms) + m0 := ms.Mallocs + + n, err := io.ReadFull(r, readBuf) + + runtime.ReadMemStats(&ms) + mallocs := ms.Mallocs - m0 + if mallocs > 1 { + t.Errorf("%d mallocs; want <= 1", mallocs) + } + + if n != len(readBuf)-1 { + t.Errorf("read %d bytes; want %d", n, len(readBuf)-1) + } + if err != io.ErrUnexpectedEOF { + t.Errorf("read error = %v; want ErrUnexpectedEOF", err) + } +} + +func TestParseHexUint(t *testing.T) { + for i := uint64(0); i <= 1234; i++ { + line := []byte(fmt.Sprintf("%x", i)) + got, err := parseHexUint(line) + if err != nil { + t.Fatalf("on %d: %v", i, err) + } + if got != i { + t.Errorf("for input %q = %d; want %d", line, got, i) + } + } + _, err := parseHexUint([]byte("bogus")) + if err == nil { + t.Error("expected error on bogus input") + } +} diff --git a/src/pkg/net/http/client.go b/src/pkg/net/http/client.go index 54564e098..5ee0804c7 100644 --- a/src/pkg/net/http/client.go +++ b/src/pkg/net/http/client.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. // HTTP client. See RFC 2616. -// +// // This is the high-level Client interface. // The low-level implementation is in transport.go. @@ -14,6 +14,7 @@ import ( "errors" "fmt" "io" + "log" "net/url" "strings" ) @@ -32,17 +33,19 @@ type Client struct { // CheckRedirect specifies the policy for handling redirects. // If CheckRedirect is not nil, the client calls it before - // following an HTTP redirect. The arguments req and via - // are the upcoming request and the requests made already, - // oldest first. If CheckRedirect returns an error, the client - // returns that error instead of issue the Request req. + // following an HTTP redirect. The arguments req and via are + // the upcoming request and the requests made already, oldest + // first. If CheckRedirect returns an error, the Client's Get + // method returns both the previous Response and + // CheckRedirect's error (wrapped in a url.Error) instead of + // issuing the Request req. // // If CheckRedirect is nil, the Client uses its default policy, // which is to stop after 10 consecutive requests. CheckRedirect func(req *Request, via []*Request) error - // Jar specifies the cookie jar. - // If Jar is nil, cookies are not sent in requests and ignored + // Jar specifies the cookie jar. + // If Jar is nil, cookies are not sent in requests and ignored // in responses. Jar CookieJar } @@ -84,10 +87,32 @@ type readClose struct { io.Closer } +func (c *Client) send(req *Request) (*Response, error) { + if c.Jar != nil { + for _, cookie := range c.Jar.Cookies(req.URL) { + req.AddCookie(cookie) + } + } + resp, err := send(req, c.Transport) + if err != nil { + return nil, err + } + if c.Jar != nil { + if rc := resp.Cookies(); len(rc) > 0 { + c.Jar.SetCookies(req.URL, rc) + } + } + return resp, err +} + // Do sends an HTTP request and returns an HTTP response, following // policy (e.g. redirects, cookies, auth) as configured on the client. // -// A non-nil response always contains a non-nil resp.Body. +// An error is returned if caused by client policy (such as +// CheckRedirect), or if there was an HTTP protocol error. +// A non-2xx response doesn't cause an error. +// +// When err is nil, resp always contains a non-nil resp.Body. // // Callers should close resp.Body when done reading from it. If // resp.Body is not closed, the Client's underlying RoundTripper @@ -97,12 +122,16 @@ type readClose struct { // Generally Get, Post, or PostForm will be used instead of Do. func (c *Client) Do(req *Request) (resp *Response, err error) { if req.Method == "GET" || req.Method == "HEAD" { - return c.doFollowingRedirects(req) + return c.doFollowingRedirects(req, shouldRedirectGet) } - return send(req, c.Transport) + if req.Method == "POST" || req.Method == "PUT" { + return c.doFollowingRedirects(req, shouldRedirectPost) + } + return c.send(req) } -// send issues an HTTP request. Caller should close resp.Body when done reading from it. +// send issues an HTTP request. +// Caller should close resp.Body when done reading from it. func send(req *Request, t RoundTripper) (resp *Response, err error) { if t == nil { t = DefaultTransport @@ -130,12 +159,19 @@ func send(req *Request, t RoundTripper) (resp *Response, err error) { if u := req.URL.User; u != nil { req.Header.Set("Authorization", "Basic "+base64.URLEncoding.EncodeToString([]byte(u.String()))) } - return t.RoundTrip(req) + resp, err = t.RoundTrip(req) + if err != nil { + if resp != nil { + log.Printf("RoundTripper returned a response & error; ignoring response") + } + return nil, err + } + return resp, nil } // True if the specified HTTP status code is one for which the Get utility should // automatically redirect. -func shouldRedirect(statusCode int) bool { +func shouldRedirectGet(statusCode int) bool { switch statusCode { case StatusMovedPermanently, StatusFound, StatusSeeOther, StatusTemporaryRedirect: return true @@ -143,6 +179,16 @@ func shouldRedirect(statusCode int) bool { return false } +// True if the specified HTTP status code is one for which the Post utility should +// automatically redirect. +func shouldRedirectPost(statusCode int) bool { + switch statusCode { + case StatusFound, StatusSeeOther: + return true + } + return false +} + // Get issues a GET to the specified URL. If the response is one of the following // redirect codes, Get follows the redirect, up to a maximum of 10 redirects: // @@ -151,10 +197,15 @@ func shouldRedirect(statusCode int) bool { // 303 (See Other) // 307 (Temporary Redirect) // -// Caller should close r.Body when done reading from it. +// An error is returned if there were too many redirects or if there +// was an HTTP protocol error. A non-2xx response doesn't cause an +// error. +// +// When err is nil, resp always contains a non-nil resp.Body. +// Caller should close resp.Body when done reading from it. // // Get is a wrapper around DefaultClient.Get. -func Get(url string) (r *Response, err error) { +func Get(url string) (resp *Response, err error) { return DefaultClient.Get(url) } @@ -167,18 +218,21 @@ func Get(url string) (r *Response, err error) { // 303 (See Other) // 307 (Temporary Redirect) // -// Caller should close r.Body when done reading from it. -func (c *Client) Get(url string) (r *Response, err error) { +// An error is returned if the Client's CheckRedirect function fails +// or if there was an HTTP protocol error. A non-2xx response doesn't +// cause an error. +// +// When err is nil, resp always contains a non-nil resp.Body. +// Caller should close resp.Body when done reading from it. +func (c *Client) Get(url string) (resp *Response, err error) { req, err := NewRequest("GET", url, nil) if err != nil { return nil, err } - return c.doFollowingRedirects(req) + return c.doFollowingRedirects(req, shouldRedirectGet) } -func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err error) { - // TODO: if/when we add cookie support, the redirected request shouldn't - // necessarily supply the same cookies as the original. +func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bool) (resp *Response, err error) { var base *url.URL redirectChecker := c.CheckRedirect if redirectChecker == nil { @@ -190,17 +244,16 @@ func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err error) { return nil, errors.New("http: nil Request.URL") } - jar := c.Jar - if jar == nil { - jar = blackHoleJar{} - } - req := ireq urlStr := "" // next relative or absolute URL to fetch (after first request) + redirectFailed := false for redirect := 0; ; redirect++ { if redirect != 0 { req = new(Request) req.Method = ireq.Method + if ireq.Method == "POST" || ireq.Method == "PUT" { + req.Method = "GET" + } req.Header = make(Header) req.URL, err = base.Parse(urlStr) if err != nil { @@ -215,26 +268,21 @@ func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err error) { err = redirectChecker(req, via) if err != nil { + redirectFailed = true break } } } - for _, cookie := range jar.Cookies(req.URL) { - req.AddCookie(cookie) - } urlStr = req.URL.String() - if r, err = send(req, c.Transport); err != nil { + if resp, err = c.send(req); err != nil { break } - if c := r.Cookies(); len(c) > 0 { - jar.SetCookies(req.URL, c) - } - if shouldRedirect(r.StatusCode) { - r.Body.Close() - if urlStr = r.Header.Get("Location"); urlStr == "" { - err = errors.New(fmt.Sprintf("%d response missing Location header", r.StatusCode)) + if shouldRedirect(resp.StatusCode) { + resp.Body.Close() + if urlStr = resp.Header.Get("Location"); urlStr == "" { + err = errors.New(fmt.Sprintf("%d response missing Location header", resp.StatusCode)) break } base = req.URL @@ -245,12 +293,23 @@ func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err error) { } method := ireq.Method - err = &url.Error{ + urlErr := &url.Error{ Op: method[0:1] + strings.ToLower(method[1:]), URL: urlStr, Err: err, } - return + + if redirectFailed { + // Special case for Go 1 compatibility: return both the response + // and an error if the CheckRedirect function failed. + // See http://golang.org/issue/3795 + return resp, urlErr + } + + if resp != nil { + resp.Body.Close() + } + return nil, urlErr } func defaultCheckRedirect(req *Request, via []*Request) error { @@ -262,49 +321,42 @@ func defaultCheckRedirect(req *Request, via []*Request) error { // Post issues a POST to the specified URL. // -// Caller should close r.Body when done reading from it. +// Caller should close resp.Body when done reading from it. // // Post is a wrapper around DefaultClient.Post -func Post(url string, bodyType string, body io.Reader) (r *Response, err error) { +func Post(url string, bodyType string, body io.Reader) (resp *Response, err error) { return DefaultClient.Post(url, bodyType, body) } // Post issues a POST to the specified URL. // -// Caller should close r.Body when done reading from it. -func (c *Client) Post(url string, bodyType string, body io.Reader) (r *Response, err error) { +// Caller should close resp.Body when done reading from it. +func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *Response, err error) { req, err := NewRequest("POST", url, body) if err != nil { return nil, err } req.Header.Set("Content-Type", bodyType) - if c.Jar != nil { - for _, cookie := range c.Jar.Cookies(req.URL) { - req.AddCookie(cookie) - } - } - r, err = send(req, c.Transport) - if err == nil && c.Jar != nil { - c.Jar.SetCookies(req.URL, r.Cookies()) - } - return r, err + return c.doFollowingRedirects(req, shouldRedirectPost) } -// PostForm issues a POST to the specified URL, -// with data's keys and values urlencoded as the request body. +// PostForm issues a POST to the specified URL, with data's keys and +// values URL-encoded as the request body. // -// Caller should close r.Body when done reading from it. +// When err is nil, resp always contains a non-nil resp.Body. +// Caller should close resp.Body when done reading from it. // // PostForm is a wrapper around DefaultClient.PostForm -func PostForm(url string, data url.Values) (r *Response, err error) { +func PostForm(url string, data url.Values) (resp *Response, err error) { return DefaultClient.PostForm(url, data) } -// PostForm issues a POST to the specified URL, +// PostForm issues a POST to the specified URL, // with data's keys and values urlencoded as the request body. // -// Caller should close r.Body when done reading from it. -func (c *Client) PostForm(url string, data url.Values) (r *Response, err error) { +// When err is nil, resp always contains a non-nil resp.Body. +// Caller should close resp.Body when done reading from it. +func (c *Client) PostForm(url string, data url.Values) (resp *Response, err error) { return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) } @@ -318,7 +370,7 @@ func (c *Client) PostForm(url string, data url.Values) (r *Response, err error) // 307 (Temporary Redirect) // // Head is a wrapper around DefaultClient.Head -func Head(url string) (r *Response, err error) { +func Head(url string) (resp *Response, err error) { return DefaultClient.Head(url) } @@ -330,10 +382,10 @@ func Head(url string) (r *Response, err error) { // 302 (Found) // 303 (See Other) // 307 (Temporary Redirect) -func (c *Client) Head(url string) (r *Response, err error) { +func (c *Client) Head(url string) (resp *Response, err error) { req, err := NewRequest("HEAD", url, nil) if err != nil { return nil, err } - return c.doFollowingRedirects(req) + return c.doFollowingRedirects(req, shouldRedirectGet) } diff --git a/src/pkg/net/http/client_test.go b/src/pkg/net/http/client_test.go index 9b4261b9f..88649bb16 100644 --- a/src/pkg/net/http/client_test.go +++ b/src/pkg/net/http/client_test.go @@ -7,7 +7,9 @@ package http_test import ( + "bytes" "crypto/tls" + "crypto/x509" "errors" "fmt" "io" @@ -53,6 +55,7 @@ func pedanticReadAll(r io.Reader) (b []byte, err error) { } func TestClient(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(robotsTxtHandler) defer ts.Close() @@ -70,6 +73,7 @@ func TestClient(t *testing.T) { } func TestClientHead(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(robotsTxtHandler) defer ts.Close() @@ -92,6 +96,7 @@ func (t *recordingTransport) RoundTrip(req *Request) (resp *Response, err error) } func TestGetRequestFormat(t *testing.T) { + defer checkLeakedTransports(t) tr := &recordingTransport{} client := &Client{Transport: tr} url := "http://dummy.faketld/" @@ -108,6 +113,7 @@ func TestGetRequestFormat(t *testing.T) { } func TestPostRequestFormat(t *testing.T) { + defer checkLeakedTransports(t) tr := &recordingTransport{} client := &Client{Transport: tr} @@ -134,6 +140,7 @@ func TestPostRequestFormat(t *testing.T) { } func TestPostFormRequestFormat(t *testing.T) { + defer checkLeakedTransports(t) tr := &recordingTransport{} client := &Client{Transport: tr} @@ -175,6 +182,7 @@ func TestPostFormRequestFormat(t *testing.T) { } func TestRedirects(t *testing.T) { + defer checkLeakedTransports(t) var ts *httptest.Server ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { n, _ := strconv.Atoi(r.FormValue("n")) @@ -218,6 +226,10 @@ func TestRedirects(t *testing.T) { return checkErr }} res, err := c.Get(ts.URL) + if err != nil { + t.Fatalf("Get error: %v", err) + } + res.Body.Close() finalUrl := res.Request.URL.String() if e, g := "<nil>", fmt.Sprintf("%v", err); e != g { t.Errorf("with custom client, expected error %q, got %q", e, g) @@ -231,9 +243,63 @@ func TestRedirects(t *testing.T) { checkErr = errors.New("no redirects allowed") res, err = c.Get(ts.URL) - finalUrl = res.Request.URL.String() - if e, g := "Get /?n=1: no redirects allowed", fmt.Sprintf("%v", err); e != g { - t.Errorf("with redirects forbidden, expected error %q, got %q", e, g) + if urlError, ok := err.(*url.Error); !ok || urlError.Err != checkErr { + t.Errorf("with redirects forbidden, expected a *url.Error with our 'no redirects allowed' error inside; got %#v (%q)", err, err) + } + if res == nil { + t.Fatalf("Expected a non-nil Response on CheckRedirect failure (http://golang.org/issue/3795)") + } + res.Body.Close() + if res.Header.Get("Location") == "" { + t.Errorf("no Location header in Response") + } +} + +func TestPostRedirects(t *testing.T) { + defer checkLeakedTransports(t) + var log struct { + sync.Mutex + bytes.Buffer + } + var ts *httptest.Server + ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + log.Lock() + fmt.Fprintf(&log.Buffer, "%s %s ", r.Method, r.RequestURI) + log.Unlock() + if v := r.URL.Query().Get("code"); v != "" { + code, _ := strconv.Atoi(v) + if code/100 == 3 { + w.Header().Set("Location", ts.URL) + } + w.WriteHeader(code) + } + })) + defer ts.Close() + tests := []struct { + suffix string + want int // response code + }{ + {"/", 200}, + {"/?code=301", 301}, + {"/?code=302", 200}, + {"/?code=303", 200}, + {"/?code=404", 404}, + } + for _, tt := range tests { + res, err := Post(ts.URL+tt.suffix, "text/plain", strings.NewReader("Some content")) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != tt.want { + t.Errorf("POST %s: status code = %d; want %d", tt.suffix, res.StatusCode, tt.want) + } + } + log.Lock() + got := log.String() + log.Unlock() + want := "POST / POST /?code=301 POST /?code=302 GET / POST /?code=303 GET / POST /?code=404 " + if got != want { + t.Errorf("Log differs.\n Got: %q\nWant: %q", got, want) } } @@ -279,6 +345,10 @@ func TestClientSendsCookieFromJar(t *testing.T) { req, _ := NewRequest("GET", us, nil) client.Do(req) // Note: doesn't hit network matchReturnedCookies(t, expectedCookies, tr.req.Cookies()) + + req, _ = NewRequest("POST", us, nil) + client.Do(req) // Note: doesn't hit network + matchReturnedCookies(t, expectedCookies, tr.req.Cookies()) } // Just enough correctness for our redirect tests. Uses the URL.Host as the @@ -291,6 +361,9 @@ type TestJar struct { func (j *TestJar) SetCookies(u *url.URL, cookies []*Cookie) { j.m.Lock() defer j.m.Unlock() + if j.perURL == nil { + j.perURL = make(map[string][]*Cookie) + } j.perURL[u.Host] = cookies } @@ -301,6 +374,7 @@ func (j *TestJar) Cookies(u *url.URL) []*Cookie { } func TestRedirectCookiesOnRequest(t *testing.T) { + defer checkLeakedTransports(t) var ts *httptest.Server ts = httptest.NewServer(echoCookiesRedirectHandler) defer ts.Close() @@ -318,14 +392,20 @@ func TestRedirectCookiesOnRequest(t *testing.T) { } func TestRedirectCookiesJar(t *testing.T) { + defer checkLeakedTransports(t) var ts *httptest.Server ts = httptest.NewServer(echoCookiesRedirectHandler) defer ts.Close() - c := &Client{} - c.Jar = &TestJar{perURL: make(map[string][]*Cookie)} + c := &Client{ + Jar: new(TestJar), + } u, _ := url.Parse(ts.URL) c.Jar.SetCookies(u, []*Cookie{expectedCookies[0]}) - resp, _ := c.Get(ts.URL) + resp, err := c.Get(ts.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + resp.Body.Close() matchReturnedCookies(t, expectedCookies, resp.Cookies()) } @@ -348,7 +428,72 @@ func matchReturnedCookies(t *testing.T, expected, given []*Cookie) { } } +func TestJarCalls(t *testing.T) { + defer checkLeakedTransports(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + pathSuffix := r.RequestURI[1:] + if r.RequestURI == "/nosetcookie" { + return // dont set cookies for this path + } + SetCookie(w, &Cookie{Name: "name" + pathSuffix, Value: "val" + pathSuffix}) + if r.RequestURI == "/" { + Redirect(w, r, "http://secondhost.fake/secondpath", 302) + } + })) + defer ts.Close() + jar := new(RecordingJar) + c := &Client{ + Jar: jar, + Transport: &Transport{ + Dial: func(_ string, _ string) (net.Conn, error) { + return net.Dial("tcp", ts.Listener.Addr().String()) + }, + }, + } + _, err := c.Get("http://firsthost.fake/") + if err != nil { + t.Fatal(err) + } + _, err = c.Get("http://firsthost.fake/nosetcookie") + if err != nil { + t.Fatal(err) + } + got := jar.log.String() + want := `Cookies("http://firsthost.fake/") +SetCookie("http://firsthost.fake/", [name=val]) +Cookies("http://secondhost.fake/secondpath") +SetCookie("http://secondhost.fake/secondpath", [namesecondpath=valsecondpath]) +Cookies("http://firsthost.fake/nosetcookie") +` + if got != want { + t.Errorf("Got Jar calls:\n%s\nWant:\n%s", got, want) + } +} + +// RecordingJar keeps a log of calls made to it, without +// tracking any cookies. +type RecordingJar struct { + mu sync.Mutex + log bytes.Buffer +} + +func (j *RecordingJar) SetCookies(u *url.URL, cookies []*Cookie) { + j.logf("SetCookie(%q, %v)\n", u, cookies) +} + +func (j *RecordingJar) Cookies(u *url.URL) []*Cookie { + j.logf("Cookies(%q)\n", u) + return nil +} + +func (j *RecordingJar) logf(format string, args ...interface{}) { + j.mu.Lock() + defer j.mu.Unlock() + fmt.Fprintf(&j.log, format, args...) +} + func TestStreamingGet(t *testing.T) { + defer checkLeakedTransports(t) say := make(chan string) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.(Flusher).Flush() @@ -399,6 +544,7 @@ func (c *writeCountingConn) Write(p []byte) (int, error) { // TestClientWrites verifies that client requests are buffered and we // don't send a TCP packet per line of the http request + body. func TestClientWrites(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { })) defer ts.Close() @@ -432,6 +578,7 @@ func TestClientWrites(t *testing.T) { } func TestClientInsecureTransport(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello")) })) @@ -446,15 +593,20 @@ func TestClientInsecureTransport(t *testing.T) { InsecureSkipVerify: insecure, }, } + defer tr.CloseIdleConnections() c := &Client{Transport: tr} - _, err := c.Get(ts.URL) + res, err := c.Get(ts.URL) if (err == nil) != insecure { t.Errorf("insecure=%v: got unexpected err=%v", insecure, err) } + if res != nil { + res.Body.Close() + } } } func TestClientErrorWithRequestURI(t *testing.T) { + defer checkLeakedTransports(t) req, _ := NewRequest("GET", "http://localhost:1234/", nil) req.RequestURI = "/this/field/is/illegal/and/should/error/" _, err := DefaultClient.Do(req) @@ -465,3 +617,87 @@ func TestClientErrorWithRequestURI(t *testing.T) { t.Errorf("wanted error mentioning RequestURI; got error: %v", err) } } + +func newTLSTransport(t *testing.T, ts *httptest.Server) *Transport { + certs := x509.NewCertPool() + for _, c := range ts.TLS.Certificates { + roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1]) + if err != nil { + t.Fatalf("error parsing server's root cert: %v", err) + } + for _, root := range roots { + certs.AddCert(root) + } + } + return &Transport{ + TLSClientConfig: &tls.Config{RootCAs: certs}, + } +} + +func TestClientWithCorrectTLSServerName(t *testing.T) { + defer checkLeakedTransports(t) + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.TLS.ServerName != "127.0.0.1" { + t.Errorf("expected client to set ServerName 127.0.0.1, got: %q", r.TLS.ServerName) + } + })) + defer ts.Close() + + c := &Client{Transport: newTLSTransport(t, ts)} + if _, err := c.Get(ts.URL); err != nil { + t.Fatalf("expected successful TLS connection, got error: %v", err) + } +} + +func TestClientWithIncorrectTLSServerName(t *testing.T) { + defer checkLeakedTransports(t) + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + defer ts.Close() + + trans := newTLSTransport(t, ts) + trans.TLSClientConfig.ServerName = "badserver" + c := &Client{Transport: trans} + _, err := c.Get(ts.URL) + if err == nil { + t.Fatalf("expected an error") + } + if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") { + t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err) + } +} + +// Verify Response.ContentLength is populated. http://golang.org/issue/4126 +func TestClientHeadContentLength(t *testing.T) { + defer checkLeakedTransports(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if v := r.FormValue("cl"); v != "" { + w.Header().Set("Content-Length", v) + } + })) + defer ts.Close() + tests := []struct { + suffix string + want int64 + }{ + {"/?cl=1234", 1234}, + {"/?cl=0", 0}, + {"", -1}, + } + for _, tt := range tests { + req, _ := NewRequest("HEAD", ts.URL+tt.suffix, nil) + res, err := DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + if res.ContentLength != tt.want { + t.Errorf("Content-Length = %d; want %d", res.ContentLength, tt.want) + } + bs, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if len(bs) != 0 { + t.Errorf("Unexpected content: %q", bs) + } + } +} diff --git a/src/pkg/net/http/cookie.go b/src/pkg/net/http/cookie.go index 2e30bbff1..155b09223 100644 --- a/src/pkg/net/http/cookie.go +++ b/src/pkg/net/http/cookie.go @@ -26,7 +26,7 @@ type Cookie struct { Expires time.Time RawExpires string - // MaxAge=0 means no 'Max-Age' attribute specified. + // MaxAge=0 means no 'Max-Age' attribute specified. // MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0' // MaxAge>0 means Max-Age attribute present and given in seconds MaxAge int @@ -258,10 +258,5 @@ func parseCookieValueUsing(raw string, validByte func(byte) bool) (string, bool) } func isCookieNameValid(raw string) bool { - for _, c := range raw { - if !isToken(byte(c)) { - return false - } - } - return true + return strings.IndexFunc(raw, isNotToken) < 0 } diff --git a/src/pkg/net/http/cookie_test.go b/src/pkg/net/http/cookie_test.go index 1e9186a05..f84f73936 100644 --- a/src/pkg/net/http/cookie_test.go +++ b/src/pkg/net/http/cookie_test.go @@ -217,7 +217,7 @@ var readCookiesTests = []struct { func TestReadCookies(t *testing.T) { for i, tt := range readCookiesTests { - for n := 0; n < 2; n++ { // to verify readCookies doesn't mutate its input + for n := 0; n < 2; n++ { // to verify readCookies doesn't mutate its input c := readCookies(tt.Header, tt.Filter) if !reflect.DeepEqual(c, tt.Cookies) { t.Errorf("#%d readCookies:\nhave: %s\nwant: %s\n", i, toJSON(c), toJSON(tt.Cookies)) diff --git a/src/pkg/net/http/cookiejar/jar.go b/src/pkg/net/http/cookiejar/jar.go new file mode 100644 index 000000000..5d1aeb87f --- /dev/null +++ b/src/pkg/net/http/cookiejar/jar.go @@ -0,0 +1,494 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package cookiejar implements an in-memory RFC 6265-compliant http.CookieJar. +package cookiejar + +import ( + "errors" + "fmt" + "net" + "net/http" + "net/url" + "sort" + "strings" + "sync" + "time" +) + +// PublicSuffixList provides the public suffix of a domain. For example: +// - the public suffix of "example.com" is "com", +// - the public suffix of "foo1.foo2.foo3.co.uk" is "co.uk", and +// - the public suffix of "bar.pvt.k12.ma.us" is "pvt.k12.ma.us". +// +// Implementations of PublicSuffixList must be safe for concurrent use by +// multiple goroutines. +// +// An implementation that always returns "" is valid and may be useful for +// testing but it is not secure: it means that the HTTP server for foo.com can +// set a cookie for bar.com. +type PublicSuffixList interface { + // PublicSuffix returns the public suffix of domain. + // + // TODO: specify which of the caller and callee is responsible for IP + // addresses, for leading and trailing dots, for case sensitivity, and + // for IDN/Punycode. + PublicSuffix(domain string) string + + // String returns a description of the source of this public suffix + // list. The description will typically contain something like a time + // stamp or version number. + String() string +} + +// Options are the options for creating a new Jar. +type Options struct { + // PublicSuffixList is the public suffix list that determines whether + // an HTTP server can set a cookie for a domain. + // + // A nil value is valid and may be useful for testing but it is not + // secure: it means that the HTTP server for foo.co.uk can set a cookie + // for bar.co.uk. + PublicSuffixList PublicSuffixList +} + +// Jar implements the http.CookieJar interface from the net/http package. +type Jar struct { + psList PublicSuffixList + + // mu locks the remaining fields. + mu sync.Mutex + + // entries is a set of entries, keyed by their eTLD+1 and subkeyed by + // their name/domain/path. + entries map[string]map[string]entry + + // nextSeqNum is the next sequence number assigned to a new cookie + // created SetCookies. + nextSeqNum uint64 +} + +// New returns a new cookie jar. A nil *Options is equivalent to a zero +// Options. +func New(o *Options) (*Jar, error) { + jar := &Jar{ + entries: make(map[string]map[string]entry), + } + if o != nil { + jar.psList = o.PublicSuffixList + } + return jar, nil +} + +// entry is the internal representation of a cookie. +// +// This struct type is not used outside of this package per se, but the exported +// fields are those of RFC 6265. +type entry struct { + Name string + Value string + Domain string + Path string + Secure bool + HttpOnly bool + Persistent bool + HostOnly bool + Expires time.Time + Creation time.Time + LastAccess time.Time + + // seqNum is a sequence number so that Cookies returns cookies in a + // deterministic order, even for cookies that have equal Path length and + // equal Creation time. This simplifies testing. + seqNum uint64 +} + +// Id returns the domain;path;name triple of e as an id. +func (e *entry) id() string { + return fmt.Sprintf("%s;%s;%s", e.Domain, e.Path, e.Name) +} + +// shouldSend determines whether e's cookie qualifies to be included in a +// request to host/path. It is the caller's responsibility to check if the +// cookie is expired. +func (e *entry) shouldSend(https bool, host, path string) bool { + return e.domainMatch(host) && e.pathMatch(path) && (https || !e.Secure) +} + +// domainMatch implements "domain-match" of RFC 6265 section 5.1.3. +func (e *entry) domainMatch(host string) bool { + if e.Domain == host { + return true + } + return !e.HostOnly && hasDotSuffix(host, e.Domain) +} + +// pathMatch implements "path-match" according to RFC 6265 section 5.1.4. +func (e *entry) pathMatch(requestPath string) bool { + if requestPath == e.Path { + return true + } + if strings.HasPrefix(requestPath, e.Path) { + if e.Path[len(e.Path)-1] == '/' { + return true // The "/any/" matches "/any/path" case. + } else if requestPath[len(e.Path)] == '/' { + return true // The "/any" matches "/any/path" case. + } + } + return false +} + +// hasDotSuffix returns whether s ends in "."+suffix. +func hasDotSuffix(s, suffix string) bool { + return len(s) > len(suffix) && s[len(s)-len(suffix)-1] == '.' && s[len(s)-len(suffix):] == suffix +} + +// byPathLength is a []entry sort.Interface that sorts according to RFC 6265 +// section 5.4 point 2: by longest path and then by earliest creation time. +type byPathLength []entry + +func (s byPathLength) Len() int { return len(s) } + +func (s byPathLength) Less(i, j int) bool { + if len(s[i].Path) != len(s[j].Path) { + return len(s[i].Path) > len(s[j].Path) + } + if !s[i].Creation.Equal(s[j].Creation) { + return s[i].Creation.Before(s[j].Creation) + } + return s[i].seqNum < s[j].seqNum +} + +func (s byPathLength) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +// Cookies implements the Cookies method of the http.CookieJar interface. +// +// It returns an empty slice if the URL's scheme is not HTTP or HTTPS. +func (j *Jar) Cookies(u *url.URL) (cookies []*http.Cookie) { + return j.cookies(u, time.Now()) +} + +// cookies is like Cookies but takes the current time as a parameter. +func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) { + if u.Scheme != "http" && u.Scheme != "https" { + return cookies + } + host, err := canonicalHost(u.Host) + if err != nil { + return cookies + } + key := jarKey(host, j.psList) + + j.mu.Lock() + defer j.mu.Unlock() + + submap := j.entries[key] + if submap == nil { + return cookies + } + + https := u.Scheme == "https" + path := u.Path + if path == "" { + path = "/" + } + + modified := false + var selected []entry + for id, e := range submap { + if e.Persistent && !e.Expires.After(now) { + delete(submap, id) + modified = true + continue + } + if !e.shouldSend(https, host, path) { + continue + } + e.LastAccess = now + submap[id] = e + selected = append(selected, e) + modified = true + } + if modified { + if len(submap) == 0 { + delete(j.entries, key) + } else { + j.entries[key] = submap + } + } + + sort.Sort(byPathLength(selected)) + for _, e := range selected { + cookies = append(cookies, &http.Cookie{Name: e.Name, Value: e.Value}) + } + + return cookies +} + +// SetCookies implements the SetCookies method of the http.CookieJar interface. +// +// It does nothing if the URL's scheme is not HTTP or HTTPS. +func (j *Jar) SetCookies(u *url.URL, cookies []*http.Cookie) { + j.setCookies(u, cookies, time.Now()) +} + +// setCookies is like SetCookies but takes the current time as parameter. +func (j *Jar) setCookies(u *url.URL, cookies []*http.Cookie, now time.Time) { + if len(cookies) == 0 { + return + } + if u.Scheme != "http" && u.Scheme != "https" { + return + } + host, err := canonicalHost(u.Host) + if err != nil { + return + } + key := jarKey(host, j.psList) + defPath := defaultPath(u.Path) + + j.mu.Lock() + defer j.mu.Unlock() + + submap := j.entries[key] + + modified := false + for _, cookie := range cookies { + e, remove, err := j.newEntry(cookie, now, defPath, host) + if err != nil { + continue + } + id := e.id() + if remove { + if submap != nil { + if _, ok := submap[id]; ok { + delete(submap, id) + modified = true + } + } + continue + } + if submap == nil { + submap = make(map[string]entry) + } + + if old, ok := submap[id]; ok { + e.Creation = old.Creation + e.seqNum = old.seqNum + } else { + e.Creation = now + e.seqNum = j.nextSeqNum + j.nextSeqNum++ + } + e.LastAccess = now + submap[id] = e + modified = true + } + + if modified { + if len(submap) == 0 { + delete(j.entries, key) + } else { + j.entries[key] = submap + } + } +} + +// canonicalHost strips port from host if present and returns the canonicalized +// host name. +func canonicalHost(host string) (string, error) { + var err error + host = strings.ToLower(host) + if hasPort(host) { + host, _, err = net.SplitHostPort(host) + if err != nil { + return "", err + } + } + if strings.HasSuffix(host, ".") { + // Strip trailing dot from fully qualified domain names. + host = host[:len(host)-1] + } + return toASCII(host) +} + +// hasPort returns whether host contains a port number. host may be a host +// name, an IPv4 or an IPv6 address. +func hasPort(host string) bool { + colons := strings.Count(host, ":") + if colons == 0 { + return false + } + if colons == 1 { + return true + } + return host[0] == '[' && strings.Contains(host, "]:") +} + +// jarKey returns the key to use for a jar. +func jarKey(host string, psl PublicSuffixList) string { + if isIP(host) { + return host + } + + var i int + if psl == nil { + i = strings.LastIndex(host, ".") + if i == -1 { + return host + } + } else { + suffix := psl.PublicSuffix(host) + if suffix == host { + return host + } + i = len(host) - len(suffix) + if i <= 0 || host[i-1] != '.' { + // The provided public suffix list psl is broken. + // Storing cookies under host is a safe stopgap. + return host + } + } + prevDot := strings.LastIndex(host[:i-1], ".") + return host[prevDot+1:] +} + +// isIP returns whether host is an IP address. +func isIP(host string) bool { + return net.ParseIP(host) != nil +} + +// defaultPath returns the directory part of an URL's path according to +// RFC 6265 section 5.1.4. +func defaultPath(path string) string { + if len(path) == 0 || path[0] != '/' { + return "/" // Path is empty or malformed. + } + + i := strings.LastIndex(path, "/") // Path starts with "/", so i != -1. + if i == 0 { + return "/" // Path has the form "/abc". + } + return path[:i] // Path is either of form "/abc/xyz" or "/abc/xyz/". +} + +// newEntry creates an entry from a http.Cookie c. now is the current time and +// is compared to c.Expires to determine deletion of c. defPath and host are the +// default-path and the canonical host name of the URL c was received from. +// +// remove is whether the jar should delete this cookie, as it has already +// expired with respect to now. In this case, e may be incomplete, but it will +// be valid to call e.id (which depends on e's Name, Domain and Path). +// +// A malformed c.Domain will result in an error. +func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e entry, remove bool, err error) { + e.Name = c.Name + + if c.Path == "" || c.Path[0] != '/' { + e.Path = defPath + } else { + e.Path = c.Path + } + + e.Domain, e.HostOnly, err = j.domainAndType(host, c.Domain) + if err != nil { + return e, false, err + } + + // MaxAge takes precedence over Expires. + if c.MaxAge < 0 { + return e, true, nil + } else if c.MaxAge > 0 { + e.Expires = now.Add(time.Duration(c.MaxAge) * time.Second) + e.Persistent = true + } else { + if c.Expires.IsZero() { + e.Expires = endOfTime + e.Persistent = false + } else { + if !c.Expires.After(now) { + return e, true, nil + } + e.Expires = c.Expires + e.Persistent = true + } + } + + e.Value = c.Value + e.Secure = c.Secure + e.HttpOnly = c.HttpOnly + + return e, false, nil +} + +var ( + errIllegalDomain = errors.New("cookiejar: illegal cookie domain attribute") + errMalformedDomain = errors.New("cookiejar: malformed cookie domain attribute") + errNoHostname = errors.New("cookiejar: no host name available (IP only)") +) + +// endOfTime is the time when session (non-persistent) cookies expire. +// This instant is representable in most date/time formats (not just +// Go's time.Time) and should be far enough in the future. +var endOfTime = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC) + +// domainAndType determines the cookie's domain and hostOnly attribute. +func (j *Jar) domainAndType(host, domain string) (string, bool, error) { + if domain == "" { + // No domain attribute in the SetCookie header indicates a + // host cookie. + return host, true, nil + } + + if isIP(host) { + // According to RFC 6265 domain-matching includes not being + // an IP address. + // TODO: This might be relaxed as in common browsers. + return "", false, errNoHostname + } + + // From here on: If the cookie is valid, it is a domain cookie (with + // the one exception of a public suffix below). + // See RFC 6265 section 5.2.3. + if domain[0] == '.' { + domain = domain[1:] + } + + if len(domain) == 0 || domain[0] == '.' { + // Received either "Domain=." or "Domain=..some.thing", + // both are illegal. + return "", false, errMalformedDomain + } + domain = strings.ToLower(domain) + + if domain[len(domain)-1] == '.' { + // We received stuff like "Domain=www.example.com.". + // Browsers do handle such stuff (actually differently) but + // RFC 6265 seems to be clear here (e.g. section 4.1.2.3) in + // requiring a reject. 4.1.2.3 is not normative, but + // "Domain Matching" (5.1.3) and "Canonicalized Host Names" + // (5.1.2) are. + return "", false, errMalformedDomain + } + + // See RFC 6265 section 5.3 #5. + if j.psList != nil { + if ps := j.psList.PublicSuffix(domain); ps != "" && !hasDotSuffix(domain, ps) { + if host == domain { + // This is the one exception in which a cookie + // with a domain attribute is a host cookie. + return host, true, nil + } + return "", false, errIllegalDomain + } + } + + // The domain must domain-match host: www.mycompany.com cannot + // set cookies for .ourcompetitors.com. + if host != domain && !hasDotSuffix(host, domain) { + return "", false, errIllegalDomain + } + + return domain, false, nil +} diff --git a/src/pkg/net/http/cookiejar/jar_test.go b/src/pkg/net/http/cookiejar/jar_test.go new file mode 100644 index 000000000..3aa601586 --- /dev/null +++ b/src/pkg/net/http/cookiejar/jar_test.go @@ -0,0 +1,1267 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cookiejar + +import ( + "fmt" + "net/http" + "net/url" + "sort" + "strings" + "testing" + "time" +) + +// tNow is the synthetic current time used as now during testing. +var tNow = time.Date(2013, 1, 1, 12, 0, 0, 0, time.UTC) + +// testPSL implements PublicSuffixList with just two rules: "co.uk" +// and the default rule "*". +type testPSL struct{} + +func (testPSL) String() string { + return "testPSL" +} +func (testPSL) PublicSuffix(d string) string { + if d == "co.uk" || strings.HasSuffix(d, ".co.uk") { + return "co.uk" + } + return d[strings.LastIndex(d, ".")+1:] +} + +// newTestJar creates an empty Jar with testPSL as the public suffix list. +func newTestJar() *Jar { + jar, err := New(&Options{PublicSuffixList: testPSL{}}) + if err != nil { + panic(err) + } + return jar +} + +var hasDotSuffixTests = [...]struct { + s, suffix string +}{ + {"", ""}, + {"", "."}, + {"", "x"}, + {".", ""}, + {".", "."}, + {".", ".."}, + {".", "x"}, + {".", "x."}, + {".", ".x"}, + {".", ".x."}, + {"x", ""}, + {"x", "."}, + {"x", ".."}, + {"x", "x"}, + {"x", "x."}, + {"x", ".x"}, + {"x", ".x."}, + {".x", ""}, + {".x", "."}, + {".x", ".."}, + {".x", "x"}, + {".x", "x."}, + {".x", ".x"}, + {".x", ".x."}, + {"x.", ""}, + {"x.", "."}, + {"x.", ".."}, + {"x.", "x"}, + {"x.", "x."}, + {"x.", ".x"}, + {"x.", ".x."}, + {"com", ""}, + {"com", "m"}, + {"com", "om"}, + {"com", "com"}, + {"com", ".com"}, + {"com", "x.com"}, + {"com", "xcom"}, + {"com", "xorg"}, + {"com", "org"}, + {"com", "rg"}, + {"foo.com", ""}, + {"foo.com", "m"}, + {"foo.com", "om"}, + {"foo.com", "com"}, + {"foo.com", ".com"}, + {"foo.com", "o.com"}, + {"foo.com", "oo.com"}, + {"foo.com", "foo.com"}, + {"foo.com", ".foo.com"}, + {"foo.com", "x.foo.com"}, + {"foo.com", "xfoo.com"}, + {"foo.com", "xfoo.org"}, + {"foo.com", "foo.org"}, + {"foo.com", "oo.org"}, + {"foo.com", "o.org"}, + {"foo.com", ".org"}, + {"foo.com", "org"}, + {"foo.com", "rg"}, +} + +func TestHasDotSuffix(t *testing.T) { + for _, tc := range hasDotSuffixTests { + got := hasDotSuffix(tc.s, tc.suffix) + want := strings.HasSuffix(tc.s, "."+tc.suffix) + if got != want { + t.Errorf("s=%q, suffix=%q: got %v, want %v", tc.s, tc.suffix, got, want) + } + } +} + +var canonicalHostTests = map[string]string{ + "www.example.com": "www.example.com", + "WWW.EXAMPLE.COM": "www.example.com", + "wWw.eXAmple.CoM": "www.example.com", + "www.example.com:80": "www.example.com", + "192.168.0.10": "192.168.0.10", + "192.168.0.5:8080": "192.168.0.5", + "2001:4860:0:2001::68": "2001:4860:0:2001::68", + "[2001:4860:0:::68]:8080": "2001:4860:0:::68", + "www.bücher.de": "www.xn--bcher-kva.de", + "www.example.com.": "www.example.com", + "[bad.unmatched.bracket:": "error", +} + +func TestCanonicalHost(t *testing.T) { + for h, want := range canonicalHostTests { + got, err := canonicalHost(h) + if want == "error" { + if err == nil { + t.Errorf("%q: got nil error, want non-nil", h) + } + continue + } + if err != nil { + t.Errorf("%q: %v", h, err) + continue + } + if got != want { + t.Errorf("%q: got %q, want %q", h, got, want) + continue + } + } +} + +var hasPortTests = map[string]bool{ + "www.example.com": false, + "www.example.com:80": true, + "127.0.0.1": false, + "127.0.0.1:8080": true, + "2001:4860:0:2001::68": false, + "[2001::0:::68]:80": true, +} + +func TestHasPort(t *testing.T) { + for host, want := range hasPortTests { + if got := hasPort(host); got != want { + t.Errorf("%q: got %t, want %t", host, got, want) + } + } +} + +var jarKeyTests = map[string]string{ + "foo.www.example.com": "example.com", + "www.example.com": "example.com", + "example.com": "example.com", + "com": "com", + "foo.www.bbc.co.uk": "bbc.co.uk", + "www.bbc.co.uk": "bbc.co.uk", + "bbc.co.uk": "bbc.co.uk", + "co.uk": "co.uk", + "uk": "uk", + "192.168.0.5": "192.168.0.5", +} + +func TestJarKey(t *testing.T) { + for host, want := range jarKeyTests { + if got := jarKey(host, testPSL{}); got != want { + t.Errorf("%q: got %q, want %q", host, got, want) + } + } +} + +var jarKeyNilPSLTests = map[string]string{ + "foo.www.example.com": "example.com", + "www.example.com": "example.com", + "example.com": "example.com", + "com": "com", + "foo.www.bbc.co.uk": "co.uk", + "www.bbc.co.uk": "co.uk", + "bbc.co.uk": "co.uk", + "co.uk": "co.uk", + "uk": "uk", + "192.168.0.5": "192.168.0.5", +} + +func TestJarKeyNilPSL(t *testing.T) { + for host, want := range jarKeyNilPSLTests { + if got := jarKey(host, nil); got != want { + t.Errorf("%q: got %q, want %q", host, got, want) + } + } +} + +var isIPTests = map[string]bool{ + "127.0.0.1": true, + "1.2.3.4": true, + "2001:4860:0:2001::68": true, + "example.com": false, + "1.1.1.300": false, + "www.foo.bar.net": false, + "123.foo.bar.net": false, +} + +func TestIsIP(t *testing.T) { + for host, want := range isIPTests { + if got := isIP(host); got != want { + t.Errorf("%q: got %t, want %t", host, got, want) + } + } +} + +var defaultPathTests = map[string]string{ + "/": "/", + "/abc": "/", + "/abc/": "/abc", + "/abc/xyz": "/abc", + "/abc/xyz/": "/abc/xyz", + "/a/b/c.html": "/a/b", + "": "/", + "strange": "/", + "//": "/", + "/a//b": "/a/", + "/a/./b": "/a/.", + "/a/../b": "/a/..", +} + +func TestDefaultPath(t *testing.T) { + for path, want := range defaultPathTests { + if got := defaultPath(path); got != want { + t.Errorf("%q: got %q, want %q", path, got, want) + } + } +} + +var domainAndTypeTests = [...]struct { + host string // host Set-Cookie header was received from + domain string // domain attribute in Set-Cookie header + wantDomain string // expected domain of cookie + wantHostOnly bool // expected host-cookie flag + wantErr error // expected error +}{ + {"www.example.com", "", "www.example.com", true, nil}, + {"127.0.0.1", "", "127.0.0.1", true, nil}, + {"2001:4860:0:2001::68", "", "2001:4860:0:2001::68", true, nil}, + {"www.example.com", "example.com", "example.com", false, nil}, + {"www.example.com", ".example.com", "example.com", false, nil}, + {"www.example.com", "www.example.com", "www.example.com", false, nil}, + {"www.example.com", ".www.example.com", "www.example.com", false, nil}, + {"foo.sso.example.com", "sso.example.com", "sso.example.com", false, nil}, + {"bar.co.uk", "bar.co.uk", "bar.co.uk", false, nil}, + {"foo.bar.co.uk", ".bar.co.uk", "bar.co.uk", false, nil}, + {"127.0.0.1", "127.0.0.1", "", false, errNoHostname}, + {"2001:4860:0:2001::68", "2001:4860:0:2001::68", "2001:4860:0:2001::68", false, errNoHostname}, + {"www.example.com", ".", "", false, errMalformedDomain}, + {"www.example.com", "..", "", false, errMalformedDomain}, + {"www.example.com", "other.com", "", false, errIllegalDomain}, + {"www.example.com", "com", "", false, errIllegalDomain}, + {"www.example.com", ".com", "", false, errIllegalDomain}, + {"foo.bar.co.uk", ".co.uk", "", false, errIllegalDomain}, + {"127.www.0.0.1", "127.0.0.1", "", false, errIllegalDomain}, + {"com", "", "com", true, nil}, + {"com", "com", "com", true, nil}, + {"com", ".com", "com", true, nil}, + {"co.uk", "", "co.uk", true, nil}, + {"co.uk", "co.uk", "co.uk", true, nil}, + {"co.uk", ".co.uk", "co.uk", true, nil}, +} + +func TestDomainAndType(t *testing.T) { + jar := newTestJar() + for _, tc := range domainAndTypeTests { + domain, hostOnly, err := jar.domainAndType(tc.host, tc.domain) + if err != tc.wantErr { + t.Errorf("%q/%q: got %q error, want %q", + tc.host, tc.domain, err, tc.wantErr) + continue + } + if err != nil { + continue + } + if domain != tc.wantDomain || hostOnly != tc.wantHostOnly { + t.Errorf("%q/%q: got %q/%t want %q/%t", + tc.host, tc.domain, domain, hostOnly, + tc.wantDomain, tc.wantHostOnly) + } + } +} + +// expiresIn creates an expires attribute delta seconds from tNow. +func expiresIn(delta int) string { + t := tNow.Add(time.Duration(delta) * time.Second) + return "expires=" + t.Format(time.RFC1123) +} + +// mustParseURL parses s to an URL and panics on error. +func mustParseURL(s string) *url.URL { + u, err := url.Parse(s) + if err != nil || u.Scheme == "" || u.Host == "" { + panic(fmt.Sprintf("Unable to parse URL %s.", s)) + } + return u +} + +// jarTest encapsulates the following actions on a jar: +// 1. Perform SetCookies with fromURL and the cookies from setCookies. +// (Done at time tNow + 0 ms.) +// 2. Check that the entries in the jar matches content. +// (Done at time tNow + 1001 ms.) +// 3. For each query in tests: Check that Cookies with toURL yields the +// cookies in want. +// (Query n done at tNow + (n+2)*1001 ms.) +type jarTest struct { + description string // The description of what this test is supposed to test + fromURL string // The full URL of the request from which Set-Cookie headers where received + setCookies []string // All the cookies received from fromURL + content string // The whole (non-expired) content of the jar + queries []query // Queries to test the Jar.Cookies method +} + +// query contains one test of the cookies returned from Jar.Cookies. +type query struct { + toURL string // the URL in the Cookies call + want string // the expected list of cookies (order matters) +} + +// run runs the jarTest. +func (test jarTest) run(t *testing.T, jar *Jar) { + now := tNow + + // Populate jar with cookies. + setCookies := make([]*http.Cookie, len(test.setCookies)) + for i, cs := range test.setCookies { + cookies := (&http.Response{Header: http.Header{"Set-Cookie": {cs}}}).Cookies() + if len(cookies) != 1 { + panic(fmt.Sprintf("Wrong cookie line %q: %#v", cs, cookies)) + } + setCookies[i] = cookies[0] + } + jar.setCookies(mustParseURL(test.fromURL), setCookies, now) + now = now.Add(1001 * time.Millisecond) + + // Serialize non-expired entries in the form "name1=val1 name2=val2". + var cs []string + for _, submap := range jar.entries { + for _, cookie := range submap { + if !cookie.Expires.After(now) { + continue + } + cs = append(cs, cookie.Name+"="+cookie.Value) + } + } + sort.Strings(cs) + got := strings.Join(cs, " ") + + // Make sure jar content matches our expectations. + if got != test.content { + t.Errorf("Test %q Content\ngot %q\nwant %q", + test.description, got, test.content) + } + + // Test different calls to Cookies. + for i, query := range test.queries { + now = now.Add(1001 * time.Millisecond) + var s []string + for _, c := range jar.cookies(mustParseURL(query.toURL), now) { + s = append(s, c.Name+"="+c.Value) + } + if got := strings.Join(s, " "); got != query.want { + t.Errorf("Test %q #%d\ngot %q\nwant %q", test.description, i, got, query.want) + } + } +} + +// basicsTests contains fundamental tests. Each jarTest has to be performed on +// a fresh, empty Jar. +var basicsTests = [...]jarTest{ + { + "Retrieval of a plain host cookie.", + "http://www.host.test/", + []string{"A=a"}, + "A=a", + []query{ + {"http://www.host.test", "A=a"}, + {"http://www.host.test/", "A=a"}, + {"http://www.host.test/some/path", "A=a"}, + {"https://www.host.test", "A=a"}, + {"https://www.host.test/", "A=a"}, + {"https://www.host.test/some/path", "A=a"}, + {"ftp://www.host.test", ""}, + {"ftp://www.host.test/", ""}, + {"ftp://www.host.test/some/path", ""}, + {"http://www.other.org", ""}, + {"http://sibling.host.test", ""}, + {"http://deep.www.host.test", ""}, + }, + }, + { + "Secure cookies are not returned to http.", + "http://www.host.test/", + []string{"A=a; secure"}, + "A=a", + []query{ + {"http://www.host.test", ""}, + {"http://www.host.test/", ""}, + {"http://www.host.test/some/path", ""}, + {"https://www.host.test", "A=a"}, + {"https://www.host.test/", "A=a"}, + {"https://www.host.test/some/path", "A=a"}, + }, + }, + { + "Explicit path.", + "http://www.host.test/", + []string{"A=a; path=/some/path"}, + "A=a", + []query{ + {"http://www.host.test", ""}, + {"http://www.host.test/", ""}, + {"http://www.host.test/some", ""}, + {"http://www.host.test/some/", ""}, + {"http://www.host.test/some/path", "A=a"}, + {"http://www.host.test/some/paths", ""}, + {"http://www.host.test/some/path/foo", "A=a"}, + {"http://www.host.test/some/path/foo/", "A=a"}, + }, + }, + { + "Implicit path #1: path is a directory.", + "http://www.host.test/some/path/", + []string{"A=a"}, + "A=a", + []query{ + {"http://www.host.test", ""}, + {"http://www.host.test/", ""}, + {"http://www.host.test/some", ""}, + {"http://www.host.test/some/", ""}, + {"http://www.host.test/some/path", "A=a"}, + {"http://www.host.test/some/paths", ""}, + {"http://www.host.test/some/path/foo", "A=a"}, + {"http://www.host.test/some/path/foo/", "A=a"}, + }, + }, + { + "Implicit path #2: path is not a directory.", + "http://www.host.test/some/path/index.html", + []string{"A=a"}, + "A=a", + []query{ + {"http://www.host.test", ""}, + {"http://www.host.test/", ""}, + {"http://www.host.test/some", ""}, + {"http://www.host.test/some/", ""}, + {"http://www.host.test/some/path", "A=a"}, + {"http://www.host.test/some/paths", ""}, + {"http://www.host.test/some/path/foo", "A=a"}, + {"http://www.host.test/some/path/foo/", "A=a"}, + }, + }, + { + "Implicit path #3: no path in URL at all.", + "http://www.host.test", + []string{"A=a"}, + "A=a", + []query{ + {"http://www.host.test", "A=a"}, + {"http://www.host.test/", "A=a"}, + {"http://www.host.test/some/path", "A=a"}, + }, + }, + { + "Cookies are sorted by path length.", + "http://www.host.test/", + []string{ + "A=a; path=/foo/bar", + "B=b; path=/foo/bar/baz/qux", + "C=c; path=/foo/bar/baz", + "D=d; path=/foo"}, + "A=a B=b C=c D=d", + []query{ + {"http://www.host.test/foo/bar/baz/qux", "B=b C=c A=a D=d"}, + {"http://www.host.test/foo/bar/baz/", "C=c A=a D=d"}, + {"http://www.host.test/foo/bar", "A=a D=d"}, + }, + }, + { + "Creation time determines sorting on same length paths.", + "http://www.host.test/", + []string{ + "A=a; path=/foo/bar", + "X=x; path=/foo/bar", + "Y=y; path=/foo/bar/baz/qux", + "B=b; path=/foo/bar/baz/qux", + "C=c; path=/foo/bar/baz", + "W=w; path=/foo/bar/baz", + "Z=z; path=/foo", + "D=d; path=/foo"}, + "A=a B=b C=c D=d W=w X=x Y=y Z=z", + []query{ + {"http://www.host.test/foo/bar/baz/qux", "Y=y B=b C=c W=w A=a X=x Z=z D=d"}, + {"http://www.host.test/foo/bar/baz/", "C=c W=w A=a X=x Z=z D=d"}, + {"http://www.host.test/foo/bar", "A=a X=x Z=z D=d"}, + }, + }, + { + "Sorting of same-name cookies.", + "http://www.host.test/", + []string{ + "A=1; path=/", + "A=2; path=/path", + "A=3; path=/quux", + "A=4; path=/path/foo", + "A=5; domain=.host.test; path=/path", + "A=6; domain=.host.test; path=/quux", + "A=7; domain=.host.test; path=/path/foo", + }, + "A=1 A=2 A=3 A=4 A=5 A=6 A=7", + []query{ + {"http://www.host.test/path", "A=2 A=5 A=1"}, + {"http://www.host.test/path/foo", "A=4 A=7 A=2 A=5 A=1"}, + }, + }, + { + "Disallow domain cookie on public suffix.", + "http://www.bbc.co.uk", + []string{ + "a=1", + "b=2; domain=co.uk", + }, + "a=1", + []query{{"http://www.bbc.co.uk", "a=1"}}, + }, + { + "Host cookie on IP.", + "http://192.168.0.10", + []string{"a=1"}, + "a=1", + []query{{"http://192.168.0.10", "a=1"}}, + }, + { + "Port is ignored #1.", + "http://www.host.test/", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://www.host.test:8080/", "a=1"}, + }, + }, + { + "Port is ignored #2.", + "http://www.host.test:8080/", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://www.host.test:8080/", "a=1"}, + {"http://www.host.test:1234/", "a=1"}, + }, + }, +} + +func TestBasics(t *testing.T) { + for _, test := range basicsTests { + jar := newTestJar() + test.run(t, jar) + } +} + +// updateAndDeleteTests contains jarTests which must be performed on the same +// Jar. +var updateAndDeleteTests = [...]jarTest{ + { + "Set initial cookies.", + "http://www.host.test", + []string{ + "a=1", + "b=2; secure", + "c=3; httponly", + "d=4; secure; httponly"}, + "a=1 b=2 c=3 d=4", + []query{ + {"http://www.host.test", "a=1 c=3"}, + {"https://www.host.test", "a=1 b=2 c=3 d=4"}, + }, + }, + { + "Update value via http.", + "http://www.host.test", + []string{ + "a=w", + "b=x; secure", + "c=y; httponly", + "d=z; secure; httponly"}, + "a=w b=x c=y d=z", + []query{ + {"http://www.host.test", "a=w c=y"}, + {"https://www.host.test", "a=w b=x c=y d=z"}, + }, + }, + { + "Clear Secure flag from a http.", + "http://www.host.test/", + []string{ + "b=xx", + "d=zz; httponly"}, + "a=w b=xx c=y d=zz", + []query{{"http://www.host.test", "a=w b=xx c=y d=zz"}}, + }, + { + "Delete all.", + "http://www.host.test/", + []string{ + "a=1; max-Age=-1", // delete via MaxAge + "b=2; " + expiresIn(-10), // delete via Expires + "c=2; max-age=-1; " + expiresIn(-10), // delete via both + "d=4; max-age=-1; " + expiresIn(10)}, // MaxAge takes precedence + "", + []query{{"http://www.host.test", ""}}, + }, + { + "Refill #1.", + "http://www.host.test", + []string{ + "A=1", + "A=2; path=/foo", + "A=3; domain=.host.test", + "A=4; path=/foo; domain=.host.test"}, + "A=1 A=2 A=3 A=4", + []query{{"http://www.host.test/foo", "A=2 A=4 A=1 A=3"}}, + }, + { + "Refill #2.", + "http://www.google.com", + []string{ + "A=6", + "A=7; path=/foo", + "A=8; domain=.google.com", + "A=9; path=/foo; domain=.google.com"}, + "A=1 A=2 A=3 A=4 A=6 A=7 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=4 A=1 A=3"}, + {"http://www.google.com/foo", "A=7 A=9 A=6 A=8"}, + }, + }, + { + "Delete A7.", + "http://www.google.com", + []string{"A=; path=/foo; max-age=-1"}, + "A=1 A=2 A=3 A=4 A=6 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=4 A=1 A=3"}, + {"http://www.google.com/foo", "A=9 A=6 A=8"}, + }, + }, + { + "Delete A4.", + "http://www.host.test", + []string{"A=; path=/foo; domain=host.test; max-age=-1"}, + "A=1 A=2 A=3 A=6 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=1 A=3"}, + {"http://www.google.com/foo", "A=9 A=6 A=8"}, + }, + }, + { + "Delete A6.", + "http://www.google.com", + []string{"A=; max-age=-1"}, + "A=1 A=2 A=3 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=1 A=3"}, + {"http://www.google.com/foo", "A=9 A=8"}, + }, + }, + { + "Delete A3.", + "http://www.host.test", + []string{"A=; domain=host.test; max-age=-1"}, + "A=1 A=2 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=1"}, + {"http://www.google.com/foo", "A=9 A=8"}, + }, + }, + { + "No cross-domain delete.", + "http://www.host.test", + []string{ + "A=; domain=google.com; max-age=-1", + "A=; path=/foo; domain=google.com; max-age=-1"}, + "A=1 A=2 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=1"}, + {"http://www.google.com/foo", "A=9 A=8"}, + }, + }, + { + "Delete A8 and A9.", + "http://www.google.com", + []string{ + "A=; domain=google.com; max-age=-1", + "A=; path=/foo; domain=google.com; max-age=-1"}, + "A=1 A=2", + []query{ + {"http://www.host.test/foo", "A=2 A=1"}, + {"http://www.google.com/foo", ""}, + }, + }, +} + +func TestUpdateAndDelete(t *testing.T) { + jar := newTestJar() + for _, test := range updateAndDeleteTests { + test.run(t, jar) + } +} + +func TestExpiration(t *testing.T) { + jar := newTestJar() + jarTest{ + "Expiration.", + "http://www.host.test", + []string{ + "a=1", + "b=2; max-age=3", + "c=3; " + expiresIn(3), + "d=4; max-age=5", + "e=5; " + expiresIn(5), + "f=6; max-age=100", + }, + "a=1 b=2 c=3 d=4 e=5 f=6", // executed at t0 + 1001 ms + []query{ + {"http://www.host.test", "a=1 b=2 c=3 d=4 e=5 f=6"}, // t0 + 2002 ms + {"http://www.host.test", "a=1 d=4 e=5 f=6"}, // t0 + 3003 ms + {"http://www.host.test", "a=1 d=4 e=5 f=6"}, // t0 + 4004 ms + {"http://www.host.test", "a=1 f=6"}, // t0 + 5005 ms + {"http://www.host.test", "a=1 f=6"}, // t0 + 6006 ms + }, + }.run(t, jar) +} + +// +// Tests derived from Chromium's cookie_store_unittest.h. +// + +// See http://src.chromium.org/viewvc/chrome/trunk/src/net/cookies/cookie_store_unittest.h?revision=159685&content-type=text/plain +// Some of the original tests are in a bad condition (e.g. +// DomainWithTrailingDotTest) or are not RFC 6265 conforming (e.g. +// TestNonDottedAndTLD #1 and #6) and have not been ported. + +// chromiumBasicsTests contains fundamental tests. Each jarTest has to be +// performed on a fresh, empty Jar. +var chromiumBasicsTests = [...]jarTest{ + { + "DomainWithTrailingDotTest.", + "http://www.google.com/", + []string{ + "a=1; domain=.www.google.com.", + "b=2; domain=.www.google.com.."}, + "", + []query{ + {"http://www.google.com", ""}, + }, + }, + { + "ValidSubdomainTest #1.", + "http://a.b.c.d.com", + []string{ + "a=1; domain=.a.b.c.d.com", + "b=2; domain=.b.c.d.com", + "c=3; domain=.c.d.com", + "d=4; domain=.d.com"}, + "a=1 b=2 c=3 d=4", + []query{ + {"http://a.b.c.d.com", "a=1 b=2 c=3 d=4"}, + {"http://b.c.d.com", "b=2 c=3 d=4"}, + {"http://c.d.com", "c=3 d=4"}, + {"http://d.com", "d=4"}, + }, + }, + { + "ValidSubdomainTest #2.", + "http://a.b.c.d.com", + []string{ + "a=1; domain=.a.b.c.d.com", + "b=2; domain=.b.c.d.com", + "c=3; domain=.c.d.com", + "d=4; domain=.d.com", + "X=bcd; domain=.b.c.d.com", + "X=cd; domain=.c.d.com"}, + "X=bcd X=cd a=1 b=2 c=3 d=4", + []query{ + {"http://b.c.d.com", "b=2 c=3 d=4 X=bcd X=cd"}, + {"http://c.d.com", "c=3 d=4 X=cd"}, + }, + }, + { + "InvalidDomainTest #1.", + "http://foo.bar.com", + []string{ + "a=1; domain=.yo.foo.bar.com", + "b=2; domain=.foo.com", + "c=3; domain=.bar.foo.com", + "d=4; domain=.foo.bar.com.net", + "e=5; domain=ar.com", + "f=6; domain=.", + "g=7; domain=/", + "h=8; domain=http://foo.bar.com", + "i=9; domain=..foo.bar.com", + "j=10; domain=..bar.com", + "k=11; domain=.foo.bar.com?blah", + "l=12; domain=.foo.bar.com/blah", + "m=12; domain=.foo.bar.com:80", + "n=14; domain=.foo.bar.com:", + "o=15; domain=.foo.bar.com#sup", + }, + "", // Jar is empty. + []query{{"http://foo.bar.com", ""}}, + }, + { + "InvalidDomainTest #2.", + "http://foo.com.com", + []string{"a=1; domain=.foo.com.com.com"}, + "", + []query{{"http://foo.bar.com", ""}}, + }, + { + "DomainWithoutLeadingDotTest #1.", + "http://manage.hosted.filefront.com", + []string{"a=1; domain=filefront.com"}, + "a=1", + []query{{"http://www.filefront.com", "a=1"}}, + }, + { + "DomainWithoutLeadingDotTest #2.", + "http://www.google.com", + []string{"a=1; domain=www.google.com"}, + "a=1", + []query{ + {"http://www.google.com", "a=1"}, + {"http://sub.www.google.com", "a=1"}, + {"http://something-else.com", ""}, + }, + }, + { + "CaseInsensitiveDomainTest.", + "http://www.google.com", + []string{ + "a=1; domain=.GOOGLE.COM", + "b=2; domain=.www.gOOgLE.coM"}, + "a=1 b=2", + []query{{"http://www.google.com", "a=1 b=2"}}, + }, + { + "TestIpAddress #1.", + "http://1.2.3.4/foo", + []string{"a=1; path=/"}, + "a=1", + []query{{"http://1.2.3.4/foo", "a=1"}}, + }, + { + "TestIpAddress #2.", + "http://1.2.3.4/foo", + []string{ + "a=1; domain=.1.2.3.4", + "b=2; domain=.3.4"}, + "", + []query{{"http://1.2.3.4/foo", ""}}, + }, + { + "TestIpAddress #3.", + "http://1.2.3.4/foo", + []string{"a=1; domain=1.2.3.4"}, + "", + []query{{"http://1.2.3.4/foo", ""}}, + }, + { + "TestNonDottedAndTLD #2.", + "http://com./index.html", + []string{"a=1"}, + "a=1", + []query{ + {"http://com./index.html", "a=1"}, + {"http://no-cookies.com./index.html", ""}, + }, + }, + { + "TestNonDottedAndTLD #3.", + "http://a.b", + []string{ + "a=1; domain=.b", + "b=2; domain=b"}, + "", + []query{{"http://bar.foo", ""}}, + }, + { + "TestNonDottedAndTLD #4.", + "http://google.com", + []string{ + "a=1; domain=.com", + "b=2; domain=com"}, + "", + []query{{"http://google.com", ""}}, + }, + { + "TestNonDottedAndTLD #5.", + "http://google.co.uk", + []string{ + "a=1; domain=.co.uk", + "b=2; domain=.uk"}, + "", + []query{ + {"http://google.co.uk", ""}, + {"http://else.co.com", ""}, + {"http://else.uk", ""}, + }, + }, + { + "TestHostEndsWithDot.", + "http://www.google.com", + []string{ + "a=1", + "b=2; domain=.www.google.com."}, + "a=1", + []query{{"http://www.google.com", "a=1"}}, + }, + { + "PathTest", + "http://www.google.izzle", + []string{"a=1; path=/wee"}, + "a=1", + []query{ + {"http://www.google.izzle/wee", "a=1"}, + {"http://www.google.izzle/wee/", "a=1"}, + {"http://www.google.izzle/wee/war", "a=1"}, + {"http://www.google.izzle/wee/war/more/more", "a=1"}, + {"http://www.google.izzle/weehee", ""}, + {"http://www.google.izzle/", ""}, + }, + }, +} + +func TestChromiumBasics(t *testing.T) { + for _, test := range chromiumBasicsTests { + jar := newTestJar() + test.run(t, jar) + } +} + +// chromiumDomainTests contains jarTests which must be executed all on the +// same Jar. +var chromiumDomainTests = [...]jarTest{ + { + "Fill #1.", + "http://www.google.izzle", + []string{"A=B"}, + "A=B", + []query{{"http://www.google.izzle", "A=B"}}, + }, + { + "Fill #2.", + "http://www.google.izzle", + []string{"C=D; domain=.google.izzle"}, + "A=B C=D", + []query{{"http://www.google.izzle", "A=B C=D"}}, + }, + { + "Verify A is a host cookie and not accessible from subdomain.", + "http://unused.nil", + []string{}, + "A=B C=D", + []query{{"http://foo.www.google.izzle", "C=D"}}, + }, + { + "Verify domain cookies are found on proper domain.", + "http://www.google.izzle", + []string{"E=F; domain=.www.google.izzle"}, + "A=B C=D E=F", + []query{{"http://www.google.izzle", "A=B C=D E=F"}}, + }, + { + "Leading dots in domain attributes are optional.", + "http://www.google.izzle", + []string{"G=H; domain=www.google.izzle"}, + "A=B C=D E=F G=H", + []query{{"http://www.google.izzle", "A=B C=D E=F G=H"}}, + }, + { + "Verify domain enforcement works #1.", + "http://www.google.izzle", + []string{"K=L; domain=.bar.www.google.izzle"}, + "A=B C=D E=F G=H", + []query{{"http://bar.www.google.izzle", "C=D E=F G=H"}}, + }, + { + "Verify domain enforcement works #2.", + "http://unused.nil", + []string{}, + "A=B C=D E=F G=H", + []query{{"http://www.google.izzle", "A=B C=D E=F G=H"}}, + }, +} + +func TestChromiumDomain(t *testing.T) { + jar := newTestJar() + for _, test := range chromiumDomainTests { + test.run(t, jar) + } + +} + +// chromiumDeletionTests must be performed all on the same Jar. +var chromiumDeletionTests = [...]jarTest{ + { + "Create session cookie a1.", + "http://www.google.com", + []string{"a=1"}, + "a=1", + []query{{"http://www.google.com", "a=1"}}, + }, + { + "Delete sc a1 via MaxAge.", + "http://www.google.com", + []string{"a=1; max-age=-1"}, + "", + []query{{"http://www.google.com", ""}}, + }, + { + "Create session cookie b2.", + "http://www.google.com", + []string{"b=2"}, + "b=2", + []query{{"http://www.google.com", "b=2"}}, + }, + { + "Delete sc b2 via Expires.", + "http://www.google.com", + []string{"b=2; " + expiresIn(-10)}, + "", + []query{{"http://www.google.com", ""}}, + }, + { + "Create persistent cookie c3.", + "http://www.google.com", + []string{"c=3; max-age=3600"}, + "c=3", + []query{{"http://www.google.com", "c=3"}}, + }, + { + "Delete pc c3 via MaxAge.", + "http://www.google.com", + []string{"c=3; max-age=-1"}, + "", + []query{{"http://www.google.com", ""}}, + }, + { + "Create persistent cookie d4.", + "http://www.google.com", + []string{"d=4; max-age=3600"}, + "d=4", + []query{{"http://www.google.com", "d=4"}}, + }, + { + "Delete pc d4 via Expires.", + "http://www.google.com", + []string{"d=4; " + expiresIn(-10)}, + "", + []query{{"http://www.google.com", ""}}, + }, +} + +func TestChromiumDeletion(t *testing.T) { + jar := newTestJar() + for _, test := range chromiumDeletionTests { + test.run(t, jar) + } +} + +// domainHandlingTests tests and documents the rules for domain handling. +// Each test must be performed on an empty new Jar. +var domainHandlingTests = [...]jarTest{ + { + "Host cookie", + "http://www.host.test", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://host.test", ""}, + {"http://bar.host.test", ""}, + {"http://foo.www.host.test", ""}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Domain cookie #1", + "http://www.host.test", + []string{"a=1; domain=host.test"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://host.test", "a=1"}, + {"http://bar.host.test", "a=1"}, + {"http://foo.www.host.test", "a=1"}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Domain cookie #2", + "http://www.host.test", + []string{"a=1; domain=.host.test"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://host.test", "a=1"}, + {"http://bar.host.test", "a=1"}, + {"http://foo.www.host.test", "a=1"}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Host cookie on IDNA domain #1", + "http://www.bücher.test", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.bücher.test", "a=1"}, + {"http://www.xn--bcher-kva.test", "a=1"}, + {"http://bücher.test", ""}, + {"http://xn--bcher-kva.test", ""}, + {"http://bar.bücher.test", ""}, + {"http://bar.xn--bcher-kva.test", ""}, + {"http://foo.www.bücher.test", ""}, + {"http://foo.www.xn--bcher-kva.test", ""}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Host cookie on IDNA domain #2", + "http://www.xn--bcher-kva.test", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.bücher.test", "a=1"}, + {"http://www.xn--bcher-kva.test", "a=1"}, + {"http://bücher.test", ""}, + {"http://xn--bcher-kva.test", ""}, + {"http://bar.bücher.test", ""}, + {"http://bar.xn--bcher-kva.test", ""}, + {"http://foo.www.bücher.test", ""}, + {"http://foo.www.xn--bcher-kva.test", ""}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Domain cookie on IDNA domain #1", + "http://www.bücher.test", + []string{"a=1; domain=xn--bcher-kva.test"}, + "a=1", + []query{ + {"http://www.bücher.test", "a=1"}, + {"http://www.xn--bcher-kva.test", "a=1"}, + {"http://bücher.test", "a=1"}, + {"http://xn--bcher-kva.test", "a=1"}, + {"http://bar.bücher.test", "a=1"}, + {"http://bar.xn--bcher-kva.test", "a=1"}, + {"http://foo.www.bücher.test", "a=1"}, + {"http://foo.www.xn--bcher-kva.test", "a=1"}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Domain cookie on IDNA domain #2", + "http://www.xn--bcher-kva.test", + []string{"a=1; domain=xn--bcher-kva.test"}, + "a=1", + []query{ + {"http://www.bücher.test", "a=1"}, + {"http://www.xn--bcher-kva.test", "a=1"}, + {"http://bücher.test", "a=1"}, + {"http://xn--bcher-kva.test", "a=1"}, + {"http://bar.bücher.test", "a=1"}, + {"http://bar.xn--bcher-kva.test", "a=1"}, + {"http://foo.www.bücher.test", "a=1"}, + {"http://foo.www.xn--bcher-kva.test", "a=1"}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Host cookie on TLD.", + "http://com", + []string{"a=1"}, + "a=1", + []query{ + {"http://com", "a=1"}, + {"http://any.com", ""}, + {"http://any.test", ""}, + }, + }, + { + "Domain cookie on TLD becomes a host cookie.", + "http://com", + []string{"a=1; domain=com"}, + "a=1", + []query{ + {"http://com", "a=1"}, + {"http://any.com", ""}, + {"http://any.test", ""}, + }, + }, + { + "Host cookie on public suffix.", + "http://co.uk", + []string{"a=1"}, + "a=1", + []query{ + {"http://co.uk", "a=1"}, + {"http://uk", ""}, + {"http://some.co.uk", ""}, + {"http://foo.some.co.uk", ""}, + {"http://any.uk", ""}, + }, + }, + { + "Domain cookie on public suffix is ignored.", + "http://some.co.uk", + []string{"a=1; domain=co.uk"}, + "", + []query{ + {"http://co.uk", ""}, + {"http://uk", ""}, + {"http://some.co.uk", ""}, + {"http://foo.some.co.uk", ""}, + {"http://any.uk", ""}, + }, + }, +} + +func TestDomainHandling(t *testing.T) { + for _, test := range domainHandlingTests { + jar := newTestJar() + test.run(t, jar) + } +} diff --git a/src/pkg/net/http/cookiejar/punycode.go b/src/pkg/net/http/cookiejar/punycode.go new file mode 100644 index 000000000..ea7ceb5ef --- /dev/null +++ b/src/pkg/net/http/cookiejar/punycode.go @@ -0,0 +1,159 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cookiejar + +// This file implements the Punycode algorithm from RFC 3492. + +import ( + "fmt" + "strings" + "unicode/utf8" +) + +// These parameter values are specified in section 5. +// +// All computation is done with int32s, so that overflow behavior is identical +// regardless of whether int is 32-bit or 64-bit. +const ( + base int32 = 36 + damp int32 = 700 + initialBias int32 = 72 + initialN int32 = 128 + skew int32 = 38 + tmax int32 = 26 + tmin int32 = 1 +) + +// encode encodes a string as specified in section 6.3 and prepends prefix to +// the result. +// +// The "while h < length(input)" line in the specification becomes "for +// remaining != 0" in the Go code, because len(s) in Go is in bytes, not runes. +func encode(prefix, s string) (string, error) { + output := make([]byte, len(prefix), len(prefix)+1+2*len(s)) + copy(output, prefix) + delta, n, bias := int32(0), initialN, initialBias + b, remaining := int32(0), int32(0) + for _, r := range s { + if r < 0x80 { + b++ + output = append(output, byte(r)) + } else { + remaining++ + } + } + h := b + if b > 0 { + output = append(output, '-') + } + for remaining != 0 { + m := int32(0x7fffffff) + for _, r := range s { + if m > r && r >= n { + m = r + } + } + delta += (m - n) * (h + 1) + if delta < 0 { + return "", fmt.Errorf("cookiejar: invalid label %q", s) + } + n = m + for _, r := range s { + if r < n { + delta++ + if delta < 0 { + return "", fmt.Errorf("cookiejar: invalid label %q", s) + } + continue + } + if r > n { + continue + } + q := delta + for k := base; ; k += base { + t := k - bias + if t < tmin { + t = tmin + } else if t > tmax { + t = tmax + } + if q < t { + break + } + output = append(output, encodeDigit(t+(q-t)%(base-t))) + q = (q - t) / (base - t) + } + output = append(output, encodeDigit(q)) + bias = adapt(delta, h+1, h == b) + delta = 0 + h++ + remaining-- + } + delta++ + n++ + } + return string(output), nil +} + +func encodeDigit(digit int32) byte { + switch { + case 0 <= digit && digit < 26: + return byte(digit + 'a') + case 26 <= digit && digit < 36: + return byte(digit + ('0' - 26)) + } + panic("cookiejar: internal error in punycode encoding") +} + +// adapt is the bias adaptation function specified in section 6.1. +func adapt(delta, numPoints int32, firstTime bool) int32 { + if firstTime { + delta /= damp + } else { + delta /= 2 + } + delta += delta / numPoints + k := int32(0) + for delta > ((base-tmin)*tmax)/2 { + delta /= base - tmin + k += base + } + return k + (base-tmin+1)*delta/(delta+skew) +} + +// Strictly speaking, the remaining code below deals with IDNA (RFC 5890 and +// friends) and not Punycode (RFC 3492) per se. + +// acePrefix is the ASCII Compatible Encoding prefix. +const acePrefix = "xn--" + +// toASCII converts a domain or domain label to its ASCII form. For example, +// toASCII("bücher.example.com") is "xn--bcher-kva.example.com", and +// toASCII("golang") is "golang". +func toASCII(s string) (string, error) { + if ascii(s) { + return s, nil + } + labels := strings.Split(s, ".") + for i, label := range labels { + if !ascii(label) { + a, err := encode(acePrefix, label) + if err != nil { + return "", err + } + labels[i] = a + } + } + return strings.Join(labels, "."), nil +} + +func ascii(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] >= utf8.RuneSelf { + return false + } + } + return true +} diff --git a/src/pkg/net/http/cookiejar/punycode_test.go b/src/pkg/net/http/cookiejar/punycode_test.go new file mode 100644 index 000000000..0301de14e --- /dev/null +++ b/src/pkg/net/http/cookiejar/punycode_test.go @@ -0,0 +1,161 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cookiejar + +import ( + "testing" +) + +var punycodeTestCases = [...]struct { + s, encoded string +}{ + {"", ""}, + {"-", "--"}, + {"-a", "-a-"}, + {"-a-", "-a--"}, + {"a", "a-"}, + {"a-", "a--"}, + {"a-b", "a-b-"}, + {"books", "books-"}, + {"bücher", "bcher-kva"}, + {"Hello世界", "Hello-ck1hg65u"}, + {"ü", "tda"}, + {"üý", "tdac"}, + + // The test cases below come from RFC 3492 section 7.1 with Errata 3026. + { + // (A) Arabic (Egyptian). + "\u0644\u064A\u0647\u0645\u0627\u0628\u062A\u0643\u0644" + + "\u0645\u0648\u0634\u0639\u0631\u0628\u064A\u061F", + "egbpdaj6bu4bxfgehfvwxn", + }, + { + // (B) Chinese (simplified). + "\u4ED6\u4EEC\u4E3A\u4EC0\u4E48\u4E0D\u8BF4\u4E2D\u6587", + "ihqwcrb4cv8a8dqg056pqjye", + }, + { + // (C) Chinese (traditional). + "\u4ED6\u5011\u7232\u4EC0\u9EBD\u4E0D\u8AAA\u4E2D\u6587", + "ihqwctvzc91f659drss3x8bo0yb", + }, + { + // (D) Czech. + "\u0050\u0072\u006F\u010D\u0070\u0072\u006F\u0073\u0074" + + "\u011B\u006E\u0065\u006D\u006C\u0075\u0076\u00ED\u010D" + + "\u0065\u0073\u006B\u0079", + "Proprostnemluvesky-uyb24dma41a", + }, + { + // (E) Hebrew. + "\u05DC\u05DE\u05D4\u05D4\u05DD\u05E4\u05E9\u05D5\u05D8" + + "\u05DC\u05D0\u05DE\u05D3\u05D1\u05E8\u05D9\u05DD\u05E2" + + "\u05D1\u05E8\u05D9\u05EA", + "4dbcagdahymbxekheh6e0a7fei0b", + }, + { + // (F) Hindi (Devanagari). + "\u092F\u0939\u0932\u094B\u0917\u0939\u093F\u0928\u094D" + + "\u0926\u0940\u0915\u094D\u092F\u094B\u0902\u0928\u0939" + + "\u0940\u0902\u092C\u094B\u0932\u0938\u0915\u0924\u0947" + + "\u0939\u0948\u0902", + "i1baa7eci9glrd9b2ae1bj0hfcgg6iyaf8o0a1dig0cd", + }, + { + // (G) Japanese (kanji and hiragana). + "\u306A\u305C\u307F\u3093\u306A\u65E5\u672C\u8A9E\u3092" + + "\u8A71\u3057\u3066\u304F\u308C\u306A\u3044\u306E\u304B", + "n8jok5ay5dzabd5bym9f0cm5685rrjetr6pdxa", + }, + { + // (H) Korean (Hangul syllables). + "\uC138\uACC4\uC758\uBAA8\uB4E0\uC0AC\uB78C\uB4E4\uC774" + + "\uD55C\uAD6D\uC5B4\uB97C\uC774\uD574\uD55C\uB2E4\uBA74" + + "\uC5BC\uB9C8\uB098\uC88B\uC744\uAE4C", + "989aomsvi5e83db1d2a355cv1e0vak1dwrv93d5xbh15a0dt30a5j" + + "psd879ccm6fea98c", + }, + { + // (I) Russian (Cyrillic). + "\u043F\u043E\u0447\u0435\u043C\u0443\u0436\u0435\u043E" + + "\u043D\u0438\u043D\u0435\u0433\u043E\u0432\u043E\u0440" + + "\u044F\u0442\u043F\u043E\u0440\u0443\u0441\u0441\u043A" + + "\u0438", + "b1abfaaepdrnnbgefbadotcwatmq2g4l", + }, + { + // (J) Spanish. + "\u0050\u006F\u0072\u0071\u0075\u00E9\u006E\u006F\u0070" + + "\u0075\u0065\u0064\u0065\u006E\u0073\u0069\u006D\u0070" + + "\u006C\u0065\u006D\u0065\u006E\u0074\u0065\u0068\u0061" + + "\u0062\u006C\u0061\u0072\u0065\u006E\u0045\u0073\u0070" + + "\u0061\u00F1\u006F\u006C", + "PorqunopuedensimplementehablarenEspaol-fmd56a", + }, + { + // (K) Vietnamese. + "\u0054\u1EA1\u0069\u0073\u0061\u006F\u0068\u1ECD\u006B" + + "\u0068\u00F4\u006E\u0067\u0074\u0068\u1EC3\u0063\u0068" + + "\u1EC9\u006E\u00F3\u0069\u0074\u0069\u1EBF\u006E\u0067" + + "\u0056\u0069\u1EC7\u0074", + "TisaohkhngthchnitingVit-kjcr8268qyxafd2f1b9g", + }, + { + // (L) 3<nen>B<gumi><kinpachi><sensei>. + "\u0033\u5E74\u0042\u7D44\u91D1\u516B\u5148\u751F", + "3B-ww4c5e180e575a65lsy2b", + }, + { + // (M) <amuro><namie>-with-SUPER-MONKEYS. + "\u5B89\u5BA4\u5948\u7F8E\u6075\u002D\u0077\u0069\u0074" + + "\u0068\u002D\u0053\u0055\u0050\u0045\u0052\u002D\u004D" + + "\u004F\u004E\u004B\u0045\u0059\u0053", + "-with-SUPER-MONKEYS-pc58ag80a8qai00g7n9n", + }, + { + // (N) Hello-Another-Way-<sorezore><no><basho>. + "\u0048\u0065\u006C\u006C\u006F\u002D\u0041\u006E\u006F" + + "\u0074\u0068\u0065\u0072\u002D\u0057\u0061\u0079\u002D" + + "\u305D\u308C\u305E\u308C\u306E\u5834\u6240", + "Hello-Another-Way--fc4qua05auwb3674vfr0b", + }, + { + // (O) <hitotsu><yane><no><shita>2. + "\u3072\u3068\u3064\u5C4B\u6839\u306E\u4E0B\u0032", + "2-u9tlzr9756bt3uc0v", + }, + { + // (P) Maji<de>Koi<suru>5<byou><mae> + "\u004D\u0061\u006A\u0069\u3067\u004B\u006F\u0069\u3059" + + "\u308B\u0035\u79D2\u524D", + "MajiKoi5-783gue6qz075azm5e", + }, + { + // (Q) <pafii>de<runba> + "\u30D1\u30D5\u30A3\u30FC\u0064\u0065\u30EB\u30F3\u30D0", + "de-jg4avhby1noc0d", + }, + { + // (R) <sono><supiido><de> + "\u305D\u306E\u30B9\u30D4\u30FC\u30C9\u3067", + "d9juau41awczczp", + }, + { + // (S) -> $1.00 <- + "\u002D\u003E\u0020\u0024\u0031\u002E\u0030\u0030\u0020" + + "\u003C\u002D", + "-> $1.00 <--", + }, +} + +func TestPunycode(t *testing.T) { + for _, tc := range punycodeTestCases { + if got, err := encode("", tc.s); err != nil { + t.Errorf(`encode("", %q): %v`, tc.s, err) + } else if got != tc.encoded { + t.Errorf(`encode("", %q): got %q, want %q`, tc.s, got, tc.encoded) + } + } +} diff --git a/src/pkg/net/http/example_test.go b/src/pkg/net/http/example_test.go index ec814407d..22073eaf7 100644 --- a/src/pkg/net/http/example_test.go +++ b/src/pkg/net/http/example_test.go @@ -43,10 +43,10 @@ func ExampleGet() { log.Fatal(err) } robots, err := ioutil.ReadAll(res.Body) + res.Body.Close() if err != nil { log.Fatal(err) } - res.Body.Close() fmt.Printf("%s", robots) } diff --git a/src/pkg/net/http/export_test.go b/src/pkg/net/http/export_test.go index 13640ca85..a7bca20a0 100644 --- a/src/pkg/net/http/export_test.go +++ b/src/pkg/net/http/export_test.go @@ -7,12 +7,25 @@ package http -import "time" +import ( + "net" + "time" +) + +func NewLoggingConn(baseName string, c net.Conn) net.Conn { + return newLoggingConn(baseName, c) +} + +func (t *Transport) NumPendingRequestsForTesting() int { + t.reqMu.Lock() + defer t.reqMu.Unlock() + return len(t.reqConn) +} func (t *Transport) IdleConnKeysForTesting() (keys []string) { keys = make([]string, 0) - t.lk.Lock() - defer t.lk.Unlock() + t.idleMu.Lock() + defer t.idleMu.Unlock() if t.idleConn == nil { return } @@ -23,8 +36,8 @@ func (t *Transport) IdleConnKeysForTesting() (keys []string) { } func (t *Transport) IdleConnCountForTesting(cacheKey string) int { - t.lk.Lock() - defer t.lk.Unlock() + t.idleMu.Lock() + defer t.idleMu.Unlock() if t.idleConn == nil { return 0 } diff --git a/src/pkg/net/http/filetransport_test.go b/src/pkg/net/http/filetransport_test.go index 039926b53..6f1a537e2 100644 --- a/src/pkg/net/http/filetransport_test.go +++ b/src/pkg/net/http/filetransport_test.go @@ -2,11 +2,10 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package http_test +package http import ( "io/ioutil" - "net/http" "os" "path/filepath" "testing" @@ -32,9 +31,9 @@ func TestFileTransport(t *testing.T) { defer os.Remove(dname) defer os.Remove(fname) - tr := &http.Transport{} - tr.RegisterProtocol("file", http.NewFileTransport(http.Dir(dname))) - c := &http.Client{Transport: tr} + tr := &Transport{} + tr.RegisterProtocol("file", NewFileTransport(Dir(dname))) + c := &Client{Transport: tr} fooURLs := []string{"file:///foo.txt", "file://../foo.txt"} for _, urlstr := range fooURLs { @@ -62,4 +61,5 @@ func TestFileTransport(t *testing.T) { if res.StatusCode != 404 { t.Errorf("for %s, StatusCode = %d, want 404", badURL, res.StatusCode) } + res.Body.Close() } diff --git a/src/pkg/net/http/fs.go b/src/pkg/net/http/fs.go index f35dd32c3..b6bea0dfa 100644 --- a/src/pkg/net/http/fs.go +++ b/src/pkg/net/http/fs.go @@ -11,6 +11,8 @@ import ( "fmt" "io" "mime" + "mime/multipart" + "net/textproto" "os" "path" "path/filepath" @@ -26,7 +28,8 @@ import ( type Dir string func (d Dir) Open(name string) (File, error) { - if filepath.Separator != '/' && strings.IndexRune(name, filepath.Separator) >= 0 { + if filepath.Separator != '/' && strings.IndexRune(name, filepath.Separator) >= 0 || + strings.Contains(name, "\x00") { return nil, errors.New("http: invalid character in file path") } dir := string(d) @@ -97,6 +100,9 @@ func dirList(w ResponseWriter, f File) { // The content's Seek method must work: ServeContent uses // a seek to the end of the content to determine its size. // +// If the caller has set w's ETag header, ServeContent uses it to +// handle requests using If-Range and If-None-Match. +// // Note that *os.File implements the io.ReadSeeker interface. func ServeContent(w ResponseWriter, req *Request, name string, modtime time.Time, content io.ReadSeeker) { size, err := content.Seek(0, os.SEEK_END) @@ -119,12 +125,17 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, if checkLastModified(w, r, modtime) { return } + rangeReq, done := checkETag(w, r) + if done { + return + } code := StatusOK // If Content-Type isn't set, use the file's extension to find it. - if w.Header().Get("Content-Type") == "" { - ctype := mime.TypeByExtension(filepath.Ext(name)) + ctype := w.Header().Get("Content-Type") + if ctype == "" { + ctype = mime.TypeByExtension(filepath.Ext(name)) if ctype == "" { // read a chunk to decide between utf-8 text and binary var buf [1024]byte @@ -141,18 +152,34 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, } // handle Content-Range header. - // TODO(adg): handle multiple ranges sendSize := size + var sendContent io.Reader = content if size >= 0 { - ranges, err := parseRange(r.Header.Get("Range"), size) - if err == nil && len(ranges) > 1 { - err = errors.New("multiple ranges not supported") - } + ranges, err := parseRange(rangeReq, size) if err != nil { Error(w, err.Error(), StatusRequestedRangeNotSatisfiable) return } - if len(ranges) == 1 { + if sumRangesSize(ranges) >= size { + // The total number of bytes in all the ranges + // is larger than the size of the file by + // itself, so this is probably an attack, or a + // dumb client. Ignore the range request. + ranges = nil + } + switch { + case len(ranges) == 1: + // RFC 2616, Section 14.16: + // "When an HTTP message includes the content of a single + // range (for example, a response to a request for a + // single range, or to a request for a set of ranges + // that overlap without any holes), this content is + // transmitted with a Content-Range header, and a + // Content-Length header showing the number of bytes + // actually transferred. + // ... + // A response to a request for a single range MUST NOT + // be sent using the multipart/byteranges media type." ra := ranges[0] if _, err := content.Seek(ra.start, os.SEEK_SET); err != nil { Error(w, err.Error(), StatusRequestedRangeNotSatisfiable) @@ -160,7 +187,41 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, } sendSize = ra.length code = StatusPartialContent - w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", ra.start, ra.start+ra.length-1, size)) + w.Header().Set("Content-Range", ra.contentRange(size)) + case len(ranges) > 1: + for _, ra := range ranges { + if ra.start > size { + Error(w, err.Error(), StatusRequestedRangeNotSatisfiable) + return + } + } + sendSize = rangesMIMESize(ranges, ctype, size) + code = StatusPartialContent + + pr, pw := io.Pipe() + mw := multipart.NewWriter(pw) + w.Header().Set("Content-Type", "multipart/byteranges; boundary="+mw.Boundary()) + sendContent = pr + defer pr.Close() // cause writing goroutine to fail and exit if CopyN doesn't finish. + go func() { + for _, ra := range ranges { + part, err := mw.CreatePart(ra.mimeHeader(ctype, size)) + if err != nil { + pw.CloseWithError(err) + return + } + if _, err := content.Seek(ra.start, os.SEEK_SET); err != nil { + pw.CloseWithError(err) + return + } + if _, err := io.CopyN(part, content, ra.length); err != nil { + pw.CloseWithError(err) + return + } + } + mw.Close() + pw.Close() + }() } w.Header().Set("Accept-Ranges", "bytes") @@ -172,11 +233,7 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, w.WriteHeader(code) if r.Method != "HEAD" { - if sendSize == -1 { - io.Copy(w, content) - } else { - io.CopyN(w, content, sendSize) - } + io.CopyN(w, sendContent, sendSize) } } @@ -190,6 +247,9 @@ func checkLastModified(w ResponseWriter, r *Request, modtime time.Time) bool { // The Date-Modified header truncates sub-second precision, so // use mtime < t+1s instead of mtime <= t to check for unmodified. if t, err := time.Parse(TimeFormat, r.Header.Get("If-Modified-Since")); err == nil && modtime.Before(t.Add(1*time.Second)) { + h := w.Header() + delete(h, "Content-Type") + delete(h, "Content-Length") w.WriteHeader(StatusNotModified) return true } @@ -197,6 +257,58 @@ func checkLastModified(w ResponseWriter, r *Request, modtime time.Time) bool { return false } +// checkETag implements If-None-Match and If-Range checks. +// The ETag must have been previously set in the ResponseWriter's headers. +// +// The return value is the effective request "Range" header to use and +// whether this request is now considered done. +func checkETag(w ResponseWriter, r *Request) (rangeReq string, done bool) { + etag := w.Header().get("Etag") + rangeReq = r.Header.get("Range") + + // Invalidate the range request if the entity doesn't match the one + // the client was expecting. + // "If-Range: version" means "ignore the Range: header unless version matches the + // current file." + // We only support ETag versions. + // The caller must have set the ETag on the response already. + if ir := r.Header.get("If-Range"); ir != "" && ir != etag { + // TODO(bradfitz): handle If-Range requests with Last-Modified + // times instead of ETags? I'd rather not, at least for + // now. That seems like a bug/compromise in the RFC 2616, and + // I've never heard of anybody caring about that (yet). + rangeReq = "" + } + + if inm := r.Header.get("If-None-Match"); inm != "" { + // Must know ETag. + if etag == "" { + return rangeReq, false + } + + // TODO(bradfitz): non-GET/HEAD requests require more work: + // sending a different status code on matches, and + // also can't use weak cache validators (those with a "W/ + // prefix). But most users of ServeContent will be using + // it on GET or HEAD, so only support those for now. + if r.Method != "GET" && r.Method != "HEAD" { + return rangeReq, false + } + + // TODO(bradfitz): deal with comma-separated or multiple-valued + // list of If-None-match values. For now just handle the common + // case of a single item. + if inm == etag || inm == "*" { + h := w.Header() + delete(h, "Content-Type") + delete(h, "Content-Length") + w.WriteHeader(StatusNotModified) + return "", true + } + } + return rangeReq, false +} + // name is '/'-separated, not filepath.Separator. func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirect bool) { const indexPage = "/index.html" @@ -243,9 +355,6 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec // use contents of index.html for directory, if present if d.IsDir() { - if checkLastModified(w, r, d.ModTime()) { - return - } index := name + indexPage ff, err := fs.Open(index) if err == nil { @@ -259,11 +368,16 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec } } + // Still a directory? (we didn't find an index.html file) if d.IsDir() { + if checkLastModified(w, r, d.ModTime()) { + return + } dirList(w, f) return } + // serverContent will check modification time serveContent(w, r, d.Name(), d.ModTime(), d.Size(), f) } @@ -312,6 +426,17 @@ type httpRange struct { start, length int64 } +func (r httpRange) contentRange(size int64) string { + return fmt.Sprintf("bytes %d-%d/%d", r.start, r.start+r.length-1, size) +} + +func (r httpRange) mimeHeader(contentType string, size int64) textproto.MIMEHeader { + return textproto.MIMEHeader{ + "Content-Range": {r.contentRange(size)}, + "Content-Type": {contentType}, + } +} + // parseRange parses a Range header string as per RFC 2616. func parseRange(s string, size int64) ([]httpRange, error) { if s == "" { @@ -323,11 +448,15 @@ func parseRange(s string, size int64) ([]httpRange, error) { } var ranges []httpRange for _, ra := range strings.Split(s[len(b):], ",") { + ra = strings.TrimSpace(ra) + if ra == "" { + continue + } i := strings.Index(ra, "-") if i < 0 { return nil, errors.New("invalid range") } - start, end := ra[:i], ra[i+1:] + start, end := strings.TrimSpace(ra[:i]), strings.TrimSpace(ra[i+1:]) var r httpRange if start == "" { // If no start is specified, end specifies the @@ -365,3 +494,32 @@ func parseRange(s string, size int64) ([]httpRange, error) { } return ranges, nil } + +// countingWriter counts how many bytes have been written to it. +type countingWriter int64 + +func (w *countingWriter) Write(p []byte) (n int, err error) { + *w += countingWriter(len(p)) + return len(p), nil +} + +// rangesMIMESize returns the nunber of bytes it takes to encode the +// provided ranges as a multipart response. +func rangesMIMESize(ranges []httpRange, contentType string, contentSize int64) (encSize int64) { + var w countingWriter + mw := multipart.NewWriter(&w) + for _, ra := range ranges { + mw.CreatePart(ra.mimeHeader(contentType, contentSize)) + encSize += ra.length + } + mw.Close() + encSize += int64(w) + return +} + +func sumRangesSize(ranges []httpRange) (size int64) { + for _, ra := range ranges { + size += ra.length + } + return +} diff --git a/src/pkg/net/http/fs_test.go b/src/pkg/net/http/fs_test.go index 5aa93ce58..0dd6d0df9 100644 --- a/src/pkg/net/http/fs_test.go +++ b/src/pkg/net/http/fs_test.go @@ -10,12 +10,15 @@ import ( "fmt" "io" "io/ioutil" + "mime" + "mime/multipart" "net" . "net/http" "net/http/httptest" "net/url" "os" "os/exec" + "path" "path/filepath" "regexp" "runtime" @@ -25,24 +28,33 @@ import ( ) const ( - testFile = "testdata/file" - testFileLength = 11 + testFile = "testdata/file" + testFileLen = 11 ) +type wantRange struct { + start, end int64 // range [start,end) +} + var ServeFileRangeTests = []struct { - start, end int - r string - code int + r string + code int + ranges []wantRange }{ - {0, testFileLength, "", StatusOK}, - {0, 5, "0-4", StatusPartialContent}, - {2, testFileLength, "2-", StatusPartialContent}, - {testFileLength - 5, testFileLength, "-5", StatusPartialContent}, - {3, 8, "3-7", StatusPartialContent}, - {0, 0, "20-", StatusRequestedRangeNotSatisfiable}, + {r: "", code: StatusOK}, + {r: "bytes=0-4", code: StatusPartialContent, ranges: []wantRange{{0, 5}}}, + {r: "bytes=2-", code: StatusPartialContent, ranges: []wantRange{{2, testFileLen}}}, + {r: "bytes=-5", code: StatusPartialContent, ranges: []wantRange{{testFileLen - 5, testFileLen}}}, + {r: "bytes=3-7", code: StatusPartialContent, ranges: []wantRange{{3, 8}}}, + {r: "bytes=20-", code: StatusRequestedRangeNotSatisfiable}, + {r: "bytes=0-0,-2", code: StatusPartialContent, ranges: []wantRange{{0, 1}, {testFileLen - 2, testFileLen}}}, + {r: "bytes=0-1,5-8", code: StatusPartialContent, ranges: []wantRange{{0, 2}, {5, 9}}}, + {r: "bytes=0-1,5-", code: StatusPartialContent, ranges: []wantRange{{0, 2}, {5, testFileLen}}}, + {r: "bytes=0-,1-,2-,3-,4-", code: StatusOK}, // ignore wasteful range request } func TestServeFile(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "testdata/file") })) @@ -65,33 +77,86 @@ func TestServeFile(t *testing.T) { // straight GET _, body := getBody(t, "straight get", req) - if !equal(body, file) { + if !bytes.Equal(body, file) { t.Fatalf("body mismatch: got %q, want %q", body, file) } // Range tests - for i, rt := range ServeFileRangeTests { - req.Header.Set("Range", "bytes="+rt.r) - if rt.r == "" { - req.Header["Range"] = nil +Cases: + for _, rt := range ServeFileRangeTests { + if rt.r != "" { + req.Header.Set("Range", rt.r) } - r, body := getBody(t, fmt.Sprintf("test %d", i), req) - if r.StatusCode != rt.code { - t.Errorf("range=%q: StatusCode=%d, want %d", rt.r, r.StatusCode, rt.code) + resp, body := getBody(t, fmt.Sprintf("range test %q", rt.r), req) + if resp.StatusCode != rt.code { + t.Errorf("range=%q: StatusCode=%d, want %d", rt.r, resp.StatusCode, rt.code) } if rt.code == StatusRequestedRangeNotSatisfiable { continue } - h := fmt.Sprintf("bytes %d-%d/%d", rt.start, rt.end-1, testFileLength) - if rt.r == "" { - h = "" + wantContentRange := "" + if len(rt.ranges) == 1 { + rng := rt.ranges[0] + wantContentRange = fmt.Sprintf("bytes %d-%d/%d", rng.start, rng.end-1, testFileLen) + } + cr := resp.Header.Get("Content-Range") + if cr != wantContentRange { + t.Errorf("range=%q: Content-Range = %q, want %q", rt.r, cr, wantContentRange) } - cr := r.Header.Get("Content-Range") - if cr != h { - t.Errorf("header mismatch: range=%q: got %q, want %q", rt.r, cr, h) + ct := resp.Header.Get("Content-Type") + if len(rt.ranges) == 1 { + rng := rt.ranges[0] + wantBody := file[rng.start:rng.end] + if !bytes.Equal(body, wantBody) { + t.Errorf("range=%q: body = %q, want %q", rt.r, body, wantBody) + } + if strings.HasPrefix(ct, "multipart/byteranges") { + t.Errorf("range=%q content-type = %q; unexpected multipart/byteranges", rt.r, ct) + } } - if !equal(body, file[rt.start:rt.end]) { - t.Errorf("body mismatch: range=%q: got %q, want %q", rt.r, body, file[rt.start:rt.end]) + if len(rt.ranges) > 1 { + typ, params, err := mime.ParseMediaType(ct) + if err != nil { + t.Errorf("range=%q content-type = %q; %v", rt.r, ct, err) + continue + } + if typ != "multipart/byteranges" { + t.Errorf("range=%q content-type = %q; want multipart/byteranges", rt.r, typ) + continue + } + if params["boundary"] == "" { + t.Errorf("range=%q content-type = %q; lacks boundary", rt.r, ct) + continue + } + if g, w := resp.ContentLength, int64(len(body)); g != w { + t.Errorf("range=%q Content-Length = %d; want %d", rt.r, g, w) + continue + } + mr := multipart.NewReader(bytes.NewReader(body), params["boundary"]) + for ri, rng := range rt.ranges { + part, err := mr.NextPart() + if err != nil { + t.Errorf("range=%q, reading part index %d: %v", rt.r, ri, err) + continue Cases + } + wantContentRange = fmt.Sprintf("bytes %d-%d/%d", rng.start, rng.end-1, testFileLen) + if g, w := part.Header.Get("Content-Range"), wantContentRange; g != w { + t.Errorf("range=%q: part Content-Range = %q; want %q", rt.r, g, w) + } + body, err := ioutil.ReadAll(part) + if err != nil { + t.Errorf("range=%q, reading part index %d body: %v", rt.r, ri, err) + continue Cases + } + wantBody := file[rng.start:rng.end] + if !bytes.Equal(body, wantBody) { + t.Errorf("range=%q: body = %q, want %q", rt.r, body, wantBody) + } + } + _, err = mr.NextPart() + if err != io.EOF { + t.Errorf("range=%q; expected final error io.EOF; got %v", rt.r, err) + } } } } @@ -105,6 +170,7 @@ var fsRedirectTestData = []struct { } func TestFSRedirect(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(StripPrefix("/test", FileServer(Dir(".")))) defer ts.Close() @@ -129,6 +195,7 @@ func (fs *testFileSystem) Open(name string) (File, error) { } func TestFileServerCleans(t *testing.T) { + defer checkLeakedTransports(t) ch := make(chan string, 1) fs := FileServer(&testFileSystem{func(name string) (File, error) { ch <- name @@ -160,6 +227,7 @@ func mustRemoveAll(dir string) { } func TestFileServerImplicitLeadingSlash(t *testing.T) { + defer checkLeakedTransports(t) tempDir, err := ioutil.TempDir("", "") if err != nil { t.Fatalf("TempDir: %v", err) @@ -193,8 +261,7 @@ func TestFileServerImplicitLeadingSlash(t *testing.T) { func TestDirJoin(t *testing.T) { wfi, err := os.Stat("/etc/hosts") if err != nil { - t.Logf("skipping test; no /etc/hosts file") - return + t.Skip("skipping test; no /etc/hosts file") } test := func(d Dir, name string) { f, err := d.Open(name) @@ -239,6 +306,7 @@ func TestEmptyDirOpenCWD(t *testing.T) { } func TestServeFileContentType(t *testing.T) { + defer checkLeakedTransports(t) const ctype = "icecream/chocolate" ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.FormValue("override") == "1" { @@ -255,12 +323,14 @@ func TestServeFileContentType(t *testing.T) { if h := resp.Header.Get("Content-Type"); h != want { t.Errorf("Content-Type mismatch: got %q, want %q", h, want) } + resp.Body.Close() } get("0", "text/plain; charset=utf-8") get("1", ctype) } func TestServeFileMimeType(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "testdata/style.css") })) @@ -269,6 +339,7 @@ func TestServeFileMimeType(t *testing.T) { if err != nil { t.Fatal(err) } + resp.Body.Close() want := "text/css; charset=utf-8" if h := resp.Header.Get("Content-Type"); h != want { t.Errorf("Content-Type mismatch: got %q, want %q", h, want) @@ -276,6 +347,7 @@ func TestServeFileMimeType(t *testing.T) { } func TestServeFileFromCWD(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "fs_test.go") })) @@ -284,12 +356,14 @@ func TestServeFileFromCWD(t *testing.T) { if err != nil { t.Fatal(err) } + r.Body.Close() if r.StatusCode != 200 { t.Fatalf("expected 200 OK, got %s", r.Status) } } func TestServeFileWithContentEncoding(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", "foo") ServeFile(w, r, "testdata/file") @@ -299,12 +373,14 @@ func TestServeFileWithContentEncoding(t *testing.T) { if err != nil { t.Fatal(err) } + resp.Body.Close() if g, e := resp.ContentLength, int64(-1); g != e { t.Errorf("Content-Length mismatch: got %d, want %d", g, e) } } func TestServeIndexHtml(t *testing.T) { + defer checkLeakedTransports(t) const want = "index.html says hello\n" ts := httptest.NewServer(FileServer(Dir("."))) defer ts.Close() @@ -325,64 +401,289 @@ func TestServeIndexHtml(t *testing.T) { } } -func TestServeContent(t *testing.T) { - type req struct { - name string - modtime time.Time - content io.ReadSeeker - } - ch := make(chan req, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - p := <-ch - ServeContent(w, r, p.name, p.modtime, p.content) - })) +func TestFileServerZeroByte(t *testing.T) { + defer checkLeakedTransports(t) + ts := httptest.NewServer(FileServer(Dir("."))) defer ts.Close() - css, err := os.Open("testdata/style.css") + res, err := Get(ts.URL + "/..\x00") if err != nil { t.Fatal(err) } - defer css.Close() + b, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal("reading Body:", err) + } + if res.StatusCode == 200 { + t.Errorf("got status 200; want an error. Body is:\n%s", string(b)) + } +} + +type fakeFileInfo struct { + dir bool + basename string + modtime time.Time + ents []*fakeFileInfo + contents string +} + +func (f *fakeFileInfo) Name() string { return f.basename } +func (f *fakeFileInfo) Sys() interface{} { return nil } +func (f *fakeFileInfo) ModTime() time.Time { return f.modtime } +func (f *fakeFileInfo) IsDir() bool { return f.dir } +func (f *fakeFileInfo) Size() int64 { return int64(len(f.contents)) } +func (f *fakeFileInfo) Mode() os.FileMode { + if f.dir { + return 0755 | os.ModeDir + } + return 0644 +} + +type fakeFile struct { + io.ReadSeeker + fi *fakeFileInfo + path string // as opened +} + +func (f *fakeFile) Close() error { return nil } +func (f *fakeFile) Stat() (os.FileInfo, error) { return f.fi, nil } +func (f *fakeFile) Readdir(count int) ([]os.FileInfo, error) { + if !f.fi.dir { + return nil, os.ErrInvalid + } + var fis []os.FileInfo + for _, fi := range f.fi.ents { + fis = append(fis, fi) + } + return fis, nil +} + +type fakeFS map[string]*fakeFileInfo + +func (fs fakeFS) Open(name string) (File, error) { + name = path.Clean(name) + f, ok := fs[name] + if !ok { + println("fake filesystem didn't find file", name) + return nil, os.ErrNotExist + } + return &fakeFile{ReadSeeker: strings.NewReader(f.contents), fi: f, path: name}, nil +} + +func TestDirectoryIfNotModified(t *testing.T) { + defer checkLeakedTransports(t) + const indexContents = "I am a fake index.html file" + fileMod := time.Unix(1000000000, 0).UTC() + fileModStr := fileMod.Format(TimeFormat) + dirMod := time.Unix(123, 0).UTC() + indexFile := &fakeFileInfo{ + basename: "index.html", + modtime: fileMod, + contents: indexContents, + } + fs := fakeFS{ + "/": &fakeFileInfo{ + dir: true, + modtime: dirMod, + ents: []*fakeFileInfo{indexFile}, + }, + "/index.html": indexFile, + } + + ts := httptest.NewServer(FileServer(fs)) + defer ts.Close() - ch <- req{"style.css", time.Time{}, css} res, err := Get(ts.URL) if err != nil { t.Fatal(err) } - if g, e := res.Header.Get("Content-Type"), "text/css; charset=utf-8"; g != e { - t.Errorf("style.css: content type = %q, want %q", g, e) + b, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) } - if g := res.Header.Get("Last-Modified"); g != "" { - t.Errorf("want empty Last-Modified; got %q", g) + if string(b) != indexContents { + t.Fatalf("Got body %q; want %q", b, indexContents) } + res.Body.Close() + + lastMod := res.Header.Get("Last-Modified") + if lastMod != fileModStr { + t.Fatalf("initial Last-Modified = %q; want %q", lastMod, fileModStr) + } + + req, _ := NewRequest("GET", ts.URL, nil) + req.Header.Set("If-Modified-Since", lastMod) - fi, err := css.Stat() + res, err = DefaultClient.Do(req) if err != nil { t.Fatal(err) } - ch <- req{"style.html", fi.ModTime(), css} - res, err = Get(ts.URL) + if res.StatusCode != 304 { + t.Fatalf("Code after If-Modified-Since request = %v; want 304", res.StatusCode) + } + res.Body.Close() + + // Advance the index.html file's modtime, but not the directory's. + indexFile.modtime = indexFile.modtime.Add(1 * time.Hour) + + res, err = DefaultClient.Do(req) if err != nil { t.Fatal(err) } - if g, e := res.Header.Get("Content-Type"), "text/html; charset=utf-8"; g != e { - t.Errorf("style.html: content type = %q, want %q", g, e) + if res.StatusCode != 200 { + t.Fatalf("Code after second If-Modified-Since request = %v; want 200; res is %#v", res.StatusCode, res) } - if g := res.Header.Get("Last-Modified"); g == "" { - t.Errorf("want non-empty last-modified") + res.Body.Close() +} + +func mustStat(t *testing.T, fileName string) os.FileInfo { + fi, err := os.Stat(fileName) + if err != nil { + t.Fatal(err) + } + return fi +} + +func TestServeContent(t *testing.T) { + defer checkLeakedTransports(t) + type serveParam struct { + name string + modtime time.Time + content io.ReadSeeker + contentType string + etag string + } + servec := make(chan serveParam, 1) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + p := <-servec + if p.etag != "" { + w.Header().Set("ETag", p.etag) + } + if p.contentType != "" { + w.Header().Set("Content-Type", p.contentType) + } + ServeContent(w, r, p.name, p.modtime, p.content) + })) + defer ts.Close() + + type testCase struct { + file string + modtime time.Time + serveETag string // optional + serveContentType string // optional + reqHeader map[string]string + wantLastMod string + wantContentType string + wantStatus int + } + htmlModTime := mustStat(t, "testdata/index.html").ModTime() + tests := map[string]testCase{ + "no_last_modified": { + file: "testdata/style.css", + wantContentType: "text/css; charset=utf-8", + wantStatus: 200, + }, + "with_last_modified": { + file: "testdata/index.html", + wantContentType: "text/html; charset=utf-8", + modtime: htmlModTime, + wantLastMod: htmlModTime.UTC().Format(TimeFormat), + wantStatus: 200, + }, + "not_modified_modtime": { + file: "testdata/style.css", + modtime: htmlModTime, + reqHeader: map[string]string{ + "If-Modified-Since": htmlModTime.UTC().Format(TimeFormat), + }, + wantStatus: 304, + }, + "not_modified_modtime_with_contenttype": { + file: "testdata/style.css", + serveContentType: "text/css", // explicit content type + modtime: htmlModTime, + reqHeader: map[string]string{ + "If-Modified-Since": htmlModTime.UTC().Format(TimeFormat), + }, + wantStatus: 304, + }, + "not_modified_etag": { + file: "testdata/style.css", + serveETag: `"foo"`, + reqHeader: map[string]string{ + "If-None-Match": `"foo"`, + }, + wantStatus: 304, + }, + "range_good": { + file: "testdata/style.css", + serveETag: `"A"`, + reqHeader: map[string]string{ + "Range": "bytes=0-4", + }, + wantStatus: StatusPartialContent, + wantContentType: "text/css; charset=utf-8", + }, + // An If-Range resource for entity "A", but entity "B" is now current. + // The Range request should be ignored. + "range_no_match": { + file: "testdata/style.css", + serveETag: `"A"`, + reqHeader: map[string]string{ + "Range": "bytes=0-4", + "If-Range": `"B"`, + }, + wantStatus: 200, + wantContentType: "text/css; charset=utf-8", + }, + } + for testName, tt := range tests { + f, err := os.Open(tt.file) + if err != nil { + t.Fatalf("test %q: %v", testName, err) + } + defer f.Close() + + servec <- serveParam{ + name: filepath.Base(tt.file), + content: f, + modtime: tt.modtime, + etag: tt.serveETag, + contentType: tt.serveContentType, + } + req, err := NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + for k, v := range tt.reqHeader { + req.Header.Set(k, v) + } + res, err := DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + io.Copy(ioutil.Discard, res.Body) + res.Body.Close() + if res.StatusCode != tt.wantStatus { + t.Errorf("test %q: status = %d; want %d", testName, res.StatusCode, tt.wantStatus) + } + if g, e := res.Header.Get("Content-Type"), tt.wantContentType; g != e { + t.Errorf("test %q: content-type = %q, want %q", testName, g, e) + } + if g, e := res.Header.Get("Last-Modified"), tt.wantLastMod; g != e { + t.Errorf("test %q: last-modified = %q, want %q", testName, g, e) + } } } // verifies that sendfile is being used on Linux func TestLinuxSendfile(t *testing.T) { + defer checkLeakedTransports(t) if runtime.GOOS != "linux" { - t.Logf("skipping; linux-only test") - return + t.Skip("skipping; linux-only test") } - _, err := exec.LookPath("strace") - if err != nil { - t.Logf("skipping; strace not found in path") - return + if _, err := exec.LookPath("strace"); err != nil { + t.Skip("skipping; strace not found in path") } ln, err := net.Listen("tcp", "127.0.0.1:0") @@ -401,10 +702,8 @@ func TestLinuxSendfile(t *testing.T) { child.Env = append([]string{"GO_WANT_HELPER_PROCESS=1"}, os.Environ()...) child.Stdout = &buf child.Stderr = &buf - err = child.Start() - if err != nil { - t.Logf("skipping; failed to start straced child: %v", err) - return + if err := child.Start(); err != nil { + t.Skipf("skipping; failed to start straced child: %v", err) } res, err := Get(fmt.Sprintf("http://%s/", ln.Addr())) @@ -464,15 +763,3 @@ func TestLinuxSendfileChild(*testing.T) { panic(err) } } - -func equal(a, b []byte) bool { - if len(a) != len(b) { - return false - } - for i := range a { - if a[i] != b[i] { - return false - } - } - return true -} diff --git a/src/pkg/net/http/header.go b/src/pkg/net/http/header.go index b107c312d..f479b7b4e 100644 --- a/src/pkg/net/http/header.go +++ b/src/pkg/net/http/header.go @@ -5,11 +5,11 @@ package http import ( - "fmt" "io" "net/textproto" "sort" "strings" + "time" ) // A Header represents the key-value pairs in an HTTP header. @@ -36,6 +36,14 @@ func (h Header) Get(key string) string { return textproto.MIMEHeader(h).Get(key) } +// get is like Get, but key must already be in CanonicalHeaderKey form. +func (h Header) get(key string) string { + if v := h[key]; len(v) > 0 { + return v[0] + } + return "" +} + // Del deletes the values associated with key. func (h Header) Del(key string) { textproto.MIMEHeader(h).Del(key) @@ -46,24 +54,87 @@ func (h Header) Write(w io.Writer) error { return h.WriteSubset(w, nil) } +func (h Header) clone() Header { + h2 := make(Header, len(h)) + for k, vv := range h { + vv2 := make([]string, len(vv)) + copy(vv2, vv) + h2[k] = vv2 + } + return h2 +} + +var timeFormats = []string{ + TimeFormat, + time.RFC850, + time.ANSIC, +} + +// ParseTime parses a time header (such as the Date: header), +// trying each of the three formats allowed by HTTP/1.1: +// TimeFormat, time.RFC850, and time.ANSIC. +func ParseTime(text string) (t time.Time, err error) { + for _, layout := range timeFormats { + t, err = time.Parse(layout, text) + if err == nil { + return + } + } + return +} + var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ") +type writeStringer interface { + WriteString(string) (int, error) +} + +// stringWriter implements WriteString on a Writer. +type stringWriter struct { + w io.Writer +} + +func (w stringWriter) WriteString(s string) (n int, err error) { + return w.w.Write([]byte(s)) +} + +type keyValues struct { + key string + values []string +} + +type byKey []keyValues + +func (s byKey) Len() int { return len(s) } +func (s byKey) Swap(i, j int) { s[i], s[j] = s[j], s[i] } +func (s byKey) Less(i, j int) bool { return s[i].key < s[j].key } + +func (h Header) sortedKeyValues(exclude map[string]bool) []keyValues { + kvs := make([]keyValues, 0, len(h)) + for k, vv := range h { + if !exclude[k] { + kvs = append(kvs, keyValues{k, vv}) + } + } + sort.Sort(byKey(kvs)) + return kvs +} + // WriteSubset writes a header in wire format. // If exclude is not nil, keys where exclude[key] == true are not written. func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error { - keys := make([]string, 0, len(h)) - for k := range h { - if exclude == nil || !exclude[k] { - keys = append(keys, k) - } + ws, ok := w.(writeStringer) + if !ok { + ws = stringWriter{w} } - sort.Strings(keys) - for _, k := range keys { - for _, v := range h[k] { + for _, kv := range h.sortedKeyValues(exclude) { + for _, v := range kv.values { v = headerNewlineToSpace.Replace(v) - v = strings.TrimSpace(v) - if _, err := fmt.Fprintf(w, "%s: %s\r\n", k, v); err != nil { - return err + v = textproto.TrimString(v) + for _, s := range []string{kv.key, ": ", v, "\r\n"} { + if _, err := ws.WriteString(s); err != nil { + return err + } } } } @@ -76,3 +147,43 @@ func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error { // the rest are converted to lowercase. For example, the // canonical key for "accept-encoding" is "Accept-Encoding". func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) } + +// hasToken returns whether token appears with v, ASCII +// case-insensitive, with space or comma boundaries. +// token must be all lowercase. +// v may contain mixed cased. +func hasToken(v, token string) bool { + if len(token) > len(v) || token == "" { + return false + } + if v == token { + return true + } + for sp := 0; sp <= len(v)-len(token); sp++ { + // Check that first character is good. + // The token is ASCII, so checking only a single byte + // is sufficient. We skip this potential starting + // position if both the first byte and its potential + // ASCII uppercase equivalent (b|0x20) don't match. + // False positives ('^' => '~') are caught by EqualFold. + if b := v[sp]; b != token[0] && b|0x20 != token[0] { + continue + } + // Check that start pos is on a valid token boundary. + if sp > 0 && !isTokenBoundary(v[sp-1]) { + continue + } + // Check that end pos is on a valid token boundary. + if endPos := sp + len(token); endPos != len(v) && !isTokenBoundary(v[endPos]) { + continue + } + if strings.EqualFold(v[sp:sp+len(token)], token) { + return true + } + } + return false +} + +func isTokenBoundary(b byte) bool { + return b == ' ' || b == ',' || b == '\t' +} diff --git a/src/pkg/net/http/header_test.go b/src/pkg/net/http/header_test.go index ccdee8a97..2313b5549 100644 --- a/src/pkg/net/http/header_test.go +++ b/src/pkg/net/http/header_test.go @@ -7,6 +7,7 @@ package http import ( "bytes" "testing" + "time" ) var headerWriteTests = []struct { @@ -67,6 +68,24 @@ var headerWriteTests = []struct { nil, "Blank: \r\nDouble-Blank: \r\nDouble-Blank: \r\n", }, + // Tests header sorting when over the insertion sort threshold side: + { + Header{ + "k1": {"1a", "1b"}, + "k2": {"2a", "2b"}, + "k3": {"3a", "3b"}, + "k4": {"4a", "4b"}, + "k5": {"5a", "5b"}, + "k6": {"6a", "6b"}, + "k7": {"7a", "7b"}, + "k8": {"8a", "8b"}, + "k9": {"9a", "9b"}, + }, + map[string]bool{"k5": true}, + "k1: 1a\r\nk1: 1b\r\nk2: 2a\r\nk2: 2b\r\nk3: 3a\r\nk3: 3b\r\n" + + "k4: 4a\r\nk4: 4b\r\nk6: 6a\r\nk6: 6b\r\n" + + "k7: 7a\r\nk7: 7b\r\nk8: 8a\r\nk8: 8b\r\nk9: 9a\r\nk9: 9b\r\n", + }, } func TestHeaderWrite(t *testing.T) { @@ -79,3 +98,107 @@ func TestHeaderWrite(t *testing.T) { buf.Reset() } } + +var parseTimeTests = []struct { + h Header + err bool +}{ + {Header{"Date": {""}}, true}, + {Header{"Date": {"invalid"}}, true}, + {Header{"Date": {"1994-11-06T08:49:37Z00:00"}}, true}, + {Header{"Date": {"Sun, 06 Nov 1994 08:49:37 GMT"}}, false}, + {Header{"Date": {"Sunday, 06-Nov-94 08:49:37 GMT"}}, false}, + {Header{"Date": {"Sun Nov 6 08:49:37 1994"}}, false}, +} + +func TestParseTime(t *testing.T) { + expect := time.Date(1994, 11, 6, 8, 49, 37, 0, time.UTC) + for i, test := range parseTimeTests { + d, err := ParseTime(test.h.Get("Date")) + if err != nil { + if !test.err { + t.Errorf("#%d:\n got err: %v", i, err) + } + continue + } + if test.err { + t.Errorf("#%d:\n should err", i) + continue + } + if !expect.Equal(d) { + t.Errorf("#%d:\n got: %v\nwant: %v", i, d, expect) + } + } +} + +type hasTokenTest struct { + header string + token string + want bool +} + +var hasTokenTests = []hasTokenTest{ + {"", "", false}, + {"", "foo", false}, + {"foo", "foo", true}, + {"foo ", "foo", true}, + {" foo", "foo", true}, + {" foo ", "foo", true}, + {"foo,bar", "foo", true}, + {"bar,foo", "foo", true}, + {"bar, foo", "foo", true}, + {"bar,foo, baz", "foo", true}, + {"bar, foo,baz", "foo", true}, + {"bar,foo, baz", "foo", true}, + {"bar, foo, baz", "foo", true}, + {"FOO", "foo", true}, + {"FOO ", "foo", true}, + {" FOO", "foo", true}, + {" FOO ", "foo", true}, + {"FOO,BAR", "foo", true}, + {"BAR,FOO", "foo", true}, + {"BAR, FOO", "foo", true}, + {"BAR,FOO, baz", "foo", true}, + {"BAR, FOO,BAZ", "foo", true}, + {"BAR,FOO, BAZ", "foo", true}, + {"BAR, FOO, BAZ", "foo", true}, + {"foobar", "foo", false}, + {"barfoo ", "foo", false}, +} + +func TestHasToken(t *testing.T) { + for _, tt := range hasTokenTests { + if hasToken(tt.header, tt.token) != tt.want { + t.Errorf("hasToken(%q, %q) = %v; want %v", tt.header, tt.token, !tt.want, tt.want) + } + } +} + +var testHeader = Header{ + "Content-Length": {"123"}, + "Content-Type": {"text/plain"}, + "Date": {"some date at some time Z"}, + "Server": {"Go http package"}, +} + +var buf bytes.Buffer + +func BenchmarkHeaderWriteSubset(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + buf.Reset() + testHeader.WriteSubset(&buf, nil) + } +} + +func TestHeaderWriteSubsetMallocs(t *testing.T) { + n := testing.AllocsPerRun(100, func() { + buf.Reset() + testHeader.WriteSubset(&buf, nil) + }) + if n > 1 { + // TODO(bradfitz,rsc): once we can sort without allocating, + // make this an error. See http://golang.org/issue/3761 + // t.Errorf("got %v allocs, want <= %v", n, 1) + } +} diff --git a/src/pkg/net/http/httptest/example_test.go b/src/pkg/net/http/httptest/example_test.go new file mode 100644 index 000000000..239470d97 --- /dev/null +++ b/src/pkg/net/http/httptest/example_test.go @@ -0,0 +1,50 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httptest_test + +import ( + "fmt" + "io/ioutil" + "log" + "net/http" + "net/http/httptest" +) + +func ExampleRecorder() { + handler := func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "something failed", http.StatusInternalServerError) + } + + req, err := http.NewRequest("GET", "http://example.com/foo", nil) + if err != nil { + log.Fatal(err) + } + + w := httptest.NewRecorder() + handler(w, req) + + fmt.Printf("%d - %s", w.Code, w.Body.String()) + // Output: 500 - something failed +} + +func ExampleServer() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello, client") + })) + defer ts.Close() + + res, err := http.Get(ts.URL) + if err != nil { + log.Fatal(err) + } + greeting, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + log.Fatal(err) + } + + fmt.Printf("%s", greeting) + // Output: Hello, client +} diff --git a/src/pkg/net/http/httptest/recorder.go b/src/pkg/net/http/httptest/recorder.go index 9aa0d510b..5451f5423 100644 --- a/src/pkg/net/http/httptest/recorder.go +++ b/src/pkg/net/http/httptest/recorder.go @@ -17,6 +17,8 @@ type ResponseRecorder struct { HeaderMap http.Header // the HTTP response headers Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to Flushed bool + + wroteHeader bool } // NewRecorder returns an initialized ResponseRecorder. @@ -24,6 +26,7 @@ func NewRecorder() *ResponseRecorder { return &ResponseRecorder{ HeaderMap: make(http.Header), Body: new(bytes.Buffer), + Code: 200, } } @@ -33,26 +36,37 @@ const DefaultRemoteAddr = "1.2.3.4" // Header returns the response headers. func (rw *ResponseRecorder) Header() http.Header { - return rw.HeaderMap + m := rw.HeaderMap + if m == nil { + m = make(http.Header) + rw.HeaderMap = m + } + return m } // Write always succeeds and writes to rw.Body, if not nil. func (rw *ResponseRecorder) Write(buf []byte) (int, error) { + if !rw.wroteHeader { + rw.WriteHeader(200) + } if rw.Body != nil { rw.Body.Write(buf) } - if rw.Code == 0 { - rw.Code = http.StatusOK - } return len(buf), nil } // WriteHeader sets rw.Code. func (rw *ResponseRecorder) WriteHeader(code int) { - rw.Code = code + if !rw.wroteHeader { + rw.Code = code + } + rw.wroteHeader = true } // Flush sets rw.Flushed to true. func (rw *ResponseRecorder) Flush() { + if !rw.wroteHeader { + rw.WriteHeader(200) + } rw.Flushed = true } diff --git a/src/pkg/net/http/httptest/recorder_test.go b/src/pkg/net/http/httptest/recorder_test.go new file mode 100644 index 000000000..2b563260c --- /dev/null +++ b/src/pkg/net/http/httptest/recorder_test.go @@ -0,0 +1,90 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httptest + +import ( + "fmt" + "net/http" + "testing" +) + +func TestRecorder(t *testing.T) { + type checkFunc func(*ResponseRecorder) error + check := func(fns ...checkFunc) []checkFunc { return fns } + + hasStatus := func(wantCode int) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Code != wantCode { + return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode) + } + return nil + } + } + hasContents := func(want string) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Body.String() != want { + return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want) + } + return nil + } + } + hasFlush := func(want bool) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Flushed != want { + return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want) + } + return nil + } + } + + tests := []struct { + name string + h func(w http.ResponseWriter, r *http.Request) + checks []checkFunc + }{ + { + "200 default", + func(w http.ResponseWriter, r *http.Request) {}, + check(hasStatus(200), hasContents("")), + }, + { + "first code only", + func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(201) + w.WriteHeader(202) + w.Write([]byte("hi")) + }, + check(hasStatus(201), hasContents("hi")), + }, + { + "write sends 200", + func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hi first")) + w.WriteHeader(201) + w.WriteHeader(202) + }, + check(hasStatus(200), hasContents("hi first"), hasFlush(false)), + }, + { + "flush", + func(w http.ResponseWriter, r *http.Request) { + w.(http.Flusher).Flush() // also sends a 200 + w.WriteHeader(201) + }, + check(hasStatus(200), hasFlush(true)), + }, + } + r, _ := http.NewRequest("GET", "http://foo.com/", nil) + for _, tt := range tests { + h := http.HandlerFunc(tt.h) + rec := NewRecorder() + h.ServeHTTP(rec, r) + for _, check := range tt.checks { + if err := check(rec); err != nil { + t.Errorf("%s: %v", tt.name, err) + } + } + } +} diff --git a/src/pkg/net/http/httptest/server.go b/src/pkg/net/http/httptest/server.go index 57cf0c941..7f265552f 100644 --- a/src/pkg/net/http/httptest/server.go +++ b/src/pkg/net/http/httptest/server.go @@ -21,7 +21,11 @@ import ( type Server struct { URL string // base URL of form http://ipaddr:port with no trailing slash Listener net.Listener - TLS *tls.Config // nil if not using using TLS + + // TLS is the optional TLS configuration, populated with a new config + // after TLS is started. If set on an unstarted server before StartTLS + // is called, existing fields are copied into the new config. + TLS *tls.Config // Config may be changed after calling NewUnstartedServer and // before Start or StartTLS. @@ -36,13 +40,16 @@ type Server struct { // accepted. type historyListener struct { net.Listener - history []net.Conn + sync.Mutex // protects history + history []net.Conn } func (hs *historyListener) Accept() (c net.Conn, err error) { c, err = hs.Listener.Accept() if err == nil { + hs.Lock() hs.history = append(hs.history, c) + hs.Unlock() } return } @@ -96,7 +103,7 @@ func (s *Server) Start() { if s.URL != "" { panic("Server already started") } - s.Listener = &historyListener{s.Listener, make([]net.Conn, 0)} + s.Listener = &historyListener{Listener: s.Listener} s.URL = "http://" + s.Listener.Addr().String() s.wrapHandler() go s.Config.Serve(s.Listener) @@ -116,13 +123,20 @@ func (s *Server) StartTLS() { panic(fmt.Sprintf("httptest: NewTLSServer: %v", err)) } - s.TLS = &tls.Config{ - NextProtos: []string{"http/1.1"}, - Certificates: []tls.Certificate{cert}, + existingConfig := s.TLS + s.TLS = new(tls.Config) + if existingConfig != nil { + *s.TLS = *existingConfig + } + if s.TLS.NextProtos == nil { + s.TLS.NextProtos = []string{"http/1.1"} + } + if len(s.TLS.Certificates) == 0 { + s.TLS.Certificates = []tls.Certificate{cert} } tlsListener := tls.NewListener(s.Listener, s.TLS) - s.Listener = &historyListener{tlsListener, make([]net.Conn, 0)} + s.Listener = &historyListener{Listener: tlsListener} s.URL = "https://" + s.Listener.Addr().String() s.wrapHandler() go s.Config.Serve(s.Listener) @@ -152,6 +166,10 @@ func NewTLSServer(handler http.Handler) *Server { func (s *Server) Close() { s.Listener.Close() s.wg.Wait() + s.CloseClientConnections() + if t, ok := http.DefaultTransport.(*http.Transport); ok { + t.CloseIdleConnections() + } } // CloseClientConnections closes any currently open HTTP connections @@ -161,9 +179,11 @@ func (s *Server) CloseClientConnections() { if !ok { return } + hl.Lock() for _, conn := range hl.history { conn.Close() } + hl.Unlock() } // waitGroupHandler wraps a handler, incrementing and decrementing a @@ -180,28 +200,29 @@ func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.h.ServeHTTP(w, r) } -// localhostCert is a PEM-encoded TLS cert with SAN DNS names +// localhostCert is a PEM-encoded TLS cert with SAN IPs // "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end // of ASN.1 time). +// generated from src/pkg/crypto/tls: +// go run generate_cert.go --rsa-bits 512 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h var localhostCert = []byte(`-----BEGIN CERTIFICATE----- -MIIBOTCB5qADAgECAgEAMAsGCSqGSIb3DQEBBTAAMB4XDTcwMDEwMTAwMDAwMFoX -DTQ5MTIzMTIzNTk1OVowADBaMAsGCSqGSIb3DQEBAQNLADBIAkEAsuA5mAFMj6Q7 -qoBzcvKzIq4kzuT5epSp2AkcQfyBHm7K13Ws7u+0b5Vb9gqTf5cAiIKcrtrXVqkL -8i1UQF6AzwIDAQABo08wTTAOBgNVHQ8BAf8EBAMCACQwDQYDVR0OBAYEBAECAwQw -DwYDVR0jBAgwBoAEAQIDBDAbBgNVHREEFDASggkxMjcuMC4wLjGCBVs6OjFdMAsG -CSqGSIb3DQEBBQNBAJH30zjLWRztrWpOCgJL8RQWLaKzhK79pVhAx6q/3NrF16C7 -+l1BRZstTwIGdoGId8BRpErK1TXkniFb95ZMynM= ------END CERTIFICATE----- -`) +MIIBdzCCASOgAwIBAgIBADALBgkqhkiG9w0BAQUwEjEQMA4GA1UEChMHQWNtZSBD +bzAeFw03MDAxMDEwMDAwMDBaFw00OTEyMzEyMzU5NTlaMBIxEDAOBgNVBAoTB0Fj +bWUgQ28wWjALBgkqhkiG9w0BAQEDSwAwSAJBAN55NcYKZeInyTuhcCwFMhDHCmwa +IUSdtXdcbItRB/yfXGBhiex00IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEA +AaNoMGYwDgYDVR0PAQH/BAQDAgCkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1Ud +EwEB/wQFMAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAA +AAAAAAAAAAAAAAEwCwYJKoZIhvcNAQEFA0EAAoQn/ytgqpiLcZu9XKbCJsJcvkgk +Se6AbGXgSlq+ZCEVo0qIwSgeBqmsJxUu7NCSOwVJLYNEBO2DtIxoYVk+MA== +-----END CERTIFICATE-----`) // localhostKey is the private key for localhostCert. var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY----- -MIIBPQIBAAJBALLgOZgBTI+kO6qAc3LysyKuJM7k+XqUqdgJHEH8gR5uytd1rO7v -tG+VW/YKk3+XAIiCnK7a11apC/ItVEBegM8CAwEAAQJBAI5sxq7naeR9ahyqRkJi -SIv2iMxLuPEHaezf5CYOPWjSjBPyVhyRevkhtqEjF/WkgL7C2nWpYHsUcBDBQVF0 -3KECIQDtEGB2ulnkZAahl3WuJziXGLB+p8Wgx7wzSM6bHu1c6QIhAMEp++CaS+SJ -/TrU0zwY/fW4SvQeb49BPZUF3oqR8Xz3AiEA1rAJHBzBgdOQKdE3ksMUPcnvNJSN -poCcELmz2clVXtkCIQCLytuLV38XHToTipR4yMl6O+6arzAjZ56uq7m7ZRV0TwIh -AM65XAOw8Dsg9Kq78aYXiOEDc5DL0sbFUu/SlmRcCg93 ------END RSA PRIVATE KEY----- -`) +MIIBPAIBAAJBAN55NcYKZeInyTuhcCwFMhDHCmwaIUSdtXdcbItRB/yfXGBhiex0 +0IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEAAQJBAQdUx66rfh8sYsgfdcvV +NoafYpnEcB5s4m/vSVe6SU7dCK6eYec9f9wpT353ljhDUHq3EbmE4foNzJngh35d +AekCIQDhRQG5Li0Wj8TM4obOnnXUXf1jRv0UkzE9AHWLG5q3AwIhAPzSjpYUDjVW +MCUXgckTpKCuGwbJk7424Nb8bLzf3kllAiA5mUBgjfr/WtFSJdWcPQ4Zt9KTMNKD +EUO0ukpTwEIl6wIhAMbGqZK3zAAFdq8DD2jPx+UJXnh0rnOkZBzDtJ6/iN69AiEA +1Aq8MJgTaYsDQWyU/hDq5YkDJc9e9DSCvUIzqxQWMQE= +-----END RSA PRIVATE KEY-----`) diff --git a/src/pkg/net/http/httputil/chunked.go b/src/pkg/net/http/httputil/chunked.go index 29eaf3475..b66d40951 100644 --- a/src/pkg/net/http/httputil/chunked.go +++ b/src/pkg/net/http/httputil/chunked.go @@ -13,10 +13,9 @@ package httputil import ( "bufio" - "bytes" "errors" + "fmt" "io" - "strconv" ) const maxLineLength = 4096 // assumed <= bufio.defaultBufSize @@ -24,7 +23,7 @@ const maxLineLength = 4096 // assumed <= bufio.defaultBufSize var ErrLineTooLong = errors.New("header line too long") // NewChunkedReader returns a new chunkedReader that translates the data read from r -// out of HTTP "chunked" format before returning it. +// out of HTTP "chunked" format before returning it. // The chunkedReader returns io.EOF when the final 0-length chunk is read. // // NewChunkedReader is not needed by normal applications. The http package @@ -41,16 +40,17 @@ type chunkedReader struct { r *bufio.Reader n uint64 // unread bytes in chunk err error + buf [2]byte } func (cr *chunkedReader) beginChunk() { // chunk-size CRLF - var line string + var line []byte line, cr.err = readLine(cr.r) if cr.err != nil { return } - cr.n, cr.err = strconv.ParseUint(line, 16, 64) + cr.n, cr.err = parseHexUint(line) if cr.err != nil { return } @@ -76,9 +76,8 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) { cr.n -= uint64(n) if cr.n == 0 && cr.err == nil { // end of chunk (CRLF) - b := make([]byte, 2) - if _, cr.err = io.ReadFull(cr.r, b); cr.err == nil { - if b[0] != '\r' || b[1] != '\n' { + if _, cr.err = io.ReadFull(cr.r, cr.buf[:]); cr.err == nil { + if cr.buf[0] != '\r' || cr.buf[1] != '\n' { cr.err = errors.New("malformed chunked encoding") } } @@ -90,7 +89,7 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) { // Give up if the line exceeds maxLineLength. // The returned bytes are a pointer into storage in // the bufio, so they are only valid until the next bufio read. -func readLineBytes(b *bufio.Reader) (p []byte, err error) { +func readLine(b *bufio.Reader) (p []byte, err error) { if p, err = b.ReadSlice('\n'); err != nil { // We always know when EOF is coming. // If the caller asked for a line, there should be a line. @@ -104,20 +103,18 @@ func readLineBytes(b *bufio.Reader) (p []byte, err error) { if len(p) >= maxLineLength { return nil, ErrLineTooLong } - - // Chop off trailing white space. - p = bytes.TrimRight(p, " \r\t\n") - - return p, nil + return trimTrailingWhitespace(p), nil } -// readLineBytes, but convert the bytes into a string. -func readLine(b *bufio.Reader) (s string, err error) { - p, e := readLineBytes(b) - if e != nil { - return "", e +func trimTrailingWhitespace(b []byte) []byte { + for len(b) > 0 && isASCIISpace(b[len(b)-1]) { + b = b[:len(b)-1] } - return string(p), nil + return b +} + +func isASCIISpace(b byte) bool { + return b == ' ' || b == '\t' || b == '\n' || b == '\r' } // NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP @@ -149,9 +146,7 @@ func (cw *chunkedWriter) Write(data []byte) (n int, err error) { return 0, nil } - head := strconv.FormatInt(int64(len(data)), 16) + "\r\n" - - if _, err = io.WriteString(cw.Wire, head); err != nil { + if _, err = fmt.Fprintf(cw.Wire, "%x\r\n", len(data)); err != nil { return 0, err } if n, err = cw.Wire.Write(data); err != nil { @@ -170,3 +165,21 @@ func (cw *chunkedWriter) Close() error { _, err := io.WriteString(cw.Wire, "0\r\n") return err } + +func parseHexUint(v []byte) (n uint64, err error) { + for _, b := range v { + n <<= 4 + switch { + case '0' <= b && b <= '9': + b = b - '0' + case 'a' <= b && b <= 'f': + b = b - 'a' + 10 + case 'A' <= b && b <= 'F': + b = b - 'A' + 10 + default: + return 0, errors.New("invalid byte in chunk length") + } + n |= uint64(b) + } + return +} diff --git a/src/pkg/net/http/httputil/chunked_test.go b/src/pkg/net/http/httputil/chunked_test.go index 155a32bdf..a06bffad5 100644 --- a/src/pkg/net/http/httputil/chunked_test.go +++ b/src/pkg/net/http/httputil/chunked_test.go @@ -11,7 +11,10 @@ package httputil import ( "bytes" + "fmt" + "io" "io/ioutil" + "runtime" "testing" ) @@ -39,3 +42,54 @@ func TestChunk(t *testing.T) { t.Errorf("chunk reader read %q; want %q", g, e) } } + +func TestChunkReaderAllocs(t *testing.T) { + // temporarily set GOMAXPROCS to 1 as we are testing memory allocations + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1)) + var buf bytes.Buffer + w := NewChunkedWriter(&buf) + a, b, c := []byte("aaaaaa"), []byte("bbbbbbbbbbbb"), []byte("cccccccccccccccccccccccc") + w.Write(a) + w.Write(b) + w.Write(c) + w.Close() + + r := NewChunkedReader(&buf) + readBuf := make([]byte, len(a)+len(b)+len(c)+1) + + var ms runtime.MemStats + runtime.ReadMemStats(&ms) + m0 := ms.Mallocs + + n, err := io.ReadFull(r, readBuf) + + runtime.ReadMemStats(&ms) + mallocs := ms.Mallocs - m0 + if mallocs > 1 { + t.Errorf("%d mallocs; want <= 1", mallocs) + } + + if n != len(readBuf)-1 { + t.Errorf("read %d bytes; want %d", n, len(readBuf)-1) + } + if err != io.ErrUnexpectedEOF { + t.Errorf("read error = %v; want ErrUnexpectedEOF", err) + } +} + +func TestParseHexUint(t *testing.T) { + for i := uint64(0); i <= 1234; i++ { + line := []byte(fmt.Sprintf("%x", i)) + got, err := parseHexUint(line) + if err != nil { + t.Fatalf("on %d: %v", i, err) + } + if got != i { + t.Errorf("for input %q = %d; want %d", line, got, i) + } + } + _, err := parseHexUint([]byte("bogus")) + if err == nil { + t.Error("expected error on bogus input") + } +} diff --git a/src/pkg/net/http/httputil/dump.go b/src/pkg/net/http/httputil/dump.go index 892ef4ede..0b0035661 100644 --- a/src/pkg/net/http/httputil/dump.go +++ b/src/pkg/net/http/httputil/dump.go @@ -75,7 +75,7 @@ func DumpRequestOut(req *http.Request, body bool) ([]byte, error) { // Use the actual Transport code to record what we would send // on the wire, but not using TCP. Use a Transport with a - // customer dialer that returns a fake net.Conn that waits + // custom dialer that returns a fake net.Conn that waits // for the full input (and recording it), and then responds // with a dummy response. var buf bytes.Buffer // records the output @@ -89,7 +89,7 @@ func DumpRequestOut(req *http.Request, body bool) ([]byte, error) { t := &http.Transport{ Dial: func(net, addr string) (net.Conn, error) { - return &dumpConn{io.MultiWriter(pw, &buf), dr}, nil + return &dumpConn{io.MultiWriter(&buf, pw), dr}, nil }, } diff --git a/src/pkg/net/http/httputil/reverseproxy.go b/src/pkg/net/http/httputil/reverseproxy.go index 9c4bd6e09..134c45299 100644 --- a/src/pkg/net/http/httputil/reverseproxy.go +++ b/src/pkg/net/http/httputil/reverseproxy.go @@ -17,6 +17,10 @@ import ( "time" ) +// onExitFlushLoop is a callback set by tests to detect the state of the +// flushLoop() goroutine. +var onExitFlushLoop func() + // ReverseProxy is an HTTP Handler that takes an incoming request and // sends it to another server, proxying the response back to the // client. @@ -102,8 +106,14 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { outreq.Header.Del("Connection") } - if clientIp, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { - outreq.Header.Set("X-Forwarded-For", clientIp) + if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + // If we aren't the first proxy retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. + if prior, ok := outreq.Header["X-Forwarded-For"]; ok { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + outreq.Header.Set("X-Forwarded-For", clientIP) } res, err := transport.RoundTrip(outreq) @@ -112,20 +122,29 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusInternalServerError) return } + defer res.Body.Close() copyHeader(rw.Header(), res.Header) rw.WriteHeader(res.StatusCode) + p.copyResponse(rw, res.Body) +} - if res.Body != nil { - var dst io.Writer = rw - if p.FlushInterval != 0 { - if wf, ok := rw.(writeFlusher); ok { - dst = &maxLatencyWriter{dst: wf, latency: p.FlushInterval} +func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { + if p.FlushInterval != 0 { + if wf, ok := dst.(writeFlusher); ok { + mlw := &maxLatencyWriter{ + dst: wf, + latency: p.FlushInterval, + done: make(chan bool), } + go mlw.flushLoop() + defer mlw.stop() + dst = mlw } - io.Copy(dst, res.Body) } + + io.Copy(dst, src) } type writeFlusher interface { @@ -137,22 +156,14 @@ type maxLatencyWriter struct { dst writeFlusher latency time.Duration - lk sync.Mutex // protects init of done, as well Write + Flush + lk sync.Mutex // protects Write + Flush done chan bool } -func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { +func (m *maxLatencyWriter) Write(p []byte) (int, error) { m.lk.Lock() defer m.lk.Unlock() - if m.done == nil { - m.done = make(chan bool) - go m.flushLoop() - } - n, err = m.dst.Write(p) - if err != nil { - m.done <- true - } - return + return m.dst.Write(p) } func (m *maxLatencyWriter) flushLoop() { @@ -160,13 +171,18 @@ func (m *maxLatencyWriter) flushLoop() { defer t.Stop() for { select { + case <-m.done: + if onExitFlushLoop != nil { + onExitFlushLoop() + } + return case <-t.C: m.lk.Lock() m.dst.Flush() m.lk.Unlock() - case <-m.done: - return } } panic("unreached") } + +func (m *maxLatencyWriter) stop() { m.done <- true } diff --git a/src/pkg/net/http/httputil/reverseproxy_test.go b/src/pkg/net/http/httputil/reverseproxy_test.go index 28e9c90ad..863927162 100644 --- a/src/pkg/net/http/httputil/reverseproxy_test.go +++ b/src/pkg/net/http/httputil/reverseproxy_test.go @@ -11,7 +11,9 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" + "time" ) func TestReverseProxy(t *testing.T) { @@ -70,6 +72,47 @@ func TestReverseProxy(t *testing.T) { } } +func TestXForwardedFor(t *testing.T) { + const prevForwardedFor = "client ip" + const backendResponse = "I am the backend" + const backendStatus = 404 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Forwarded-For") == "" { + t.Errorf("didn't get X-Forwarded-For header") + } + if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) { + t.Errorf("X-Forwarded-For didn't contain prior data") + } + w.WriteHeader(backendStatus) + w.Write([]byte(backendResponse)) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Host = "some-name" + getReq.Header.Set("Connection", "close") + getReq.Header.Set("X-Forwarded-For", prevForwardedFor) + getReq.Close = true + res, err := http.DefaultClient.Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + if g, e := res.StatusCode, backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + bodyBytes, _ := ioutil.ReadAll(res.Body) + if g, e := string(bodyBytes), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +} + var proxyQueryTests = []struct { baseSuffix string // suffix to add to backend URL reqSuffix string // suffix to add to frontend's request URL @@ -107,3 +150,44 @@ func TestReverseProxyQuery(t *testing.T) { frontend.Close() } } + +func TestReverseProxyFlushInterval(t *testing.T) { + const expected = "hi" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(expected)) + })) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.FlushInterval = time.Microsecond + + done := make(chan bool) + onExitFlushLoop = func() { done <- true } + defer func() { onExitFlushLoop = nil }() + + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + req, _ := http.NewRequest("GET", frontend.URL, nil) + req.Close = true + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected { + t.Errorf("got body %q; expected %q", bodyBytes, expected) + } + + select { + case <-done: + // OK + case <-time.After(5 * time.Second): + t.Error("maxLatencyWriter flushLoop() never exited") + } +} diff --git a/src/pkg/net/http/jar.go b/src/pkg/net/http/jar.go index 2c2caa251..5c3de0dad 100644 --- a/src/pkg/net/http/jar.go +++ b/src/pkg/net/http/jar.go @@ -8,23 +8,20 @@ import ( "net/url" ) -// A CookieJar manages storage and use of cookies in HTTP requests. +// A CookieJar manages storage and use of cookies in HTTP requests. // // Implementations of CookieJar must be safe for concurrent use by multiple // goroutines. +// +// The net/http/cookiejar package provides a CookieJar implementation. type CookieJar interface { - // SetCookies handles the receipt of the cookies in a reply for the - // given URL. It may or may not choose to save the cookies, depending - // on the jar's policy and implementation. + // SetCookies handles the receipt of the cookies in a reply for the + // given URL. It may or may not choose to save the cookies, depending + // on the jar's policy and implementation. SetCookies(u *url.URL, cookies []*Cookie) // Cookies returns the cookies to send in a request for the given URL. - // It is up to the implementation to honor the standard cookie use - // restrictions such as in RFC 6265. + // It is up to the implementation to honor the standard cookie use + // restrictions such as in RFC 6265. Cookies(u *url.URL) []*Cookie } - -type blackHoleJar struct{} - -func (blackHoleJar) SetCookies(u *url.URL, cookies []*Cookie) {} -func (blackHoleJar) Cookies(u *url.URL) []*Cookie { return nil } diff --git a/src/pkg/net/http/lex.go b/src/pkg/net/http/lex.go index ffb393ccf..cb33318f4 100644 --- a/src/pkg/net/http/lex.go +++ b/src/pkg/net/http/lex.go @@ -6,131 +6,91 @@ package http // This file deals with lexical matters of HTTP -func isSeparator(c byte) bool { - switch c { - case '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', '?', '=', '{', '}', ' ', '\t': - return true - } - return false +var isTokenTable = [127]bool{ + '!': true, + '#': true, + '$': true, + '%': true, + '&': true, + '\'': true, + '*': true, + '+': true, + '-': true, + '.': true, + '0': true, + '1': true, + '2': true, + '3': true, + '4': true, + '5': true, + '6': true, + '7': true, + '8': true, + '9': true, + 'A': true, + 'B': true, + 'C': true, + 'D': true, + 'E': true, + 'F': true, + 'G': true, + 'H': true, + 'I': true, + 'J': true, + 'K': true, + 'L': true, + 'M': true, + 'N': true, + 'O': true, + 'P': true, + 'Q': true, + 'R': true, + 'S': true, + 'T': true, + 'U': true, + 'W': true, + 'V': true, + 'X': true, + 'Y': true, + 'Z': true, + '^': true, + '_': true, + '`': true, + 'a': true, + 'b': true, + 'c': true, + 'd': true, + 'e': true, + 'f': true, + 'g': true, + 'h': true, + 'i': true, + 'j': true, + 'k': true, + 'l': true, + 'm': true, + 'n': true, + 'o': true, + 'p': true, + 'q': true, + 'r': true, + 's': true, + 't': true, + 'u': true, + 'v': true, + 'w': true, + 'x': true, + 'y': true, + 'z': true, + '|': true, + '~': true, } -func isCtl(c byte) bool { return (0 <= c && c <= 31) || c == 127 } - -func isChar(c byte) bool { return 0 <= c && c <= 127 } - -func isAnyText(c byte) bool { return !isCtl(c) } - -func isQdText(c byte) bool { return isAnyText(c) && c != '"' } - -func isToken(c byte) bool { return isChar(c) && !isCtl(c) && !isSeparator(c) } - -// Valid escaped sequences are not specified in RFC 2616, so for now, we assume -// that they coincide with the common sense ones used by GO. Malformed -// characters should probably not be treated as errors by a robust (forgiving) -// parser, so we replace them with the '?' character. -func httpUnquotePair(b byte) byte { - // skip the first byte, which should always be '\' - switch b { - case 'a': - return '\a' - case 'b': - return '\b' - case 'f': - return '\f' - case 'n': - return '\n' - case 'r': - return '\r' - case 't': - return '\t' - case 'v': - return '\v' - case '\\': - return '\\' - case '\'': - return '\'' - case '"': - return '"' - } - return '?' -} - -// raw must begin with a valid quoted string. Only the first quoted string is -// parsed and is unquoted in result. eaten is the number of bytes parsed, or -1 -// upon failure. -func httpUnquote(raw []byte) (eaten int, result string) { - buf := make([]byte, len(raw)) - if raw[0] != '"' { - return -1, "" - } - eaten = 1 - j := 0 // # of bytes written in buf - for i := 1; i < len(raw); i++ { - switch b := raw[i]; b { - case '"': - eaten++ - buf = buf[0:j] - return i + 1, string(buf) - case '\\': - if len(raw) < i+2 { - return -1, "" - } - buf[j] = httpUnquotePair(raw[i+1]) - eaten += 2 - j++ - i++ - default: - if isQdText(b) { - buf[j] = b - } else { - buf[j] = '?' - } - eaten++ - j++ - } - } - return -1, "" +func isToken(r rune) bool { + i := int(r) + return i < len(isTokenTable) && isTokenTable[i] } -// This is a best effort parse, so errors are not returned, instead not all of -// the input string might be parsed. result is always non-nil. -func httpSplitFieldValue(fv string) (eaten int, result []string) { - result = make([]string, 0, len(fv)) - raw := []byte(fv) - i := 0 - chunk := "" - for i < len(raw) { - b := raw[i] - switch { - case b == '"': - eaten, unq := httpUnquote(raw[i:len(raw)]) - if eaten < 0 { - return i, result - } else { - i += eaten - chunk += unq - } - case isSeparator(b): - if chunk != "" { - result = result[0 : len(result)+1] - result[len(result)-1] = chunk - chunk = "" - } - i++ - case isToken(b): - chunk += string(b) - i++ - case b == '\n' || b == '\r': - i++ - default: - chunk += "?" - i++ - } - } - if chunk != "" { - result = result[0 : len(result)+1] - result[len(result)-1] = chunk - chunk = "" - } - return i, result +func isNotToken(r rune) bool { + return !isToken(r) } diff --git a/src/pkg/net/http/lex_test.go b/src/pkg/net/http/lex_test.go index 5386f7534..6d9d294f7 100644 --- a/src/pkg/net/http/lex_test.go +++ b/src/pkg/net/http/lex_test.go @@ -8,63 +8,24 @@ import ( "testing" ) -type lexTest struct { - Raw string - Parsed int // # of parsed characters - Result []string -} +func isChar(c rune) bool { return c <= 127 } -var lexTests = []lexTest{ - { - Raw: `"abc"def,:ghi`, - Parsed: 13, - Result: []string{"abcdef", "ghi"}, - }, - // My understanding of the RFC is that escape sequences outside of - // quotes are not interpreted? - { - Raw: `"\t"\t"\t"`, - Parsed: 10, - Result: []string{"\t", "t\t"}, - }, - { - Raw: `"\yab"\r\n`, - Parsed: 10, - Result: []string{"?ab", "r", "n"}, - }, - { - Raw: "ab\f", - Parsed: 3, - Result: []string{"ab?"}, - }, - { - Raw: "\"ab \" c,de f, gh, ij\n\t\r", - Parsed: 23, - Result: []string{"ab ", "c", "de", "f", "gh", "ij"}, - }, -} +func isCtl(c rune) bool { return c <= 31 || c == 127 } -func min(x, y int) int { - if x <= y { - return x +func isSeparator(c rune) bool { + switch c { + case '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', '?', '=', '{', '}', ' ', '\t': + return true } - return y + return false } -func TestSplitFieldValue(t *testing.T) { - for k, l := range lexTests { - parsed, result := httpSplitFieldValue(l.Raw) - if parsed != l.Parsed { - t.Errorf("#%d: Parsed %d, expected %d", k, parsed, l.Parsed) - } - if len(result) != len(l.Result) { - t.Errorf("#%d: Result len %d, expected %d", k, len(result), len(l.Result)) - } - for i := 0; i < min(len(result), len(l.Result)); i++ { - if result[i] != l.Result[i] { - t.Errorf("#%d: %d-th entry mismatch. Have {%s}, expect {%s}", - k, i, result[i], l.Result[i]) - } +func TestIsToken(t *testing.T) { + for i := 0; i <= 130; i++ { + r := rune(i) + expected := isChar(r) && !isCtl(r) && !isSeparator(r) + if isToken(r) != expected { + t.Errorf("isToken(0x%x) = %v", r, !expected) } } } diff --git a/src/pkg/net/http/npn_test.go b/src/pkg/net/http/npn_test.go new file mode 100644 index 000000000..98b8930d0 --- /dev/null +++ b/src/pkg/net/http/npn_test.go @@ -0,0 +1,118 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http_test + +import ( + "bufio" + "crypto/tls" + "fmt" + "io" + "io/ioutil" + . "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestNextProtoUpgrade(t *testing.T) { + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "path=%s,proto=", r.URL.Path) + if r.TLS != nil { + w.Write([]byte(r.TLS.NegotiatedProtocol)) + } + if r.RemoteAddr == "" { + t.Error("request with no RemoteAddr") + } + if r.Body == nil { + t.Errorf("request with nil Body") + } + })) + ts.TLS = &tls.Config{ + NextProtos: []string{"unhandled-proto", "tls-0.9"}, + } + ts.Config.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){ + "tls-0.9": handleTLSProtocol09, + } + ts.StartTLS() + defer ts.Close() + + tr := newTLSTransport(t, ts) + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + // Normal request, without NPN. + { + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if want := "path=/,proto="; string(body) != want { + t.Errorf("plain request = %q; want %q", body, want) + } + } + + // Request to an advertised but unhandled NPN protocol. + // Server will hang up. + { + tr.CloseIdleConnections() + tr.TLSClientConfig.NextProtos = []string{"unhandled-proto"} + _, err := c.Get(ts.URL) + if err == nil { + t.Errorf("expected error on unhandled-proto request") + } + } + + // Request using the "tls-0.9" protocol, which we register here. + // It is HTTP/0.9 over TLS. + { + tlsConfig := newTLSTransport(t, ts).TLSClientConfig + tlsConfig.NextProtos = []string{"tls-0.9"} + conn, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig) + if err != nil { + t.Fatal(err) + } + conn.Write([]byte("GET /foo\n")) + body, err := ioutil.ReadAll(conn) + if err != nil { + t.Fatal(err) + } + if want := "path=/foo,proto=tls-0.9"; string(body) != want { + t.Errorf("plain request = %q; want %q", body, want) + } + } +} + +// handleTLSProtocol09 implements the HTTP/0.9 protocol over TLS, for the +// TestNextProtoUpgrade test. +func handleTLSProtocol09(srv *Server, conn *tls.Conn, h Handler) { + br := bufio.NewReader(conn) + line, err := br.ReadString('\n') + if err != nil { + return + } + line = strings.TrimSpace(line) + path := strings.TrimPrefix(line, "GET ") + if path == line { + return + } + req, _ := NewRequest("GET", path, nil) + req.Proto = "HTTP/0.9" + req.ProtoMajor = 0 + req.ProtoMinor = 9 + rw := &http09Writer{conn, make(Header)} + h.ServeHTTP(rw, req) +} + +type http09Writer struct { + io.Writer + h Header +} + +func (w http09Writer) Header() Header { return w.h } +func (w http09Writer) WriteHeader(int) {} // no headers diff --git a/src/pkg/net/http/pprof/pprof.go b/src/pkg/net/http/pprof/pprof.go index 06fcde144..0c7548e3e 100644 --- a/src/pkg/net/http/pprof/pprof.go +++ b/src/pkg/net/http/pprof/pprof.go @@ -14,6 +14,14 @@ // To use pprof, link this package into your program: // import _ "net/http/pprof" // +// If your application is not already running an http server, you +// need to start one. Add "net/http" and "log" to your imports and +// the following code to your main function: +// +// go func() { +// log.Println(http.ListenAndServe("localhost:6060", nil)) +// }() +// // Then use the pprof tool to look at the heap profile: // // go tool pprof http://localhost:6060/debug/pprof/heap @@ -22,9 +30,12 @@ // // go tool pprof http://localhost:6060/debug/pprof/profile // -// Or to view all available profiles: +// Or to look at the goroutine blocking profile: +// +// go tool pprof http://localhost:6060/debug/pprof/block // -// go tool pprof http://localhost:6060/debug/pprof/ +// To view all available profiles, open http://localhost:6060/debug/pprof/ +// in your browser. // // For a study of the facility in action, visit // @@ -161,7 +172,7 @@ func (name handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // listing the available profiles. func Index(w http.ResponseWriter, r *http.Request) { if strings.HasPrefix(r.URL.Path, "/debug/pprof/") { - name := r.URL.Path[len("/debug/pprof/"):] + name := strings.TrimPrefix(r.URL.Path, "/debug/pprof/") if name != "" { handler(name).ServeHTTP(w, r) return diff --git a/src/pkg/net/http/proxy_test.go b/src/pkg/net/http/proxy_test.go index 5ecffafac..449ccaeea 100644 --- a/src/pkg/net/http/proxy_test.go +++ b/src/pkg/net/http/proxy_test.go @@ -25,13 +25,13 @@ var UseProxyTests = []struct { {"[::2]", true}, // not a loopback address {"barbaz.net", false}, // match as .barbaz.net - {"foobar.com", false}, // have a port but match + {"foobar.com", false}, // have a port but match {"foofoobar.com", true}, // not match as a part of foobar.com {"baz.com", true}, // not match as a part of barbaz.com {"localhost.net", true}, // not match as suffix of address {"local.localhost", true}, // not match as prefix as address {"barbarbaz.net", true}, // not match because NO_PROXY have a '.' - {"www.foobar.com", true}, // not match because NO_PROXY is not .foobar.com + {"www.foobar.com", false}, // match because NO_PROXY includes "foobar.com" } func TestUseProxy(t *testing.T) { diff --git a/src/pkg/net/http/range_test.go b/src/pkg/net/http/range_test.go index 5274a81fa..ef911af7b 100644 --- a/src/pkg/net/http/range_test.go +++ b/src/pkg/net/http/range_test.go @@ -14,15 +14,34 @@ var ParseRangeTests = []struct { r []httpRange }{ {"", 0, nil}, + {"", 1000, nil}, {"foo", 0, nil}, {"bytes=", 0, nil}, + {"bytes=7", 10, nil}, + {"bytes= 7 ", 10, nil}, + {"bytes=1-", 0, nil}, {"bytes=5-4", 10, nil}, {"bytes=0-2,5-4", 10, nil}, + {"bytes=2-5,4-3", 10, nil}, + {"bytes=--5,4--3", 10, nil}, + {"bytes=A-", 10, nil}, + {"bytes=A- ", 10, nil}, + {"bytes=A-Z", 10, nil}, + {"bytes= -Z", 10, nil}, + {"bytes=5-Z", 10, nil}, + {"bytes=Ran-dom, garbage", 10, nil}, + {"bytes=0x01-0x02", 10, nil}, + {"bytes= ", 10, nil}, + {"bytes= , , , ", 10, nil}, + {"bytes=0-9", 10, []httpRange{{0, 10}}}, {"bytes=0-", 10, []httpRange{{0, 10}}}, {"bytes=5-", 10, []httpRange{{5, 5}}}, {"bytes=0-20", 10, []httpRange{{0, 10}}}, {"bytes=15-,0-5", 10, nil}, + {"bytes=1-2,5-", 10, []httpRange{{1, 2}, {5, 5}}}, + {"bytes=-2 , 7-", 11, []httpRange{{9, 2}, {7, 4}}}, + {"bytes=0-0 ,2-2, 7-", 11, []httpRange{{0, 1}, {2, 1}, {7, 4}}}, {"bytes=-5", 10, []httpRange{{5, 5}}}, {"bytes=-15", 10, []httpRange{{0, 10}}}, {"bytes=0-499", 10000, []httpRange{{0, 500}}}, @@ -32,6 +51,9 @@ var ParseRangeTests = []struct { {"bytes=0-0,-1", 10000, []httpRange{{0, 1}, {9999, 1}}}, {"bytes=500-600,601-999", 10000, []httpRange{{500, 101}, {601, 399}}}, {"bytes=500-700,601-999", 10000, []httpRange{{500, 201}, {601, 399}}}, + + // Match Apache laxity: + {"bytes= 1 -2 , 4- 5, 7 - 8 , ,,", 11, []httpRange{{1, 2}, {4, 2}, {7, 2}}}, } func TestParseRange(t *testing.T) { diff --git a/src/pkg/net/http/readrequest_test.go b/src/pkg/net/http/readrequest_test.go index 2e03c658a..ffdd6a892 100644 --- a/src/pkg/net/http/readrequest_test.go +++ b/src/pkg/net/http/readrequest_test.go @@ -247,6 +247,54 @@ var reqTests = []reqTest{ noTrailer, noError, }, + + // SSDP Notify request. golang.org/issue/3692 + { + "NOTIFY * HTTP/1.1\r\nServer: foo\r\n\r\n", + &Request{ + Method: "NOTIFY", + URL: &url.URL{ + Path: "*", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{ + "Server": []string{"foo"}, + }, + Close: false, + ContentLength: 0, + RequestURI: "*", + }, + + noBody, + noTrailer, + noError, + }, + + // OPTIONS request. Similar to golang.org/issue/3692 + { + "OPTIONS * HTTP/1.1\r\nServer: foo\r\n\r\n", + &Request{ + Method: "OPTIONS", + URL: &url.URL{ + Path: "*", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{ + "Server": []string{"foo"}, + }, + Close: false, + ContentLength: 0, + RequestURI: "*", + }, + + noBody, + noTrailer, + noError, + }, } func TestReadRequest(t *testing.T) { diff --git a/src/pkg/net/http/request.go b/src/pkg/net/http/request.go index f5bc6eb91..217f35b48 100644 --- a/src/pkg/net/http/request.go +++ b/src/pkg/net/http/request.go @@ -19,6 +19,7 @@ import ( "mime/multipart" "net/textproto" "net/url" + "strconv" "strings" ) @@ -70,7 +71,13 @@ var reqWriteExcludeHeader = map[string]bool{ // or to be sent by a client. type Request struct { Method string // GET, POST, PUT, etc. - URL *url.URL + + // URL is created from the URI supplied on the Request-Line + // as stored in RequestURI. + // + // For most requests, fields other than Path and RawQuery + // will be empty. (See RFC 2616, Section 5.1.2) + URL *url.URL // The protocol version for incoming requests. // Outgoing requests always use HTTP/1.1. @@ -123,6 +130,7 @@ type Request struct { // The host on which the URL is sought. // Per RFC 2616, this is either the value of the Host: header // or the host name given in the URL itself. + // It may be of the form "host:port". Host string // Form contains the parsed form data, including both the URL @@ -131,6 +139,12 @@ type Request struct { // The HTTP client ignores Form and uses Body instead. Form url.Values + // PostForm contains the parsed form data from POST or PUT + // body parameters. + // This field is only available after ParseForm is called. + // The HTTP client ignores PostForm and uses Body instead. + PostForm url.Values + // MultipartForm is the parsed multipart form, including file uploads. // This field is only available after ParseMultipartForm is called. // The HTTP client ignores MultipartForm and uses Body instead. @@ -317,11 +331,20 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err } // TODO(bradfitz): escape at least newlines in ruri? - bw := bufio.NewWriter(w) - fmt.Fprintf(bw, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri) + // Wrap the writer in a bufio Writer if it's not already buffered. + // Don't always call NewWriter, as that forces a bytes.Buffer + // and other small bufio Writers to have a minimum 4k buffer + // size. + var bw *bufio.Writer + if _, ok := w.(io.ByteWriter); !ok { + bw = bufio.NewWriter(w) + w = bw + } + + fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri) // Header lines - fmt.Fprintf(bw, "Host: %s\r\n", host) + fmt.Fprintf(w, "Host: %s\r\n", host) // Use the defaultUserAgent unless the Header contains one, which // may be blank to not send the header. @@ -332,7 +355,7 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err } } if userAgent != "" { - fmt.Fprintf(bw, "User-Agent: %s\r\n", userAgent) + fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent) } // Process Body,ContentLength,Close,Trailer @@ -340,65 +363,61 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err if err != nil { return err } - err = tw.WriteHeader(bw) + err = tw.WriteHeader(w) if err != nil { return err } // TODO: split long values? (If so, should share code with Conn.Write) - err = req.Header.WriteSubset(bw, reqWriteExcludeHeader) + err = req.Header.WriteSubset(w, reqWriteExcludeHeader) if err != nil { return err } if extraHeaders != nil { - err = extraHeaders.Write(bw) + err = extraHeaders.Write(w) if err != nil { return err } } - io.WriteString(bw, "\r\n") + io.WriteString(w, "\r\n") // Write body and trailer - err = tw.WriteBody(bw) + err = tw.WriteBody(w) if err != nil { return err } - return bw.Flush() -} - -// Convert decimal at s[i:len(s)] to integer, -// returning value, string position where the digits stopped, -// and whether there was a valid number (digits, not too big). -func atoi(s string, i int) (n, i1 int, ok bool) { - const Big = 1000000 - if i >= len(s) || s[i] < '0' || s[i] > '9' { - return 0, 0, false - } - n = 0 - for ; i < len(s) && '0' <= s[i] && s[i] <= '9'; i++ { - n = n*10 + int(s[i]-'0') - if n > Big { - return 0, 0, false - } + if bw != nil { + return bw.Flush() } - return n, i, true + return nil } // ParseHTTPVersion parses a HTTP version string. // "HTTP/1.0" returns (1, 0, true). func ParseHTTPVersion(vers string) (major, minor int, ok bool) { - if len(vers) < 5 || vers[0:5] != "HTTP/" { + const Big = 1000000 // arbitrary upper bound + switch vers { + case "HTTP/1.1": + return 1, 1, true + case "HTTP/1.0": + return 1, 0, true + } + if !strings.HasPrefix(vers, "HTTP/") { return 0, 0, false } - major, i, ok := atoi(vers, 5) - if !ok || i >= len(vers) || vers[i] != '.' { + dot := strings.Index(vers, ".") + if dot < 0 { return 0, 0, false } - minor, i, ok = atoi(vers, i+1) - if !ok || i != len(vers) { + major, err := strconv.Atoi(vers[5:dot]) + if err != nil || major < 0 || major > Big { + return 0, 0, false + } + minor, err = strconv.Atoi(vers[dot+1:]) + if err != nil || minor < 0 || minor > Big { return 0, 0, false } return major, minor, true @@ -426,10 +445,12 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { } if body != nil { switch v := body.(type) { - case *strings.Reader: - req.ContentLength = int64(v.Len()) case *bytes.Buffer: req.ContentLength = int64(v.Len()) + case *bytes.Reader: + req.ContentLength = int64(v.Len()) + case *strings.Reader: + req.ContentLength = int64(v.Len()) } } @@ -513,9 +534,9 @@ func ReadRequest(b *bufio.Reader) (req *Request, err error) { // the same. In the second case, any Host line is ignored. req.Host = req.URL.Host if req.Host == "" { - req.Host = req.Header.Get("Host") + req.Host = req.Header.get("Host") } - req.Header.Del("Host") + delete(req.Header, "Host") fixPragmaCacheControl(req.Header) @@ -594,66 +615,97 @@ func (l *maxBytesReader) Close() error { return l.r.Close() } -// ParseForm parses the raw query from the URL. +func copyValues(dst, src url.Values) { + for k, vs := range src { + for _, value := range vs { + dst.Add(k, value) + } + } +} + +func parsePostForm(r *Request) (vs url.Values, err error) { + if r.Body == nil { + err = errors.New("missing form body") + return + } + ct := r.Header.Get("Content-Type") + ct, _, err = mime.ParseMediaType(ct) + switch { + case ct == "application/x-www-form-urlencoded": + var reader io.Reader = r.Body + maxFormSize := int64(1<<63 - 1) + if _, ok := r.Body.(*maxBytesReader); !ok { + maxFormSize = int64(10 << 20) // 10 MB is a lot of text. + reader = io.LimitReader(r.Body, maxFormSize+1) + } + b, e := ioutil.ReadAll(reader) + if e != nil { + if err == nil { + err = e + } + break + } + if int64(len(b)) > maxFormSize { + err = errors.New("http: POST too large") + return + } + vs, e = url.ParseQuery(string(b)) + if err == nil { + err = e + } + case ct == "multipart/form-data": + // handled by ParseMultipartForm (which is calling us, or should be) + // TODO(bradfitz): there are too many possible + // orders to call too many functions here. + // Clean this up and write more tests. + // request_test.go contains the start of this, + // in TestRequestMultipartCallOrder. + } + return +} + +// ParseForm parses the raw query from the URL and updates r.Form. +// +// For POST or PUT requests, it also parses the request body as a form and +// put the results into both r.PostForm and r.Form. +// POST and PUT body parameters take precedence over URL query string values +// in r.Form. // -// For POST or PUT requests, it also parses the request body as a form. // If the request Body's size has not already been limited by MaxBytesReader, // the size is capped at 10MB. // // ParseMultipartForm calls ParseForm automatically. // It is idempotent. -func (r *Request) ParseForm() (err error) { - if r.Form != nil { - return - } - if r.URL != nil { - r.Form, err = url.ParseQuery(r.URL.RawQuery) +func (r *Request) ParseForm() error { + var err error + if r.PostForm == nil { + if r.Method == "POST" || r.Method == "PUT" { + r.PostForm, err = parsePostForm(r) + } + if r.PostForm == nil { + r.PostForm = make(url.Values) + } } - if r.Method == "POST" || r.Method == "PUT" { - if r.Body == nil { - return errors.New("missing form body") + if r.Form == nil { + if len(r.PostForm) > 0 { + r.Form = make(url.Values) + copyValues(r.Form, r.PostForm) } - ct := r.Header.Get("Content-Type") - ct, _, err = mime.ParseMediaType(ct) - switch { - case ct == "application/x-www-form-urlencoded": - var reader io.Reader = r.Body - maxFormSize := int64(1<<63 - 1) - if _, ok := r.Body.(*maxBytesReader); !ok { - maxFormSize = int64(10 << 20) // 10 MB is a lot of text. - reader = io.LimitReader(r.Body, maxFormSize+1) - } - b, e := ioutil.ReadAll(reader) - if e != nil { - if err == nil { - err = e - } - break - } - if int64(len(b)) > maxFormSize { - return errors.New("http: POST too large") - } - var newValues url.Values - newValues, e = url.ParseQuery(string(b)) + var newValues url.Values + if r.URL != nil { + var e error + newValues, e = url.ParseQuery(r.URL.RawQuery) if err == nil { err = e } - if r.Form == nil { - r.Form = make(url.Values) - } - // Copy values into r.Form. TODO: make this smoother. - for k, vs := range newValues { - for _, value := range vs { - r.Form.Add(k, value) - } - } - case ct == "multipart/form-data": - // handled by ParseMultipartForm (which is calling us, or should be) - // TODO(bradfitz): there are too many possible - // orders to call too many functions here. - // Clean this up and write more tests. - // request_test.go contains the start of this, - // in TestRequestMultipartCallOrder. + } + if newValues == nil { + newValues = make(url.Values) + } + if r.Form == nil { + r.Form = newValues + } else { + copyValues(r.Form, newValues) } } return err @@ -699,7 +751,9 @@ func (r *Request) ParseMultipartForm(maxMemory int64) error { } // FormValue returns the first value for the named component of the query. +// POST and PUT body parameters take precedence over URL query string values. // FormValue calls ParseMultipartForm and ParseForm if necessary. +// To access multiple values of the same key use ParseForm. func (r *Request) FormValue(key string) string { if r.Form == nil { r.ParseMultipartForm(defaultMaxMemory) @@ -710,6 +764,19 @@ func (r *Request) FormValue(key string) string { return "" } +// PostFormValue returns the first value for the named component of the POST +// or PUT request body. URL query parameters are ignored. +// PostFormValue calls ParseMultipartForm and ParseForm if necessary. +func (r *Request) PostFormValue(key string) string { + if r.PostForm == nil { + r.ParseMultipartForm(defaultMaxMemory) + } + if vs := r.PostForm[key]; len(vs) > 0 { + return vs[0] + } + return "" +} + // FormFile returns the first file for the provided form key. // FormFile calls ParseMultipartForm and ParseForm if necessary. func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, error) { @@ -732,12 +799,16 @@ func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, e } func (r *Request) expectsContinue() bool { - return strings.ToLower(r.Header.Get("Expect")) == "100-continue" + return hasToken(r.Header.get("Expect"), "100-continue") } func (r *Request) wantsHttp10KeepAlive() bool { if r.ProtoMajor != 1 || r.ProtoMinor != 0 { return false } - return strings.Contains(strings.ToLower(r.Header.Get("Connection")), "keep-alive") + return hasToken(r.Header.get("Connection"), "keep-alive") +} + +func (r *Request) wantsClose() bool { + return hasToken(r.Header.get("Connection"), "close") } diff --git a/src/pkg/net/http/request_test.go b/src/pkg/net/http/request_test.go index 6e00b9bfd..00ad791de 100644 --- a/src/pkg/net/http/request_test.go +++ b/src/pkg/net/http/request_test.go @@ -30,8 +30,8 @@ func TestQuery(t *testing.T) { } func TestPostQuery(t *testing.T) { - req, _ := NewRequest("POST", "http://www.google.com/search?q=foo&q=bar&both=x", - strings.NewReader("z=post&both=y")) + req, _ := NewRequest("POST", "http://www.google.com/search?q=foo&q=bar&both=x&prio=1&empty=not", + strings.NewReader("z=post&both=y&prio=2&empty=")) req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") if q := req.FormValue("q"); q != "foo" { @@ -40,8 +40,23 @@ func TestPostQuery(t *testing.T) { if z := req.FormValue("z"); z != "post" { t.Errorf(`req.FormValue("z") = %q, want "post"`, z) } - if both := req.Form["both"]; !reflect.DeepEqual(both, []string{"x", "y"}) { - t.Errorf(`req.FormValue("both") = %q, want ["x", "y"]`, both) + if bq, found := req.PostForm["q"]; found { + t.Errorf(`req.PostForm["q"] = %q, want no entry in map`, bq) + } + if bz := req.PostFormValue("z"); bz != "post" { + t.Errorf(`req.PostFormValue("z") = %q, want "post"`, bz) + } + if qs := req.Form["q"]; !reflect.DeepEqual(qs, []string{"foo", "bar"}) { + t.Errorf(`req.Form["q"] = %q, want ["foo", "bar"]`, qs) + } + if both := req.Form["both"]; !reflect.DeepEqual(both, []string{"y", "x"}) { + t.Errorf(`req.Form["both"] = %q, want ["y", "x"]`, both) + } + if prio := req.FormValue("prio"); prio != "2" { + t.Errorf(`req.FormValue("prio") = %q, want "2" (from body)`, prio) + } + if empty := req.FormValue("empty"); empty != "" { + t.Errorf(`req.FormValue("empty") = %q, want "" (from body)`, empty) } } @@ -76,6 +91,23 @@ func TestParseFormUnknownContentType(t *testing.T) { } } +func TestParseFormInitializeOnError(t *testing.T) { + nilBody, _ := NewRequest("POST", "http://www.google.com/search?q=foo", nil) + tests := []*Request{ + nilBody, + {Method: "GET", URL: nil}, + } + for i, req := range tests { + err := req.ParseForm() + if req.Form == nil { + t.Errorf("%d. Form not initialized, error %v", i, err) + } + if req.PostForm == nil { + t.Errorf("%d. PostForm not initialized, error %v", i, err) + } + } +} + func TestMultipartReader(t *testing.T) { req := &Request{ Method: "POST", @@ -129,7 +161,7 @@ func TestSetBasicAuth(t *testing.T) { } func TestMultipartRequest(t *testing.T) { - // Test that we can read the values and files of a + // Test that we can read the values and files of a // multipart request with FormValue and FormFile, // and that ParseMultipartForm can be called multiple times. req := newTestMultipartRequest(t) @@ -196,6 +228,75 @@ func TestReadRequestErrors(t *testing.T) { } } +func TestNewRequestHost(t *testing.T) { + req, err := NewRequest("GET", "http://localhost:1234/", nil) + if err != nil { + t.Fatal(err) + } + if req.Host != "localhost:1234" { + t.Errorf("Host = %q; want localhost:1234", req.Host) + } +} + +func TestNewRequestContentLength(t *testing.T) { + readByte := func(r io.Reader) io.Reader { + var b [1]byte + r.Read(b[:]) + return r + } + tests := []struct { + r io.Reader + want int64 + }{ + {bytes.NewReader([]byte("123")), 3}, + {bytes.NewBuffer([]byte("1234")), 4}, + {strings.NewReader("12345"), 5}, + // Not detected: + {struct{ io.Reader }{strings.NewReader("xyz")}, 0}, + {io.NewSectionReader(strings.NewReader("x"), 0, 6), 0}, + {readByte(io.NewSectionReader(strings.NewReader("xy"), 0, 6)), 0}, + } + for _, tt := range tests { + req, err := NewRequest("POST", "http://localhost/", tt.r) + if err != nil { + t.Fatal(err) + } + if req.ContentLength != tt.want { + t.Errorf("ContentLength(%T) = %d; want %d", tt.r, req.ContentLength, tt.want) + } + } +} + +type logWrites struct { + t *testing.T + dst *[]string +} + +func (l logWrites) WriteByte(c byte) error { + l.t.Fatalf("unexpected WriteByte call") + return nil +} + +func (l logWrites) Write(p []byte) (n int, err error) { + *l.dst = append(*l.dst, string(p)) + return len(p), nil +} + +func TestRequestWriteBufferedWriter(t *testing.T) { + got := []string{} + req, _ := NewRequest("GET", "http://foo.com/", nil) + req.Write(logWrites{t, &got}) + want := []string{ + "GET / HTTP/1.1\r\n", + "Host: foo.com\r\n", + "User-Agent: Go http package\r\n", + "\r\n", + } + if !reflect.DeepEqual(got, want) { + t.Errorf("Writes = %q\n Want = %q", got, want) + } +} + func testMissingFile(t *testing.T, req *Request) { f, fh, err := req.FormFile("missing") if f != nil { @@ -300,3 +401,81 @@ Content-Disposition: form-data; name="textb" ` + textbValue + ` --MyBoundary-- ` + +func benchmarkReadRequest(b *testing.B, request string) { + request = request + "\n" // final \n + request = strings.Replace(request, "\n", "\r\n", -1) // expand \n to \r\n + b.SetBytes(int64(len(request))) + r := bufio.NewReader(&infiniteReader{buf: []byte(request)}) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := ReadRequest(r) + if err != nil { + b.Fatalf("failed to read request: %v", err) + } + } +} + +// infiniteReader satisfies Read requests as if the contents of buf +// loop indefinitely. +type infiniteReader struct { + buf []byte + offset int +} + +func (r *infiniteReader) Read(b []byte) (int, error) { + n := copy(b, r.buf[r.offset:]) + r.offset = (r.offset + n) % len(r.buf) + return n, nil +} + +func BenchmarkReadRequestChrome(b *testing.B) { + // https://github.com/felixge/node-http-perf/blob/master/fixtures/get.http + benchmarkReadRequest(b, `GET / HTTP/1.1 +Host: localhost:8080 +Connection: keep-alive +Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8 +User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17 +Accept-Encoding: gzip,deflate,sdch +Accept-Language: en-US,en;q=0.8 +Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3 +Cookie: __utma=1.1978842379.1323102373.1323102373.1323102373.1; EPi:NumberOfVisits=1,2012-02-28T13:42:18; CrmSession=5b707226b9563e1bc69084d07a107c98; plushContainerWidth=100%25; plushNoTopMenu=0; hudson_auto_refresh=false +`) +} + +func BenchmarkReadRequestCurl(b *testing.B) { + // curl http://localhost:8080/ + benchmarkReadRequest(b, `GET / HTTP/1.1 +User-Agent: curl/7.27.0 +Host: localhost:8080 +Accept: */* +`) +} + +func BenchmarkReadRequestApachebench(b *testing.B) { + // ab -n 1 -c 1 http://localhost:8080/ + benchmarkReadRequest(b, `GET / HTTP/1.0 +Host: localhost:8080 +User-Agent: ApacheBench/2.3 +Accept: */* +`) +} + +func BenchmarkReadRequestSiege(b *testing.B) { + // siege -r 1 -c 1 http://localhost:8080/ + benchmarkReadRequest(b, `GET / HTTP/1.1 +Host: localhost:8080 +Accept: */* +Accept-Encoding: gzip +User-Agent: JoeDog/1.00 [en] (X11; I; Siege 2.70) +Connection: keep-alive +`) +} + +func BenchmarkReadRequestWrk(b *testing.B) { + // wrk -t 1 -r 1 -c 1 http://localhost:8080/ + benchmarkReadRequest(b, `GET / HTTP/1.1 +Host: localhost:8080 +`) +} diff --git a/src/pkg/net/http/requestwrite_test.go b/src/pkg/net/http/requestwrite_test.go index fc3186f0c..bc637f18b 100644 --- a/src/pkg/net/http/requestwrite_test.go +++ b/src/pkg/net/http/requestwrite_test.go @@ -328,6 +328,69 @@ var reqWriteTests = []reqWriteTest{ "User-Agent: Go http package\r\n" + "X-Foo: X-Bar\r\n\r\n", }, + + // If no Request.Host and no Request.URL.Host, we send + // an empty Host header, and don't use + // Request.Header["Host"]. This is just testing that + // we don't change Go 1.0 behavior. + { + Req: Request{ + Method: "GET", + Host: "", + URL: &url.URL{ + Scheme: "http", + Host: "", + Path: "/search", + }, + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{ + "Host": []string{"bad.example.com"}, + }, + }, + + WantWrite: "GET /search HTTP/1.1\r\n" + + "Host: \r\n" + + "User-Agent: Go http package\r\n\r\n", + }, + + // Opaque test #1 from golang.org/issue/4860 + { + Req: Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Host: "www.google.com", + Opaque: "/%2F/%2F/", + }, + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + }, + + WantWrite: "GET /%2F/%2F/ HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "User-Agent: Go http package\r\n\r\n", + }, + + // Opaque test #2 from golang.org/issue/4860 + { + Req: Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Host: "x.google.com", + Opaque: "//y.google.com/%2F/%2F/", + }, + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + }, + + WantWrite: "GET http://y.google.com/%2F/%2F/ HTTP/1.1\r\n" + + "Host: x.google.com\r\n" + + "User-Agent: Go http package\r\n\r\n", + }, } func TestRequestWrite(t *testing.T) { diff --git a/src/pkg/net/http/response.go b/src/pkg/net/http/response.go index 945ecd8a4..391ebbf6d 100644 --- a/src/pkg/net/http/response.go +++ b/src/pkg/net/http/response.go @@ -49,7 +49,7 @@ type Response struct { Body io.ReadCloser // ContentLength records the length of the associated content. The - // value -1 indicates that the length is unknown. Unless RequestMethod + // value -1 indicates that the length is unknown. Unless Request.Method // is "HEAD", values >= 0 indicate that the given number of bytes may // be read from Body. ContentLength int64 @@ -107,7 +107,6 @@ func ReadResponse(r *bufio.Reader, req *Request) (resp *Response, err error) { resp = new(Response) resp.Request = req - resp.Request.Method = strings.ToUpper(resp.Request.Method) // Parse the first line of the response. line, err := tp.ReadLine() @@ -179,7 +178,7 @@ func (r *Response) ProtoAtLeast(major, minor int) bool { // StatusCode // ProtoMajor // ProtoMinor -// RequestMethod +// Request.Method // TransferEncoding // Trailer // Body @@ -188,11 +187,6 @@ func (r *Response) ProtoAtLeast(major, minor int) bool { // func (r *Response) Write(w io.Writer) error { - // RequestMethod should be upper-case - if r.Request != nil { - r.Request.Method = strings.ToUpper(r.Request.Method) - } - // Status line text := r.Status if text == "" { @@ -204,9 +198,7 @@ func (r *Response) Write(w io.Writer) error { } protoMajor, protoMinor := strconv.Itoa(r.ProtoMajor), strconv.Itoa(r.ProtoMinor) statusCode := strconv.Itoa(r.StatusCode) + " " - if strings.HasPrefix(text, statusCode) { - text = text[len(statusCode):] - } + text = strings.TrimPrefix(text, statusCode) io.WriteString(w, "HTTP/"+protoMajor+"."+protoMinor+" "+statusCode+text+"\r\n") // Process Body,ContentLength,Close,Trailer diff --git a/src/pkg/net/http/response_test.go b/src/pkg/net/http/response_test.go index 6eed4887d..2f5f77369 100644 --- a/src/pkg/net/http/response_test.go +++ b/src/pkg/net/http/response_test.go @@ -124,7 +124,7 @@ var respTests = []respTest{ // Chunked response without Content-Length. { - "HTTP/1.0 200 OK\r\n" + + "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "0a\r\n" + @@ -137,12 +137,12 @@ var respTests = []respTest{ Response{ Status: "200 OK", StatusCode: 200, - Proto: "HTTP/1.0", + Proto: "HTTP/1.1", ProtoMajor: 1, - ProtoMinor: 0, + ProtoMinor: 1, Request: dummyReq("GET"), Header: Header{}, - Close: true, + Close: false, ContentLength: -1, TransferEncoding: []string{"chunked"}, }, @@ -152,24 +152,24 @@ var respTests = []respTest{ // Chunked response with Content-Length. { - "HTTP/1.0 200 OK\r\n" + + "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "Content-Length: 10\r\n" + "\r\n" + "0a\r\n" + - "Body here\n" + + "Body here\n\r\n" + "0\r\n" + "\r\n", Response{ Status: "200 OK", StatusCode: 200, - Proto: "HTTP/1.0", + Proto: "HTTP/1.1", ProtoMajor: 1, - ProtoMinor: 0, + ProtoMinor: 1, Request: dummyReq("GET"), Header: Header{}, - Close: true, + Close: false, ContentLength: -1, // TODO(rsc): Fix? TransferEncoding: []string{"chunked"}, }, @@ -177,23 +177,88 @@ var respTests = []respTest{ "Body here\n", }, - // Chunked response in response to a HEAD request (the "chunked" should - // be ignored, as HEAD responses never have bodies) + // Chunked response in response to a HEAD request { - "HTTP/1.0 200 OK\r\n" + + "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n", Response{ - Status: "200 OK", - StatusCode: 200, - Proto: "HTTP/1.0", - ProtoMajor: 1, - ProtoMinor: 0, - Request: dummyReq("HEAD"), - Header: Header{}, - Close: true, - ContentLength: 0, + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("HEAD"), + Header: Header{}, + TransferEncoding: []string{"chunked"}, + Close: false, + ContentLength: -1, + }, + + "", + }, + + // Content-Length in response to a HEAD request + { + "HTTP/1.0 200 OK\r\n" + + "Content-Length: 256\r\n" + + "\r\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.0", + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("HEAD"), + Header: Header{"Content-Length": {"256"}}, + TransferEncoding: nil, + Close: true, + ContentLength: 256, + }, + + "", + }, + + // Content-Length in response to a HEAD request with HTTP/1.1 + { + "HTTP/1.1 200 OK\r\n" + + "Content-Length: 256\r\n" + + "\r\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("HEAD"), + Header: Header{"Content-Length": {"256"}}, + TransferEncoding: nil, + Close: false, + ContentLength: 256, + }, + + "", + }, + + // No Content-Length or Chunked in response to a HEAD request + { + "HTTP/1.0 200 OK\r\n" + + "\r\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.0", + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("HEAD"), + Header: Header{}, + TransferEncoding: nil, + Close: true, + ContentLength: -1, }, "", @@ -259,16 +324,37 @@ var respTests = []respTest{ "", }, + + // golang.org/issue/4767: don't special-case multipart/byteranges responses + { + `HTTP/1.1 206 Partial Content +Connection: close +Content-Type: multipart/byteranges; boundary=18a75608c8f47cef + +some body`, + Response{ + Status: "206 Partial Content", + StatusCode: 206, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("GET"), + Header: Header{ + "Content-Type": []string{"multipart/byteranges; boundary=18a75608c8f47cef"}, + }, + Close: true, + ContentLength: -1, + }, + + "some body", + }, } func TestReadResponse(t *testing.T) { - for i := range respTests { - tt := &respTests[i] - var braw bytes.Buffer - braw.WriteString(tt.Raw) - resp, err := ReadResponse(bufio.NewReader(&braw), tt.Resp.Request) + for i, tt := range respTests { + resp, err := ReadResponse(bufio.NewReader(strings.NewReader(tt.Raw)), tt.Resp.Request) if err != nil { - t.Errorf("#%d: %s", i, err) + t.Errorf("#%d: %v", i, err) continue } rbody := resp.Body @@ -276,7 +362,11 @@ func TestReadResponse(t *testing.T) { diff(t, fmt.Sprintf("#%d Response", i), resp, &tt.Resp) var bout bytes.Buffer if rbody != nil { - io.Copy(&bout, rbody) + _, err = io.Copy(&bout, rbody) + if err != nil { + t.Errorf("#%d: %v", i, err) + continue + } rbody.Close() } body := bout.String() @@ -286,6 +376,22 @@ func TestReadResponse(t *testing.T) { } } +func TestWriteResponse(t *testing.T) { + for i, tt := range respTests { + resp, err := ReadResponse(bufio.NewReader(strings.NewReader(tt.Raw)), tt.Resp.Request) + if err != nil { + t.Errorf("#%d: %v", i, err) + continue + } + bout := bytes.NewBuffer(nil) + err = resp.Write(bout) + if err != nil { + t.Errorf("#%d: %v", i, err) + continue + } + } +} + var readResponseCloseInMiddleTests = []struct { chunked, compressed bool }{ diff --git a/src/pkg/net/http/responsewrite_test.go b/src/pkg/net/http/responsewrite_test.go index f8e63acf4..5c10e2161 100644 --- a/src/pkg/net/http/responsewrite_test.go +++ b/src/pkg/net/http/responsewrite_test.go @@ -15,83 +15,83 @@ type respWriteTest struct { Raw string } -var respWriteTests = []respWriteTest{ - // HTTP/1.0, identity coding; no trailer - { - Response{ - StatusCode: 503, - ProtoMajor: 1, - ProtoMinor: 0, - Request: dummyReq("GET"), - Header: Header{}, - Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), - ContentLength: 6, - }, +func TestResponseWrite(t *testing.T) { + respWriteTests := []respWriteTest{ + // HTTP/1.0, identity coding; no trailer + { + Response{ + StatusCode: 503, + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("GET"), + Header: Header{}, + Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), + ContentLength: 6, + }, - "HTTP/1.0 503 Service Unavailable\r\n" + - "Content-Length: 6\r\n\r\n" + - "abcdef", - }, - // Unchunked response without Content-Length. - { - Response{ - StatusCode: 200, - ProtoMajor: 1, - ProtoMinor: 0, - Request: dummyReq("GET"), - Header: Header{}, - Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), - ContentLength: -1, + "HTTP/1.0 503 Service Unavailable\r\n" + + "Content-Length: 6\r\n\r\n" + + "abcdef", }, - "HTTP/1.0 200 OK\r\n" + - "\r\n" + - "abcdef", - }, - // HTTP/1.1, chunked coding; empty trailer; close - { - Response{ - StatusCode: 200, - ProtoMajor: 1, - ProtoMinor: 1, - Request: dummyReq("GET"), - Header: Header{}, - Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), - ContentLength: 6, - TransferEncoding: []string{"chunked"}, - Close: true, + // Unchunked response without Content-Length. + { + Response{ + StatusCode: 200, + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("GET"), + Header: Header{}, + Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), + ContentLength: -1, + }, + "HTTP/1.0 200 OK\r\n" + + "\r\n" + + "abcdef", }, + // HTTP/1.1, chunked coding; empty trailer; close + { + Response{ + StatusCode: 200, + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("GET"), + Header: Header{}, + Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), + ContentLength: 6, + TransferEncoding: []string{"chunked"}, + Close: true, + }, - "HTTP/1.1 200 OK\r\n" + - "Connection: close\r\n" + - "Transfer-Encoding: chunked\r\n\r\n" + - "6\r\nabcdef\r\n0\r\n\r\n", - }, + "HTTP/1.1 200 OK\r\n" + + "Connection: close\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "6\r\nabcdef\r\n0\r\n\r\n", + }, - // Header value with a newline character (Issue 914). - // Also tests removal of leading and trailing whitespace. - { - Response{ - StatusCode: 204, - ProtoMajor: 1, - ProtoMinor: 1, - Request: dummyReq("GET"), - Header: Header{ - "Foo": []string{" Bar\nBaz "}, + // Header value with a newline character (Issue 914). + // Also tests removal of leading and trailing whitespace. + { + Response{ + StatusCode: 204, + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("GET"), + Header: Header{ + "Foo": []string{" Bar\nBaz "}, + }, + Body: nil, + ContentLength: 0, + TransferEncoding: []string{"chunked"}, + Close: true, }, - Body: nil, - ContentLength: 0, - TransferEncoding: []string{"chunked"}, - Close: true, - }, - "HTTP/1.1 204 No Content\r\n" + - "Connection: close\r\n" + - "Foo: Bar Baz\r\n" + - "\r\n", - }, -} + "HTTP/1.1 204 No Content\r\n" + + "Connection: close\r\n" + + "Foo: Bar Baz\r\n" + + "\r\n", + }, + } -func TestResponseWrite(t *testing.T) { for i := range respWriteTests { tt := &respWriteTests[i] var braw bytes.Buffer diff --git a/src/pkg/net/http/serve_test.go b/src/pkg/net/http/serve_test.go index b6a6b4c77..3300fef59 100644 --- a/src/pkg/net/http/serve_test.go +++ b/src/pkg/net/http/serve_test.go @@ -20,8 +20,13 @@ import ( "net/http/httputil" "net/url" "os" + "os/exec" "reflect" + "runtime" + "strconv" "strings" + "sync" + "sync/atomic" "syscall" "testing" "time" @@ -62,6 +67,7 @@ func (a dummyAddr) String() string { type testConn struct { readBuf bytes.Buffer writeBuf bytes.Buffer + closec chan bool // if non-nil, send value to it on close } func (c *testConn) Read(b []byte) (int, error) { @@ -73,6 +79,10 @@ func (c *testConn) Write(b []byte) (int, error) { } func (c *testConn) Close() error { + select { + case c.closec <- true: + default: + } return nil } @@ -168,13 +178,18 @@ var vtests = []struct { {"http://someHost.com/someDir/apage", "someHost.com/someDir"}, {"http://otherHost.com/someDir/apage", "someDir"}, {"http://otherHost.com/aDir/apage", "Default"}, + // redirections for trees + {"http://localhost/someDir", "/someDir/"}, + {"http://someHost.com/someDir", "/someDir/"}, } func TestHostHandlers(t *testing.T) { + defer checkLeakedTransports(t) + mux := NewServeMux() for _, h := range handlers { - Handle(h.pattern, stringHandler(h.msg)) + mux.Handle(h.pattern, stringHandler(h.msg)) } - ts := httptest.NewServer(nil) + ts := httptest.NewServer(mux) defer ts.Close() conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -199,9 +214,19 @@ func TestHostHandlers(t *testing.T) { t.Errorf("reading response: %v", err) continue } - s := r.Header.Get("Result") - if s != vt.expected { - t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected) + switch r.StatusCode { + case StatusOK: + s := r.Header.Get("Result") + if s != vt.expected { + t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected) + } + case StatusMovedPermanently: + s := r.Header.Get("Location") + if s != vt.expected { + t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected) + } + default: + t.Errorf("Get(%q) unhandled status code %d", vt.url, r.StatusCode) } } } @@ -232,28 +257,22 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) { } func TestServerTimeouts(t *testing.T) { - // TODO(bradfitz): convert this to use httptest.Server - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("listen error: %v", err) - } - addr, _ := l.Addr().(*net.TCPAddr) - + defer checkLeakedTransports(t) reqNum := 0 - handler := HandlerFunc(func(res ResponseWriter, req *Request) { + ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { reqNum++ fmt.Fprintf(res, "req=%d", reqNum) - }) - - server := &Server{Handler: handler, ReadTimeout: 250 * time.Millisecond, WriteTimeout: 250 * time.Millisecond} - go server.Serve(l) - - url := fmt.Sprintf("http://%s/", addr) + })) + ts.Config.ReadTimeout = 250 * time.Millisecond + ts.Config.WriteTimeout = 250 * time.Millisecond + ts.Start() + defer ts.Close() // Hit the HTTP server successfully. tr := &Transport{DisableKeepAlives: true} // they interfere with this test + defer tr.CloseIdleConnections() c := &Client{Transport: tr} - r, err := c.Get(url) + r, err := c.Get(ts.URL) if err != nil { t.Fatalf("http Get #1: %v", err) } @@ -266,13 +285,13 @@ func TestServerTimeouts(t *testing.T) { // Slow client that should timeout. t1 := time.Now() - conn, err := net.Dial("tcp", addr.String()) + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("Dial: %v", err) } buf := make([]byte, 1) n, err := conn.Read(buf) - latency := time.Now().Sub(t1) + latency := time.Since(t1) if n != 0 || err != io.EOF { t.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF) } @@ -283,7 +302,7 @@ func TestServerTimeouts(t *testing.T) { // Hit the HTTP server successfully again, verifying that the // previous slow connection didn't run our handler. (that we // get "req=2", not "req=3") - r, err = Get(url) + r, err = Get(ts.URL) if err != nil { t.Fatalf("http Get #2: %v", err) } @@ -293,11 +312,87 @@ func TestServerTimeouts(t *testing.T) { t.Errorf("Get #2 got %q, want %q", string(got), expected) } - l.Close() + if !testing.Short() { + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + go io.Copy(ioutil.Discard, conn) + for i := 0; i < 5; i++ { + _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n")) + if err != nil { + t.Fatalf("on write %d: %v", i, err) + } + time.Sleep(ts.Config.ReadTimeout / 2) + } + } +} + +// golang.org/issue/4741 -- setting only a write timeout that triggers +// shouldn't cause a handler to block forever on reads (next HTTP +// request) that will never happen. +func TestOnlyWriteTimeout(t *testing.T) { + defer checkLeakedTransports(t) + var conn net.Conn + var afterTimeoutErrc = make(chan error, 1) + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, req *Request) { + buf := make([]byte, 512<<10) + _, err := w.Write(buf) + if err != nil { + t.Errorf("handler Write error: %v", err) + return + } + conn.SetWriteDeadline(time.Now().Add(-30 * time.Second)) + _, err = w.Write(buf) + afterTimeoutErrc <- err + })) + ts.Listener = trackLastConnListener{ts.Listener, &conn} + ts.Start() + defer ts.Close() + + tr := &Transport{DisableKeepAlives: false} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + errc := make(chan error) + go func() { + res, err := c.Get(ts.URL) + if err != nil { + errc <- err + return + } + _, err = io.Copy(ioutil.Discard, res.Body) + errc <- err + }() + select { + case err := <-errc: + if err == nil { + t.Errorf("expected an error from Get request") + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for Get error") + } + if err := <-afterTimeoutErrc; err == nil { + t.Error("expected write error after timeout") + } +} + +// trackLastConnListener tracks the last net.Conn that was accepted. +type trackLastConnListener struct { + net.Listener + last *net.Conn // destination } -// TestIdentityResponse verifies that a handler can unset +func (l trackLastConnListener) Accept() (c net.Conn, err error) { + c, err = l.Listener.Accept() + *l.last = c + return +} + +// TestIdentityResponse verifies that a handler can unset func TestIdentityResponse(t *testing.T) { + defer checkLeakedTransports(t) handler := HandlerFunc(func(rw ResponseWriter, req *Request) { rw.Header().Set("Content-Length", "3") rw.Header().Set("Transfer-Encoding", req.FormValue("te")) @@ -343,10 +438,12 @@ func TestIdentityResponse(t *testing.T) { // Verify that ErrContentLength is returned url := ts.URL + "/?overwrite=1" - _, err := Get(url) + res, err := Get(url) if err != nil { t.Fatalf("error with Get of %s: %v", url, err) } + res.Body.Close() + // Verify that the connection is closed when the declared Content-Length // is larger than what the handler wrote. conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -370,7 +467,8 @@ func TestIdentityResponse(t *testing.T) { }) } -func testTcpConnectionCloses(t *testing.T, req string, h Handler) { +func testTCPConnectionCloses(t *testing.T, req string, h Handler) { + defer checkLeakedTransports(t) s := httptest.NewServer(h) defer s.Close() @@ -386,17 +484,18 @@ func testTcpConnectionCloses(t *testing.T, req string, h Handler) { } r := bufio.NewReader(conn) - _, err = ReadResponse(r, &Request{Method: "GET"}) + res, err := ReadResponse(r, &Request{Method: "GET"}) if err != nil { t.Fatal("ReadResponse error:", err) } - success := make(chan bool) + didReadAll := make(chan bool, 1) go func() { select { case <-time.After(5 * time.Second): - t.Fatal("body not closed after 5s") - case <-success: + t.Error("body not closed after 5s") + return + case <-didReadAll: } }() @@ -404,32 +503,43 @@ func testTcpConnectionCloses(t *testing.T, req string, h Handler) { if err != nil { t.Fatal("read error:", err) } + didReadAll <- true - success <- true + if !res.Close { + t.Errorf("Response.Close = false; want true") + } } // TestServeHTTP10Close verifies that HTTP/1.0 requests won't be kept alive. func TestServeHTTP10Close(t *testing.T) { - testTcpConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + testTCPConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "testdata/file") })) } +// TestClientCanClose verifies that clients can also force a connection to close. +func TestClientCanClose(t *testing.T) { + testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + // Nothing. + })) +} + // TestHandlersCanSetConnectionClose verifies that handlers can force a connection to close, // even for HTTP/1.1 requests. func TestHandlersCanSetConnectionClose11(t *testing.T) { - testTcpConnectionCloses(t, "GET / HTTP/1.1\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + testTCPConnectionCloses(t, "GET / HTTP/1.1\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Connection", "close") })) } func TestHandlersCanSetConnectionClose10(t *testing.T) { - testTcpConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + testTCPConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Connection", "close") })) } func TestSetsRemoteAddr(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%s", r.RemoteAddr) })) @@ -450,11 +560,13 @@ func TestSetsRemoteAddr(t *testing.T) { } func TestChunkedResponseHeaders(t *testing.T) { + defer checkLeakedTransports(t) log.SetOutput(ioutil.Discard) // is noisy otherwise defer log.SetOutput(os.Stderr) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted + w.(Flusher).Flush() fmt.Fprintf(w, "I am a chunked response.") })) defer ts.Close() @@ -463,6 +575,7 @@ func TestChunkedResponseHeaders(t *testing.T) { if err != nil { t.Fatalf("Get error: %v", err) } + defer res.Body.Close() if g, e := res.ContentLength, int64(-1); g != e { t.Errorf("expected ContentLength of %d; got %d", e, g) } @@ -478,6 +591,7 @@ func TestChunkedResponseHeaders(t *testing.T) { // chunking in their response headers and aren't allowed to produce // output. func Test304Responses(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(StatusNotModified) _, err := w.Write([]byte("illegal body")) @@ -507,6 +621,7 @@ func Test304Responses(t *testing.T) { // allowed to produce output, and don't set a Content-Type since // the real type of the body data cannot be inferred. func TestHeadResponses(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { _, err := w.Write([]byte("Ignored body")) if err != ErrBodyNotAllowed { @@ -541,6 +656,7 @@ func TestHeadResponses(t *testing.T) { } func TestTLSHandshakeTimeout(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) ts.Config.ReadTimeout = 250 * time.Millisecond ts.StartTLS() @@ -560,6 +676,7 @@ func TestTLSHandshakeTimeout(t *testing.T) { } func TestTLSServer(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.TLS != nil { w.Header().Set("X-TLS-Set", "true") @@ -642,6 +759,7 @@ var serverExpectTests = []serverExpectTest{ // Tests that the server responds to the "Expect" request header // correctly. func TestServerExpect(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { // Note using r.FormValue("readbody") because for POST // requests that would read from r.Body, which we only @@ -661,30 +779,51 @@ func TestServerExpect(t *testing.T) { t.Fatalf("Dial: %v", err) } defer conn.Close() - sendf := func(format string, args ...interface{}) { - _, err := fmt.Fprintf(conn, format, args...) - if err != nil { - t.Fatalf("On test %#v, error writing %q: %v", test, format, err) - } - } + + // Only send the body immediately if we're acting like an HTTP client + // that doesn't send 100-continue expectations. + writeBody := test.contentLength > 0 && strings.ToLower(test.expectation) != "100-continue" + go func() { - sendf("POST /?readbody=%v HTTP/1.1\r\n"+ + _, err := fmt.Fprintf(conn, "POST /?readbody=%v HTTP/1.1\r\n"+ "Connection: close\r\n"+ "Content-Length: %d\r\n"+ "Expect: %s\r\nHost: foo\r\n\r\n", test.readBody, test.contentLength, test.expectation) - if test.contentLength > 0 && strings.ToLower(test.expectation) != "100-continue" { + if err != nil { + t.Errorf("On test %#v, error writing request headers: %v", test, err) + return + } + if writeBody { body := strings.Repeat("A", test.contentLength) - sendf(body) + _, err = fmt.Fprint(conn, body) + if err != nil { + if !test.readBody { + // Server likely already hung up on us. + // See larger comment below. + t.Logf("On test %#v, acceptable error writing request body: %v", test, err) + return + } + t.Errorf("On test %#v, error writing request body: %v", test, err) + } } }() bufr := bufio.NewReader(conn) line, err := bufr.ReadString('\n') if err != nil { - t.Fatalf("ReadString: %v", err) + if writeBody && !test.readBody { + // This is an acceptable failure due to a possible TCP race: + // We were still writing data and the server hung up on us. A TCP + // implementation may send a RST if our request body data was known + // to be lost, which may trigger our reads to fail. + // See RFC 1122 page 88. + t.Logf("On test %#v, acceptable error from ReadString: %v", test, err) + return + } + t.Fatalf("On test %#v, ReadString: %v", test, err) } if !strings.Contains(line, test.expectedResponse) { - t.Errorf("for test %#v got first line=%q", test, line) + t.Errorf("On test %#v, got first line = %q; want %q", test, line, test.expectedResponse) } } @@ -714,6 +853,7 @@ func TestServerUnreadRequestBodyLittle(t *testing.T) { t.Errorf("on request, read buffer length is %d; expected about 100 KB", conn.readBuf.Len()) } rw.WriteHeader(200) + rw.(Flusher).Flush() if g, e := conn.readBuf.Len(), 0; g != e { t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e) } @@ -736,27 +876,28 @@ func TestServerUnreadRequestBodyLarge(t *testing.T) { "Content-Length: %d\r\n"+ "\r\n", len(body)))) conn.readBuf.Write([]byte(body)) - - done := make(chan bool) + conn.closec = make(chan bool, 1) ls := &oneConnListener{conn} go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) { - defer close(done) if conn.readBuf.Len() < len(body)/2 { t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len()) } rw.WriteHeader(200) + rw.(Flusher).Flush() if conn.readBuf.Len() < len(body)/2 { t.Errorf("post-WriteHeader, read buffer length is %d; expected about 1MB", conn.readBuf.Len()) } - if c := rw.Header().Get("Connection"); c != "close" { - t.Errorf(`Connection header = %q; want "close"`, c) - } })) - <-done + <-conn.closec + + if res := conn.writeBuf.String(); !strings.Contains(res, "Connection: close") { + t.Errorf("Expected a Connection: close header; got response: %s", res) + } } func TestTimeoutHandler(t *testing.T) { + defer checkLeakedTransports(t) sendHi := make(chan bool, 1) writeErrors := make(chan error, 1) sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { @@ -831,6 +972,7 @@ func TestRedirectMunging(t *testing.T) { // the previous request's body, which is not optimal for zero-lengthed bodies, // as the client would then see http.ErrBodyReadAfterClose and not 0, io.EOF. func TestZeroLengthPostAndResponse(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { all, err := ioutil.ReadAll(r.Body) if err != nil { @@ -868,15 +1010,20 @@ func TestZeroLengthPostAndResponse(t *testing.T) { } } +func TestHandlerPanicNil(t *testing.T) { + testHandlerPanic(t, false, nil) +} + func TestHandlerPanic(t *testing.T) { - testHandlerPanic(t, false) + testHandlerPanic(t, false, "intentional death for testing") } func TestHandlerPanicWithHijack(t *testing.T) { - testHandlerPanic(t, true) + testHandlerPanic(t, true, "intentional death for testing") } -func testHandlerPanic(t *testing.T, withHijack bool) { +func testHandlerPanic(t *testing.T, withHijack bool, panicValue interface{}) { + defer checkLeakedTransports(t) // Unlike the other tests that set the log output to ioutil.Discard // to quiet the output, this test uses a pipe. The pipe serves three // purposes: @@ -896,6 +1043,7 @@ func testHandlerPanic(t *testing.T, withHijack bool) { pr, pw := io.Pipe() log.SetOutput(pw) defer log.SetOutput(os.Stderr) + defer pw.Close() ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if withHijack { @@ -905,7 +1053,7 @@ func testHandlerPanic(t *testing.T, withHijack bool) { } defer rwc.Close() } - panic("intentional death for testing") + panic(panicValue) })) defer ts.Close() @@ -917,8 +1065,8 @@ func testHandlerPanic(t *testing.T, withHijack bool) { buf := make([]byte, 4<<10) _, err := pr.Read(buf) pr.Close() - if err != nil { - t.Fatal(err) + if err != nil && err != io.EOF { + t.Error(err) } done <- true }() @@ -928,6 +1076,10 @@ func testHandlerPanic(t *testing.T, withHijack bool) { t.Logf("expected an error") } + if panicValue == nil { + return + } + select { case <-done: return @@ -937,6 +1089,7 @@ func testHandlerPanic(t *testing.T, withHijack bool) { } func TestNoDate(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header()["Date"] = nil })) @@ -952,6 +1105,7 @@ func TestNoDate(t *testing.T) { } func TestStripPrefix(t *testing.T) { + defer checkLeakedTransports(t) h := HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("X-Path", r.URL.Path) }) @@ -965,6 +1119,7 @@ func TestStripPrefix(t *testing.T) { if g, e := res.Header.Get("X-Path"), "/bar"; g != e { t.Errorf("test 1: got %s, want %s", g, e) } + res.Body.Close() res, err = Get(ts.URL + "/bar") if err != nil { @@ -973,9 +1128,11 @@ func TestStripPrefix(t *testing.T) { if g, e := res.StatusCode, 404; g != e { t.Errorf("test 2: got status %v, want %v", g, e) } + res.Body.Close() } func TestRequestLimit(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { t.Fatalf("didn't expect to get request in Handler") })) @@ -992,6 +1149,7 @@ func TestRequestLimit(t *testing.T) { // we do support it (at least currently), so we expect a response below. t.Fatalf("Do: %v", err) } + defer res.Body.Close() if res.StatusCode != 413 { t.Fatalf("expected 413 response status; got: %d %s", res.StatusCode, res.Status) } @@ -1013,11 +1171,12 @@ type countReader struct { func (cr countReader) Read(p []byte) (n int, err error) { n, err = cr.r.Read(p) - *cr.n += int64(n) + atomic.AddInt64(cr.n, int64(n)) return } func TestRequestBodyLimit(t *testing.T) { + defer checkLeakedTransports(t) const limit = 1 << 20 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { r.Body = MaxBytesReader(w, r.Body, limit) @@ -1031,8 +1190,8 @@ func TestRequestBodyLimit(t *testing.T) { })) defer ts.Close() - nWritten := int64(0) - req, _ := NewRequest("POST", ts.URL, io.LimitReader(countReader{neverEnding('a'), &nWritten}, limit*200)) + nWritten := new(int64) + req, _ := NewRequest("POST", ts.URL, io.LimitReader(countReader{neverEnding('a'), nWritten}, limit*200)) // Send the POST, but don't care it succeeds or not. The // remote side is going to reply and then close the TCP @@ -1045,7 +1204,7 @@ func TestRequestBodyLimit(t *testing.T) { // the remote side hung up on us before we wrote too much. _, _ = DefaultClient.Do(req) - if nWritten > limit*100 { + if atomic.LoadInt64(nWritten) > limit*100 { t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d", limit, nWritten) } @@ -1054,6 +1213,7 @@ func TestRequestBodyLimit(t *testing.T) { // TestClientWriteShutdown tests that if the client shuts down the write // side of their TCP connection, the server doesn't send a 400 Bad Request. func TestClientWriteShutdown(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) defer ts.Close() conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -1086,28 +1246,207 @@ func TestClientWriteShutdown(t *testing.T) { // Tests that chunked server responses that write 1 byte at a time are // buffered before chunk headers are added, not after chunk headers. func TestServerBufferedChunking(t *testing.T) { - if true { - t.Logf("Skipping known broken test; see Issue 2357") - return - } conn := new(testConn) conn.readBuf.Write([]byte("GET / HTTP/1.1\r\n\r\n")) - done := make(chan bool) + conn.closec = make(chan bool, 1) ls := &oneConnListener{conn} go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) { - defer close(done) - rw.Header().Set("Content-Type", "text/plain") // prevent sniffing, which buffers + rw.(Flusher).Flush() // force the Header to be sent, in chunking mode, not counting the length rw.Write([]byte{'x'}) rw.Write([]byte{'y'}) rw.Write([]byte{'z'}) })) - <-done + <-conn.closec if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) { t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q", conn.writeBuf.Bytes()) } } +// Tests that the server flushes its response headers out when it's +// ignoring the response body and waits a bit before forcefully +// closing the TCP connection, causing the client to get a RST. +// See http://golang.org/issue/3595 +func TestServerGracefulClose(t *testing.T) { + defer checkLeakedTransports(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + Error(w, "bye", StatusUnauthorized) + })) + defer ts.Close() + + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + const bodySize = 5 << 20 + req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize)) + for i := 0; i < bodySize; i++ { + req = append(req, 'x') + } + writeErr := make(chan error) + go func() { + _, err := conn.Write(req) + writeErr <- err + }() + br := bufio.NewReader(conn) + lineNum := 0 + for { + line, err := br.ReadString('\n') + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("ReadLine: %v", err) + } + lineNum++ + if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") { + t.Errorf("Response line = %q; want a 401", line) + } + } + // Wait for write to finish. This is a broken pipe on both + // Darwin and Linux, but checking this isn't the point of + // the test. + <-writeErr +} + +func TestCaseSensitiveMethod(t *testing.T) { + defer checkLeakedTransports(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Method != "get" { + t.Errorf(`Got method %q; want "get"`, r.Method) + } + })) + defer ts.Close() + req, _ := NewRequest("get", ts.URL, nil) + res, err := DefaultClient.Do(req) + if err != nil { + t.Error(err) + return + } + res.Body.Close() +} + +// TestContentLengthZero tests that for both an HTTP/1.0 and HTTP/1.1 +// request (both keep-alive), when a Handler never writes any +// response, the net/http package adds a "Content-Length: 0" response +// header. +func TestContentLengthZero(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {})) + defer ts.Close() + + for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} { + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("error dialing: %v", err) + } + _, err = fmt.Fprintf(conn, "GET / %v\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n", version) + if err != nil { + t.Fatalf("error writing: %v", err) + } + req, _ := NewRequest("GET", "/", nil) + res, err := ReadResponse(bufio.NewReader(conn), req) + if err != nil { + t.Fatalf("error reading response: %v", err) + } + if te := res.TransferEncoding; len(te) > 0 { + t.Errorf("For version %q, Transfer-Encoding = %q; want none", version, te) + } + if cl := res.ContentLength; cl != 0 { + t.Errorf("For version %q, Content-Length = %v; want 0", version, cl) + } + conn.Close() + } +} + +func TestCloseNotifier(t *testing.T) { + gotReq := make(chan bool, 1) + sawClose := make(chan bool, 1) + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + gotReq <- true + cc := rw.(CloseNotifier).CloseNotify() + <-cc + sawClose <- true + })) + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("error dialing: %v", err) + } + diec := make(chan bool) + go func() { + _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n") + if err != nil { + t.Fatal(err) + } + <-diec + conn.Close() + }() +For: + for { + select { + case <-gotReq: + diec <- true + case <-sawClose: + break For + case <-time.After(5 * time.Second): + t.Fatal("timeout") + } + } + ts.Close() +} + +func TestOptions(t *testing.T) { + uric := make(chan string, 2) // only expect 1, but leave space for 2 + mux := NewServeMux() + mux.HandleFunc("/", func(w ResponseWriter, r *Request) { + uric <- r.RequestURI + }) + ts := httptest.NewServer(mux) + defer ts.Close() + + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + // An OPTIONS * request should succeed. + _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n")) + if err != nil { + t.Fatal(err) + } + br := bufio.NewReader(conn) + res, err := ReadResponse(br, &Request{Method: "OPTIONS"}) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != 200 { + t.Errorf("Got non-200 response to OPTIONS *: %#v", res) + } + + // A GET * request on a ServeMux should fail. + _, err = conn.Write([]byte("GET * HTTP/1.1\r\nHost: foo.com\r\n\r\n")) + if err != nil { + t.Fatal(err) + } + res, err = ReadResponse(br, &Request{Method: "GET"}) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != 400 { + t.Errorf("Got non-400 response to GET *: %#v", res) + } + + res, err = Get(ts.URL + "/second") + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if got := <-uric; got != "/second" { + t.Errorf("Handler saw request for %q; want /second", got) + } +} + // goTimeout runs f, failing t if f takes more than ns to complete. func goTimeout(t *testing.T, d time.Duration, f func()) { ch := make(chan bool, 2) @@ -1184,3 +1523,100 @@ func BenchmarkClientServer(b *testing.B) { b.StopTimer() } + +func BenchmarkClientServerParallel4(b *testing.B) { + benchmarkClientServerParallel(b, 4) +} + +func BenchmarkClientServerParallel64(b *testing.B) { + benchmarkClientServerParallel(b, 64) +} + +func benchmarkClientServerParallel(b *testing.B, conc int) { + b.StopTimer() + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + fmt.Fprintf(rw, "Hello world.\n") + })) + defer ts.Close() + b.StartTimer() + + numProcs := runtime.GOMAXPROCS(-1) * conc + var wg sync.WaitGroup + wg.Add(numProcs) + n := int32(b.N) + for p := 0; p < numProcs; p++ { + go func() { + for atomic.AddInt32(&n, -1) >= 0 { + res, err := Get(ts.URL) + if err != nil { + b.Logf("Get: %v", err) + continue + } + all, err := ioutil.ReadAll(res.Body) + if err != nil { + b.Logf("ReadAll: %v", err) + continue + } + body := string(all) + if body != "Hello world.\n" { + panic("Got body: " + body) + } + } + wg.Done() + }() + } + wg.Wait() +} + +// A benchmark for profiling the server without the HTTP client code. +// The client code runs in a subprocess. +// +// For use like: +// $ go test -c +// $ ./http.test -test.run=XX -test.bench=BenchmarkServer -test.benchtime=15s -test.cpuprofile=http.prof +// $ go tool pprof http.test http.prof +// (pprof) web +func BenchmarkServer(b *testing.B) { + // Child process mode; + if url := os.Getenv("TEST_BENCH_SERVER_URL"); url != "" { + n, err := strconv.Atoi(os.Getenv("TEST_BENCH_CLIENT_N")) + if err != nil { + panic(err) + } + for i := 0; i < n; i++ { + res, err := Get(url) + if err != nil { + log.Panicf("Get: %v", err) + } + all, err := ioutil.ReadAll(res.Body) + if err != nil { + log.Panicf("ReadAll: %v", err) + } + body := string(all) + if body != "Hello world.\n" { + log.Panicf("Got body: %q", body) + } + } + os.Exit(0) + return + } + + var res = []byte("Hello world.\n") + b.StopTimer() + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + rw.Header().Set("Content-Type", "text/html; charset=utf-8") + rw.Write(res) + })) + defer ts.Close() + b.StartTimer() + + cmd := exec.Command(os.Args[0], "-test.run=XXXX", "-test.bench=BenchmarkServer") + cmd.Env = append([]string{ + fmt.Sprintf("TEST_BENCH_CLIENT_N=%d", b.N), + fmt.Sprintf("TEST_BENCH_SERVER_URL=%s", ts.URL), + }, os.Environ()...) + out, err := cmd.CombinedOutput() + if err != nil { + b.Errorf("Test failure: %v, with output: %s", err, out) + } +} diff --git a/src/pkg/net/http/server.go b/src/pkg/net/http/server.go index 0572b4ae3..b6ab78228 100644 --- a/src/pkg/net/http/server.go +++ b/src/pkg/net/http/server.go @@ -11,7 +11,6 @@ package http import ( "bufio" - "bytes" "crypto/tls" "errors" "fmt" @@ -21,7 +20,7 @@ import ( "net" "net/url" "path" - "runtime/debug" + "runtime" "strconv" "strings" "sync" @@ -94,30 +93,188 @@ type Hijacker interface { Hijack() (net.Conn, *bufio.ReadWriter, error) } +// The CloseNotifier interface is implemented by ResponseWriters which +// allow detecting when the underlying connection has gone away. +// +// This mechanism can be used to cancel long operations on the server +// if the client has disconnected before the response is ready. +type CloseNotifier interface { + // CloseNotify returns a channel that receives a single value + // when the client connection has gone away. + CloseNotify() <-chan bool +} + // A conn represents the server side of an HTTP connection. type conn struct { remoteAddr string // network address of remote side server *Server // the Server on which the connection arrived rwc net.Conn // i/o connection - lr *io.LimitedReader // io.LimitReader(rwc) - buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->rwc - hijacked bool // connection has been hijacked by handler + sr switchReader // where the LimitReader reads from; usually the rwc + lr *io.LimitedReader // io.LimitReader(sr) + buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->sr->rwc tlsState *tls.ConnectionState // or nil when not using TLS - body []byte + + mu sync.Mutex // guards the following + clientGone bool // if client has disconnected mid-request + closeNotifyc chan bool // made lazily + hijackedv bool // connection has been hijacked by handler +} + +func (c *conn) hijacked() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.hijackedv +} + +func (c *conn) hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.hijackedv { + return nil, nil, ErrHijacked + } + if c.closeNotifyc != nil { + return nil, nil, errors.New("http: Hijack is incompatible with use of CloseNotifier") + } + c.hijackedv = true + rwc = c.rwc + buf = c.buf + c.rwc = nil + c.buf = nil + return +} + +func (c *conn) closeNotify() <-chan bool { + c.mu.Lock() + defer c.mu.Unlock() + if c.closeNotifyc == nil { + c.closeNotifyc = make(chan bool) + if c.hijackedv { + // to obey the function signature, even though + // it'll never receive a value. + return c.closeNotifyc + } + pr, pw := io.Pipe() + + readSource := c.sr.r + c.sr.Lock() + c.sr.r = pr + c.sr.Unlock() + go func() { + _, err := io.Copy(pw, readSource) + if err == nil { + err = io.EOF + } + pw.CloseWithError(err) + c.noteClientGone() + }() + } + return c.closeNotifyc +} + +func (c *conn) noteClientGone() { + c.mu.Lock() + defer c.mu.Unlock() + if c.closeNotifyc != nil && !c.clientGone { + c.closeNotifyc <- true + } + c.clientGone = true +} + +type switchReader struct { + sync.Mutex + r io.Reader +} + +func (sr *switchReader) Read(p []byte) (n int, err error) { + sr.Lock() + r := sr.r + sr.Unlock() + return r.Read(p) +} + +// This should be >= 512 bytes for DetectContentType, +// but otherwise it's somewhat arbitrary. +const bufferBeforeChunkingSize = 2048 + +// chunkWriter writes to a response's conn buffer, and is the writer +// wrapped by the response.bufw buffered writer. +// +// chunkWriter also is responsible for finalizing the Header, including +// conditionally setting the Content-Type and setting a Content-Length +// in cases where the handler's final output is smaller than the buffer +// size. It also conditionally adds chunk headers, when in chunking mode. +// +// See the comment above (*response).Write for the entire write flow. +type chunkWriter struct { + res *response + header Header // a deep copy of r.Header, once WriteHeader is called + wroteHeader bool // whether the header's been sent + + // set by the writeHeader method: + chunking bool // using chunked transfer encoding for reply body +} + +var crlf = []byte("\r\n") + +func (cw *chunkWriter) Write(p []byte) (n int, err error) { + if !cw.wroteHeader { + cw.writeHeader(p) + } + if cw.chunking { + _, err = fmt.Fprintf(cw.res.conn.buf, "%x\r\n", len(p)) + if err != nil { + cw.res.conn.rwc.Close() + return + } + } + n, err = cw.res.conn.buf.Write(p) + if cw.chunking && err == nil { + _, err = cw.res.conn.buf.Write(crlf) + } + if err != nil { + cw.res.conn.rwc.Close() + } + return +} + +func (cw *chunkWriter) flush() { + if !cw.wroteHeader { + cw.writeHeader(nil) + } + cw.res.conn.buf.Flush() +} + +func (cw *chunkWriter) close() { + if !cw.wroteHeader { + cw.writeHeader(nil) + } + if cw.chunking { + // zero EOF chunk, trailer key/value pairs (currently + // unsupported in Go's server), followed by a blank + // line. + io.WriteString(cw.res.conn.buf, "0\r\n\r\n") + } } // A response represents the server side of an HTTP response. type response struct { conn *conn req *Request // request for this response - chunking bool // using chunked transfer encoding for reply body - wroteHeader bool // reply header has been written + wroteHeader bool // reply header has been (logically) written wroteContinue bool // 100 Continue response was written - header Header // reply header parameters - written int64 // number of bytes written in body - contentLength int64 // explicitly-declared Content-Length; or -1 - status int // status code passed to WriteHeader - needSniff bool // need to sniff to find Content-Type + + w *bufio.Writer // buffers output in chunks to chunkWriter + cw *chunkWriter + + // handlerHeader is the Header that Handlers get access to, + // which may be retained and mutated even after WriteHeader. + // handlerHeader is copied into cw.header at WriteHeader + // time, and privately mutated thereafter. + handlerHeader Header + + written int64 // number of bytes written in body + contentLength int64 // explicitly-declared Content-Length; or -1 + status int // status code passed to WriteHeader // close connection after this reply. set on request and // updated after response from handler if there's a @@ -127,12 +284,14 @@ type response struct { // requestBodyLimitHit is set by requestTooLarge when // maxBytesReader hits its max size. It is checked in - // WriteHeader, to make sure we don't consume the the + // WriteHeader, to make sure we don't consume the // remaining request body to try to advance to the next HTTP - // request. Instead, when this is set, we stop doing + // request. Instead, when this is set, we stop reading // subsequent requests on this connection and stop reading // input from it. requestBodyLimitHit bool + + handlerDone bool // set true when the handler exits } // requestTooLarge is called by maxBytesReader when too much input has @@ -145,42 +304,68 @@ func (w *response) requestTooLarge() { } } +// needsSniff returns whether a Content-Type still needs to be sniffed. +func (w *response) needsSniff() bool { + return !w.cw.wroteHeader && w.handlerHeader.Get("Content-Type") == "" && w.written < sniffLen +} + type writerOnly struct { io.Writer } func (w *response) ReadFrom(src io.Reader) (n int64, err error) { - // Call WriteHeader before checking w.chunking if it hasn't - // been called yet, since WriteHeader is what sets w.chunking. if !w.wroteHeader { w.WriteHeader(StatusOK) } - if !w.chunking && w.bodyAllowed() && !w.needSniff { - w.Flush() + + if w.needsSniff() { + n0, err := io.Copy(writerOnly{w}, io.LimitReader(src, sniffLen)) + n += n0 + if err != nil { + return n, err + } + } + + w.w.Flush() // get rid of any previous writes + w.cw.flush() // make sure Header is written; flush data to rwc + + // Now that cw has been flushed, its chunking field is guaranteed initialized. + if !w.cw.chunking && w.bodyAllowed() { if rf, ok := w.conn.rwc.(io.ReaderFrom); ok { - n, err = rf.ReadFrom(src) - w.written += n - return + n0, err := rf.ReadFrom(src) + n += n0 + w.written += n0 + return n, err } } + // Fall back to default io.Copy implementation. // Use wrapper to hide w.ReadFrom from io.Copy. - return io.Copy(writerOnly{w}, src) + n0, err := io.Copy(writerOnly{w}, src) + n += n0 + return n, err } // noLimit is an effective infinite upper bound for io.LimitedReader const noLimit int64 = (1 << 63) - 1 +// debugServerConnections controls whether all server connections are wrapped +// with a verbose logging wrapper. +const debugServerConnections = false + // Create new connection from rwc. func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) { c = new(conn) c.remoteAddr = rwc.RemoteAddr().String() c.server = srv c.rwc = rwc - c.body = make([]byte, sniffLen) - c.lr = io.LimitReader(rwc, noLimit).(*io.LimitedReader) + if debugServerConnections { + c.rwc = newLoggingConn("server", c.rwc) + } + c.sr = switchReader{r: c.rwc} + c.lr = io.LimitReader(&c.sr, noLimit).(*io.LimitedReader) br := bufio.NewReader(c.lr) - bw := bufio.NewWriter(rwc) + bw := bufio.NewWriter(c.rwc) c.buf = bufio.NewReadWriter(br, bw) return c, nil } @@ -207,9 +392,9 @@ type expectContinueReader struct { func (ecr *expectContinueReader) Read(p []byte) (n int, err error) { if ecr.closed { - return 0, errors.New("http: Read after Close on request Body") + return 0, ErrBodyReadAfterClose } - if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked { + if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked() { ecr.resp.wroteContinue = true io.WriteString(ecr.resp.conn.buf, "HTTP/1.1 100 Continue\r\n\r\n") ecr.resp.conn.buf.Flush() @@ -232,9 +417,19 @@ var errTooLarge = errors.New("http: request too large") // Read next request from connection. func (c *conn) readRequest() (w *response, err error) { - if c.hijacked { + if c.hijacked() { return nil, ErrHijacked } + + if d := c.server.ReadTimeout; d != 0 { + c.rwc.SetReadDeadline(time.Now().Add(d)) + } + if d := c.server.WriteTimeout; d != 0 { + defer func() { + c.rwc.SetWriteDeadline(time.Now().Add(d)) + }() + } + c.lr.N = int64(c.server.maxHeaderBytes()) + 4096 /* bufio slop */ var req *Request if req, err = ReadRequest(c.buf.Reader); err != nil { @@ -248,17 +443,20 @@ func (c *conn) readRequest() (w *response, err error) { req.RemoteAddr = c.remoteAddr req.TLS = c.tlsState - w = new(response) - w.conn = c - w.req = req - w.header = make(Header) - w.contentLength = -1 - c.body = c.body[:0] + w = &response{ + conn: c, + req: req, + handlerHeader: make(Header), + contentLength: -1, + cw: new(chunkWriter), + } + w.cw.res = w + w.w = bufio.NewWriterSize(w.cw, bufferBeforeChunkingSize) return w, nil } func (w *response) Header() Header { - return w.header + return w.handlerHeader } // maxPostHandlerReadBytes is the max number of Request.Body bytes not @@ -273,7 +471,7 @@ func (w *response) Header() Header { const maxPostHandlerReadBytes = 256 << 10 func (w *response) WriteHeader(code int) { - if w.conn.hijacked { + if w.conn.hijacked() { log.Print("http: response.WriteHeader on hijacked connection") return } @@ -284,31 +482,68 @@ func (w *response) WriteHeader(code int) { w.wroteHeader = true w.status = code - // Check for a explicit (and valid) Content-Length header. - var hasCL bool - var contentLength int64 - if clenStr := w.header.Get("Content-Length"); clenStr != "" { - var err error - contentLength, err = strconv.ParseInt(clenStr, 10, 64) - if err == nil { - hasCL = true + w.cw.header = w.handlerHeader.clone() + + if cl := w.cw.header.get("Content-Length"); cl != "" { + v, err := strconv.ParseInt(cl, 10, 64) + if err == nil && v >= 0 { + w.contentLength = v } else { - log.Printf("http: invalid Content-Length of %q sent", clenStr) - w.header.Del("Content-Length") + log.Printf("http: invalid Content-Length of %q", cl) + w.cw.header.Del("Content-Length") + } + } +} + +// writeHeader finalizes the header sent to the client and writes it +// to cw.res.conn.buf. +// +// p is not written by writeHeader, but is the first chunk of the body +// that will be written. It is sniffed for a Content-Type if none is +// set explicitly. It's also used to set the Content-Length, if the +// total body size was small and the handler has already finished +// running. +func (cw *chunkWriter) writeHeader(p []byte) { + if cw.wroteHeader { + return + } + cw.wroteHeader = true + + w := cw.res + code := w.status + done := w.handlerDone + + // If the handler is done but never sent a Content-Length + // response header and this is our first (and last) write, set + // it, even to zero. This helps HTTP/1.0 clients keep their + // "keep-alive" connections alive. + if done && cw.header.get("Content-Length") == "" && w.req.Method != "HEAD" { + w.contentLength = int64(len(p)) + cw.header.Set("Content-Length", strconv.Itoa(len(p))) + } + + // If this was an HTTP/1.0 request with keep-alive and we sent a + // Content-Length back, we can make this a keep-alive response ... + if w.req.wantsHttp10KeepAlive() { + sentLength := cw.header.get("Content-Length") != "" + if sentLength && cw.header.get("Connection") == "keep-alive" { + w.closeAfterReply = false } } + // Check for a explicit (and valid) Content-Length header. + hasCL := w.contentLength != -1 + if w.req.wantsHttp10KeepAlive() && (w.req.Method == "HEAD" || hasCL) { - _, connectionHeaderSet := w.header["Connection"] + _, connectionHeaderSet := cw.header["Connection"] if !connectionHeaderSet { - w.header.Set("Connection", "keep-alive") + cw.header.Set("Connection", "keep-alive") } - } else if !w.req.ProtoAtLeast(1, 1) { - // Client did not ask to keep connection alive. + } else if !w.req.ProtoAtLeast(1, 1) || w.req.wantsClose() { w.closeAfterReply = true } - if w.header.Get("Connection") == "close" { + if cw.header.get("Connection") == "close" { w.closeAfterReply = true } @@ -322,7 +557,7 @@ func (w *response) WriteHeader(code int) { n, _ := io.CopyN(ioutil.Discard, w.req.Body, maxPostHandlerReadBytes+1) if n >= maxPostHandlerReadBytes { w.requestTooLarge() - w.header.Set("Connection", "close") + cw.header.Set("Connection", "close") } else { w.req.Body.Close() } @@ -332,64 +567,67 @@ func (w *response) WriteHeader(code int) { if code == StatusNotModified { // Must not have body. for _, header := range []string{"Content-Type", "Content-Length", "Transfer-Encoding"} { - if w.header.Get(header) != "" { - // TODO: return an error if WriteHeader gets a return parameter - // or set a flag on w to make future Writes() write an error page? - // for now just log and drop the header. - log.Printf("http: StatusNotModified response with header %q defined", header) - w.header.Del(header) + // RFC 2616 section 10.3.5: "the response MUST NOT include other entity-headers" + if cw.header.get(header) != "" { + cw.header.Del(header) } } } else { // If no content type, apply sniffing algorithm to body. - if w.header.Get("Content-Type") == "" && w.req.Method != "HEAD" { - w.needSniff = true + if cw.header.get("Content-Type") == "" && w.req.Method != "HEAD" { + cw.header.Set("Content-Type", DetectContentType(p)) } } - if _, ok := w.header["Date"]; !ok { - w.Header().Set("Date", time.Now().UTC().Format(TimeFormat)) + if _, ok := cw.header["Date"]; !ok { + cw.header.Set("Date", time.Now().UTC().Format(TimeFormat)) } - te := w.header.Get("Transfer-Encoding") + te := cw.header.get("Transfer-Encoding") hasTE := te != "" if hasCL && hasTE && te != "identity" { // TODO: return an error if WriteHeader gets a return parameter // For now just ignore the Content-Length. log.Printf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d", - te, contentLength) - w.header.Del("Content-Length") + te, w.contentLength) + cw.header.Del("Content-Length") hasCL = false } if w.req.Method == "HEAD" || code == StatusNotModified { // do nothing + } else if code == StatusNoContent { + cw.header.Del("Transfer-Encoding") } else if hasCL { - w.contentLength = contentLength - w.header.Del("Transfer-Encoding") + cw.header.Del("Transfer-Encoding") } else if w.req.ProtoAtLeast(1, 1) { // HTTP/1.1 or greater: use chunked transfer encoding // to avoid closing the connection at EOF. // TODO: this blows away any custom or stacked Transfer-Encoding they // might have set. Deal with that as need arises once we have a valid // use case. - w.chunking = true - w.header.Set("Transfer-Encoding", "chunked") + cw.chunking = true + cw.header.Set("Transfer-Encoding", "chunked") } else { // HTTP version < 1.1: cannot do chunked transfer // encoding and we don't know the Content-Length so // signal EOF by closing connection. w.closeAfterReply = true - w.header.Del("Transfer-Encoding") // in case already set + cw.header.Del("Transfer-Encoding") // in case already set } // Cannot use Content-Length with non-identity Transfer-Encoding. - if w.chunking { - w.header.Del("Content-Length") + if cw.chunking { + cw.header.Del("Content-Length") } if !w.req.ProtoAtLeast(1, 0) { return } + + if w.closeAfterReply && !hasToken(cw.header.get("Connection"), "close") { + cw.header.Set("Connection", "close") + } + proto := "HTTP/1.0" if w.req.ProtoAtLeast(1, 1) { proto = "HTTP/1.1" @@ -400,37 +638,8 @@ func (w *response) WriteHeader(code int) { text = "status code " + codestring } io.WriteString(w.conn.buf, proto+" "+codestring+" "+text+"\r\n") - w.header.Write(w.conn.buf) - - // If we need to sniff the body, leave the header open. - // Otherwise, end it here. - if !w.needSniff { - io.WriteString(w.conn.buf, "\r\n") - } -} - -// sniff uses the first block of written data, -// stored in w.conn.body, to decide the Content-Type -// for the HTTP body. -func (w *response) sniff() { - if !w.needSniff { - return - } - w.needSniff = false - - data := w.conn.body - fmt.Fprintf(w.conn.buf, "Content-Type: %s\r\n\r\n", DetectContentType(data)) - - if len(data) == 0 { - return - } - if w.chunking { - fmt.Fprintf(w.conn.buf, "%x\r\n", len(data)) - } - _, err := w.conn.buf.Write(data) - if w.chunking && err == nil { - io.WriteString(w.conn.buf, "\r\n") - } + cw.header.Write(w.conn.buf) + w.conn.buf.Write(crlf) } // bodyAllowed returns true if a Write is allowed for this response type. @@ -442,8 +651,40 @@ func (w *response) bodyAllowed() bool { return w.status != StatusNotModified && w.req.Method != "HEAD" } +// The Life Of A Write is like this: +// +// Handler starts. No header has been sent. The handler can either +// write a header, or just start writing. Writing before sending a header +// sends an implicity empty 200 OK header. +// +// If the handler didn't declare a Content-Length up front, we either +// go into chunking mode or, if the handler finishes running before +// the chunking buffer size, we compute a Content-Length and send that +// in the header instead. +// +// Likewise, if the handler didn't set a Content-Type, we sniff that +// from the initial chunk of output. +// +// The Writers are wired together like: +// +// 1. *response (the ResponseWriter) -> +// 2. (*response).w, a *bufio.Writer of bufferBeforeChunkingSize bytes +// 3. chunkWriter.Writer (whose writeHeader finalizes Content-Length/Type) +// and which writes the chunk headers, if needed. +// 4. conn.buf, a bufio.Writer of default (4kB) bytes +// 5. the rwc, the net.Conn. +// +// TODO(bradfitz): short-circuit some of the buffering when the +// initial header contains both a Content-Type and Content-Length. +// Also short-circuit in (1) when the header's been sent and not in +// chunking mode, writing directly to (4) instead, if (2) has no +// buffered data. More generally, we could short-circuit from (1) to +// (3) even in chunking mode if the write size from (1) is over some +// threshold and nothing is in (2). The answer might be mostly making +// bufferBeforeChunkingSize smaller and having bufio's fast-paths deal +// with this instead. func (w *response) Write(data []byte) (n int, err error) { - if w.conn.hijacked { + if w.conn.hijacked() { log.Print("http: response.Write on hijacked connection") return 0, ErrHijacked } @@ -461,73 +702,20 @@ func (w *response) Write(data []byte) (n int, err error) { if w.contentLength != -1 && w.written > w.contentLength { return 0, ErrContentLength } - - var m int - if w.needSniff { - // We need to sniff the beginning of the output to - // determine the content type. Accumulate the - // initial writes in w.conn.body. - // Cap m so that append won't allocate. - m = cap(w.conn.body) - len(w.conn.body) - if m > len(data) { - m = len(data) - } - w.conn.body = append(w.conn.body, data[:m]...) - data = data[m:] - if len(data) == 0 { - // Copied everything into the buffer. - // Wait for next write. - return m, nil - } - - // Filled the buffer; more data remains. - // Sniff the content (flushes the buffer) - // and then proceed with the remainder - // of the data as a normal Write. - // Calling sniff clears needSniff. - w.sniff() - } - - // TODO(rsc): if chunking happened after the buffering, - // then there would be fewer chunk headers. - // On the other hand, it would make hijacking more difficult. - if w.chunking { - fmt.Fprintf(w.conn.buf, "%x\r\n", len(data)) // TODO(rsc): use strconv not fmt - } - n, err = w.conn.buf.Write(data) - if err == nil && w.chunking { - if n != len(data) { - err = io.ErrShortWrite - } - if err == nil { - io.WriteString(w.conn.buf, "\r\n") - } - } - - return m + n, err + return w.w.Write(data) } func (w *response) finishRequest() { - // If this was an HTTP/1.0 request with keep-alive and we sent a Content-Length - // back, we can make this a keep-alive response ... - if w.req.wantsHttp10KeepAlive() { - sentLength := w.header.Get("Content-Length") != "" - if sentLength && w.header.Get("Connection") == "keep-alive" { - w.closeAfterReply = false - } - } + w.handlerDone = true + if !w.wroteHeader { w.WriteHeader(StatusOK) } - if w.needSniff { - w.sniff() - } - if w.chunking { - io.WriteString(w.conn.buf, "0\r\n") - // trailer key/value pairs, followed by blank line - io.WriteString(w.conn.buf, "\r\n") - } + + w.w.Flush() + w.cw.close() w.conn.buf.Flush() + // Close the body, unless we're about to close the whole TCP connection // anyway. if !w.closeAfterReply { @@ -537,7 +725,7 @@ func (w *response) finishRequest() { w.req.MultipartForm.RemoveAll() } - if w.contentLength != -1 && w.contentLength != w.written { + if w.contentLength != -1 && w.bodyAllowed() && w.contentLength != w.written { // Did not write enough. Avoid getting out of sync. w.closeAfterReply = true } @@ -547,66 +735,114 @@ func (w *response) Flush() { if !w.wroteHeader { w.WriteHeader(StatusOK) } - w.sniff() - w.conn.buf.Flush() + w.w.Flush() + w.cw.flush() } -// Close the connection. -func (c *conn) close() { +func (c *conn) finalFlush() { if c.buf != nil { c.buf.Flush() c.buf = nil } +} + +// Close the connection. +func (c *conn) close() { + c.finalFlush() if c.rwc != nil { c.rwc.Close() c.rwc = nil } } +// rstAvoidanceDelay is the amount of time we sleep after closing the +// write side of a TCP connection before closing the entire socket. +// By sleeping, we increase the chances that the client sees our FIN +// and processes its final data before they process the subsequent RST +// from closing a connection with known unread data. +// This RST seems to occur mostly on BSD systems. (And Windows?) +// This timeout is somewhat arbitrary (~latency around the planet). +const rstAvoidanceDelay = 500 * time.Millisecond + +// closeWrite flushes any outstanding data and sends a FIN packet (if +// client is connected via TCP), signalling that we're done. We then +// pause for a bit, hoping the client processes it before `any +// subsequent RST. +// +// See http://golang.org/issue/3595 +func (c *conn) closeWriteAndWait() { + c.finalFlush() + if tcp, ok := c.rwc.(*net.TCPConn); ok { + tcp.CloseWrite() + } + time.Sleep(rstAvoidanceDelay) +} + +// validNPN returns whether the proto is not a blacklisted Next +// Protocol Negotiation protocol. Empty and built-in protocol types +// are blacklisted and can't be overridden with alternate +// implementations. +func validNPN(proto string) bool { + switch proto { + case "", "http/1.1", "http/1.0": + return false + } + return true +} + // Serve a new connection. func (c *conn) serve() { defer func() { - err := recover() - if err == nil { - return + if err := recover(); err != nil { + const size = 4096 + buf := make([]byte, size) + buf = buf[:runtime.Stack(buf, false)] + log.Printf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf) } - - var buf bytes.Buffer - fmt.Fprintf(&buf, "http: panic serving %v: %v\n", c.remoteAddr, err) - buf.Write(debug.Stack()) - log.Print(buf.String()) - - if c.rwc != nil { // may be nil if connection hijacked - c.rwc.Close() + if !c.hijacked() { + c.close() } }() if tlsConn, ok := c.rwc.(*tls.Conn); ok { + if d := c.server.ReadTimeout; d != 0 { + c.rwc.SetReadDeadline(time.Now().Add(d)) + } + if d := c.server.WriteTimeout; d != 0 { + c.rwc.SetWriteDeadline(time.Now().Add(d)) + } if err := tlsConn.Handshake(); err != nil { - c.close() return } c.tlsState = new(tls.ConnectionState) *c.tlsState = tlsConn.ConnectionState() + if proto := c.tlsState.NegotiatedProtocol; validNPN(proto) { + if fn := c.server.TLSNextProto[proto]; fn != nil { + h := initNPNRequest{tlsConn, serverHandler{c.server}} + fn(c.server, tlsConn, h) + } + return + } } for { w, err := c.readRequest() if err != nil { - msg := "400 Bad Request" if err == errTooLarge { // Their HTTP client may or may not be // able to read this if we're // responding to them and hanging up // while they're still writing their // request. Undefined behavior. - msg = "413 Request Entity Too Large" + io.WriteString(c.rwc, "HTTP/1.1 413 Request Entity Too Large\r\n\r\n") + c.closeWriteAndWait() + break } else if err == io.EOF { break // Don't reply } else if neterr, ok := err.(net.Error); ok && neterr.Timeout() { break // Don't reply } - fmt.Fprintf(c.rwc, "HTTP/1.1 %s\r\n\r\n", msg) + io.WriteString(c.rwc, "HTTP/1.1 400 Bad Request\r\n\r\n") break } @@ -624,59 +860,59 @@ func (c *conn) serve() { break } req.Header.Del("Expect") - } else if req.Header.Get("Expect") != "" { - // TODO(bradfitz): let ServeHTTP handlers handle - // requests with non-standard expectation[s]? Seems - // theoretical at best, and doesn't fit into the - // current ServeHTTP model anyway. We'd need to - // make the ResponseWriter an optional - // "ExpectReplier" interface or something. - // - // For now we'll just obey RFC 2616 14.20 which says - // "If a server receives a request containing an - // Expect field that includes an expectation- - // extension that it does not support, it MUST - // respond with a 417 (Expectation Failed) status." - w.Header().Set("Connection", "close") - w.WriteHeader(StatusExpectationFailed) - w.finishRequest() + } else if req.Header.get("Expect") != "" { + w.sendExpectationFailed() break } - handler := c.server.Handler - if handler == nil { - handler = DefaultServeMux - } - // HTTP cannot have multiple simultaneous active requests.[*] // Until the server replies to this request, it can't read another, // so we might as well run the handler in this goroutine. // [*] Not strictly true: HTTP pipelining. We could let them all process // in parallel even if their responses need to be serialized. - handler.ServeHTTP(w, w.req) - if c.hijacked { + serverHandler{c.server}.ServeHTTP(w, w.req) + if c.hijacked() { return } w.finishRequest() if w.closeAfterReply { + if w.requestBodyLimitHit { + c.closeWriteAndWait() + } break } } - c.close() +} + +func (w *response) sendExpectationFailed() { + // TODO(bradfitz): let ServeHTTP handlers handle + // requests with non-standard expectation[s]? Seems + // theoretical at best, and doesn't fit into the + // current ServeHTTP model anyway. We'd need to + // make the ResponseWriter an optional + // "ExpectReplier" interface or something. + // + // For now we'll just obey RFC 2616 14.20 which says + // "If a server receives a request containing an + // Expect field that includes an expectation- + // extension that it does not support, it MUST + // respond with a 417 (Expectation Failed) status." + w.Header().Set("Connection", "close") + w.WriteHeader(StatusExpectationFailed) + w.finishRequest() } // Hijack implements the Hijacker.Hijack method. Our response is both a ResponseWriter // and a Hijacker. func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { - if w.conn.hijacked { - return nil, nil, ErrHijacked + if w.wroteHeader { + w.cw.flush() } - w.conn.hijacked = true - rwc = w.conn.rwc - buf = w.conn.buf - w.conn.rwc = nil - w.conn.buf = nil - return + return w.conn.hijack() +} + +func (w *response) CloseNotify() <-chan bool { + return w.conn.closeNotify() } // The HandlerFunc type is an adapter to allow the use of @@ -817,13 +1053,13 @@ func RedirectHandler(url string, code int) Handler { // patterns and calls the handler for the pattern that // most closely matches the URL. // -// Patterns named fixed, rooted paths, like "/favicon.ico", +// Patterns name fixed, rooted paths, like "/favicon.ico", // or rooted subtrees, like "/images/" (note the trailing slash). // Longer patterns take precedence over shorter ones, so that // if there are handlers registered for both "/images/" // and "/images/thumbnails/", the latter handler will be // called for paths beginning "/images/thumbnails/" and the -// former will receiver requests for any other paths in the +// former will receive requests for any other paths in the // "/images/" subtree. // // Patterns may optionally begin with a host name, restricting matches to @@ -836,13 +1072,15 @@ func RedirectHandler(url string, code int) Handler { // redirecting any request containing . or .. elements to an // equivalent .- and ..-free URL. type ServeMux struct { - mu sync.RWMutex - m map[string]muxEntry + mu sync.RWMutex + m map[string]muxEntry + hosts bool // whether any patterns contain hostnames } type muxEntry struct { explicit bool h Handler + pattern string } // NewServeMux allocates and returns a new ServeMux. @@ -883,8 +1121,7 @@ func cleanPath(p string) string { // Find a handler on a handler map given a path string // Most-specific (longest) pattern wins -func (mux *ServeMux) match(path string) Handler { - var h Handler +func (mux *ServeMux) match(path string) (h Handler, pattern string) { var n = 0 for k, v := range mux.m { if !pathMatch(k, path) { @@ -893,37 +1130,64 @@ func (mux *ServeMux) match(path string) Handler { if h == nil || len(k) > n { n = len(k) h = v.h + pattern = v.pattern + } + } + return +} + +// Handler returns the handler to use for the given request, +// consulting r.Method, r.Host, and r.URL.Path. It always returns +// a non-nil handler. If the path is not in its canonical form, the +// handler will be an internally-generated handler that redirects +// to the canonical path. +// +// Handler also returns the registered pattern that matches the +// request or, in the case of internally-generated redirects, +// the pattern that will match after following the redirect. +// +// If there is no registered handler that applies to the request, +// Handler returns a ``page not found'' handler and an empty pattern. +func (mux *ServeMux) Handler(r *Request) (h Handler, pattern string) { + if r.Method != "CONNECT" { + if p := cleanPath(r.URL.Path); p != r.URL.Path { + _, pattern = mux.handler(r.Host, p) + return RedirectHandler(p, StatusMovedPermanently), pattern } } - return h + + return mux.handler(r.Host, r.URL.Path) } -// handler returns the handler to use for the request r. -func (mux *ServeMux) handler(r *Request) Handler { +// handler is the main implementation of Handler. +// The path is known to be in canonical form, except for CONNECT methods. +func (mux *ServeMux) handler(host, path string) (h Handler, pattern string) { mux.mu.RLock() defer mux.mu.RUnlock() // Host-specific pattern takes precedence over generic ones - h := mux.match(r.Host + r.URL.Path) + if mux.hosts { + h, pattern = mux.match(host + path) + } if h == nil { - h = mux.match(r.URL.Path) + h, pattern = mux.match(path) } if h == nil { - h = NotFoundHandler() + h, pattern = NotFoundHandler(), "" } - return h + return } // ServeHTTP dispatches the request to the handler whose // pattern most closely matches the request URL. func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) { - // Clean path to canonical form and redirect. - if p := cleanPath(r.URL.Path); p != r.URL.Path { - w.Header().Set("Location", p) - w.WriteHeader(StatusMovedPermanently) + if r.RequestURI == "*" { + w.Header().Set("Connection", "close") + w.WriteHeader(StatusBadRequest) return } - mux.handler(r).ServeHTTP(w, r) + h, _ := mux.Handler(r) + h.ServeHTTP(w, r) } // Handle registers the handler for the given pattern. @@ -942,14 +1206,26 @@ func (mux *ServeMux) Handle(pattern string, handler Handler) { panic("http: multiple registrations for " + pattern) } - mux.m[pattern] = muxEntry{explicit: true, h: handler} + mux.m[pattern] = muxEntry{explicit: true, h: handler, pattern: pattern} + + if pattern[0] != '/' { + mux.hosts = true + } // Helpful behavior: // If pattern is /tree/, insert an implicit permanent redirect for /tree. // It can be overridden by an explicit registration. n := len(pattern) if n > 0 && pattern[n-1] == '/' && !mux.m[pattern[0:n-1]].explicit { - mux.m[pattern[0:n-1]] = muxEntry{h: RedirectHandler(pattern, StatusMovedPermanently)} + // If pattern contains a host name, strip it and use remaining + // path for redirect. + path := pattern + if pattern[0] != '/' { + // In pattern, at least the last character is a '/', so + // strings.Index can't be -1. + path = pattern[strings.Index(pattern, "/"):] + } + mux.m[pattern[0:n-1]] = muxEntry{h: RedirectHandler(path, StatusMovedPermanently), pattern: pattern} } } @@ -971,7 +1247,7 @@ func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) { } // Serve accepts incoming HTTP connections on the listener l, -// creating a new service thread for each. The service threads +// creating a new service goroutine for each. The service goroutines // read requests and then call handler to reply to them. // Handler is typically nil, in which case the DefaultServeMux is used. func Serve(l net.Listener, handler Handler) error { @@ -987,6 +1263,32 @@ type Server struct { WriteTimeout time.Duration // maximum duration before timing out write of the response MaxHeaderBytes int // maximum size of request headers, DefaultMaxHeaderBytes if 0 TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS + + // TLSNextProto optionally specifies a function to take over + // ownership of the provided TLS connection when an NPN + // protocol upgrade has occured. The map key is the protocol + // name negotiated. The Handler argument should be used to + // handle HTTP requests and will initialize the Request's TLS + // and RemoteAddr if not already set. The connection is + // automatically closed when the function returns. + TLSNextProto map[string]func(*Server, *tls.Conn, Handler) +} + +// serverHandler delegates to either the server's Handler or +// DefaultServeMux and also handles "OPTIONS *" requests. +type serverHandler struct { + srv *Server +} + +func (sh serverHandler) ServeHTTP(rw ResponseWriter, req *Request) { + handler := sh.srv.Handler + if handler == nil { + handler = DefaultServeMux + } + if req.RequestURI == "*" && req.Method == "OPTIONS" { + handler = globalOptionsHandler{} + } + handler.ServeHTTP(rw, req) } // ListenAndServe listens on the TCP network address srv.Addr and then @@ -1005,7 +1307,7 @@ func (srv *Server) ListenAndServe() error { } // Serve accepts incoming connections on the Listener l, creating a -// new service thread for each. The service threads read requests and +// new service goroutine for each. The service goroutines read requests and // then call srv.Handler to reply to them. func (srv *Server) Serve(l net.Listener) error { defer l.Close() @@ -1029,12 +1331,6 @@ func (srv *Server) Serve(l net.Listener) error { return e } tempDelay = 0 - if srv.ReadTimeout != 0 { - rw.SetReadDeadline(time.Now().Add(srv.ReadTimeout)) - } - if srv.WriteTimeout != 0 { - rw.SetWriteDeadline(time.Now().Add(srv.WriteTimeout)) - } c, err := srv.newConn(rw) if err != nil { continue @@ -1150,7 +1446,7 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { // TimeoutHandler returns a Handler that runs h with the given time limit. // // The new Handler calls h.ServeHTTP to handle each request, but if a -// call runs for more than ns nanoseconds, the handler responds with +// call runs for longer than its time limit, the handler responds with // a 503 Service Unavailable error and the given message in its body. // (If msg is empty, a suitable default message will be sent.) // After such a timeout, writes by h to its ResponseWriter will return @@ -1180,7 +1476,7 @@ func (h *timeoutHandler) errorBody() string { } func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) { - done := make(chan bool) + done := make(chan bool, 1) tw := &timeoutWriter{w: w} go func() { h.handler.ServeHTTP(tw, r) @@ -1232,3 +1528,86 @@ func (tw *timeoutWriter) WriteHeader(code int) { tw.mu.Unlock() tw.w.WriteHeader(code) } + +// globalOptionsHandler responds to "OPTIONS *" requests. +type globalOptionsHandler struct{} + +func (globalOptionsHandler) ServeHTTP(w ResponseWriter, r *Request) { + w.Header().Set("Content-Length", "0") + if r.ContentLength != 0 { + // Read up to 4KB of OPTIONS body (as mentioned in the + // spec as being reserved for future use), but anything + // over that is considered a waste of server resources + // (or an attack) and we abort and close the connection, + // courtesy of MaxBytesReader's EOF behavior. + mb := MaxBytesReader(w, r.Body, 4<<10) + io.Copy(ioutil.Discard, mb) + } +} + +// eofReader is a non-nil io.ReadCloser that always returns EOF. +var eofReader = ioutil.NopCloser(strings.NewReader("")) + +// initNPNRequest is an HTTP handler that initializes certain +// uninitialized fields in its *Request. Such partially-initialized +// Requests come from NPN protocol handlers. +type initNPNRequest struct { + c *tls.Conn + h serverHandler +} + +func (h initNPNRequest) ServeHTTP(rw ResponseWriter, req *Request) { + if req.TLS == nil { + req.TLS = &tls.ConnectionState{} + *req.TLS = h.c.ConnectionState() + } + if req.Body == nil { + req.Body = eofReader + } + if req.RemoteAddr == "" { + req.RemoteAddr = h.c.RemoteAddr().String() + } + h.h.ServeHTTP(rw, req) +} + +// loggingConn is used for debugging. +type loggingConn struct { + name string + net.Conn +} + +var ( + uniqNameMu sync.Mutex + uniqNameNext = make(map[string]int) +) + +func newLoggingConn(baseName string, c net.Conn) net.Conn { + uniqNameMu.Lock() + defer uniqNameMu.Unlock() + uniqNameNext[baseName]++ + return &loggingConn{ + name: fmt.Sprintf("%s-%d", baseName, uniqNameNext[baseName]), + Conn: c, + } +} + +func (c *loggingConn) Write(p []byte) (n int, err error) { + log.Printf("%s.Write(%d) = ....", c.name, len(p)) + n, err = c.Conn.Write(p) + log.Printf("%s.Write(%d) = %d, %v", c.name, len(p), n, err) + return +} + +func (c *loggingConn) Read(p []byte) (n int, err error) { + log.Printf("%s.Read(%d) = ....", c.name, len(p)) + n, err = c.Conn.Read(p) + log.Printf("%s.Read(%d) = %d, %v", c.name, len(p), n, err) + return +} + +func (c *loggingConn) Close() (err error) { + log.Printf("%s.Close() = ...", c.name) + err = c.Conn.Close() + log.Printf("%s.Close() = %v", c.name, err) + return +} diff --git a/src/pkg/net/http/server_test.go b/src/pkg/net/http/server_test.go new file mode 100644 index 000000000..8b4e8c6d6 --- /dev/null +++ b/src/pkg/net/http/server_test.go @@ -0,0 +1,95 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "net/url" + "testing" +) + +var serveMuxRegister = []struct { + pattern string + h Handler +}{ + {"/dir/", serve(200)}, + {"/search", serve(201)}, + {"codesearch.google.com/search", serve(202)}, + {"codesearch.google.com/", serve(203)}, +} + +// serve returns a handler that sends a response with the given code. +func serve(code int) HandlerFunc { + return func(w ResponseWriter, r *Request) { + w.WriteHeader(code) + } +} + +var serveMuxTests = []struct { + method string + host string + path string + code int + pattern string +}{ + {"GET", "google.com", "/", 404, ""}, + {"GET", "google.com", "/dir", 301, "/dir/"}, + {"GET", "google.com", "/dir/", 200, "/dir/"}, + {"GET", "google.com", "/dir/file", 200, "/dir/"}, + {"GET", "google.com", "/search", 201, "/search"}, + {"GET", "google.com", "/search/", 404, ""}, + {"GET", "google.com", "/search/foo", 404, ""}, + {"GET", "codesearch.google.com", "/search", 202, "codesearch.google.com/search"}, + {"GET", "codesearch.google.com", "/search/", 203, "codesearch.google.com/"}, + {"GET", "codesearch.google.com", "/search/foo", 203, "codesearch.google.com/"}, + {"GET", "codesearch.google.com", "/", 203, "codesearch.google.com/"}, + {"GET", "images.google.com", "/search", 201, "/search"}, + {"GET", "images.google.com", "/search/", 404, ""}, + {"GET", "images.google.com", "/search/foo", 404, ""}, + {"GET", "google.com", "/../search", 301, "/search"}, + {"GET", "google.com", "/dir/..", 301, ""}, + {"GET", "google.com", "/dir/..", 301, ""}, + {"GET", "google.com", "/dir/./file", 301, "/dir/"}, + + // The /foo -> /foo/ redirect applies to CONNECT requests + // but the path canonicalization does not. + {"CONNECT", "google.com", "/dir", 301, "/dir/"}, + {"CONNECT", "google.com", "/../search", 404, ""}, + {"CONNECT", "google.com", "/dir/..", 200, "/dir/"}, + {"CONNECT", "google.com", "/dir/..", 200, "/dir/"}, + {"CONNECT", "google.com", "/dir/./file", 200, "/dir/"}, +} + +func TestServeMuxHandler(t *testing.T) { + mux := NewServeMux() + for _, e := range serveMuxRegister { + mux.Handle(e.pattern, e.h) + } + + for _, tt := range serveMuxTests { + r := &Request{ + Method: tt.method, + Host: tt.host, + URL: &url.URL{ + Path: tt.path, + }, + } + h, pattern := mux.Handler(r) + cs := &codeSaver{h: Header{}} + h.ServeHTTP(cs, r) + if pattern != tt.pattern || cs.code != tt.code { + t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, cs.code, pattern, tt.code, tt.pattern) + } + } +} + +// A codeSaver is a ResponseWriter that saves the code passed to WriteHeader. +type codeSaver struct { + h Header + code int +} + +func (cs *codeSaver) Header() Header { return cs.h } +func (cs *codeSaver) Write(p []byte) (int, error) { return len(p), nil } +func (cs *codeSaver) WriteHeader(code int) { cs.code = code } diff --git a/src/pkg/net/http/transfer.go b/src/pkg/net/http/transfer.go index 9e9d84172..43c6023a3 100644 --- a/src/pkg/net/http/transfer.go +++ b/src/pkg/net/http/transfer.go @@ -87,10 +87,8 @@ func newTransferWriter(r interface{}) (t *transferWriter, err error) { // Sanitize Body,ContentLength,TransferEncoding if t.ResponseToHEAD { t.Body = nil - t.TransferEncoding = nil - // ContentLength is expected to hold Content-Length - if t.ContentLength < 0 { - return nil, ErrMissingContentLength + if chunked(t.TransferEncoding) { + t.ContentLength = -1 } } else { if !atLeastHTTP11 || t.Body == nil { @@ -122,9 +120,6 @@ func (t *transferWriter) shouldSendContentLength() bool { if t.ContentLength > 0 { return true } - if t.ResponseToHEAD { - return true - } // Many servers expect a Content-Length for these methods if t.Method == "POST" || t.Method == "PUT" { return true @@ -199,10 +194,11 @@ func (t *transferWriter) WriteBody(w io.Writer) (err error) { ncopy, err = io.Copy(w, t.Body) } else { ncopy, err = io.Copy(w, io.LimitReader(t.Body, t.ContentLength)) - nextra, err := io.Copy(ioutil.Discard, t.Body) if err != nil { return err } + var nextra int64 + nextra, err = io.Copy(ioutil.Discard, t.Body) ncopy += nextra } if err != nil { @@ -213,7 +209,7 @@ func (t *transferWriter) WriteBody(w io.Writer) (err error) { } } - if t.ContentLength != -1 && t.ContentLength != ncopy { + if !t.ResponseToHEAD && t.ContentLength != -1 && t.ContentLength != ncopy { return fmt.Errorf("http: Request.ContentLength=%d with Body length %d", t.ContentLength, ncopy) } @@ -294,10 +290,19 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { return err } - t.ContentLength, err = fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.TransferEncoding) + realLength, err := fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.TransferEncoding) if err != nil { return err } + if isResponse && t.RequestMethod == "HEAD" { + if n, err := parseContentLength(t.Header.get("Content-Length")); err != nil { + return err + } else { + t.ContentLength = n + } + } else { + t.ContentLength = realLength + } // Trailer t.Trailer, err = fixTrailer(t.Header, t.TransferEncoding) @@ -310,7 +315,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { // See RFC2616, section 4.4. switch msg.(type) { case *Response: - if t.ContentLength == -1 && + if realLength == -1 && !chunked(t.TransferEncoding) && bodyAllowedForStatus(t.StatusCode) { // Unbounded body. @@ -322,12 +327,16 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { // or close connection when finished, since multipart is not supported yet switch { case chunked(t.TransferEncoding): - t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close} - case t.ContentLength >= 0: + if noBodyExpected(t.RequestMethod) { + t.Body = &body{Reader: io.LimitReader(r, 0), closing: t.Close} + } else { + t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close} + } + case realLength >= 0: // TODO: limit the Content-Length. This is an easy DoS vector. - t.Body = &body{Reader: io.LimitReader(r, t.ContentLength), closing: t.Close} + t.Body = &body{Reader: io.LimitReader(r, realLength), closing: t.Close} default: - // t.ContentLength < 0, i.e. "Content-Length" not mentioned in header + // realLength < 0, i.e. "Content-Length" not mentioned in header if t.Close { // Close semantics (i.e. HTTP/1.0) t.Body = &body{Reader: r, closing: t.Close} @@ -371,12 +380,6 @@ func fixTransferEncoding(requestMethod string, header Header) ([]string, error) delete(header, "Transfer-Encoding") - // Head responses have no bodies, so the transfer encoding - // should be ignored. - if requestMethod == "HEAD" { - return nil, nil - } - encodings := strings.Split(raw[0], ",") te := make([]string, 0, len(encodings)) // TODO: Even though we only support "identity" and "chunked" @@ -432,11 +435,11 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header, } // Logic based on Content-Length - cl := strings.TrimSpace(header.Get("Content-Length")) + cl := strings.TrimSpace(header.get("Content-Length")) if cl != "" { - n, err := strconv.ParseInt(cl, 10, 64) - if err != nil || n < 0 { - return -1, &badStringError{"bad Content-Length", cl} + n, err := parseContentLength(cl) + if err != nil { + return -1, err } return n, nil } else { @@ -451,13 +454,6 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header, return 0, nil } - // Logic based on media type. The purpose of the following code is just - // to detect whether the unsupported "multipart/byteranges" is being - // used. A proper Content-Type parser is needed in the future. - if strings.Contains(strings.ToLower(header.Get("Content-Type")), "multipart/byteranges") { - return -1, ErrNotSupported - } - // Body-EOF logic based on other methods (like closing, or chunked coding) return -1, nil } @@ -469,14 +465,14 @@ func shouldClose(major, minor int, header Header) bool { if major < 1 { return true } else if major == 1 && minor == 0 { - if !strings.Contains(strings.ToLower(header.Get("Connection")), "keep-alive") { + if !strings.Contains(strings.ToLower(header.get("Connection")), "keep-alive") { return true } return false } else { // TODO: Should split on commas, toss surrounding white space, // and check each field. - if strings.ToLower(header.Get("Connection")) == "close" { + if strings.ToLower(header.get("Connection")) == "close" { header.Del("Connection") return true } @@ -486,7 +482,7 @@ func shouldClose(major, minor int, header Header) bool { // Parse the trailer header func fixTrailer(header Header, te []string) (Header, error) { - raw := header.Get("Trailer") + raw := header.get("Trailer") if raw == "" { return nil, nil } @@ -525,11 +521,11 @@ type body struct { res *response // response writer for server requests, else nil } -// ErrBodyReadAfterClose is returned when reading a Request Body after -// the body has been closed. This typically happens when the body is +// ErrBodyReadAfterClose is returned when reading a Request or Response +// Body after the body has been closed. This typically happens when the body is // read after an HTTP Handler calls WriteHeader or Write on its // ResponseWriter. -var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed request Body") +var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed Body") func (b *body) Read(p []byte) (n int, err error) { if b.closed { @@ -567,14 +563,22 @@ func seeUpcomingDoubleCRLF(r *bufio.Reader) bool { return false } +var errTrailerEOF = errors.New("http: unexpected EOF reading trailer") + func (b *body) readTrailer() error { // The common case, since nobody uses trailers. - buf, _ := b.r.Peek(2) + buf, err := b.r.Peek(2) if bytes.Equal(buf, singleCRLF) { b.r.ReadByte() b.r.ReadByte() return nil } + if len(buf) < 2 { + return errTrailerEOF + } + if err != nil { + return err + } // Make sure there's a header terminator coming up, to prevent // a DoS with an unbounded size Trailer. It's not easy to @@ -590,6 +594,9 @@ func (b *body) readTrailer() error { hdr, err := textproto.NewReader(b.r).ReadMIMEHeader() if err != nil { + if err == io.EOF { + return errTrailerEOF + } return err } switch rr := b.hdr.(type) { @@ -630,3 +637,18 @@ func (b *body) Close() error { } return nil } + +// parseContentLength trims whitespace from s and returns -1 if no value +// is set, or the value if it's >= 0. +func parseContentLength(cl string) (int64, error) { + cl = strings.TrimSpace(cl) + if cl == "" { + return -1, nil + } + n, err := strconv.ParseInt(cl, 10, 64) + if err != nil || n < 0 { + return 0, &badStringError{"bad Content-Length", cl} + } + return n, nil + +} diff --git a/src/pkg/net/http/transfer_test.go b/src/pkg/net/http/transfer_test.go new file mode 100644 index 000000000..8627a374c --- /dev/null +++ b/src/pkg/net/http/transfer_test.go @@ -0,0 +1,37 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bufio" + "strings" + "testing" +) + +func TestBodyReadBadTrailer(t *testing.T) { + b := &body{ + Reader: strings.NewReader("foobar"), + hdr: true, // force reading the trailer + r: bufio.NewReader(strings.NewReader("")), + } + buf := make([]byte, 7) + n, err := b.Read(buf[:3]) + got := string(buf[:n]) + if got != "foo" || err != nil { + t.Fatalf(`first Read = %d (%q), %v; want 3 ("foo")`, n, got, err) + } + + n, err = b.Read(buf[:]) + got = string(buf[:n]) + if got != "bar" || err != nil { + t.Fatalf(`second Read = %d (%q), %v; want 3 ("bar")`, n, got, err) + } + + n, err = b.Read(buf[:]) + got = string(buf[:n]) + if err == nil { + t.Errorf("final Read was successful (%q), expected error from trailer read", got) + } +} diff --git a/src/pkg/net/http/transport.go b/src/pkg/net/http/transport.go index 6efe191eb..685d7d56c 100644 --- a/src/pkg/net/http/transport.go +++ b/src/pkg/net/http/transport.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. // HTTP client implementation. See RFC 2616. -// +// // This is the low-level Transport implementation of RoundTripper. // The high-level interface is in client.go. @@ -24,13 +24,14 @@ import ( "os" "strings" "sync" + "time" ) // DefaultTransport is the default implementation of Transport and is -// used by DefaultClient. It establishes a new network connection for -// each call to Do and uses HTTP proxies as directed by the -// $HTTP_PROXY and $NO_PROXY (or $http_proxy and $no_proxy) -// environment variables. +// used by DefaultClient. It establishes network connections as needed +// and caches them for reuse by subsequent calls. It uses HTTP proxies +// as directed by the $HTTP_PROXY and $NO_PROXY (or $http_proxy and +// $no_proxy) environment variables. var DefaultTransport RoundTripper = &Transport{Proxy: ProxyFromEnvironment} // DefaultMaxIdleConnsPerHost is the default value of Transport's @@ -41,8 +42,11 @@ const DefaultMaxIdleConnsPerHost = 2 // https, and http proxies (for either http or https with CONNECT). // Transport can also cache connections for future re-use. type Transport struct { - lk sync.Mutex + idleMu sync.Mutex idleConn map[string][]*persistConn + reqMu sync.Mutex + reqConn map[*Request]*persistConn + altMu sync.RWMutex altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper // TODO: tunable on global max cached connections @@ -68,9 +72,15 @@ type Transport struct { DisableCompression bool // MaxIdleConnsPerHost, if non-zero, controls the maximum idle - // (keep-alive) to keep to keep per-host. If zero, + // (keep-alive) to keep per-host. If zero, // DefaultMaxIdleConnsPerHost is used. MaxIdleConnsPerHost int + + // ResponseHeaderTimeout, if non-zero, specifies the amount of + // time to wait for a server's response headers after fully + // writing the request (including its body, if any). This + // time does not include the time to read the response body. + ResponseHeaderTimeout time.Duration } // ProxyFromEnvironment returns the URL of the proxy to use for a @@ -88,7 +98,7 @@ func ProxyFromEnvironment(req *Request) (*url.URL, error) { return nil, nil } proxyURL, err := url.Parse(proxy) - if err != nil || proxyURL.Scheme == "" { + if err != nil || !strings.HasPrefix(proxyURL.Scheme, "http") { if u, err := url.Parse("http://" + proxy); err == nil { proxyURL = u err = nil @@ -131,17 +141,20 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) { return nil, errors.New("http: nil Request.Header") } if req.URL.Scheme != "http" && req.URL.Scheme != "https" { - t.lk.Lock() + t.altMu.RLock() var rt RoundTripper if t.altProto != nil { rt = t.altProto[req.URL.Scheme] } - t.lk.Unlock() + t.altMu.RUnlock() if rt == nil { return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme} } return rt.RoundTrip(req) } + if req.URL.Host == "" { + return nil, errors.New("http: no Host in request URL") + } treq := &transportRequest{Request: req} cm, err := t.connectMethodForRequest(treq) if err != nil { @@ -170,8 +183,8 @@ func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) { if scheme == "http" || scheme == "https" { panic("protocol " + scheme + " already registered") } - t.lk.Lock() - defer t.lk.Unlock() + t.altMu.Lock() + defer t.altMu.Unlock() if t.altProto == nil { t.altProto = make(map[string]RoundTripper) } @@ -186,17 +199,29 @@ func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) { // a "keep-alive" state. It does not interrupt any connections currently // in use. func (t *Transport) CloseIdleConnections() { - t.lk.Lock() - defer t.lk.Unlock() - if t.idleConn == nil { + t.idleMu.Lock() + m := t.idleConn + t.idleConn = nil + t.idleMu.Unlock() + if m == nil { return } - for _, conns := range t.idleConn { + for _, conns := range m { for _, pconn := range conns { pconn.close() } } - t.idleConn = make(map[string][]*persistConn) +} + +// CancelRequest cancels an in-flight request by closing its +// connection. +func (t *Transport) CancelRequest(req *Request) { + t.reqMu.Lock() + pc := t.reqConn[req] + t.reqMu.Unlock() + if pc != nil { + pc.conn.Close() + } } // @@ -242,8 +267,6 @@ func (cm *connectMethod) proxyAuth() string { // If pconn is no longer needed or not in a good state, putIdleConn // returns false. func (t *Transport) putIdleConn(pconn *persistConn) bool { - t.lk.Lock() - defer t.lk.Unlock() if t.DisableKeepAlives || t.MaxIdleConnsPerHost < 0 { pconn.close() return false @@ -256,21 +279,32 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool { if max == 0 { max = DefaultMaxIdleConnsPerHost } + t.idleMu.Lock() + if t.idleConn == nil { + t.idleConn = make(map[string][]*persistConn) + } if len(t.idleConn[key]) >= max { + t.idleMu.Unlock() pconn.close() return false } + for _, exist := range t.idleConn[key] { + if exist == pconn { + log.Fatalf("dup idle pconn %p in freelist", pconn) + } + } t.idleConn[key] = append(t.idleConn[key], pconn) + t.idleMu.Unlock() return true } func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { - t.lk.Lock() - defer t.lk.Unlock() + key := cm.String() + t.idleMu.Lock() + defer t.idleMu.Unlock() if t.idleConn == nil { - t.idleConn = make(map[string][]*persistConn) + return nil } - key := cm.String() for { pconns, ok := t.idleConn[key] if !ok { @@ -289,7 +323,20 @@ func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { return } } - return + panic("unreachable") +} + +func (t *Transport) setReqConn(r *Request, pc *persistConn) { + t.reqMu.Lock() + defer t.reqMu.Unlock() + if t.reqConn == nil { + t.reqConn = make(map[*Request]*persistConn) + } + if pc != nil { + t.reqConn[r] = pc + } else { + delete(t.reqConn, r) + } } func (t *Transport) dial(network, addr string) (c net.Conn, err error) { @@ -323,6 +370,8 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { cacheKey: cm.String(), conn: conn, reqch: make(chan requestAndChan, 50), + writech: make(chan writeRequest, 50), + closech: make(chan struct{}), } switch { @@ -365,7 +414,18 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { if cm.targetScheme == "https" { // Initiate TLS and check remote host name against certificate. - conn = tls.Client(conn, t.TLSClientConfig) + cfg := t.TLSClientConfig + if cfg == nil || cfg.ServerName == "" { + host := cm.tlsHost() + if cfg == nil { + cfg = &tls.Config{ServerName: host} + } else { + clone := *cfg // shallow clone + clone.ServerName = host + cfg = &clone + } + } + conn = tls.Client(conn, cfg) if err = conn.(*tls.Conn).Handshake(); err != nil { return nil, err } @@ -380,6 +440,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { pconn.br = bufio.NewReader(pconn.conn) pconn.bw = bufio.NewWriter(pconn.conn) go pconn.readLoop() + go pconn.writeLoop() return pconn, nil } @@ -421,7 +482,15 @@ func useProxy(addr string) bool { if hasPort(p) { p = p[:strings.LastIndex(p, ":")] } - if addr == p || (p[0] == '.' && (strings.HasSuffix(addr, p) || addr == p[1:])) { + if addr == p { + return false + } + if p[0] == '.' && (strings.HasSuffix(addr, p) || addr == p[1:]) { + // no_proxy ".foo.com" matches "bar.foo.com" or "foo.com" + return false + } + if p[0] != '.' && strings.HasSuffix(addr, p) && addr[len(addr)-len(p)-1] == '.' { + // no_proxy "foo.com" matches "bar.foo.com" return false } } @@ -484,25 +553,28 @@ type persistConn struct { t *Transport cacheKey string // its connectMethod.String() conn net.Conn + closed bool // whether conn has been closed br *bufio.Reader // from conn bw *bufio.Writer // to conn - reqch chan requestAndChan // written by roundTrip(); read by readLoop() + reqch chan requestAndChan // written by roundTrip; read by readLoop + writech chan writeRequest // written by roundTrip; read by writeLoop + closech chan struct{} // broadcast close when readLoop (TCP connection) closes isProxy bool + lk sync.Mutex // guards following 3 fields + numExpectedResponses int + broken bool // an error has happened on this connection; marked broken so it's not reused. // mutateHeaderFunc is an optional func to modify extra // headers on each outbound request before it's written. (the // original Request given to RoundTrip is not modified) mutateHeaderFunc func(Header) - - lk sync.Mutex // guards numExpectedResponses and broken - numExpectedResponses int - broken bool // an error has happened on this connection; marked broken so it's not reused. } func (pc *persistConn) isBroken() bool { pc.lk.Lock() - defer pc.lk.Unlock() - return pc.broken + b := pc.broken + pc.lk.Unlock() + return b } var remoteSideClosedFunc func(error) bool // or nil to use default @@ -518,6 +590,7 @@ func remoteSideClosed(err error) bool { } func (pc *persistConn) readLoop() { + defer close(pc.closech) alive := true var lastbody io.ReadCloser // last response body, if any, read on this connection @@ -544,12 +617,16 @@ func (pc *persistConn) readLoop() { lastbody.Close() // assumed idempotent lastbody = nil } - resp, err := ReadResponse(pc.br, rc.req) + + var resp *Response + if err == nil { + resp, err = ReadResponse(pc.br, rc.req) + } + hasBody := resp != nil && rc.req.Method != "HEAD" && resp.ContentLength != 0 if err != nil { pc.close() } else { - hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0 if rc.addedGzip && hasBody && resp.Header.Get("Content-Encoding") == "gzip" { resp.Header.Del("Content-Encoding") resp.Header.Del("Content-Length") @@ -569,31 +646,37 @@ func (pc *persistConn) readLoop() { alive = false } - hasBody := resp != nil && resp.ContentLength != 0 var waitForBodyRead chan bool - if alive { - if hasBody { - lastbody = resp.Body - waitForBodyRead = make(chan bool) - resp.Body.(*bodyEOFSignal).fn = func() { - if !pc.t.putIdleConn(pc) { - alive = false - } - waitForBodyRead <- true + if hasBody { + lastbody = resp.Body + waitForBodyRead = make(chan bool, 1) + resp.Body.(*bodyEOFSignal).fn = func(err error) { + alive1 := alive + if err != nil { + alive1 = false } - } else { - // When there's no response body, we immediately - // reuse the TCP connection (putIdleConn), but - // we need to prevent ClientConn.Read from - // closing the Response.Body on the next - // loop, otherwise it might close the body - // before the client code has had a chance to - // read it (even though it'll just be 0, EOF). - lastbody = nil - - if !pc.t.putIdleConn(pc) { - alive = false + if alive1 && !pc.t.putIdleConn(pc) { + alive1 = false + } + if !alive1 || pc.isBroken() { + pc.close() } + waitForBodyRead <- alive1 + } + } + + if alive && !hasBody { + // When there's no response body, we immediately + // reuse the TCP connection (putIdleConn), but + // we need to prevent ClientConn.Read from + // closing the Response.Body on the next + // loop, otherwise it might close the body + // before the client code has had a chance to + // read it (even though it'll just be 0, EOF). + lastbody = nil + + if !pc.t.putIdleConn(pc) { + alive = false } } @@ -602,7 +685,35 @@ func (pc *persistConn) readLoop() { // Wait for the just-returned response body to be fully consumed // before we race and peek on the underlying bufio reader. if waitForBodyRead != nil { - <-waitForBodyRead + alive = <-waitForBodyRead + } + + pc.t.setReqConn(rc.req, nil) + + if !alive { + pc.close() + } + } +} + +func (pc *persistConn) writeLoop() { + for { + select { + case wr := <-pc.writech: + if pc.isBroken() { + wr.ch <- errors.New("http: can't write HTTP request on broken connection") + continue + } + err := wr.req.Request.write(pc.bw, pc.isProxy, wr.req.extra) + if err == nil { + err = pc.bw.Flush() + } + if err != nil { + pc.markBroken() + } + wr.ch <- err + case <-pc.closech: + return } } } @@ -622,9 +733,24 @@ type requestAndChan struct { addedGzip bool } +// A writeRequest is sent by the readLoop's goroutine to the +// writeLoop's goroutine to write a request while the read loop +// concurrently waits on both the write response and the server's +// reply. +type writeRequest struct { + req *transportRequest + ch chan<- error +} + func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { - if pc.mutateHeaderFunc != nil { - pc.mutateHeaderFunc(req.extraHeaders()) + pc.t.setReqConn(req.Request, pc) + pc.lk.Lock() + pc.numExpectedResponses++ + headerFn := pc.mutateHeaderFunc + pc.lk.Unlock() + + if headerFn != nil { + headerFn(req.extraHeaders()) } // Ask for a compressed version if the caller didn't set their @@ -633,34 +759,84 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // requested it. requestedGzip := false if !pc.t.DisableCompression && req.Header.Get("Accept-Encoding") == "" { - // Request gzip only, not deflate. Deflate is ambiguous and + // Request gzip only, not deflate. Deflate is ambiguous and // not as universally supported anyway. // See: http://www.gzip.org/zlib/zlib_faq.html#faq38 requestedGzip = true req.extraHeaders().Set("Accept-Encoding", "gzip") } - pc.lk.Lock() - pc.numExpectedResponses++ - pc.lk.Unlock() + // Write the request concurrently with waiting for a response, + // in case the server decides to reply before reading our full + // request body. + writeErrCh := make(chan error, 1) + pc.writech <- writeRequest{req, writeErrCh} - err = req.Request.write(pc.bw, pc.isProxy, req.extra) - if err != nil { - pc.close() - return + resc := make(chan responseAndError, 1) + pc.reqch <- requestAndChan{req.Request, resc, requestedGzip} + + var re responseAndError + var pconnDeadCh = pc.closech + var failTicker <-chan time.Time + var respHeaderTimer <-chan time.Time +WaitResponse: + for { + select { + case err := <-writeErrCh: + if err != nil { + re = responseAndError{nil, err} + pc.close() + break WaitResponse + } + if d := pc.t.ResponseHeaderTimeout; d > 0 { + respHeaderTimer = time.After(d) + } + case <-pconnDeadCh: + // The persist connection is dead. This shouldn't + // usually happen (only with Connection: close responses + // with no response bodies), but if it does happen it + // means either a) the remote server hung up on us + // prematurely, or b) the readLoop sent us a response & + // closed its closech at roughly the same time, and we + // selected this case first, in which case a response + // might still be coming soon. + // + // We can't avoid the select race in b) by using a unbuffered + // resc channel instead, because then goroutines can + // leak if we exit due to other errors. + pconnDeadCh = nil // avoid spinning + failTicker = time.After(100 * time.Millisecond) // arbitrary time to wait for resc + case <-failTicker: + re = responseAndError{err: errors.New("net/http: transport closed before response was received")} + break WaitResponse + case <-respHeaderTimer: + pc.close() + re = responseAndError{err: errors.New("net/http: timeout awaiting response headers")} + break WaitResponse + case re = <-resc: + break WaitResponse + } } - pc.bw.Flush() - ch := make(chan responseAndError, 1) - pc.reqch <- requestAndChan{req.Request, ch, requestedGzip} - re := <-ch pc.lk.Lock() pc.numExpectedResponses-- pc.lk.Unlock() + if re.err != nil { + pc.t.setReqConn(req.Request, nil) + } return re.res, re.err } +// markBroken marks a connection as broken (so it's not reused). +// It differs from close in that it doesn't close the underlying +// connection for use when it's still being read. +func (pc *persistConn) markBroken() { + pc.lk.Lock() + defer pc.lk.Unlock() + pc.broken = true +} + func (pc *persistConn) close() { pc.lk.Lock() defer pc.lk.Unlock() @@ -669,7 +845,10 @@ func (pc *persistConn) close() { func (pc *persistConn) closeLocked() { pc.broken = true - pc.conn.Close() + if !pc.closed { + pc.conn.Close() + pc.closed = true + } pc.mutateHeaderFunc = nil } @@ -687,43 +866,62 @@ func canonicalAddr(url *url.URL) string { return addr } -func responseIsKeepAlive(res *Response) bool { - // TODO: implement. for now just always shutting down the connection. - return false -} - // bodyEOFSignal wraps a ReadCloser but runs fn (if non-nil) at most -// once, right before the final Read() or Close() call returns, but after -// EOF has been seen. +// once, right before its final (error-producing) Read or Close call +// returns. type bodyEOFSignal struct { - body io.ReadCloser - fn func() - isClosed bool + body io.ReadCloser + mu sync.Mutex // guards closed, rerr and fn + closed bool // whether Close has been called + rerr error // sticky Read error + fn func(error) // error will be nil on Read io.EOF } func (es *bodyEOFSignal) Read(p []byte) (n int, err error) { - n, err = es.body.Read(p) - if es.isClosed && n > 0 { - panic("http: unexpected bodyEOFSignal Read after Close; see issue 1725") + es.mu.Lock() + closed, rerr := es.closed, es.rerr + es.mu.Unlock() + if closed { + return 0, errors.New("http: read on closed response body") + } + if rerr != nil { + return 0, rerr } - if err == io.EOF && es.fn != nil { - es.fn() - es.fn = nil + + n, err = es.body.Read(p) + if err != nil { + es.mu.Lock() + defer es.mu.Unlock() + if es.rerr == nil { + es.rerr = err + } + es.condfn(err) } return } -func (es *bodyEOFSignal) Close() (err error) { - if es.isClosed { +func (es *bodyEOFSignal) Close() error { + es.mu.Lock() + defer es.mu.Unlock() + if es.closed { return nil } - es.isClosed = true - err = es.body.Close() - if err == nil && es.fn != nil { - es.fn() - es.fn = nil + es.closed = true + err := es.body.Close() + es.condfn(err) + return err +} + +// caller must hold es.mu. +func (es *bodyEOFSignal) condfn(err error) { + if es.fn == nil { + return } - return + if err == io.EOF { + err = nil + } + es.fn(err) + es.fn = nil } type readFirstCloseBoth struct { diff --git a/src/pkg/net/http/transport_test.go b/src/pkg/net/http/transport_test.go index a9e401de5..68010e68b 100644 --- a/src/pkg/net/http/transport_test.go +++ b/src/pkg/net/http/transport_test.go @@ -13,6 +13,7 @@ import ( "fmt" "io" "io/ioutil" + "net" . "net/http" "net/http/httptest" "net/url" @@ -20,6 +21,7 @@ import ( "runtime" "strconv" "strings" + "sync" "testing" "time" ) @@ -35,14 +37,78 @@ var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte(r.RemoteAddr)) }) +// testCloseConn is a net.Conn tracked by a testConnSet. +type testCloseConn struct { + net.Conn + set *testConnSet +} + +func (c *testCloseConn) Close() error { + c.set.remove(c) + return c.Conn.Close() +} + +// testConnSet tracks a set of TCP connections and whether they've +// been closed. +type testConnSet struct { + t *testing.T + closed map[net.Conn]bool + list []net.Conn // in order created + mutex sync.Mutex +} + +func (tcs *testConnSet) insert(c net.Conn) { + tcs.mutex.Lock() + defer tcs.mutex.Unlock() + tcs.closed[c] = false + tcs.list = append(tcs.list, c) +} + +func (tcs *testConnSet) remove(c net.Conn) { + tcs.mutex.Lock() + defer tcs.mutex.Unlock() + tcs.closed[c] = true +} + +// some tests use this to manage raw tcp connections for later inspection +func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) { + connSet := &testConnSet{ + t: t, + closed: make(map[net.Conn]bool), + } + dial := func(n, addr string) (net.Conn, error) { + c, err := net.Dial(n, addr) + if err != nil { + return nil, err + } + tc := &testCloseConn{c, connSet} + connSet.insert(tc) + return tc, nil + } + return connSet, dial +} + +func (tcs *testConnSet) check(t *testing.T) { + tcs.mutex.Lock() + defer tcs.mutex.Unlock() + + for i, c := range tcs.list { + if !tcs.closed[c] { + t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list)) + } + } +} + // Two subsequent requests and verify their response is the same. // The response from the server is our own IP:port func TestTransportKeepAlives(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() for _, disableKeepAlive := range []bool{false, true} { tr := &Transport{DisableKeepAlives: disableKeepAlive} + defer tr.CloseIdleConnections() c := &Client{Transport: tr} fetch := func(n int) string { @@ -69,11 +135,16 @@ func TestTransportKeepAlives(t *testing.T) { } func TestTransportConnectionCloseOnResponse(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() + connSet, testDial := makeTestDial(t) + for _, connectionClose := range []bool{false, true} { - tr := &Transport{} + tr := &Transport{ + Dial: testDial, + } c := &Client{Transport: tr} fetch := func(n int) string { @@ -92,8 +163,8 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) { if err != nil { t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err) } - body, err := ioutil.ReadAll(res.Body) defer res.Body.Close() + body, err := ioutil.ReadAll(res.Body) if err != nil { t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err) } @@ -107,15 +178,24 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) { t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", connectionClose, bodiesDiffer, body1, body2) } + + tr.CloseIdleConnections() } + + connSet.check(t) } func TestTransportConnectionCloseOnRequest(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() + connSet, testDial := makeTestDial(t) + for _, connectionClose := range []bool{false, true} { - tr := &Transport{} + tr := &Transport{ + Dial: testDial, + } c := &Client{Transport: tr} fetch := func(n int) string { @@ -149,10 +229,15 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) { t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", connectionClose, bodiesDiffer, body1, body2) } + + tr.CloseIdleConnections() } + + connSet.check(t) } func TestTransportIdleCacheKeys(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() @@ -185,6 +270,7 @@ func TestTransportIdleCacheKeys(t *testing.T) { } func TestTransportMaxPerHostIdleConns(t *testing.T) { + defer checkLeakedTransports(t) resch := make(chan string) gotReq := make(chan bool) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -201,7 +287,7 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { c := &Client{Transport: tr} // Start 3 outstanding requests and wait for the server to get them. - // Their responses will hang until we we write to resch, though. + // Their responses will hang until we write to resch, though. donech := make(chan bool) doReq := func() { resp, err := c.Get(ts.URL) @@ -253,6 +339,7 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { } func TestTransportServerClosingUnexpectedly(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() @@ -309,9 +396,9 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) { // Test for http://golang.org/issue/2616 (appropriate issue number) // This fails pretty reliably with GOMAXPROCS=100 or something high. func TestStressSurpriseServerCloses(t *testing.T) { + defer checkLeakedTransports(t) if testing.Short() { - t.Logf("skipping test in short mode") - return + t.Skip("skipping test in short mode") } ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "5") @@ -365,6 +452,7 @@ func TestStressSurpriseServerCloses(t *testing.T) { // TestTransportHeadResponses verifies that we deal with Content-Lengths // with no bodies properly func TestTransportHeadResponses(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "HEAD" { panic("expected HEAD; got " + r.Method) @@ -384,7 +472,7 @@ func TestTransportHeadResponses(t *testing.T) { if e, g := "123", res.Header.Get("Content-Length"); e != g { t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g) } - if e, g := int64(0), res.ContentLength; e != g { + if e, g := int64(123), res.ContentLength; e != g { t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g) } } @@ -393,6 +481,7 @@ func TestTransportHeadResponses(t *testing.T) { // TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding // on responses to HEAD requests. func TestTransportHeadChunkedResponse(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "HEAD" { panic("expected HEAD; got " + r.Method) @@ -434,6 +523,7 @@ var roundTripTests = []struct { // Test that the modification made to the Request by the RoundTripper is cleaned up func TestRoundTripGzip(t *testing.T) { + defer checkLeakedTransports(t) const responseBody = "test response body" ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { accept := req.Header.Get("Accept-Encoding") @@ -490,6 +580,7 @@ func TestRoundTripGzip(t *testing.T) { } func TestTransportGzip(t *testing.T) { + defer checkLeakedTransports(t) const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" const nRandBytes = 1024 * 1024 ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { @@ -582,6 +673,7 @@ func TestTransportGzip(t *testing.T) { } func TestTransportProxy(t *testing.T) { + defer checkLeakedTransports(t) ch := make(chan string, 1) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ch <- "real server" @@ -610,6 +702,7 @@ func TestTransportProxy(t *testing.T) { // but checks that we don't recurse forever, and checks that // Content-Encoding is removed. func TestTransportGzipRecursive(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", "gzip") w.Write(rgz) @@ -636,6 +729,7 @@ func TestTransportGzipRecursive(t *testing.T) { // tests that persistent goroutine connections shut down when no longer desired. func TestTransportPersistConnLeak(t *testing.T) { + defer checkLeakedTransports(t) gotReqCh := make(chan bool) unblockCh := make(chan bool) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -698,8 +792,49 @@ func TestTransportPersistConnLeak(t *testing.T) { } } +// golang.org/issue/4531: Transport leaks goroutines when +// request.ContentLength is explicitly short +func TestTransportPersistConnLeakShortBody(t *testing.T) { + defer checkLeakedTransports(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + })) + defer ts.Close() + + tr := &Transport{} + c := &Client{Transport: tr} + + n0 := runtime.NumGoroutine() + body := []byte("Hello") + for i := 0; i < 20; i++ { + req, err := NewRequest("POST", ts.URL, bytes.NewReader(body)) + if err != nil { + t.Fatal(err) + } + req.ContentLength = int64(len(body) - 2) // explicitly short + _, err = c.Do(req) + if err == nil { + t.Fatal("Expect an error from writing too long of a body.") + } + } + nhigh := runtime.NumGoroutine() + tr.CloseIdleConnections() + time.Sleep(50 * time.Millisecond) + runtime.GC() + nfinal := runtime.NumGoroutine() + + growth := nfinal - n0 + + // We expect 0 or 1 extra goroutine, empirically. Allow up to 5. + // Previously we were leaking one per numReq. + t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth) + if int(growth) > 5 { + t.Error("too many new goroutines") + } +} + // This used to crash; http://golang.org/issue/3266 func TestTransportIdleConnCrash(t *testing.T) { + defer checkLeakedTransports(t) tr := &Transport{} c := &Client{Transport: tr} @@ -724,6 +859,361 @@ func TestTransportIdleConnCrash(t *testing.T) { <-didreq } +// Test that the transport doesn't close the TCP connection early, +// before the response body has been read. This was a regression +// which sadly lacked a triggering test. The large response body made +// the old race easier to trigger. +func TestIssue3644(t *testing.T) { + defer checkLeakedTransports(t) + const numFoos = 5000 + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Connection", "close") + for i := 0; i < numFoos; i++ { + w.Write([]byte("foo ")) + } + })) + defer ts.Close() + tr := &Transport{} + c := &Client{Transport: tr} + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + bs, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if len(bs) != numFoos*len("foo ") { + t.Errorf("unexpected response length") + } +} + +// Test that a client receives a server's reply, even if the server doesn't read +// the entire request body. +func TestIssue3595(t *testing.T) { + defer checkLeakedTransports(t) + const deniedMsg = "sorry, denied." + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + Error(w, deniedMsg, StatusUnauthorized) + })) + defer ts.Close() + tr := &Transport{} + c := &Client{Transport: tr} + res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a')) + if err != nil { + t.Errorf("Post: %v", err) + return + } + got, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("Body ReadAll: %v", err) + } + if !strings.Contains(string(got), deniedMsg) { + t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg) + } +} + +// From http://golang.org/issue/4454 , +// "client fails to handle requests with no body and chunked encoding" +func TestChunkedNoContent(t *testing.T) { + defer checkLeakedTransports(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.WriteHeader(StatusNoContent) + })) + defer ts.Close() + + for _, closeBody := range []bool{true, false} { + c := &Client{Transport: &Transport{}} + const n = 4 + for i := 1; i <= n; i++ { + res, err := c.Get(ts.URL) + if err != nil { + t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err) + } else { + if closeBody { + res.Body.Close() + } + } + } + } +} + +func TestTransportConcurrency(t *testing.T) { + defer checkLeakedTransports(t) + const maxProcs = 16 + const numReqs = 500 + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs)) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "%v", r.FormValue("echo")) + })) + defer ts.Close() + tr := &Transport{} + c := &Client{Transport: tr} + reqs := make(chan string) + defer close(reqs) + + var wg sync.WaitGroup + wg.Add(numReqs) + for i := 0; i < maxProcs*2; i++ { + go func() { + for req := range reqs { + res, err := c.Get(ts.URL + "/?echo=" + req) + if err != nil { + t.Errorf("error on req %s: %v", req, err) + wg.Done() + continue + } + all, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Errorf("read error on req %s: %v", req, err) + wg.Done() + continue + } + if string(all) != req { + t.Errorf("body of req %s = %q; want %q", req, all, req) + } + wg.Done() + res.Body.Close() + } + }() + } + for i := 0; i < numReqs; i++ { + reqs <- fmt.Sprintf("request-%d", i) + } + wg.Wait() +} + +func TestIssue4191_InfiniteGetTimeout(t *testing.T) { + defer checkLeakedTransports(t) + const debug = false + mux := NewServeMux() + mux.HandleFunc("/get", func(w ResponseWriter, r *Request) { + io.Copy(w, neverEnding('a')) + }) + ts := httptest.NewServer(mux) + timeout := 100 * time.Millisecond + + client := &Client{ + Transport: &Transport{ + Dial: func(n, addr string) (net.Conn, error) { + conn, err := net.Dial(n, addr) + if err != nil { + return nil, err + } + conn.SetDeadline(time.Now().Add(timeout)) + if debug { + conn = NewLoggingConn("client", conn) + } + return conn, nil + }, + DisableKeepAlives: true, + }, + } + + getFailed := false + nRuns := 5 + if testing.Short() { + nRuns = 1 + } + for i := 0; i < nRuns; i++ { + if debug { + println("run", i+1, "of", nRuns) + } + sres, err := client.Get(ts.URL + "/get") + if err != nil { + if !getFailed { + // Make the timeout longer, once. + getFailed = true + t.Logf("increasing timeout") + i-- + timeout *= 10 + continue + } + t.Errorf("Error issuing GET: %v", err) + break + } + _, err = io.Copy(ioutil.Discard, sres.Body) + if err == nil { + t.Errorf("Unexpected successful copy") + break + } + } + if debug { + println("tests complete; waiting for handlers to finish") + } + ts.Close() +} + +func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { + defer checkLeakedTransports(t) + const debug = false + mux := NewServeMux() + mux.HandleFunc("/get", func(w ResponseWriter, r *Request) { + io.Copy(w, neverEnding('a')) + }) + mux.HandleFunc("/put", func(w ResponseWriter, r *Request) { + defer r.Body.Close() + io.Copy(ioutil.Discard, r.Body) + }) + ts := httptest.NewServer(mux) + timeout := 100 * time.Millisecond + + client := &Client{ + Transport: &Transport{ + Dial: func(n, addr string) (net.Conn, error) { + conn, err := net.Dial(n, addr) + if err != nil { + return nil, err + } + conn.SetDeadline(time.Now().Add(timeout)) + if debug { + conn = NewLoggingConn("client", conn) + } + return conn, nil + }, + DisableKeepAlives: true, + }, + } + + getFailed := false + nRuns := 5 + if testing.Short() { + nRuns = 1 + } + for i := 0; i < nRuns; i++ { + if debug { + println("run", i+1, "of", nRuns) + } + sres, err := client.Get(ts.URL + "/get") + if err != nil { + if !getFailed { + // Make the timeout longer, once. + getFailed = true + t.Logf("increasing timeout") + i-- + timeout *= 10 + continue + } + t.Errorf("Error issuing GET: %v", err) + break + } + req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body) + _, err = client.Do(req) + if err == nil { + sres.Body.Close() + t.Errorf("Unexpected successful PUT") + break + } + sres.Body.Close() + } + if debug { + println("tests complete; waiting for handlers to finish") + } + ts.Close() +} + +func TestTransportResponseHeaderTimeout(t *testing.T) { + defer checkLeakedTransports(t) + if testing.Short() { + t.Skip("skipping timeout test in -short mode") + } + mux := NewServeMux() + mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {}) + mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) { + time.Sleep(2 * time.Second) + }) + ts := httptest.NewServer(mux) + defer ts.Close() + + tr := &Transport{ + ResponseHeaderTimeout: 500 * time.Millisecond, + } + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + tests := []struct { + path string + want int + wantErr string + }{ + {path: "/fast", want: 200}, + {path: "/slow", wantErr: "timeout awaiting response headers"}, + {path: "/fast", want: 200}, + } + for i, tt := range tests { + res, err := c.Get(ts.URL + tt.path) + if err != nil { + if strings.Contains(err.Error(), tt.wantErr) { + continue + } + t.Errorf("%d. unexpected error: %v", i, err) + continue + } + if tt.wantErr != "" { + t.Errorf("%d. no error. expected error: %v", i, tt.wantErr) + continue + } + if res.StatusCode != tt.want { + t.Errorf("%d for path %q status = %d; want %d", i, tt.path, res.StatusCode, tt.want) + } + } +} + +func TestTransportCancelRequest(t *testing.T) { + defer checkLeakedTransports(t) + if testing.Short() { + t.Skip("skipping test in -short mode") + } + unblockc := make(chan bool) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "Hello") + w.(Flusher).Flush() // send headers and some body + <-unblockc + })) + defer ts.Close() + defer close(unblockc) + + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + req, _ := NewRequest("GET", ts.URL, nil) + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + go func() { + time.Sleep(1 * time.Second) + tr.CancelRequest(req) + }() + t0 := time.Now() + body, err := ioutil.ReadAll(res.Body) + d := time.Since(t0) + + if err == nil { + t.Error("expected an error reading the body") + } + if string(body) != "Hello" { + t.Errorf("Body = %q; want Hello", body) + } + if d < 500*time.Millisecond { + t.Errorf("expected ~1 second delay; got %v", d) + } + // Verify no outstanding requests after readLoop/writeLoop + // goroutines shut down. + for tries := 3; tries > 0; tries-- { + n := tr.NumPendingRequestsForTesting() + if n == 0 { + break + } + time.Sleep(100 * time.Millisecond) + if tries == 1 { + t.Errorf("pending requests = %d; want 0", n) + } + } +} + type fooProto struct{} func (fooProto) RoundTrip(req *Request) (*Response, error) { @@ -737,6 +1227,7 @@ func (fooProto) RoundTrip(req *Request) (*Response, error) { } func TestTransportAltProto(t *testing.T) { + defer checkLeakedTransports(t) tr := &Transport{} c := &Client{Transport: tr} tr.RegisterProtocol("foo", fooProto{}) @@ -754,15 +1245,58 @@ func TestTransportAltProto(t *testing.T) { } } -var proxyFromEnvTests = []struct { +func TestTransportNoHost(t *testing.T) { + defer checkLeakedTransports(t) + tr := &Transport{} + _, err := tr.RoundTrip(&Request{ + Header: make(Header), + URL: &url.URL{ + Scheme: "http", + }, + }) + want := "http: no Host in request URL" + if got := fmt.Sprint(err); got != want { + t.Errorf("error = %v; want %q", err, want) + } +} + +type proxyFromEnvTest struct { + req string // URL to fetch; blank means "http://example.com" env string - wanturl string + noenv string + want string wanterr error -}{ - {"127.0.0.1:8080", "http://127.0.0.1:8080", nil}, - {"http://127.0.0.1:8080", "http://127.0.0.1:8080", nil}, - {"https://127.0.0.1:8080", "https://127.0.0.1:8080", nil}, - {"", "<nil>", nil}, +} + +func (t proxyFromEnvTest) String() string { + var buf bytes.Buffer + if t.env != "" { + fmt.Fprintf(&buf, "http_proxy=%q", t.env) + } + if t.noenv != "" { + fmt.Fprintf(&buf, " no_proxy=%q", t.noenv) + } + req := "http://example.com" + if t.req != "" { + req = t.req + } + fmt.Fprintf(&buf, " req=%q", req) + return strings.TrimSpace(buf.String()) +} + +var proxyFromEnvTests = []proxyFromEnvTest{ + {env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"}, + {env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"}, + {env: "cache.corp.example.com", want: "http://cache.corp.example.com"}, + {env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"}, + {env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"}, + {env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"}, + {want: "<nil>"}, + {noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"}, + {noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "<nil>"}, + {noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"}, + {noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"}, + {noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"}, } func TestProxyFromEnvironment(t *testing.T) { @@ -770,16 +1304,21 @@ func TestProxyFromEnvironment(t *testing.T) { os.Setenv("http_proxy", "") os.Setenv("NO_PROXY", "") os.Setenv("no_proxy", "") - for i, tt := range proxyFromEnvTests { + for _, tt := range proxyFromEnvTests { os.Setenv("HTTP_PROXY", tt.env) - req, _ := NewRequest("GET", "http://example.com", nil) + os.Setenv("NO_PROXY", tt.noenv) + reqURL := tt.req + if reqURL == "" { + reqURL = "http://example.com" + } + req, _ := NewRequest("GET", reqURL, nil) url, err := ProxyFromEnvironment(req) if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e { - t.Errorf("%d. got error = %q, want %q", i, g, e) + t.Errorf("%v: got error = %q, want %q", tt, g, e) continue } - if got := fmt.Sprintf("%s", url); got != tt.wanturl { - t.Errorf("%d. got URL = %q, want %q", i, url, tt.wanturl) + if got := fmt.Sprintf("%s", url); got != tt.want { + t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want) } } } diff --git a/src/pkg/net/http/z_last_test.go b/src/pkg/net/http/z_last_test.go new file mode 100644 index 000000000..44095a8d9 --- /dev/null +++ b/src/pkg/net/http/z_last_test.go @@ -0,0 +1,60 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http_test + +import ( + "net/http" + "runtime" + "strings" + "testing" + "time" +) + +// Verify the other tests didn't leave any goroutines running. +// This is in a file named z_last_test.go so it sorts at the end. +func TestGoroutinesRunning(t *testing.T) { + n := runtime.NumGoroutine() + t.Logf("num goroutines = %d", n) + if n > 20 { + // Currently 14 on Linux (blocked in epoll_wait, + // waiting for on fds that are closed?), but give some + // slop for now. + buf := make([]byte, 1<<20) + buf = buf[:runtime.Stack(buf, true)] + t.Errorf("Too many goroutines:\n%s", buf) + } +} + +func checkLeakedTransports(t *testing.T) { + http.DefaultTransport.(*http.Transport).CloseIdleConnections() + if testing.Short() { + return + } + buf := make([]byte, 1<<20) + var stacks string + var bad string + badSubstring := map[string]string{ + ").readLoop(": "a Transport", + ").writeLoop(": "a Transport", + "created by net/http/httptest.(*Server).Start": "an httptest.Server", + "timeoutHandler": "a TimeoutHandler", + } + for i := 0; i < 4; i++ { + bad = "" + stacks = string(buf[:runtime.Stack(buf, true)]) + for substr, what := range badSubstring { + if strings.Contains(stacks, substr) { + bad = what + } + } + if bad == "" { + return + } + // Bad stuff found, but goroutines might just still be + // shutting down, so give it some time. + time.Sleep(250 * time.Millisecond) + } + t.Errorf("Test appears to have leaked %s:\n%s", bad, stacks) +} |