summaryrefslogtreecommitdiff
path: root/src/pkg/net/http
diff options
context:
space:
mode:
authorMichael Stapelberg <stapelberg@debian.org>2013-03-04 21:27:36 +0100
committerMichael Stapelberg <michael@stapelberg.de>2013-03-04 21:27:36 +0100
commit04b08da9af0c450d645ab7389d1467308cfc2db8 (patch)
treedb247935fa4f2f94408edc3acd5d0d4f997aa0d8 /src/pkg/net/http
parent917c5fb8ec48e22459d77e3849e6d388f93d3260 (diff)
downloadgolang-upstream/1.1_hg20130304.tar.gz
Imported Upstream version 1.1~hg20130304upstream/1.1_hg20130304
Diffstat (limited to 'src/pkg/net/http')
-rw-r--r--src/pkg/net/http/cgi/child.go12
-rw-r--r--src/pkg/net/http/cgi/child_test.go24
-rw-r--r--src/pkg/net/http/cgi/host_test.go86
-rw-r--r--src/pkg/net/http/cgi/plan9_test.go18
-rw-r--r--src/pkg/net/http/cgi/posix_test.go21
-rwxr-xr-xsrc/pkg/net/http/cgi/testdata/test.cgi55
-rw-r--r--src/pkg/net/http/chunked.go59
-rw-r--r--src/pkg/net/http/chunked_test.go54
-rw-r--r--src/pkg/net/http/client.go178
-rw-r--r--src/pkg/net/http/client_test.go250
-rw-r--r--src/pkg/net/http/cookie.go9
-rw-r--r--src/pkg/net/http/cookie_test.go2
-rw-r--r--src/pkg/net/http/cookiejar/jar.go494
-rw-r--r--src/pkg/net/http/cookiejar/jar_test.go1267
-rw-r--r--src/pkg/net/http/cookiejar/punycode.go159
-rw-r--r--src/pkg/net/http/cookiejar/punycode_test.go161
-rw-r--r--src/pkg/net/http/example_test.go2
-rw-r--r--src/pkg/net/http/export_test.go23
-rw-r--r--src/pkg/net/http/filetransport_test.go10
-rw-r--r--src/pkg/net/http/fs.go196
-rw-r--r--src/pkg/net/http/fs_test.go439
-rw-r--r--src/pkg/net/http/header.go135
-rw-r--r--src/pkg/net/http/header_test.go123
-rw-r--r--src/pkg/net/http/httptest/example_test.go50
-rw-r--r--src/pkg/net/http/httptest/recorder.go24
-rw-r--r--src/pkg/net/http/httptest/recorder_test.go90
-rw-r--r--src/pkg/net/http/httptest/server.go73
-rw-r--r--src/pkg/net/http/httputil/chunked.go59
-rw-r--r--src/pkg/net/http/httputil/chunked_test.go54
-rw-r--r--src/pkg/net/http/httputil/dump.go4
-rw-r--r--src/pkg/net/http/httputil/reverseproxy.go58
-rw-r--r--src/pkg/net/http/httputil/reverseproxy_test.go84
-rw-r--r--src/pkg/net/http/jar.go19
-rw-r--r--src/pkg/net/http/lex.go206
-rw-r--r--src/pkg/net/http/lex_test.go65
-rw-r--r--src/pkg/net/http/npn_test.go118
-rw-r--r--src/pkg/net/http/pprof/pprof.go17
-rw-r--r--src/pkg/net/http/proxy_test.go4
-rw-r--r--src/pkg/net/http/range_test.go22
-rw-r--r--src/pkg/net/http/readrequest_test.go48
-rw-r--r--src/pkg/net/http/request.go247
-rw-r--r--src/pkg/net/http/request_test.go189
-rw-r--r--src/pkg/net/http/requestwrite_test.go63
-rw-r--r--src/pkg/net/http/response.go14
-rw-r--r--src/pkg/net/http/response_test.go162
-rw-r--r--src/pkg/net/http/responsewrite_test.go138
-rw-r--r--src/pkg/net/http/serve_test.go578
-rw-r--r--src/pkg/net/http/server.go873
-rw-r--r--src/pkg/net/http/server_test.go95
-rw-r--r--src/pkg/net/http/transfer.go100
-rw-r--r--src/pkg/net/http/transfer_test.go37
-rw-r--r--src/pkg/net/http/transport.go392
-rw-r--r--src/pkg/net/http/transport_test.go577
-rw-r--r--src/pkg/net/http/z_last_test.go60
54 files changed, 7051 insertions, 1246 deletions
diff --git a/src/pkg/net/http/cgi/child.go b/src/pkg/net/http/cgi/child.go
index 1ba7bec5f..100b8b777 100644
--- a/src/pkg/net/http/cgi/child.go
+++ b/src/pkg/net/http/cgi/child.go
@@ -91,10 +91,19 @@ func RequestFromMap(params map[string]string) (*http.Request, error) {
// TODO: cookies. parsing them isn't exported, though.
+ uriStr := params["REQUEST_URI"]
+ if uriStr == "" {
+ // Fallback to SCRIPT_NAME, PATH_INFO and QUERY_STRING.
+ uriStr = params["SCRIPT_NAME"] + params["PATH_INFO"]
+ s := params["QUERY_STRING"]
+ if s != "" {
+ uriStr += "?" + s
+ }
+ }
if r.Host != "" {
// Hostname is provided, so we can reasonably construct a URL,
// even if we have to assume 'http' for the scheme.
- rawurl := "http://" + r.Host + params["REQUEST_URI"]
+ rawurl := "http://" + r.Host + uriStr
url, err := url.Parse(rawurl)
if err != nil {
return nil, errors.New("cgi: failed to parse host and REQUEST_URI into a URL: " + rawurl)
@@ -104,7 +113,6 @@ func RequestFromMap(params map[string]string) (*http.Request, error) {
// Fallback logic if we don't have a Host header or the URL
// failed to parse
if r.URL == nil {
- uriStr := params["REQUEST_URI"]
url, err := url.Parse(uriStr)
if err != nil {
return nil, errors.New("cgi: failed to parse REQUEST_URI into a URL: " + uriStr)
diff --git a/src/pkg/net/http/cgi/child_test.go b/src/pkg/net/http/cgi/child_test.go
index ec53ab851..74e068014 100644
--- a/src/pkg/net/http/cgi/child_test.go
+++ b/src/pkg/net/http/cgi/child_test.go
@@ -82,6 +82,28 @@ func TestRequestWithoutHost(t *testing.T) {
t.Fatalf("unexpected nil URL")
}
if g, e := req.URL.String(), "/path?a=b"; e != g {
- t.Errorf("expected URL %q; got %q", e, g)
+ t.Errorf("URL = %q; want %q", g, e)
+ }
+}
+
+func TestRequestWithoutRequestURI(t *testing.T) {
+ env := map[string]string{
+ "SERVER_PROTOCOL": "HTTP/1.1",
+ "HTTP_HOST": "example.com",
+ "REQUEST_METHOD": "GET",
+ "SCRIPT_NAME": "/dir/scriptname",
+ "PATH_INFO": "/p1/p2",
+ "QUERY_STRING": "a=1&b=2",
+ "CONTENT_LENGTH": "123",
+ }
+ req, err := RequestFromMap(env)
+ if err != nil {
+ t.Fatalf("RequestFromMap: %v", err)
+ }
+ if req.URL == nil {
+ t.Fatalf("unexpected nil URL")
+ }
+ if g, e := req.URL.String(), "http://example.com/dir/scriptname/p1/p2?a=1&b=2"; e != g {
+ t.Errorf("URL = %q; want %q", g, e)
}
}
diff --git a/src/pkg/net/http/cgi/host_test.go b/src/pkg/net/http/cgi/host_test.go
index 4db3d850c..8c16e6897 100644
--- a/src/pkg/net/http/cgi/host_test.go
+++ b/src/pkg/net/http/cgi/host_test.go
@@ -19,7 +19,6 @@ import (
"runtime"
"strconv"
"strings"
- "syscall"
"testing"
"time"
)
@@ -63,17 +62,25 @@ readlines:
}
for key, expected := range expectedMap {
- if got := m[key]; got != expected {
+ got := m[key]
+ if key == "cwd" {
+ // For Windows. golang.org/issue/4645.
+ fi1, _ := os.Stat(got)
+ fi2, _ := os.Stat(expected)
+ if os.SameFile(fi1, fi2) {
+ got = expected
+ }
+ }
+ if got != expected {
t.Errorf("for key %q got %q; expected %q", key, got, expected)
}
}
return rw
}
-var cgiTested = false
-var cgiWorks bool
+var cgiTested, cgiWorks bool
-func skipTest(t *testing.T) bool {
+func check(t *testing.T) {
if !cgiTested {
cgiTested = true
cgiWorks = exec.Command("./testdata/test.cgi").Run() == nil
@@ -81,16 +88,12 @@ func skipTest(t *testing.T) bool {
if !cgiWorks {
// No Perl on Windows, needed by test.cgi
// TODO: make the child process be Go, not Perl.
- t.Logf("Skipping test: test.cgi failed.")
- return true
+ t.Skip("Skipping test: test.cgi failed.")
}
- return false
}
func TestCGIBasicGet(t *testing.T) {
- if skipTest(t) {
- return
- }
+ check(t)
h := &Handler{
Path: "testdata/test.cgi",
Root: "/test.cgi",
@@ -124,9 +127,7 @@ func TestCGIBasicGet(t *testing.T) {
}
func TestCGIBasicGetAbsPath(t *testing.T) {
- if skipTest(t) {
- return
- }
+ check(t)
pwd, err := os.Getwd()
if err != nil {
t.Fatalf("getwd error: %v", err)
@@ -144,9 +145,7 @@ func TestCGIBasicGetAbsPath(t *testing.T) {
}
func TestPathInfo(t *testing.T) {
- if skipTest(t) {
- return
- }
+ check(t)
h := &Handler{
Path: "testdata/test.cgi",
Root: "/test.cgi",
@@ -163,9 +162,7 @@ func TestPathInfo(t *testing.T) {
}
func TestPathInfoDirRoot(t *testing.T) {
- if skipTest(t) {
- return
- }
+ check(t)
h := &Handler{
Path: "testdata/test.cgi",
Root: "/myscript/",
@@ -181,9 +178,7 @@ func TestPathInfoDirRoot(t *testing.T) {
}
func TestDupHeaders(t *testing.T) {
- if skipTest(t) {
- return
- }
+ check(t)
h := &Handler{
Path: "testdata/test.cgi",
}
@@ -203,9 +198,7 @@ func TestDupHeaders(t *testing.T) {
}
func TestPathInfoNoRoot(t *testing.T) {
- if skipTest(t) {
- return
- }
+ check(t)
h := &Handler{
Path: "testdata/test.cgi",
Root: "",
@@ -221,9 +214,7 @@ func TestPathInfoNoRoot(t *testing.T) {
}
func TestCGIBasicPost(t *testing.T) {
- if skipTest(t) {
- return
- }
+ check(t)
postReq := `POST /test.cgi?a=b HTTP/1.0
Host: example.com
Content-Type: application/x-www-form-urlencoded
@@ -250,9 +241,7 @@ func chunk(s string) string {
// The CGI spec doesn't allow chunked requests.
func TestCGIPostChunked(t *testing.T) {
- if skipTest(t) {
- return
- }
+ check(t)
postReq := `POST /test.cgi?a=b HTTP/1.1
Host: example.com
Content-Type: application/x-www-form-urlencoded
@@ -273,9 +262,7 @@ Transfer-Encoding: chunked
}
func TestRedirect(t *testing.T) {
- if skipTest(t) {
- return
- }
+ check(t)
h := &Handler{
Path: "testdata/test.cgi",
Root: "/test.cgi",
@@ -290,9 +277,7 @@ func TestRedirect(t *testing.T) {
}
func TestInternalRedirect(t *testing.T) {
- if skipTest(t) {
- return
- }
+ check(t)
baseHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
fmt.Fprintf(rw, "basepath=%s\n", req.URL.Path)
fmt.Fprintf(rw, "remoteaddr=%s\n", req.RemoteAddr)
@@ -312,8 +297,9 @@ func TestInternalRedirect(t *testing.T) {
// TestCopyError tests that we kill the process if there's an error copying
// its output. (for example, from the client having gone away)
func TestCopyError(t *testing.T) {
- if skipTest(t) || runtime.GOOS == "windows" {
- return
+ check(t)
+ if runtime.GOOS == "windows" {
+ t.Skipf("skipping test on %q", runtime.GOOS)
}
h := &Handler{
Path: "testdata/test.cgi",
@@ -353,11 +339,7 @@ func TestCopyError(t *testing.T) {
}
childRunning := func() bool {
- p, err := os.FindProcess(pid)
- if err != nil {
- return false
- }
- return p.Signal(syscall.Signal(0)) == nil
+ return isProcessRunning(t, pid)
}
if !childRunning() {
@@ -376,10 +358,10 @@ func TestCopyError(t *testing.T) {
}
func TestDirUnix(t *testing.T) {
- if skipTest(t) || runtime.GOOS == "windows" {
- return
+ check(t)
+ if runtime.GOOS == "windows" {
+ t.Skipf("skipping test on %q", runtime.GOOS)
}
-
cwd, _ := os.Getwd()
h := &Handler{
Path: "testdata/test.cgi",
@@ -404,8 +386,8 @@ func TestDirUnix(t *testing.T) {
}
func TestDirWindows(t *testing.T) {
- if skipTest(t) || runtime.GOOS != "windows" {
- return
+ if runtime.GOOS != "windows" {
+ t.Skip("Skipping windows specific test.")
}
cgifile, _ := filepath.Abs("testdata/test.cgi")
@@ -414,7 +396,7 @@ func TestDirWindows(t *testing.T) {
var err error
perl, err = exec.LookPath("perl")
if err != nil {
- return
+ t.Skip("Skipping test: perl not found.")
}
perl, _ = filepath.Abs(perl)
@@ -456,7 +438,7 @@ func TestEnvOverride(t *testing.T) {
var err error
perl, err = exec.LookPath("perl")
if err != nil {
- return
+ t.Skipf("Skipping test: perl not found.")
}
perl, _ = filepath.Abs(perl)
diff --git a/src/pkg/net/http/cgi/plan9_test.go b/src/pkg/net/http/cgi/plan9_test.go
new file mode 100644
index 000000000..c8235831b
--- /dev/null
+++ b/src/pkg/net/http/cgi/plan9_test.go
@@ -0,0 +1,18 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build plan9
+
+package cgi
+
+import (
+ "os"
+ "strconv"
+ "testing"
+)
+
+func isProcessRunning(t *testing.T, pid int) bool {
+ _, err := os.Stat("/proc/" + strconv.Itoa(pid))
+ return err == nil
+}
diff --git a/src/pkg/net/http/cgi/posix_test.go b/src/pkg/net/http/cgi/posix_test.go
new file mode 100644
index 000000000..5ff9e7d5e
--- /dev/null
+++ b/src/pkg/net/http/cgi/posix_test.go
@@ -0,0 +1,21 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !plan9
+
+package cgi
+
+import (
+ "os"
+ "syscall"
+ "testing"
+)
+
+func isProcessRunning(t *testing.T, pid int) bool {
+ p, err := os.FindProcess(pid)
+ if err != nil {
+ return false
+ }
+ return p.Signal(syscall.Signal(0)) == nil
+}
diff --git a/src/pkg/net/http/cgi/testdata/test.cgi b/src/pkg/net/http/cgi/testdata/test.cgi
index b46b1330f..3214df6f0 100755
--- a/src/pkg/net/http/cgi/testdata/test.cgi
+++ b/src/pkg/net/http/cgi/testdata/test.cgi
@@ -8,6 +8,8 @@
use strict;
use Cwd;
+binmode STDOUT;
+
my $q = MiniCGI->new;
my $params = $q->Vars;
@@ -16,51 +18,44 @@ if ($params->{"loc"}) {
exit(0);
}
-my $NL = "\r\n";
-$NL = "\n" if $params->{mode} eq "NL";
-
-my $p = sub {
- print "$_[0]$NL";
-};
-
-# With carriage returns
-$p->("Content-Type: text/html");
-$p->("X-CGI-Pid: $$");
-$p->("X-Test-Header: X-Test-Value");
-$p->("");
+print "Content-Type: text/html\r\n";
+print "X-CGI-Pid: $$\r\n";
+print "X-Test-Header: X-Test-Value\r\n";
+print "\r\n";
if ($params->{"bigresponse"}) {
- for (1..1024) {
- print "A" x 1024, "\n";
+ # 17 MB, for OS X: golang.org/issue/4958
+ for (1..(17 * 1024)) {
+ print "A" x 1024, "\r\n";
}
exit 0;
}
-print "test=Hello CGI\n";
+print "test=Hello CGI\r\n";
foreach my $k (sort keys %$params) {
- print "param-$k=$params->{$k}\n";
+ print "param-$k=$params->{$k}\r\n";
}
foreach my $k (sort keys %ENV) {
- my $clean_env = $ENV{$k};
- $clean_env =~ s/[\n\r]//g;
- print "env-$k=$clean_env\n";
+ my $clean_env = $ENV{$k};
+ $clean_env =~ s/[\n\r]//g;
+ print "env-$k=$clean_env\r\n";
}
-# NOTE: don't call getcwd() for windows.
-# msys return /c/go/src/... not C:\go\...
-my $dir;
+# NOTE: msys perl returns /c/go/src/... not C:\go\....
+my $dir = getcwd();
if ($^O eq 'MSWin32' || $^O eq 'msys') {
- my $cmd = $ENV{'COMSPEC'} || 'c:\\windows\\system32\\cmd.exe';
- $cmd =~ s!\\!/!g;
- $dir = `$cmd /c cd`;
- chomp $dir;
-} else {
- $dir = getcwd();
+ if ($dir =~ /^.:/) {
+ $dir =~ s!/!\\!g;
+ } else {
+ my $cmd = $ENV{'COMSPEC'} || 'c:\\windows\\system32\\cmd.exe';
+ $cmd =~ s!\\!/!g;
+ $dir = `$cmd /c cd`;
+ chomp $dir;
+ }
}
-print "cwd=$dir\n";
-
+print "cwd=$dir\r\n";
# A minimal version of CGI.pm, for people without the perl-modules
# package installed. (CGI.pm used to be part of the Perl core, but
diff --git a/src/pkg/net/http/chunked.go b/src/pkg/net/http/chunked.go
index 60a478fd8..91db01724 100644
--- a/src/pkg/net/http/chunked.go
+++ b/src/pkg/net/http/chunked.go
@@ -11,10 +11,9 @@ package http
import (
"bufio"
- "bytes"
"errors"
+ "fmt"
"io"
- "strconv"
)
const maxLineLength = 4096 // assumed <= bufio.defaultBufSize
@@ -22,7 +21,7 @@ const maxLineLength = 4096 // assumed <= bufio.defaultBufSize
var ErrLineTooLong = errors.New("header line too long")
// newChunkedReader returns a new chunkedReader that translates the data read from r
-// out of HTTP "chunked" format before returning it.
+// out of HTTP "chunked" format before returning it.
// The chunkedReader returns io.EOF when the final 0-length chunk is read.
//
// newChunkedReader is not needed by normal applications. The http package
@@ -39,16 +38,17 @@ type chunkedReader struct {
r *bufio.Reader
n uint64 // unread bytes in chunk
err error
+ buf [2]byte
}
func (cr *chunkedReader) beginChunk() {
// chunk-size CRLF
- var line string
+ var line []byte
line, cr.err = readLine(cr.r)
if cr.err != nil {
return
}
- cr.n, cr.err = strconv.ParseUint(line, 16, 64)
+ cr.n, cr.err = parseHexUint(line)
if cr.err != nil {
return
}
@@ -74,9 +74,8 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
cr.n -= uint64(n)
if cr.n == 0 && cr.err == nil {
// end of chunk (CRLF)
- b := make([]byte, 2)
- if _, cr.err = io.ReadFull(cr.r, b); cr.err == nil {
- if b[0] != '\r' || b[1] != '\n' {
+ if _, cr.err = io.ReadFull(cr.r, cr.buf[:]); cr.err == nil {
+ if cr.buf[0] != '\r' || cr.buf[1] != '\n' {
cr.err = errors.New("malformed chunked encoding")
}
}
@@ -88,7 +87,7 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
// Give up if the line exceeds maxLineLength.
// The returned bytes are a pointer into storage in
// the bufio, so they are only valid until the next bufio read.
-func readLineBytes(b *bufio.Reader) (p []byte, err error) {
+func readLine(b *bufio.Reader) (p []byte, err error) {
if p, err = b.ReadSlice('\n'); err != nil {
// We always know when EOF is coming.
// If the caller asked for a line, there should be a line.
@@ -102,20 +101,18 @@ func readLineBytes(b *bufio.Reader) (p []byte, err error) {
if len(p) >= maxLineLength {
return nil, ErrLineTooLong
}
-
- // Chop off trailing white space.
- p = bytes.TrimRight(p, " \r\t\n")
-
- return p, nil
+ return trimTrailingWhitespace(p), nil
}
-// readLineBytes, but convert the bytes into a string.
-func readLine(b *bufio.Reader) (s string, err error) {
- p, e := readLineBytes(b)
- if e != nil {
- return "", e
+func trimTrailingWhitespace(b []byte) []byte {
+ for len(b) > 0 && isASCIISpace(b[len(b)-1]) {
+ b = b[:len(b)-1]
}
- return string(p), nil
+ return b
+}
+
+func isASCIISpace(b byte) bool {
+ return b == ' ' || b == '\t' || b == '\n' || b == '\r'
}
// newChunkedWriter returns a new chunkedWriter that translates writes into HTTP
@@ -147,9 +144,7 @@ func (cw *chunkedWriter) Write(data []byte) (n int, err error) {
return 0, nil
}
- head := strconv.FormatInt(int64(len(data)), 16) + "\r\n"
-
- if _, err = io.WriteString(cw.Wire, head); err != nil {
+ if _, err = fmt.Fprintf(cw.Wire, "%x\r\n", len(data)); err != nil {
return 0, err
}
if n, err = cw.Wire.Write(data); err != nil {
@@ -168,3 +163,21 @@ func (cw *chunkedWriter) Close() error {
_, err := io.WriteString(cw.Wire, "0\r\n")
return err
}
+
+func parseHexUint(v []byte) (n uint64, err error) {
+ for _, b := range v {
+ n <<= 4
+ switch {
+ case '0' <= b && b <= '9':
+ b = b - '0'
+ case 'a' <= b && b <= 'f':
+ b = b - 'a' + 10
+ case 'A' <= b && b <= 'F':
+ b = b - 'A' + 10
+ default:
+ return 0, errors.New("invalid byte in chunk length")
+ }
+ n |= uint64(b)
+ }
+ return
+}
diff --git a/src/pkg/net/http/chunked_test.go b/src/pkg/net/http/chunked_test.go
index b77ee2ff2..0b18c7b55 100644
--- a/src/pkg/net/http/chunked_test.go
+++ b/src/pkg/net/http/chunked_test.go
@@ -9,7 +9,10 @@ package http
import (
"bytes"
+ "fmt"
+ "io"
"io/ioutil"
+ "runtime"
"testing"
)
@@ -37,3 +40,54 @@ func TestChunk(t *testing.T) {
t.Errorf("chunk reader read %q; want %q", g, e)
}
}
+
+func TestChunkReaderAllocs(t *testing.T) {
+ // temporarily set GOMAXPROCS to 1 as we are testing memory allocations
+ defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1))
+ var buf bytes.Buffer
+ w := newChunkedWriter(&buf)
+ a, b, c := []byte("aaaaaa"), []byte("bbbbbbbbbbbb"), []byte("cccccccccccccccccccccccc")
+ w.Write(a)
+ w.Write(b)
+ w.Write(c)
+ w.Close()
+
+ r := newChunkedReader(&buf)
+ readBuf := make([]byte, len(a)+len(b)+len(c)+1)
+
+ var ms runtime.MemStats
+ runtime.ReadMemStats(&ms)
+ m0 := ms.Mallocs
+
+ n, err := io.ReadFull(r, readBuf)
+
+ runtime.ReadMemStats(&ms)
+ mallocs := ms.Mallocs - m0
+ if mallocs > 1 {
+ t.Errorf("%d mallocs; want <= 1", mallocs)
+ }
+
+ if n != len(readBuf)-1 {
+ t.Errorf("read %d bytes; want %d", n, len(readBuf)-1)
+ }
+ if err != io.ErrUnexpectedEOF {
+ t.Errorf("read error = %v; want ErrUnexpectedEOF", err)
+ }
+}
+
+func TestParseHexUint(t *testing.T) {
+ for i := uint64(0); i <= 1234; i++ {
+ line := []byte(fmt.Sprintf("%x", i))
+ got, err := parseHexUint(line)
+ if err != nil {
+ t.Fatalf("on %d: %v", i, err)
+ }
+ if got != i {
+ t.Errorf("for input %q = %d; want %d", line, got, i)
+ }
+ }
+ _, err := parseHexUint([]byte("bogus"))
+ if err == nil {
+ t.Error("expected error on bogus input")
+ }
+}
diff --git a/src/pkg/net/http/client.go b/src/pkg/net/http/client.go
index 54564e098..5ee0804c7 100644
--- a/src/pkg/net/http/client.go
+++ b/src/pkg/net/http/client.go
@@ -3,7 +3,7 @@
// license that can be found in the LICENSE file.
// HTTP client. See RFC 2616.
-//
+//
// This is the high-level Client interface.
// The low-level implementation is in transport.go.
@@ -14,6 +14,7 @@ import (
"errors"
"fmt"
"io"
+ "log"
"net/url"
"strings"
)
@@ -32,17 +33,19 @@ type Client struct {
// CheckRedirect specifies the policy for handling redirects.
// If CheckRedirect is not nil, the client calls it before
- // following an HTTP redirect. The arguments req and via
- // are the upcoming request and the requests made already,
- // oldest first. If CheckRedirect returns an error, the client
- // returns that error instead of issue the Request req.
+ // following an HTTP redirect. The arguments req and via are
+ // the upcoming request and the requests made already, oldest
+ // first. If CheckRedirect returns an error, the Client's Get
+ // method returns both the previous Response and
+ // CheckRedirect's error (wrapped in a url.Error) instead of
+ // issuing the Request req.
//
// If CheckRedirect is nil, the Client uses its default policy,
// which is to stop after 10 consecutive requests.
CheckRedirect func(req *Request, via []*Request) error
- // Jar specifies the cookie jar.
- // If Jar is nil, cookies are not sent in requests and ignored
+ // Jar specifies the cookie jar.
+ // If Jar is nil, cookies are not sent in requests and ignored
// in responses.
Jar CookieJar
}
@@ -84,10 +87,32 @@ type readClose struct {
io.Closer
}
+func (c *Client) send(req *Request) (*Response, error) {
+ if c.Jar != nil {
+ for _, cookie := range c.Jar.Cookies(req.URL) {
+ req.AddCookie(cookie)
+ }
+ }
+ resp, err := send(req, c.Transport)
+ if err != nil {
+ return nil, err
+ }
+ if c.Jar != nil {
+ if rc := resp.Cookies(); len(rc) > 0 {
+ c.Jar.SetCookies(req.URL, rc)
+ }
+ }
+ return resp, err
+}
+
// Do sends an HTTP request and returns an HTTP response, following
// policy (e.g. redirects, cookies, auth) as configured on the client.
//
-// A non-nil response always contains a non-nil resp.Body.
+// An error is returned if caused by client policy (such as
+// CheckRedirect), or if there was an HTTP protocol error.
+// A non-2xx response doesn't cause an error.
+//
+// When err is nil, resp always contains a non-nil resp.Body.
//
// Callers should close resp.Body when done reading from it. If
// resp.Body is not closed, the Client's underlying RoundTripper
@@ -97,12 +122,16 @@ type readClose struct {
// Generally Get, Post, or PostForm will be used instead of Do.
func (c *Client) Do(req *Request) (resp *Response, err error) {
if req.Method == "GET" || req.Method == "HEAD" {
- return c.doFollowingRedirects(req)
+ return c.doFollowingRedirects(req, shouldRedirectGet)
}
- return send(req, c.Transport)
+ if req.Method == "POST" || req.Method == "PUT" {
+ return c.doFollowingRedirects(req, shouldRedirectPost)
+ }
+ return c.send(req)
}
-// send issues an HTTP request. Caller should close resp.Body when done reading from it.
+// send issues an HTTP request.
+// Caller should close resp.Body when done reading from it.
func send(req *Request, t RoundTripper) (resp *Response, err error) {
if t == nil {
t = DefaultTransport
@@ -130,12 +159,19 @@ func send(req *Request, t RoundTripper) (resp *Response, err error) {
if u := req.URL.User; u != nil {
req.Header.Set("Authorization", "Basic "+base64.URLEncoding.EncodeToString([]byte(u.String())))
}
- return t.RoundTrip(req)
+ resp, err = t.RoundTrip(req)
+ if err != nil {
+ if resp != nil {
+ log.Printf("RoundTripper returned a response & error; ignoring response")
+ }
+ return nil, err
+ }
+ return resp, nil
}
// True if the specified HTTP status code is one for which the Get utility should
// automatically redirect.
-func shouldRedirect(statusCode int) bool {
+func shouldRedirectGet(statusCode int) bool {
switch statusCode {
case StatusMovedPermanently, StatusFound, StatusSeeOther, StatusTemporaryRedirect:
return true
@@ -143,6 +179,16 @@ func shouldRedirect(statusCode int) bool {
return false
}
+// True if the specified HTTP status code is one for which the Post utility should
+// automatically redirect.
+func shouldRedirectPost(statusCode int) bool {
+ switch statusCode {
+ case StatusFound, StatusSeeOther:
+ return true
+ }
+ return false
+}
+
// Get issues a GET to the specified URL. If the response is one of the following
// redirect codes, Get follows the redirect, up to a maximum of 10 redirects:
//
@@ -151,10 +197,15 @@ func shouldRedirect(statusCode int) bool {
// 303 (See Other)
// 307 (Temporary Redirect)
//
-// Caller should close r.Body when done reading from it.
+// An error is returned if there were too many redirects or if there
+// was an HTTP protocol error. A non-2xx response doesn't cause an
+// error.
+//
+// When err is nil, resp always contains a non-nil resp.Body.
+// Caller should close resp.Body when done reading from it.
//
// Get is a wrapper around DefaultClient.Get.
-func Get(url string) (r *Response, err error) {
+func Get(url string) (resp *Response, err error) {
return DefaultClient.Get(url)
}
@@ -167,18 +218,21 @@ func Get(url string) (r *Response, err error) {
// 303 (See Other)
// 307 (Temporary Redirect)
//
-// Caller should close r.Body when done reading from it.
-func (c *Client) Get(url string) (r *Response, err error) {
+// An error is returned if the Client's CheckRedirect function fails
+// or if there was an HTTP protocol error. A non-2xx response doesn't
+// cause an error.
+//
+// When err is nil, resp always contains a non-nil resp.Body.
+// Caller should close resp.Body when done reading from it.
+func (c *Client) Get(url string) (resp *Response, err error) {
req, err := NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
- return c.doFollowingRedirects(req)
+ return c.doFollowingRedirects(req, shouldRedirectGet)
}
-func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err error) {
- // TODO: if/when we add cookie support, the redirected request shouldn't
- // necessarily supply the same cookies as the original.
+func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bool) (resp *Response, err error) {
var base *url.URL
redirectChecker := c.CheckRedirect
if redirectChecker == nil {
@@ -190,17 +244,16 @@ func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err error) {
return nil, errors.New("http: nil Request.URL")
}
- jar := c.Jar
- if jar == nil {
- jar = blackHoleJar{}
- }
-
req := ireq
urlStr := "" // next relative or absolute URL to fetch (after first request)
+ redirectFailed := false
for redirect := 0; ; redirect++ {
if redirect != 0 {
req = new(Request)
req.Method = ireq.Method
+ if ireq.Method == "POST" || ireq.Method == "PUT" {
+ req.Method = "GET"
+ }
req.Header = make(Header)
req.URL, err = base.Parse(urlStr)
if err != nil {
@@ -215,26 +268,21 @@ func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err error) {
err = redirectChecker(req, via)
if err != nil {
+ redirectFailed = true
break
}
}
}
- for _, cookie := range jar.Cookies(req.URL) {
- req.AddCookie(cookie)
- }
urlStr = req.URL.String()
- if r, err = send(req, c.Transport); err != nil {
+ if resp, err = c.send(req); err != nil {
break
}
- if c := r.Cookies(); len(c) > 0 {
- jar.SetCookies(req.URL, c)
- }
- if shouldRedirect(r.StatusCode) {
- r.Body.Close()
- if urlStr = r.Header.Get("Location"); urlStr == "" {
- err = errors.New(fmt.Sprintf("%d response missing Location header", r.StatusCode))
+ if shouldRedirect(resp.StatusCode) {
+ resp.Body.Close()
+ if urlStr = resp.Header.Get("Location"); urlStr == "" {
+ err = errors.New(fmt.Sprintf("%d response missing Location header", resp.StatusCode))
break
}
base = req.URL
@@ -245,12 +293,23 @@ func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err error) {
}
method := ireq.Method
- err = &url.Error{
+ urlErr := &url.Error{
Op: method[0:1] + strings.ToLower(method[1:]),
URL: urlStr,
Err: err,
}
- return
+
+ if redirectFailed {
+ // Special case for Go 1 compatibility: return both the response
+ // and an error if the CheckRedirect function failed.
+ // See http://golang.org/issue/3795
+ return resp, urlErr
+ }
+
+ if resp != nil {
+ resp.Body.Close()
+ }
+ return nil, urlErr
}
func defaultCheckRedirect(req *Request, via []*Request) error {
@@ -262,49 +321,42 @@ func defaultCheckRedirect(req *Request, via []*Request) error {
// Post issues a POST to the specified URL.
//
-// Caller should close r.Body when done reading from it.
+// Caller should close resp.Body when done reading from it.
//
// Post is a wrapper around DefaultClient.Post
-func Post(url string, bodyType string, body io.Reader) (r *Response, err error) {
+func Post(url string, bodyType string, body io.Reader) (resp *Response, err error) {
return DefaultClient.Post(url, bodyType, body)
}
// Post issues a POST to the specified URL.
//
-// Caller should close r.Body when done reading from it.
-func (c *Client) Post(url string, bodyType string, body io.Reader) (r *Response, err error) {
+// Caller should close resp.Body when done reading from it.
+func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *Response, err error) {
req, err := NewRequest("POST", url, body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", bodyType)
- if c.Jar != nil {
- for _, cookie := range c.Jar.Cookies(req.URL) {
- req.AddCookie(cookie)
- }
- }
- r, err = send(req, c.Transport)
- if err == nil && c.Jar != nil {
- c.Jar.SetCookies(req.URL, r.Cookies())
- }
- return r, err
+ return c.doFollowingRedirects(req, shouldRedirectPost)
}
-// PostForm issues a POST to the specified URL,
-// with data's keys and values urlencoded as the request body.
+// PostForm issues a POST to the specified URL, with data's keys and
+// values URL-encoded as the request body.
//
-// Caller should close r.Body when done reading from it.
+// When err is nil, resp always contains a non-nil resp.Body.
+// Caller should close resp.Body when done reading from it.
//
// PostForm is a wrapper around DefaultClient.PostForm
-func PostForm(url string, data url.Values) (r *Response, err error) {
+func PostForm(url string, data url.Values) (resp *Response, err error) {
return DefaultClient.PostForm(url, data)
}
-// PostForm issues a POST to the specified URL,
+// PostForm issues a POST to the specified URL,
// with data's keys and values urlencoded as the request body.
//
-// Caller should close r.Body when done reading from it.
-func (c *Client) PostForm(url string, data url.Values) (r *Response, err error) {
+// When err is nil, resp always contains a non-nil resp.Body.
+// Caller should close resp.Body when done reading from it.
+func (c *Client) PostForm(url string, data url.Values) (resp *Response, err error) {
return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
}
@@ -318,7 +370,7 @@ func (c *Client) PostForm(url string, data url.Values) (r *Response, err error)
// 307 (Temporary Redirect)
//
// Head is a wrapper around DefaultClient.Head
-func Head(url string) (r *Response, err error) {
+func Head(url string) (resp *Response, err error) {
return DefaultClient.Head(url)
}
@@ -330,10 +382,10 @@ func Head(url string) (r *Response, err error) {
// 302 (Found)
// 303 (See Other)
// 307 (Temporary Redirect)
-func (c *Client) Head(url string) (r *Response, err error) {
+func (c *Client) Head(url string) (resp *Response, err error) {
req, err := NewRequest("HEAD", url, nil)
if err != nil {
return nil, err
}
- return c.doFollowingRedirects(req)
+ return c.doFollowingRedirects(req, shouldRedirectGet)
}
diff --git a/src/pkg/net/http/client_test.go b/src/pkg/net/http/client_test.go
index 9b4261b9f..88649bb16 100644
--- a/src/pkg/net/http/client_test.go
+++ b/src/pkg/net/http/client_test.go
@@ -7,7 +7,9 @@
package http_test
import (
+ "bytes"
"crypto/tls"
+ "crypto/x509"
"errors"
"fmt"
"io"
@@ -53,6 +55,7 @@ func pedanticReadAll(r io.Reader) (b []byte, err error) {
}
func TestClient(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(robotsTxtHandler)
defer ts.Close()
@@ -70,6 +73,7 @@ func TestClient(t *testing.T) {
}
func TestClientHead(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(robotsTxtHandler)
defer ts.Close()
@@ -92,6 +96,7 @@ func (t *recordingTransport) RoundTrip(req *Request) (resp *Response, err error)
}
func TestGetRequestFormat(t *testing.T) {
+ defer checkLeakedTransports(t)
tr := &recordingTransport{}
client := &Client{Transport: tr}
url := "http://dummy.faketld/"
@@ -108,6 +113,7 @@ func TestGetRequestFormat(t *testing.T) {
}
func TestPostRequestFormat(t *testing.T) {
+ defer checkLeakedTransports(t)
tr := &recordingTransport{}
client := &Client{Transport: tr}
@@ -134,6 +140,7 @@ func TestPostRequestFormat(t *testing.T) {
}
func TestPostFormRequestFormat(t *testing.T) {
+ defer checkLeakedTransports(t)
tr := &recordingTransport{}
client := &Client{Transport: tr}
@@ -175,6 +182,7 @@ func TestPostFormRequestFormat(t *testing.T) {
}
func TestRedirects(t *testing.T) {
+ defer checkLeakedTransports(t)
var ts *httptest.Server
ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
n, _ := strconv.Atoi(r.FormValue("n"))
@@ -218,6 +226,10 @@ func TestRedirects(t *testing.T) {
return checkErr
}}
res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatalf("Get error: %v", err)
+ }
+ res.Body.Close()
finalUrl := res.Request.URL.String()
if e, g := "<nil>", fmt.Sprintf("%v", err); e != g {
t.Errorf("with custom client, expected error %q, got %q", e, g)
@@ -231,9 +243,63 @@ func TestRedirects(t *testing.T) {
checkErr = errors.New("no redirects allowed")
res, err = c.Get(ts.URL)
- finalUrl = res.Request.URL.String()
- if e, g := "Get /?n=1: no redirects allowed", fmt.Sprintf("%v", err); e != g {
- t.Errorf("with redirects forbidden, expected error %q, got %q", e, g)
+ if urlError, ok := err.(*url.Error); !ok || urlError.Err != checkErr {
+ t.Errorf("with redirects forbidden, expected a *url.Error with our 'no redirects allowed' error inside; got %#v (%q)", err, err)
+ }
+ if res == nil {
+ t.Fatalf("Expected a non-nil Response on CheckRedirect failure (http://golang.org/issue/3795)")
+ }
+ res.Body.Close()
+ if res.Header.Get("Location") == "" {
+ t.Errorf("no Location header in Response")
+ }
+}
+
+func TestPostRedirects(t *testing.T) {
+ defer checkLeakedTransports(t)
+ var log struct {
+ sync.Mutex
+ bytes.Buffer
+ }
+ var ts *httptest.Server
+ ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ log.Lock()
+ fmt.Fprintf(&log.Buffer, "%s %s ", r.Method, r.RequestURI)
+ log.Unlock()
+ if v := r.URL.Query().Get("code"); v != "" {
+ code, _ := strconv.Atoi(v)
+ if code/100 == 3 {
+ w.Header().Set("Location", ts.URL)
+ }
+ w.WriteHeader(code)
+ }
+ }))
+ defer ts.Close()
+ tests := []struct {
+ suffix string
+ want int // response code
+ }{
+ {"/", 200},
+ {"/?code=301", 301},
+ {"/?code=302", 200},
+ {"/?code=303", 200},
+ {"/?code=404", 404},
+ }
+ for _, tt := range tests {
+ res, err := Post(ts.URL+tt.suffix, "text/plain", strings.NewReader("Some content"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != tt.want {
+ t.Errorf("POST %s: status code = %d; want %d", tt.suffix, res.StatusCode, tt.want)
+ }
+ }
+ log.Lock()
+ got := log.String()
+ log.Unlock()
+ want := "POST / POST /?code=301 POST /?code=302 GET / POST /?code=303 GET / POST /?code=404 "
+ if got != want {
+ t.Errorf("Log differs.\n Got: %q\nWant: %q", got, want)
}
}
@@ -279,6 +345,10 @@ func TestClientSendsCookieFromJar(t *testing.T) {
req, _ := NewRequest("GET", us, nil)
client.Do(req) // Note: doesn't hit network
matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
+
+ req, _ = NewRequest("POST", us, nil)
+ client.Do(req) // Note: doesn't hit network
+ matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
}
// Just enough correctness for our redirect tests. Uses the URL.Host as the
@@ -291,6 +361,9 @@ type TestJar struct {
func (j *TestJar) SetCookies(u *url.URL, cookies []*Cookie) {
j.m.Lock()
defer j.m.Unlock()
+ if j.perURL == nil {
+ j.perURL = make(map[string][]*Cookie)
+ }
j.perURL[u.Host] = cookies
}
@@ -301,6 +374,7 @@ func (j *TestJar) Cookies(u *url.URL) []*Cookie {
}
func TestRedirectCookiesOnRequest(t *testing.T) {
+ defer checkLeakedTransports(t)
var ts *httptest.Server
ts = httptest.NewServer(echoCookiesRedirectHandler)
defer ts.Close()
@@ -318,14 +392,20 @@ func TestRedirectCookiesOnRequest(t *testing.T) {
}
func TestRedirectCookiesJar(t *testing.T) {
+ defer checkLeakedTransports(t)
var ts *httptest.Server
ts = httptest.NewServer(echoCookiesRedirectHandler)
defer ts.Close()
- c := &Client{}
- c.Jar = &TestJar{perURL: make(map[string][]*Cookie)}
+ c := &Client{
+ Jar: new(TestJar),
+ }
u, _ := url.Parse(ts.URL)
c.Jar.SetCookies(u, []*Cookie{expectedCookies[0]})
- resp, _ := c.Get(ts.URL)
+ resp, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ resp.Body.Close()
matchReturnedCookies(t, expectedCookies, resp.Cookies())
}
@@ -348,7 +428,72 @@ func matchReturnedCookies(t *testing.T, expected, given []*Cookie) {
}
}
+func TestJarCalls(t *testing.T) {
+ defer checkLeakedTransports(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ pathSuffix := r.RequestURI[1:]
+ if r.RequestURI == "/nosetcookie" {
+ return // dont set cookies for this path
+ }
+ SetCookie(w, &Cookie{Name: "name" + pathSuffix, Value: "val" + pathSuffix})
+ if r.RequestURI == "/" {
+ Redirect(w, r, "http://secondhost.fake/secondpath", 302)
+ }
+ }))
+ defer ts.Close()
+ jar := new(RecordingJar)
+ c := &Client{
+ Jar: jar,
+ Transport: &Transport{
+ Dial: func(_ string, _ string) (net.Conn, error) {
+ return net.Dial("tcp", ts.Listener.Addr().String())
+ },
+ },
+ }
+ _, err := c.Get("http://firsthost.fake/")
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = c.Get("http://firsthost.fake/nosetcookie")
+ if err != nil {
+ t.Fatal(err)
+ }
+ got := jar.log.String()
+ want := `Cookies("http://firsthost.fake/")
+SetCookie("http://firsthost.fake/", [name=val])
+Cookies("http://secondhost.fake/secondpath")
+SetCookie("http://secondhost.fake/secondpath", [namesecondpath=valsecondpath])
+Cookies("http://firsthost.fake/nosetcookie")
+`
+ if got != want {
+ t.Errorf("Got Jar calls:\n%s\nWant:\n%s", got, want)
+ }
+}
+
+// RecordingJar keeps a log of calls made to it, without
+// tracking any cookies.
+type RecordingJar struct {
+ mu sync.Mutex
+ log bytes.Buffer
+}
+
+func (j *RecordingJar) SetCookies(u *url.URL, cookies []*Cookie) {
+ j.logf("SetCookie(%q, %v)\n", u, cookies)
+}
+
+func (j *RecordingJar) Cookies(u *url.URL) []*Cookie {
+ j.logf("Cookies(%q)\n", u)
+ return nil
+}
+
+func (j *RecordingJar) logf(format string, args ...interface{}) {
+ j.mu.Lock()
+ defer j.mu.Unlock()
+ fmt.Fprintf(&j.log, format, args...)
+}
+
func TestStreamingGet(t *testing.T) {
+ defer checkLeakedTransports(t)
say := make(chan string)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
w.(Flusher).Flush()
@@ -399,6 +544,7 @@ func (c *writeCountingConn) Write(p []byte) (int, error) {
// TestClientWrites verifies that client requests are buffered and we
// don't send a TCP packet per line of the http request + body.
func TestClientWrites(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
}))
defer ts.Close()
@@ -432,6 +578,7 @@ func TestClientWrites(t *testing.T) {
}
func TestClientInsecureTransport(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
w.Write([]byte("Hello"))
}))
@@ -446,15 +593,20 @@ func TestClientInsecureTransport(t *testing.T) {
InsecureSkipVerify: insecure,
},
}
+ defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
- _, err := c.Get(ts.URL)
+ res, err := c.Get(ts.URL)
if (err == nil) != insecure {
t.Errorf("insecure=%v: got unexpected err=%v", insecure, err)
}
+ if res != nil {
+ res.Body.Close()
+ }
}
}
func TestClientErrorWithRequestURI(t *testing.T) {
+ defer checkLeakedTransports(t)
req, _ := NewRequest("GET", "http://localhost:1234/", nil)
req.RequestURI = "/this/field/is/illegal/and/should/error/"
_, err := DefaultClient.Do(req)
@@ -465,3 +617,87 @@ func TestClientErrorWithRequestURI(t *testing.T) {
t.Errorf("wanted error mentioning RequestURI; got error: %v", err)
}
}
+
+func newTLSTransport(t *testing.T, ts *httptest.Server) *Transport {
+ certs := x509.NewCertPool()
+ for _, c := range ts.TLS.Certificates {
+ roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
+ if err != nil {
+ t.Fatalf("error parsing server's root cert: %v", err)
+ }
+ for _, root := range roots {
+ certs.AddCert(root)
+ }
+ }
+ return &Transport{
+ TLSClientConfig: &tls.Config{RootCAs: certs},
+ }
+}
+
+func TestClientWithCorrectTLSServerName(t *testing.T) {
+ defer checkLeakedTransports(t)
+ ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.TLS.ServerName != "127.0.0.1" {
+ t.Errorf("expected client to set ServerName 127.0.0.1, got: %q", r.TLS.ServerName)
+ }
+ }))
+ defer ts.Close()
+
+ c := &Client{Transport: newTLSTransport(t, ts)}
+ if _, err := c.Get(ts.URL); err != nil {
+ t.Fatalf("expected successful TLS connection, got error: %v", err)
+ }
+}
+
+func TestClientWithIncorrectTLSServerName(t *testing.T) {
+ defer checkLeakedTransports(t)
+ ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
+ defer ts.Close()
+
+ trans := newTLSTransport(t, ts)
+ trans.TLSClientConfig.ServerName = "badserver"
+ c := &Client{Transport: trans}
+ _, err := c.Get(ts.URL)
+ if err == nil {
+ t.Fatalf("expected an error")
+ }
+ if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") {
+ t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err)
+ }
+}
+
+// Verify Response.ContentLength is populated. http://golang.org/issue/4126
+func TestClientHeadContentLength(t *testing.T) {
+ defer checkLeakedTransports(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if v := r.FormValue("cl"); v != "" {
+ w.Header().Set("Content-Length", v)
+ }
+ }))
+ defer ts.Close()
+ tests := []struct {
+ suffix string
+ want int64
+ }{
+ {"/?cl=1234", 1234},
+ {"/?cl=0", 0},
+ {"", -1},
+ }
+ for _, tt := range tests {
+ req, _ := NewRequest("HEAD", ts.URL+tt.suffix, nil)
+ res, err := DefaultClient.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.ContentLength != tt.want {
+ t.Errorf("Content-Length = %d; want %d", res.ContentLength, tt.want)
+ }
+ bs, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(bs) != 0 {
+ t.Errorf("Unexpected content: %q", bs)
+ }
+ }
+}
diff --git a/src/pkg/net/http/cookie.go b/src/pkg/net/http/cookie.go
index 2e30bbff1..155b09223 100644
--- a/src/pkg/net/http/cookie.go
+++ b/src/pkg/net/http/cookie.go
@@ -26,7 +26,7 @@ type Cookie struct {
Expires time.Time
RawExpires string
- // MaxAge=0 means no 'Max-Age' attribute specified.
+ // MaxAge=0 means no 'Max-Age' attribute specified.
// MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0'
// MaxAge>0 means Max-Age attribute present and given in seconds
MaxAge int
@@ -258,10 +258,5 @@ func parseCookieValueUsing(raw string, validByte func(byte) bool) (string, bool)
}
func isCookieNameValid(raw string) bool {
- for _, c := range raw {
- if !isToken(byte(c)) {
- return false
- }
- }
- return true
+ return strings.IndexFunc(raw, isNotToken) < 0
}
diff --git a/src/pkg/net/http/cookie_test.go b/src/pkg/net/http/cookie_test.go
index 1e9186a05..f84f73936 100644
--- a/src/pkg/net/http/cookie_test.go
+++ b/src/pkg/net/http/cookie_test.go
@@ -217,7 +217,7 @@ var readCookiesTests = []struct {
func TestReadCookies(t *testing.T) {
for i, tt := range readCookiesTests {
- for n := 0; n < 2; n++ { // to verify readCookies doesn't mutate its input
+ for n := 0; n < 2; n++ { // to verify readCookies doesn't mutate its input
c := readCookies(tt.Header, tt.Filter)
if !reflect.DeepEqual(c, tt.Cookies) {
t.Errorf("#%d readCookies:\nhave: %s\nwant: %s\n", i, toJSON(c), toJSON(tt.Cookies))
diff --git a/src/pkg/net/http/cookiejar/jar.go b/src/pkg/net/http/cookiejar/jar.go
new file mode 100644
index 000000000..5d1aeb87f
--- /dev/null
+++ b/src/pkg/net/http/cookiejar/jar.go
@@ -0,0 +1,494 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package cookiejar implements an in-memory RFC 6265-compliant http.CookieJar.
+package cookiejar
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "net/http"
+ "net/url"
+ "sort"
+ "strings"
+ "sync"
+ "time"
+)
+
+// PublicSuffixList provides the public suffix of a domain. For example:
+// - the public suffix of "example.com" is "com",
+// - the public suffix of "foo1.foo2.foo3.co.uk" is "co.uk", and
+// - the public suffix of "bar.pvt.k12.ma.us" is "pvt.k12.ma.us".
+//
+// Implementations of PublicSuffixList must be safe for concurrent use by
+// multiple goroutines.
+//
+// An implementation that always returns "" is valid and may be useful for
+// testing but it is not secure: it means that the HTTP server for foo.com can
+// set a cookie for bar.com.
+type PublicSuffixList interface {
+ // PublicSuffix returns the public suffix of domain.
+ //
+ // TODO: specify which of the caller and callee is responsible for IP
+ // addresses, for leading and trailing dots, for case sensitivity, and
+ // for IDN/Punycode.
+ PublicSuffix(domain string) string
+
+ // String returns a description of the source of this public suffix
+ // list. The description will typically contain something like a time
+ // stamp or version number.
+ String() string
+}
+
+// Options are the options for creating a new Jar.
+type Options struct {
+ // PublicSuffixList is the public suffix list that determines whether
+ // an HTTP server can set a cookie for a domain.
+ //
+ // A nil value is valid and may be useful for testing but it is not
+ // secure: it means that the HTTP server for foo.co.uk can set a cookie
+ // for bar.co.uk.
+ PublicSuffixList PublicSuffixList
+}
+
+// Jar implements the http.CookieJar interface from the net/http package.
+type Jar struct {
+ psList PublicSuffixList
+
+ // mu locks the remaining fields.
+ mu sync.Mutex
+
+ // entries is a set of entries, keyed by their eTLD+1 and subkeyed by
+ // their name/domain/path.
+ entries map[string]map[string]entry
+
+ // nextSeqNum is the next sequence number assigned to a new cookie
+ // created SetCookies.
+ nextSeqNum uint64
+}
+
+// New returns a new cookie jar. A nil *Options is equivalent to a zero
+// Options.
+func New(o *Options) (*Jar, error) {
+ jar := &Jar{
+ entries: make(map[string]map[string]entry),
+ }
+ if o != nil {
+ jar.psList = o.PublicSuffixList
+ }
+ return jar, nil
+}
+
+// entry is the internal representation of a cookie.
+//
+// This struct type is not used outside of this package per se, but the exported
+// fields are those of RFC 6265.
+type entry struct {
+ Name string
+ Value string
+ Domain string
+ Path string
+ Secure bool
+ HttpOnly bool
+ Persistent bool
+ HostOnly bool
+ Expires time.Time
+ Creation time.Time
+ LastAccess time.Time
+
+ // seqNum is a sequence number so that Cookies returns cookies in a
+ // deterministic order, even for cookies that have equal Path length and
+ // equal Creation time. This simplifies testing.
+ seqNum uint64
+}
+
+// Id returns the domain;path;name triple of e as an id.
+func (e *entry) id() string {
+ return fmt.Sprintf("%s;%s;%s", e.Domain, e.Path, e.Name)
+}
+
+// shouldSend determines whether e's cookie qualifies to be included in a
+// request to host/path. It is the caller's responsibility to check if the
+// cookie is expired.
+func (e *entry) shouldSend(https bool, host, path string) bool {
+ return e.domainMatch(host) && e.pathMatch(path) && (https || !e.Secure)
+}
+
+// domainMatch implements "domain-match" of RFC 6265 section 5.1.3.
+func (e *entry) domainMatch(host string) bool {
+ if e.Domain == host {
+ return true
+ }
+ return !e.HostOnly && hasDotSuffix(host, e.Domain)
+}
+
+// pathMatch implements "path-match" according to RFC 6265 section 5.1.4.
+func (e *entry) pathMatch(requestPath string) bool {
+ if requestPath == e.Path {
+ return true
+ }
+ if strings.HasPrefix(requestPath, e.Path) {
+ if e.Path[len(e.Path)-1] == '/' {
+ return true // The "/any/" matches "/any/path" case.
+ } else if requestPath[len(e.Path)] == '/' {
+ return true // The "/any" matches "/any/path" case.
+ }
+ }
+ return false
+}
+
+// hasDotSuffix returns whether s ends in "."+suffix.
+func hasDotSuffix(s, suffix string) bool {
+ return len(s) > len(suffix) && s[len(s)-len(suffix)-1] == '.' && s[len(s)-len(suffix):] == suffix
+}
+
+// byPathLength is a []entry sort.Interface that sorts according to RFC 6265
+// section 5.4 point 2: by longest path and then by earliest creation time.
+type byPathLength []entry
+
+func (s byPathLength) Len() int { return len(s) }
+
+func (s byPathLength) Less(i, j int) bool {
+ if len(s[i].Path) != len(s[j].Path) {
+ return len(s[i].Path) > len(s[j].Path)
+ }
+ if !s[i].Creation.Equal(s[j].Creation) {
+ return s[i].Creation.Before(s[j].Creation)
+ }
+ return s[i].seqNum < s[j].seqNum
+}
+
+func (s byPathLength) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+
+// Cookies implements the Cookies method of the http.CookieJar interface.
+//
+// It returns an empty slice if the URL's scheme is not HTTP or HTTPS.
+func (j *Jar) Cookies(u *url.URL) (cookies []*http.Cookie) {
+ return j.cookies(u, time.Now())
+}
+
+// cookies is like Cookies but takes the current time as a parameter.
+func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) {
+ if u.Scheme != "http" && u.Scheme != "https" {
+ return cookies
+ }
+ host, err := canonicalHost(u.Host)
+ if err != nil {
+ return cookies
+ }
+ key := jarKey(host, j.psList)
+
+ j.mu.Lock()
+ defer j.mu.Unlock()
+
+ submap := j.entries[key]
+ if submap == nil {
+ return cookies
+ }
+
+ https := u.Scheme == "https"
+ path := u.Path
+ if path == "" {
+ path = "/"
+ }
+
+ modified := false
+ var selected []entry
+ for id, e := range submap {
+ if e.Persistent && !e.Expires.After(now) {
+ delete(submap, id)
+ modified = true
+ continue
+ }
+ if !e.shouldSend(https, host, path) {
+ continue
+ }
+ e.LastAccess = now
+ submap[id] = e
+ selected = append(selected, e)
+ modified = true
+ }
+ if modified {
+ if len(submap) == 0 {
+ delete(j.entries, key)
+ } else {
+ j.entries[key] = submap
+ }
+ }
+
+ sort.Sort(byPathLength(selected))
+ for _, e := range selected {
+ cookies = append(cookies, &http.Cookie{Name: e.Name, Value: e.Value})
+ }
+
+ return cookies
+}
+
+// SetCookies implements the SetCookies method of the http.CookieJar interface.
+//
+// It does nothing if the URL's scheme is not HTTP or HTTPS.
+func (j *Jar) SetCookies(u *url.URL, cookies []*http.Cookie) {
+ j.setCookies(u, cookies, time.Now())
+}
+
+// setCookies is like SetCookies but takes the current time as parameter.
+func (j *Jar) setCookies(u *url.URL, cookies []*http.Cookie, now time.Time) {
+ if len(cookies) == 0 {
+ return
+ }
+ if u.Scheme != "http" && u.Scheme != "https" {
+ return
+ }
+ host, err := canonicalHost(u.Host)
+ if err != nil {
+ return
+ }
+ key := jarKey(host, j.psList)
+ defPath := defaultPath(u.Path)
+
+ j.mu.Lock()
+ defer j.mu.Unlock()
+
+ submap := j.entries[key]
+
+ modified := false
+ for _, cookie := range cookies {
+ e, remove, err := j.newEntry(cookie, now, defPath, host)
+ if err != nil {
+ continue
+ }
+ id := e.id()
+ if remove {
+ if submap != nil {
+ if _, ok := submap[id]; ok {
+ delete(submap, id)
+ modified = true
+ }
+ }
+ continue
+ }
+ if submap == nil {
+ submap = make(map[string]entry)
+ }
+
+ if old, ok := submap[id]; ok {
+ e.Creation = old.Creation
+ e.seqNum = old.seqNum
+ } else {
+ e.Creation = now
+ e.seqNum = j.nextSeqNum
+ j.nextSeqNum++
+ }
+ e.LastAccess = now
+ submap[id] = e
+ modified = true
+ }
+
+ if modified {
+ if len(submap) == 0 {
+ delete(j.entries, key)
+ } else {
+ j.entries[key] = submap
+ }
+ }
+}
+
+// canonicalHost strips port from host if present and returns the canonicalized
+// host name.
+func canonicalHost(host string) (string, error) {
+ var err error
+ host = strings.ToLower(host)
+ if hasPort(host) {
+ host, _, err = net.SplitHostPort(host)
+ if err != nil {
+ return "", err
+ }
+ }
+ if strings.HasSuffix(host, ".") {
+ // Strip trailing dot from fully qualified domain names.
+ host = host[:len(host)-1]
+ }
+ return toASCII(host)
+}
+
+// hasPort returns whether host contains a port number. host may be a host
+// name, an IPv4 or an IPv6 address.
+func hasPort(host string) bool {
+ colons := strings.Count(host, ":")
+ if colons == 0 {
+ return false
+ }
+ if colons == 1 {
+ return true
+ }
+ return host[0] == '[' && strings.Contains(host, "]:")
+}
+
+// jarKey returns the key to use for a jar.
+func jarKey(host string, psl PublicSuffixList) string {
+ if isIP(host) {
+ return host
+ }
+
+ var i int
+ if psl == nil {
+ i = strings.LastIndex(host, ".")
+ if i == -1 {
+ return host
+ }
+ } else {
+ suffix := psl.PublicSuffix(host)
+ if suffix == host {
+ return host
+ }
+ i = len(host) - len(suffix)
+ if i <= 0 || host[i-1] != '.' {
+ // The provided public suffix list psl is broken.
+ // Storing cookies under host is a safe stopgap.
+ return host
+ }
+ }
+ prevDot := strings.LastIndex(host[:i-1], ".")
+ return host[prevDot+1:]
+}
+
+// isIP returns whether host is an IP address.
+func isIP(host string) bool {
+ return net.ParseIP(host) != nil
+}
+
+// defaultPath returns the directory part of an URL's path according to
+// RFC 6265 section 5.1.4.
+func defaultPath(path string) string {
+ if len(path) == 0 || path[0] != '/' {
+ return "/" // Path is empty or malformed.
+ }
+
+ i := strings.LastIndex(path, "/") // Path starts with "/", so i != -1.
+ if i == 0 {
+ return "/" // Path has the form "/abc".
+ }
+ return path[:i] // Path is either of form "/abc/xyz" or "/abc/xyz/".
+}
+
+// newEntry creates an entry from a http.Cookie c. now is the current time and
+// is compared to c.Expires to determine deletion of c. defPath and host are the
+// default-path and the canonical host name of the URL c was received from.
+//
+// remove is whether the jar should delete this cookie, as it has already
+// expired with respect to now. In this case, e may be incomplete, but it will
+// be valid to call e.id (which depends on e's Name, Domain and Path).
+//
+// A malformed c.Domain will result in an error.
+func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e entry, remove bool, err error) {
+ e.Name = c.Name
+
+ if c.Path == "" || c.Path[0] != '/' {
+ e.Path = defPath
+ } else {
+ e.Path = c.Path
+ }
+
+ e.Domain, e.HostOnly, err = j.domainAndType(host, c.Domain)
+ if err != nil {
+ return e, false, err
+ }
+
+ // MaxAge takes precedence over Expires.
+ if c.MaxAge < 0 {
+ return e, true, nil
+ } else if c.MaxAge > 0 {
+ e.Expires = now.Add(time.Duration(c.MaxAge) * time.Second)
+ e.Persistent = true
+ } else {
+ if c.Expires.IsZero() {
+ e.Expires = endOfTime
+ e.Persistent = false
+ } else {
+ if !c.Expires.After(now) {
+ return e, true, nil
+ }
+ e.Expires = c.Expires
+ e.Persistent = true
+ }
+ }
+
+ e.Value = c.Value
+ e.Secure = c.Secure
+ e.HttpOnly = c.HttpOnly
+
+ return e, false, nil
+}
+
+var (
+ errIllegalDomain = errors.New("cookiejar: illegal cookie domain attribute")
+ errMalformedDomain = errors.New("cookiejar: malformed cookie domain attribute")
+ errNoHostname = errors.New("cookiejar: no host name available (IP only)")
+)
+
+// endOfTime is the time when session (non-persistent) cookies expire.
+// This instant is representable in most date/time formats (not just
+// Go's time.Time) and should be far enough in the future.
+var endOfTime = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC)
+
+// domainAndType determines the cookie's domain and hostOnly attribute.
+func (j *Jar) domainAndType(host, domain string) (string, bool, error) {
+ if domain == "" {
+ // No domain attribute in the SetCookie header indicates a
+ // host cookie.
+ return host, true, nil
+ }
+
+ if isIP(host) {
+ // According to RFC 6265 domain-matching includes not being
+ // an IP address.
+ // TODO: This might be relaxed as in common browsers.
+ return "", false, errNoHostname
+ }
+
+ // From here on: If the cookie is valid, it is a domain cookie (with
+ // the one exception of a public suffix below).
+ // See RFC 6265 section 5.2.3.
+ if domain[0] == '.' {
+ domain = domain[1:]
+ }
+
+ if len(domain) == 0 || domain[0] == '.' {
+ // Received either "Domain=." or "Domain=..some.thing",
+ // both are illegal.
+ return "", false, errMalformedDomain
+ }
+ domain = strings.ToLower(domain)
+
+ if domain[len(domain)-1] == '.' {
+ // We received stuff like "Domain=www.example.com.".
+ // Browsers do handle such stuff (actually differently) but
+ // RFC 6265 seems to be clear here (e.g. section 4.1.2.3) in
+ // requiring a reject. 4.1.2.3 is not normative, but
+ // "Domain Matching" (5.1.3) and "Canonicalized Host Names"
+ // (5.1.2) are.
+ return "", false, errMalformedDomain
+ }
+
+ // See RFC 6265 section 5.3 #5.
+ if j.psList != nil {
+ if ps := j.psList.PublicSuffix(domain); ps != "" && !hasDotSuffix(domain, ps) {
+ if host == domain {
+ // This is the one exception in which a cookie
+ // with a domain attribute is a host cookie.
+ return host, true, nil
+ }
+ return "", false, errIllegalDomain
+ }
+ }
+
+ // The domain must domain-match host: www.mycompany.com cannot
+ // set cookies for .ourcompetitors.com.
+ if host != domain && !hasDotSuffix(host, domain) {
+ return "", false, errIllegalDomain
+ }
+
+ return domain, false, nil
+}
diff --git a/src/pkg/net/http/cookiejar/jar_test.go b/src/pkg/net/http/cookiejar/jar_test.go
new file mode 100644
index 000000000..3aa601586
--- /dev/null
+++ b/src/pkg/net/http/cookiejar/jar_test.go
@@ -0,0 +1,1267 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cookiejar
+
+import (
+ "fmt"
+ "net/http"
+ "net/url"
+ "sort"
+ "strings"
+ "testing"
+ "time"
+)
+
+// tNow is the synthetic current time used as now during testing.
+var tNow = time.Date(2013, 1, 1, 12, 0, 0, 0, time.UTC)
+
+// testPSL implements PublicSuffixList with just two rules: "co.uk"
+// and the default rule "*".
+type testPSL struct{}
+
+func (testPSL) String() string {
+ return "testPSL"
+}
+func (testPSL) PublicSuffix(d string) string {
+ if d == "co.uk" || strings.HasSuffix(d, ".co.uk") {
+ return "co.uk"
+ }
+ return d[strings.LastIndex(d, ".")+1:]
+}
+
+// newTestJar creates an empty Jar with testPSL as the public suffix list.
+func newTestJar() *Jar {
+ jar, err := New(&Options{PublicSuffixList: testPSL{}})
+ if err != nil {
+ panic(err)
+ }
+ return jar
+}
+
+var hasDotSuffixTests = [...]struct {
+ s, suffix string
+}{
+ {"", ""},
+ {"", "."},
+ {"", "x"},
+ {".", ""},
+ {".", "."},
+ {".", ".."},
+ {".", "x"},
+ {".", "x."},
+ {".", ".x"},
+ {".", ".x."},
+ {"x", ""},
+ {"x", "."},
+ {"x", ".."},
+ {"x", "x"},
+ {"x", "x."},
+ {"x", ".x"},
+ {"x", ".x."},
+ {".x", ""},
+ {".x", "."},
+ {".x", ".."},
+ {".x", "x"},
+ {".x", "x."},
+ {".x", ".x"},
+ {".x", ".x."},
+ {"x.", ""},
+ {"x.", "."},
+ {"x.", ".."},
+ {"x.", "x"},
+ {"x.", "x."},
+ {"x.", ".x"},
+ {"x.", ".x."},
+ {"com", ""},
+ {"com", "m"},
+ {"com", "om"},
+ {"com", "com"},
+ {"com", ".com"},
+ {"com", "x.com"},
+ {"com", "xcom"},
+ {"com", "xorg"},
+ {"com", "org"},
+ {"com", "rg"},
+ {"foo.com", ""},
+ {"foo.com", "m"},
+ {"foo.com", "om"},
+ {"foo.com", "com"},
+ {"foo.com", ".com"},
+ {"foo.com", "o.com"},
+ {"foo.com", "oo.com"},
+ {"foo.com", "foo.com"},
+ {"foo.com", ".foo.com"},
+ {"foo.com", "x.foo.com"},
+ {"foo.com", "xfoo.com"},
+ {"foo.com", "xfoo.org"},
+ {"foo.com", "foo.org"},
+ {"foo.com", "oo.org"},
+ {"foo.com", "o.org"},
+ {"foo.com", ".org"},
+ {"foo.com", "org"},
+ {"foo.com", "rg"},
+}
+
+func TestHasDotSuffix(t *testing.T) {
+ for _, tc := range hasDotSuffixTests {
+ got := hasDotSuffix(tc.s, tc.suffix)
+ want := strings.HasSuffix(tc.s, "."+tc.suffix)
+ if got != want {
+ t.Errorf("s=%q, suffix=%q: got %v, want %v", tc.s, tc.suffix, got, want)
+ }
+ }
+}
+
+var canonicalHostTests = map[string]string{
+ "www.example.com": "www.example.com",
+ "WWW.EXAMPLE.COM": "www.example.com",
+ "wWw.eXAmple.CoM": "www.example.com",
+ "www.example.com:80": "www.example.com",
+ "192.168.0.10": "192.168.0.10",
+ "192.168.0.5:8080": "192.168.0.5",
+ "2001:4860:0:2001::68": "2001:4860:0:2001::68",
+ "[2001:4860:0:::68]:8080": "2001:4860:0:::68",
+ "www.bücher.de": "www.xn--bcher-kva.de",
+ "www.example.com.": "www.example.com",
+ "[bad.unmatched.bracket:": "error",
+}
+
+func TestCanonicalHost(t *testing.T) {
+ for h, want := range canonicalHostTests {
+ got, err := canonicalHost(h)
+ if want == "error" {
+ if err == nil {
+ t.Errorf("%q: got nil error, want non-nil", h)
+ }
+ continue
+ }
+ if err != nil {
+ t.Errorf("%q: %v", h, err)
+ continue
+ }
+ if got != want {
+ t.Errorf("%q: got %q, want %q", h, got, want)
+ continue
+ }
+ }
+}
+
+var hasPortTests = map[string]bool{
+ "www.example.com": false,
+ "www.example.com:80": true,
+ "127.0.0.1": false,
+ "127.0.0.1:8080": true,
+ "2001:4860:0:2001::68": false,
+ "[2001::0:::68]:80": true,
+}
+
+func TestHasPort(t *testing.T) {
+ for host, want := range hasPortTests {
+ if got := hasPort(host); got != want {
+ t.Errorf("%q: got %t, want %t", host, got, want)
+ }
+ }
+}
+
+var jarKeyTests = map[string]string{
+ "foo.www.example.com": "example.com",
+ "www.example.com": "example.com",
+ "example.com": "example.com",
+ "com": "com",
+ "foo.www.bbc.co.uk": "bbc.co.uk",
+ "www.bbc.co.uk": "bbc.co.uk",
+ "bbc.co.uk": "bbc.co.uk",
+ "co.uk": "co.uk",
+ "uk": "uk",
+ "192.168.0.5": "192.168.0.5",
+}
+
+func TestJarKey(t *testing.T) {
+ for host, want := range jarKeyTests {
+ if got := jarKey(host, testPSL{}); got != want {
+ t.Errorf("%q: got %q, want %q", host, got, want)
+ }
+ }
+}
+
+var jarKeyNilPSLTests = map[string]string{
+ "foo.www.example.com": "example.com",
+ "www.example.com": "example.com",
+ "example.com": "example.com",
+ "com": "com",
+ "foo.www.bbc.co.uk": "co.uk",
+ "www.bbc.co.uk": "co.uk",
+ "bbc.co.uk": "co.uk",
+ "co.uk": "co.uk",
+ "uk": "uk",
+ "192.168.0.5": "192.168.0.5",
+}
+
+func TestJarKeyNilPSL(t *testing.T) {
+ for host, want := range jarKeyNilPSLTests {
+ if got := jarKey(host, nil); got != want {
+ t.Errorf("%q: got %q, want %q", host, got, want)
+ }
+ }
+}
+
+var isIPTests = map[string]bool{
+ "127.0.0.1": true,
+ "1.2.3.4": true,
+ "2001:4860:0:2001::68": true,
+ "example.com": false,
+ "1.1.1.300": false,
+ "www.foo.bar.net": false,
+ "123.foo.bar.net": false,
+}
+
+func TestIsIP(t *testing.T) {
+ for host, want := range isIPTests {
+ if got := isIP(host); got != want {
+ t.Errorf("%q: got %t, want %t", host, got, want)
+ }
+ }
+}
+
+var defaultPathTests = map[string]string{
+ "/": "/",
+ "/abc": "/",
+ "/abc/": "/abc",
+ "/abc/xyz": "/abc",
+ "/abc/xyz/": "/abc/xyz",
+ "/a/b/c.html": "/a/b",
+ "": "/",
+ "strange": "/",
+ "//": "/",
+ "/a//b": "/a/",
+ "/a/./b": "/a/.",
+ "/a/../b": "/a/..",
+}
+
+func TestDefaultPath(t *testing.T) {
+ for path, want := range defaultPathTests {
+ if got := defaultPath(path); got != want {
+ t.Errorf("%q: got %q, want %q", path, got, want)
+ }
+ }
+}
+
+var domainAndTypeTests = [...]struct {
+ host string // host Set-Cookie header was received from
+ domain string // domain attribute in Set-Cookie header
+ wantDomain string // expected domain of cookie
+ wantHostOnly bool // expected host-cookie flag
+ wantErr error // expected error
+}{
+ {"www.example.com", "", "www.example.com", true, nil},
+ {"127.0.0.1", "", "127.0.0.1", true, nil},
+ {"2001:4860:0:2001::68", "", "2001:4860:0:2001::68", true, nil},
+ {"www.example.com", "example.com", "example.com", false, nil},
+ {"www.example.com", ".example.com", "example.com", false, nil},
+ {"www.example.com", "www.example.com", "www.example.com", false, nil},
+ {"www.example.com", ".www.example.com", "www.example.com", false, nil},
+ {"foo.sso.example.com", "sso.example.com", "sso.example.com", false, nil},
+ {"bar.co.uk", "bar.co.uk", "bar.co.uk", false, nil},
+ {"foo.bar.co.uk", ".bar.co.uk", "bar.co.uk", false, nil},
+ {"127.0.0.1", "127.0.0.1", "", false, errNoHostname},
+ {"2001:4860:0:2001::68", "2001:4860:0:2001::68", "2001:4860:0:2001::68", false, errNoHostname},
+ {"www.example.com", ".", "", false, errMalformedDomain},
+ {"www.example.com", "..", "", false, errMalformedDomain},
+ {"www.example.com", "other.com", "", false, errIllegalDomain},
+ {"www.example.com", "com", "", false, errIllegalDomain},
+ {"www.example.com", ".com", "", false, errIllegalDomain},
+ {"foo.bar.co.uk", ".co.uk", "", false, errIllegalDomain},
+ {"127.www.0.0.1", "127.0.0.1", "", false, errIllegalDomain},
+ {"com", "", "com", true, nil},
+ {"com", "com", "com", true, nil},
+ {"com", ".com", "com", true, nil},
+ {"co.uk", "", "co.uk", true, nil},
+ {"co.uk", "co.uk", "co.uk", true, nil},
+ {"co.uk", ".co.uk", "co.uk", true, nil},
+}
+
+func TestDomainAndType(t *testing.T) {
+ jar := newTestJar()
+ for _, tc := range domainAndTypeTests {
+ domain, hostOnly, err := jar.domainAndType(tc.host, tc.domain)
+ if err != tc.wantErr {
+ t.Errorf("%q/%q: got %q error, want %q",
+ tc.host, tc.domain, err, tc.wantErr)
+ continue
+ }
+ if err != nil {
+ continue
+ }
+ if domain != tc.wantDomain || hostOnly != tc.wantHostOnly {
+ t.Errorf("%q/%q: got %q/%t want %q/%t",
+ tc.host, tc.domain, domain, hostOnly,
+ tc.wantDomain, tc.wantHostOnly)
+ }
+ }
+}
+
+// expiresIn creates an expires attribute delta seconds from tNow.
+func expiresIn(delta int) string {
+ t := tNow.Add(time.Duration(delta) * time.Second)
+ return "expires=" + t.Format(time.RFC1123)
+}
+
+// mustParseURL parses s to an URL and panics on error.
+func mustParseURL(s string) *url.URL {
+ u, err := url.Parse(s)
+ if err != nil || u.Scheme == "" || u.Host == "" {
+ panic(fmt.Sprintf("Unable to parse URL %s.", s))
+ }
+ return u
+}
+
+// jarTest encapsulates the following actions on a jar:
+// 1. Perform SetCookies with fromURL and the cookies from setCookies.
+// (Done at time tNow + 0 ms.)
+// 2. Check that the entries in the jar matches content.
+// (Done at time tNow + 1001 ms.)
+// 3. For each query in tests: Check that Cookies with toURL yields the
+// cookies in want.
+// (Query n done at tNow + (n+2)*1001 ms.)
+type jarTest struct {
+ description string // The description of what this test is supposed to test
+ fromURL string // The full URL of the request from which Set-Cookie headers where received
+ setCookies []string // All the cookies received from fromURL
+ content string // The whole (non-expired) content of the jar
+ queries []query // Queries to test the Jar.Cookies method
+}
+
+// query contains one test of the cookies returned from Jar.Cookies.
+type query struct {
+ toURL string // the URL in the Cookies call
+ want string // the expected list of cookies (order matters)
+}
+
+// run runs the jarTest.
+func (test jarTest) run(t *testing.T, jar *Jar) {
+ now := tNow
+
+ // Populate jar with cookies.
+ setCookies := make([]*http.Cookie, len(test.setCookies))
+ for i, cs := range test.setCookies {
+ cookies := (&http.Response{Header: http.Header{"Set-Cookie": {cs}}}).Cookies()
+ if len(cookies) != 1 {
+ panic(fmt.Sprintf("Wrong cookie line %q: %#v", cs, cookies))
+ }
+ setCookies[i] = cookies[0]
+ }
+ jar.setCookies(mustParseURL(test.fromURL), setCookies, now)
+ now = now.Add(1001 * time.Millisecond)
+
+ // Serialize non-expired entries in the form "name1=val1 name2=val2".
+ var cs []string
+ for _, submap := range jar.entries {
+ for _, cookie := range submap {
+ if !cookie.Expires.After(now) {
+ continue
+ }
+ cs = append(cs, cookie.Name+"="+cookie.Value)
+ }
+ }
+ sort.Strings(cs)
+ got := strings.Join(cs, " ")
+
+ // Make sure jar content matches our expectations.
+ if got != test.content {
+ t.Errorf("Test %q Content\ngot %q\nwant %q",
+ test.description, got, test.content)
+ }
+
+ // Test different calls to Cookies.
+ for i, query := range test.queries {
+ now = now.Add(1001 * time.Millisecond)
+ var s []string
+ for _, c := range jar.cookies(mustParseURL(query.toURL), now) {
+ s = append(s, c.Name+"="+c.Value)
+ }
+ if got := strings.Join(s, " "); got != query.want {
+ t.Errorf("Test %q #%d\ngot %q\nwant %q", test.description, i, got, query.want)
+ }
+ }
+}
+
+// basicsTests contains fundamental tests. Each jarTest has to be performed on
+// a fresh, empty Jar.
+var basicsTests = [...]jarTest{
+ {
+ "Retrieval of a plain host cookie.",
+ "http://www.host.test/",
+ []string{"A=a"},
+ "A=a",
+ []query{
+ {"http://www.host.test", "A=a"},
+ {"http://www.host.test/", "A=a"},
+ {"http://www.host.test/some/path", "A=a"},
+ {"https://www.host.test", "A=a"},
+ {"https://www.host.test/", "A=a"},
+ {"https://www.host.test/some/path", "A=a"},
+ {"ftp://www.host.test", ""},
+ {"ftp://www.host.test/", ""},
+ {"ftp://www.host.test/some/path", ""},
+ {"http://www.other.org", ""},
+ {"http://sibling.host.test", ""},
+ {"http://deep.www.host.test", ""},
+ },
+ },
+ {
+ "Secure cookies are not returned to http.",
+ "http://www.host.test/",
+ []string{"A=a; secure"},
+ "A=a",
+ []query{
+ {"http://www.host.test", ""},
+ {"http://www.host.test/", ""},
+ {"http://www.host.test/some/path", ""},
+ {"https://www.host.test", "A=a"},
+ {"https://www.host.test/", "A=a"},
+ {"https://www.host.test/some/path", "A=a"},
+ },
+ },
+ {
+ "Explicit path.",
+ "http://www.host.test/",
+ []string{"A=a; path=/some/path"},
+ "A=a",
+ []query{
+ {"http://www.host.test", ""},
+ {"http://www.host.test/", ""},
+ {"http://www.host.test/some", ""},
+ {"http://www.host.test/some/", ""},
+ {"http://www.host.test/some/path", "A=a"},
+ {"http://www.host.test/some/paths", ""},
+ {"http://www.host.test/some/path/foo", "A=a"},
+ {"http://www.host.test/some/path/foo/", "A=a"},
+ },
+ },
+ {
+ "Implicit path #1: path is a directory.",
+ "http://www.host.test/some/path/",
+ []string{"A=a"},
+ "A=a",
+ []query{
+ {"http://www.host.test", ""},
+ {"http://www.host.test/", ""},
+ {"http://www.host.test/some", ""},
+ {"http://www.host.test/some/", ""},
+ {"http://www.host.test/some/path", "A=a"},
+ {"http://www.host.test/some/paths", ""},
+ {"http://www.host.test/some/path/foo", "A=a"},
+ {"http://www.host.test/some/path/foo/", "A=a"},
+ },
+ },
+ {
+ "Implicit path #2: path is not a directory.",
+ "http://www.host.test/some/path/index.html",
+ []string{"A=a"},
+ "A=a",
+ []query{
+ {"http://www.host.test", ""},
+ {"http://www.host.test/", ""},
+ {"http://www.host.test/some", ""},
+ {"http://www.host.test/some/", ""},
+ {"http://www.host.test/some/path", "A=a"},
+ {"http://www.host.test/some/paths", ""},
+ {"http://www.host.test/some/path/foo", "A=a"},
+ {"http://www.host.test/some/path/foo/", "A=a"},
+ },
+ },
+ {
+ "Implicit path #3: no path in URL at all.",
+ "http://www.host.test",
+ []string{"A=a"},
+ "A=a",
+ []query{
+ {"http://www.host.test", "A=a"},
+ {"http://www.host.test/", "A=a"},
+ {"http://www.host.test/some/path", "A=a"},
+ },
+ },
+ {
+ "Cookies are sorted by path length.",
+ "http://www.host.test/",
+ []string{
+ "A=a; path=/foo/bar",
+ "B=b; path=/foo/bar/baz/qux",
+ "C=c; path=/foo/bar/baz",
+ "D=d; path=/foo"},
+ "A=a B=b C=c D=d",
+ []query{
+ {"http://www.host.test/foo/bar/baz/qux", "B=b C=c A=a D=d"},
+ {"http://www.host.test/foo/bar/baz/", "C=c A=a D=d"},
+ {"http://www.host.test/foo/bar", "A=a D=d"},
+ },
+ },
+ {
+ "Creation time determines sorting on same length paths.",
+ "http://www.host.test/",
+ []string{
+ "A=a; path=/foo/bar",
+ "X=x; path=/foo/bar",
+ "Y=y; path=/foo/bar/baz/qux",
+ "B=b; path=/foo/bar/baz/qux",
+ "C=c; path=/foo/bar/baz",
+ "W=w; path=/foo/bar/baz",
+ "Z=z; path=/foo",
+ "D=d; path=/foo"},
+ "A=a B=b C=c D=d W=w X=x Y=y Z=z",
+ []query{
+ {"http://www.host.test/foo/bar/baz/qux", "Y=y B=b C=c W=w A=a X=x Z=z D=d"},
+ {"http://www.host.test/foo/bar/baz/", "C=c W=w A=a X=x Z=z D=d"},
+ {"http://www.host.test/foo/bar", "A=a X=x Z=z D=d"},
+ },
+ },
+ {
+ "Sorting of same-name cookies.",
+ "http://www.host.test/",
+ []string{
+ "A=1; path=/",
+ "A=2; path=/path",
+ "A=3; path=/quux",
+ "A=4; path=/path/foo",
+ "A=5; domain=.host.test; path=/path",
+ "A=6; domain=.host.test; path=/quux",
+ "A=7; domain=.host.test; path=/path/foo",
+ },
+ "A=1 A=2 A=3 A=4 A=5 A=6 A=7",
+ []query{
+ {"http://www.host.test/path", "A=2 A=5 A=1"},
+ {"http://www.host.test/path/foo", "A=4 A=7 A=2 A=5 A=1"},
+ },
+ },
+ {
+ "Disallow domain cookie on public suffix.",
+ "http://www.bbc.co.uk",
+ []string{
+ "a=1",
+ "b=2; domain=co.uk",
+ },
+ "a=1",
+ []query{{"http://www.bbc.co.uk", "a=1"}},
+ },
+ {
+ "Host cookie on IP.",
+ "http://192.168.0.10",
+ []string{"a=1"},
+ "a=1",
+ []query{{"http://192.168.0.10", "a=1"}},
+ },
+ {
+ "Port is ignored #1.",
+ "http://www.host.test/",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://www.host.test", "a=1"},
+ {"http://www.host.test:8080/", "a=1"},
+ },
+ },
+ {
+ "Port is ignored #2.",
+ "http://www.host.test:8080/",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://www.host.test", "a=1"},
+ {"http://www.host.test:8080/", "a=1"},
+ {"http://www.host.test:1234/", "a=1"},
+ },
+ },
+}
+
+func TestBasics(t *testing.T) {
+ for _, test := range basicsTests {
+ jar := newTestJar()
+ test.run(t, jar)
+ }
+}
+
+// updateAndDeleteTests contains jarTests which must be performed on the same
+// Jar.
+var updateAndDeleteTests = [...]jarTest{
+ {
+ "Set initial cookies.",
+ "http://www.host.test",
+ []string{
+ "a=1",
+ "b=2; secure",
+ "c=3; httponly",
+ "d=4; secure; httponly"},
+ "a=1 b=2 c=3 d=4",
+ []query{
+ {"http://www.host.test", "a=1 c=3"},
+ {"https://www.host.test", "a=1 b=2 c=3 d=4"},
+ },
+ },
+ {
+ "Update value via http.",
+ "http://www.host.test",
+ []string{
+ "a=w",
+ "b=x; secure",
+ "c=y; httponly",
+ "d=z; secure; httponly"},
+ "a=w b=x c=y d=z",
+ []query{
+ {"http://www.host.test", "a=w c=y"},
+ {"https://www.host.test", "a=w b=x c=y d=z"},
+ },
+ },
+ {
+ "Clear Secure flag from a http.",
+ "http://www.host.test/",
+ []string{
+ "b=xx",
+ "d=zz; httponly"},
+ "a=w b=xx c=y d=zz",
+ []query{{"http://www.host.test", "a=w b=xx c=y d=zz"}},
+ },
+ {
+ "Delete all.",
+ "http://www.host.test/",
+ []string{
+ "a=1; max-Age=-1", // delete via MaxAge
+ "b=2; " + expiresIn(-10), // delete via Expires
+ "c=2; max-age=-1; " + expiresIn(-10), // delete via both
+ "d=4; max-age=-1; " + expiresIn(10)}, // MaxAge takes precedence
+ "",
+ []query{{"http://www.host.test", ""}},
+ },
+ {
+ "Refill #1.",
+ "http://www.host.test",
+ []string{
+ "A=1",
+ "A=2; path=/foo",
+ "A=3; domain=.host.test",
+ "A=4; path=/foo; domain=.host.test"},
+ "A=1 A=2 A=3 A=4",
+ []query{{"http://www.host.test/foo", "A=2 A=4 A=1 A=3"}},
+ },
+ {
+ "Refill #2.",
+ "http://www.google.com",
+ []string{
+ "A=6",
+ "A=7; path=/foo",
+ "A=8; domain=.google.com",
+ "A=9; path=/foo; domain=.google.com"},
+ "A=1 A=2 A=3 A=4 A=6 A=7 A=8 A=9",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=4 A=1 A=3"},
+ {"http://www.google.com/foo", "A=7 A=9 A=6 A=8"},
+ },
+ },
+ {
+ "Delete A7.",
+ "http://www.google.com",
+ []string{"A=; path=/foo; max-age=-1"},
+ "A=1 A=2 A=3 A=4 A=6 A=8 A=9",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=4 A=1 A=3"},
+ {"http://www.google.com/foo", "A=9 A=6 A=8"},
+ },
+ },
+ {
+ "Delete A4.",
+ "http://www.host.test",
+ []string{"A=; path=/foo; domain=host.test; max-age=-1"},
+ "A=1 A=2 A=3 A=6 A=8 A=9",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=1 A=3"},
+ {"http://www.google.com/foo", "A=9 A=6 A=8"},
+ },
+ },
+ {
+ "Delete A6.",
+ "http://www.google.com",
+ []string{"A=; max-age=-1"},
+ "A=1 A=2 A=3 A=8 A=9",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=1 A=3"},
+ {"http://www.google.com/foo", "A=9 A=8"},
+ },
+ },
+ {
+ "Delete A3.",
+ "http://www.host.test",
+ []string{"A=; domain=host.test; max-age=-1"},
+ "A=1 A=2 A=8 A=9",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=1"},
+ {"http://www.google.com/foo", "A=9 A=8"},
+ },
+ },
+ {
+ "No cross-domain delete.",
+ "http://www.host.test",
+ []string{
+ "A=; domain=google.com; max-age=-1",
+ "A=; path=/foo; domain=google.com; max-age=-1"},
+ "A=1 A=2 A=8 A=9",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=1"},
+ {"http://www.google.com/foo", "A=9 A=8"},
+ },
+ },
+ {
+ "Delete A8 and A9.",
+ "http://www.google.com",
+ []string{
+ "A=; domain=google.com; max-age=-1",
+ "A=; path=/foo; domain=google.com; max-age=-1"},
+ "A=1 A=2",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=1"},
+ {"http://www.google.com/foo", ""},
+ },
+ },
+}
+
+func TestUpdateAndDelete(t *testing.T) {
+ jar := newTestJar()
+ for _, test := range updateAndDeleteTests {
+ test.run(t, jar)
+ }
+}
+
+func TestExpiration(t *testing.T) {
+ jar := newTestJar()
+ jarTest{
+ "Expiration.",
+ "http://www.host.test",
+ []string{
+ "a=1",
+ "b=2; max-age=3",
+ "c=3; " + expiresIn(3),
+ "d=4; max-age=5",
+ "e=5; " + expiresIn(5),
+ "f=6; max-age=100",
+ },
+ "a=1 b=2 c=3 d=4 e=5 f=6", // executed at t0 + 1001 ms
+ []query{
+ {"http://www.host.test", "a=1 b=2 c=3 d=4 e=5 f=6"}, // t0 + 2002 ms
+ {"http://www.host.test", "a=1 d=4 e=5 f=6"}, // t0 + 3003 ms
+ {"http://www.host.test", "a=1 d=4 e=5 f=6"}, // t0 + 4004 ms
+ {"http://www.host.test", "a=1 f=6"}, // t0 + 5005 ms
+ {"http://www.host.test", "a=1 f=6"}, // t0 + 6006 ms
+ },
+ }.run(t, jar)
+}
+
+//
+// Tests derived from Chromium's cookie_store_unittest.h.
+//
+
+// See http://src.chromium.org/viewvc/chrome/trunk/src/net/cookies/cookie_store_unittest.h?revision=159685&content-type=text/plain
+// Some of the original tests are in a bad condition (e.g.
+// DomainWithTrailingDotTest) or are not RFC 6265 conforming (e.g.
+// TestNonDottedAndTLD #1 and #6) and have not been ported.
+
+// chromiumBasicsTests contains fundamental tests. Each jarTest has to be
+// performed on a fresh, empty Jar.
+var chromiumBasicsTests = [...]jarTest{
+ {
+ "DomainWithTrailingDotTest.",
+ "http://www.google.com/",
+ []string{
+ "a=1; domain=.www.google.com.",
+ "b=2; domain=.www.google.com.."},
+ "",
+ []query{
+ {"http://www.google.com", ""},
+ },
+ },
+ {
+ "ValidSubdomainTest #1.",
+ "http://a.b.c.d.com",
+ []string{
+ "a=1; domain=.a.b.c.d.com",
+ "b=2; domain=.b.c.d.com",
+ "c=3; domain=.c.d.com",
+ "d=4; domain=.d.com"},
+ "a=1 b=2 c=3 d=4",
+ []query{
+ {"http://a.b.c.d.com", "a=1 b=2 c=3 d=4"},
+ {"http://b.c.d.com", "b=2 c=3 d=4"},
+ {"http://c.d.com", "c=3 d=4"},
+ {"http://d.com", "d=4"},
+ },
+ },
+ {
+ "ValidSubdomainTest #2.",
+ "http://a.b.c.d.com",
+ []string{
+ "a=1; domain=.a.b.c.d.com",
+ "b=2; domain=.b.c.d.com",
+ "c=3; domain=.c.d.com",
+ "d=4; domain=.d.com",
+ "X=bcd; domain=.b.c.d.com",
+ "X=cd; domain=.c.d.com"},
+ "X=bcd X=cd a=1 b=2 c=3 d=4",
+ []query{
+ {"http://b.c.d.com", "b=2 c=3 d=4 X=bcd X=cd"},
+ {"http://c.d.com", "c=3 d=4 X=cd"},
+ },
+ },
+ {
+ "InvalidDomainTest #1.",
+ "http://foo.bar.com",
+ []string{
+ "a=1; domain=.yo.foo.bar.com",
+ "b=2; domain=.foo.com",
+ "c=3; domain=.bar.foo.com",
+ "d=4; domain=.foo.bar.com.net",
+ "e=5; domain=ar.com",
+ "f=6; domain=.",
+ "g=7; domain=/",
+ "h=8; domain=http://foo.bar.com",
+ "i=9; domain=..foo.bar.com",
+ "j=10; domain=..bar.com",
+ "k=11; domain=.foo.bar.com?blah",
+ "l=12; domain=.foo.bar.com/blah",
+ "m=12; domain=.foo.bar.com:80",
+ "n=14; domain=.foo.bar.com:",
+ "o=15; domain=.foo.bar.com#sup",
+ },
+ "", // Jar is empty.
+ []query{{"http://foo.bar.com", ""}},
+ },
+ {
+ "InvalidDomainTest #2.",
+ "http://foo.com.com",
+ []string{"a=1; domain=.foo.com.com.com"},
+ "",
+ []query{{"http://foo.bar.com", ""}},
+ },
+ {
+ "DomainWithoutLeadingDotTest #1.",
+ "http://manage.hosted.filefront.com",
+ []string{"a=1; domain=filefront.com"},
+ "a=1",
+ []query{{"http://www.filefront.com", "a=1"}},
+ },
+ {
+ "DomainWithoutLeadingDotTest #2.",
+ "http://www.google.com",
+ []string{"a=1; domain=www.google.com"},
+ "a=1",
+ []query{
+ {"http://www.google.com", "a=1"},
+ {"http://sub.www.google.com", "a=1"},
+ {"http://something-else.com", ""},
+ },
+ },
+ {
+ "CaseInsensitiveDomainTest.",
+ "http://www.google.com",
+ []string{
+ "a=1; domain=.GOOGLE.COM",
+ "b=2; domain=.www.gOOgLE.coM"},
+ "a=1 b=2",
+ []query{{"http://www.google.com", "a=1 b=2"}},
+ },
+ {
+ "TestIpAddress #1.",
+ "http://1.2.3.4/foo",
+ []string{"a=1; path=/"},
+ "a=1",
+ []query{{"http://1.2.3.4/foo", "a=1"}},
+ },
+ {
+ "TestIpAddress #2.",
+ "http://1.2.3.4/foo",
+ []string{
+ "a=1; domain=.1.2.3.4",
+ "b=2; domain=.3.4"},
+ "",
+ []query{{"http://1.2.3.4/foo", ""}},
+ },
+ {
+ "TestIpAddress #3.",
+ "http://1.2.3.4/foo",
+ []string{"a=1; domain=1.2.3.4"},
+ "",
+ []query{{"http://1.2.3.4/foo", ""}},
+ },
+ {
+ "TestNonDottedAndTLD #2.",
+ "http://com./index.html",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://com./index.html", "a=1"},
+ {"http://no-cookies.com./index.html", ""},
+ },
+ },
+ {
+ "TestNonDottedAndTLD #3.",
+ "http://a.b",
+ []string{
+ "a=1; domain=.b",
+ "b=2; domain=b"},
+ "",
+ []query{{"http://bar.foo", ""}},
+ },
+ {
+ "TestNonDottedAndTLD #4.",
+ "http://google.com",
+ []string{
+ "a=1; domain=.com",
+ "b=2; domain=com"},
+ "",
+ []query{{"http://google.com", ""}},
+ },
+ {
+ "TestNonDottedAndTLD #5.",
+ "http://google.co.uk",
+ []string{
+ "a=1; domain=.co.uk",
+ "b=2; domain=.uk"},
+ "",
+ []query{
+ {"http://google.co.uk", ""},
+ {"http://else.co.com", ""},
+ {"http://else.uk", ""},
+ },
+ },
+ {
+ "TestHostEndsWithDot.",
+ "http://www.google.com",
+ []string{
+ "a=1",
+ "b=2; domain=.www.google.com."},
+ "a=1",
+ []query{{"http://www.google.com", "a=1"}},
+ },
+ {
+ "PathTest",
+ "http://www.google.izzle",
+ []string{"a=1; path=/wee"},
+ "a=1",
+ []query{
+ {"http://www.google.izzle/wee", "a=1"},
+ {"http://www.google.izzle/wee/", "a=1"},
+ {"http://www.google.izzle/wee/war", "a=1"},
+ {"http://www.google.izzle/wee/war/more/more", "a=1"},
+ {"http://www.google.izzle/weehee", ""},
+ {"http://www.google.izzle/", ""},
+ },
+ },
+}
+
+func TestChromiumBasics(t *testing.T) {
+ for _, test := range chromiumBasicsTests {
+ jar := newTestJar()
+ test.run(t, jar)
+ }
+}
+
+// chromiumDomainTests contains jarTests which must be executed all on the
+// same Jar.
+var chromiumDomainTests = [...]jarTest{
+ {
+ "Fill #1.",
+ "http://www.google.izzle",
+ []string{"A=B"},
+ "A=B",
+ []query{{"http://www.google.izzle", "A=B"}},
+ },
+ {
+ "Fill #2.",
+ "http://www.google.izzle",
+ []string{"C=D; domain=.google.izzle"},
+ "A=B C=D",
+ []query{{"http://www.google.izzle", "A=B C=D"}},
+ },
+ {
+ "Verify A is a host cookie and not accessible from subdomain.",
+ "http://unused.nil",
+ []string{},
+ "A=B C=D",
+ []query{{"http://foo.www.google.izzle", "C=D"}},
+ },
+ {
+ "Verify domain cookies are found on proper domain.",
+ "http://www.google.izzle",
+ []string{"E=F; domain=.www.google.izzle"},
+ "A=B C=D E=F",
+ []query{{"http://www.google.izzle", "A=B C=D E=F"}},
+ },
+ {
+ "Leading dots in domain attributes are optional.",
+ "http://www.google.izzle",
+ []string{"G=H; domain=www.google.izzle"},
+ "A=B C=D E=F G=H",
+ []query{{"http://www.google.izzle", "A=B C=D E=F G=H"}},
+ },
+ {
+ "Verify domain enforcement works #1.",
+ "http://www.google.izzle",
+ []string{"K=L; domain=.bar.www.google.izzle"},
+ "A=B C=D E=F G=H",
+ []query{{"http://bar.www.google.izzle", "C=D E=F G=H"}},
+ },
+ {
+ "Verify domain enforcement works #2.",
+ "http://unused.nil",
+ []string{},
+ "A=B C=D E=F G=H",
+ []query{{"http://www.google.izzle", "A=B C=D E=F G=H"}},
+ },
+}
+
+func TestChromiumDomain(t *testing.T) {
+ jar := newTestJar()
+ for _, test := range chromiumDomainTests {
+ test.run(t, jar)
+ }
+
+}
+
+// chromiumDeletionTests must be performed all on the same Jar.
+var chromiumDeletionTests = [...]jarTest{
+ {
+ "Create session cookie a1.",
+ "http://www.google.com",
+ []string{"a=1"},
+ "a=1",
+ []query{{"http://www.google.com", "a=1"}},
+ },
+ {
+ "Delete sc a1 via MaxAge.",
+ "http://www.google.com",
+ []string{"a=1; max-age=-1"},
+ "",
+ []query{{"http://www.google.com", ""}},
+ },
+ {
+ "Create session cookie b2.",
+ "http://www.google.com",
+ []string{"b=2"},
+ "b=2",
+ []query{{"http://www.google.com", "b=2"}},
+ },
+ {
+ "Delete sc b2 via Expires.",
+ "http://www.google.com",
+ []string{"b=2; " + expiresIn(-10)},
+ "",
+ []query{{"http://www.google.com", ""}},
+ },
+ {
+ "Create persistent cookie c3.",
+ "http://www.google.com",
+ []string{"c=3; max-age=3600"},
+ "c=3",
+ []query{{"http://www.google.com", "c=3"}},
+ },
+ {
+ "Delete pc c3 via MaxAge.",
+ "http://www.google.com",
+ []string{"c=3; max-age=-1"},
+ "",
+ []query{{"http://www.google.com", ""}},
+ },
+ {
+ "Create persistent cookie d4.",
+ "http://www.google.com",
+ []string{"d=4; max-age=3600"},
+ "d=4",
+ []query{{"http://www.google.com", "d=4"}},
+ },
+ {
+ "Delete pc d4 via Expires.",
+ "http://www.google.com",
+ []string{"d=4; " + expiresIn(-10)},
+ "",
+ []query{{"http://www.google.com", ""}},
+ },
+}
+
+func TestChromiumDeletion(t *testing.T) {
+ jar := newTestJar()
+ for _, test := range chromiumDeletionTests {
+ test.run(t, jar)
+ }
+}
+
+// domainHandlingTests tests and documents the rules for domain handling.
+// Each test must be performed on an empty new Jar.
+var domainHandlingTests = [...]jarTest{
+ {
+ "Host cookie",
+ "http://www.host.test",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://www.host.test", "a=1"},
+ {"http://host.test", ""},
+ {"http://bar.host.test", ""},
+ {"http://foo.www.host.test", ""},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Domain cookie #1",
+ "http://www.host.test",
+ []string{"a=1; domain=host.test"},
+ "a=1",
+ []query{
+ {"http://www.host.test", "a=1"},
+ {"http://host.test", "a=1"},
+ {"http://bar.host.test", "a=1"},
+ {"http://foo.www.host.test", "a=1"},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Domain cookie #2",
+ "http://www.host.test",
+ []string{"a=1; domain=.host.test"},
+ "a=1",
+ []query{
+ {"http://www.host.test", "a=1"},
+ {"http://host.test", "a=1"},
+ {"http://bar.host.test", "a=1"},
+ {"http://foo.www.host.test", "a=1"},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Host cookie on IDNA domain #1",
+ "http://www.bücher.test",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://www.bücher.test", "a=1"},
+ {"http://www.xn--bcher-kva.test", "a=1"},
+ {"http://bücher.test", ""},
+ {"http://xn--bcher-kva.test", ""},
+ {"http://bar.bücher.test", ""},
+ {"http://bar.xn--bcher-kva.test", ""},
+ {"http://foo.www.bücher.test", ""},
+ {"http://foo.www.xn--bcher-kva.test", ""},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Host cookie on IDNA domain #2",
+ "http://www.xn--bcher-kva.test",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://www.bücher.test", "a=1"},
+ {"http://www.xn--bcher-kva.test", "a=1"},
+ {"http://bücher.test", ""},
+ {"http://xn--bcher-kva.test", ""},
+ {"http://bar.bücher.test", ""},
+ {"http://bar.xn--bcher-kva.test", ""},
+ {"http://foo.www.bücher.test", ""},
+ {"http://foo.www.xn--bcher-kva.test", ""},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Domain cookie on IDNA domain #1",
+ "http://www.bücher.test",
+ []string{"a=1; domain=xn--bcher-kva.test"},
+ "a=1",
+ []query{
+ {"http://www.bücher.test", "a=1"},
+ {"http://www.xn--bcher-kva.test", "a=1"},
+ {"http://bücher.test", "a=1"},
+ {"http://xn--bcher-kva.test", "a=1"},
+ {"http://bar.bücher.test", "a=1"},
+ {"http://bar.xn--bcher-kva.test", "a=1"},
+ {"http://foo.www.bücher.test", "a=1"},
+ {"http://foo.www.xn--bcher-kva.test", "a=1"},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Domain cookie on IDNA domain #2",
+ "http://www.xn--bcher-kva.test",
+ []string{"a=1; domain=xn--bcher-kva.test"},
+ "a=1",
+ []query{
+ {"http://www.bücher.test", "a=1"},
+ {"http://www.xn--bcher-kva.test", "a=1"},
+ {"http://bücher.test", "a=1"},
+ {"http://xn--bcher-kva.test", "a=1"},
+ {"http://bar.bücher.test", "a=1"},
+ {"http://bar.xn--bcher-kva.test", "a=1"},
+ {"http://foo.www.bücher.test", "a=1"},
+ {"http://foo.www.xn--bcher-kva.test", "a=1"},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Host cookie on TLD.",
+ "http://com",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://com", "a=1"},
+ {"http://any.com", ""},
+ {"http://any.test", ""},
+ },
+ },
+ {
+ "Domain cookie on TLD becomes a host cookie.",
+ "http://com",
+ []string{"a=1; domain=com"},
+ "a=1",
+ []query{
+ {"http://com", "a=1"},
+ {"http://any.com", ""},
+ {"http://any.test", ""},
+ },
+ },
+ {
+ "Host cookie on public suffix.",
+ "http://co.uk",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://co.uk", "a=1"},
+ {"http://uk", ""},
+ {"http://some.co.uk", ""},
+ {"http://foo.some.co.uk", ""},
+ {"http://any.uk", ""},
+ },
+ },
+ {
+ "Domain cookie on public suffix is ignored.",
+ "http://some.co.uk",
+ []string{"a=1; domain=co.uk"},
+ "",
+ []query{
+ {"http://co.uk", ""},
+ {"http://uk", ""},
+ {"http://some.co.uk", ""},
+ {"http://foo.some.co.uk", ""},
+ {"http://any.uk", ""},
+ },
+ },
+}
+
+func TestDomainHandling(t *testing.T) {
+ for _, test := range domainHandlingTests {
+ jar := newTestJar()
+ test.run(t, jar)
+ }
+}
diff --git a/src/pkg/net/http/cookiejar/punycode.go b/src/pkg/net/http/cookiejar/punycode.go
new file mode 100644
index 000000000..ea7ceb5ef
--- /dev/null
+++ b/src/pkg/net/http/cookiejar/punycode.go
@@ -0,0 +1,159 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cookiejar
+
+// This file implements the Punycode algorithm from RFC 3492.
+
+import (
+ "fmt"
+ "strings"
+ "unicode/utf8"
+)
+
+// These parameter values are specified in section 5.
+//
+// All computation is done with int32s, so that overflow behavior is identical
+// regardless of whether int is 32-bit or 64-bit.
+const (
+ base int32 = 36
+ damp int32 = 700
+ initialBias int32 = 72
+ initialN int32 = 128
+ skew int32 = 38
+ tmax int32 = 26
+ tmin int32 = 1
+)
+
+// encode encodes a string as specified in section 6.3 and prepends prefix to
+// the result.
+//
+// The "while h < length(input)" line in the specification becomes "for
+// remaining != 0" in the Go code, because len(s) in Go is in bytes, not runes.
+func encode(prefix, s string) (string, error) {
+ output := make([]byte, len(prefix), len(prefix)+1+2*len(s))
+ copy(output, prefix)
+ delta, n, bias := int32(0), initialN, initialBias
+ b, remaining := int32(0), int32(0)
+ for _, r := range s {
+ if r < 0x80 {
+ b++
+ output = append(output, byte(r))
+ } else {
+ remaining++
+ }
+ }
+ h := b
+ if b > 0 {
+ output = append(output, '-')
+ }
+ for remaining != 0 {
+ m := int32(0x7fffffff)
+ for _, r := range s {
+ if m > r && r >= n {
+ m = r
+ }
+ }
+ delta += (m - n) * (h + 1)
+ if delta < 0 {
+ return "", fmt.Errorf("cookiejar: invalid label %q", s)
+ }
+ n = m
+ for _, r := range s {
+ if r < n {
+ delta++
+ if delta < 0 {
+ return "", fmt.Errorf("cookiejar: invalid label %q", s)
+ }
+ continue
+ }
+ if r > n {
+ continue
+ }
+ q := delta
+ for k := base; ; k += base {
+ t := k - bias
+ if t < tmin {
+ t = tmin
+ } else if t > tmax {
+ t = tmax
+ }
+ if q < t {
+ break
+ }
+ output = append(output, encodeDigit(t+(q-t)%(base-t)))
+ q = (q - t) / (base - t)
+ }
+ output = append(output, encodeDigit(q))
+ bias = adapt(delta, h+1, h == b)
+ delta = 0
+ h++
+ remaining--
+ }
+ delta++
+ n++
+ }
+ return string(output), nil
+}
+
+func encodeDigit(digit int32) byte {
+ switch {
+ case 0 <= digit && digit < 26:
+ return byte(digit + 'a')
+ case 26 <= digit && digit < 36:
+ return byte(digit + ('0' - 26))
+ }
+ panic("cookiejar: internal error in punycode encoding")
+}
+
+// adapt is the bias adaptation function specified in section 6.1.
+func adapt(delta, numPoints int32, firstTime bool) int32 {
+ if firstTime {
+ delta /= damp
+ } else {
+ delta /= 2
+ }
+ delta += delta / numPoints
+ k := int32(0)
+ for delta > ((base-tmin)*tmax)/2 {
+ delta /= base - tmin
+ k += base
+ }
+ return k + (base-tmin+1)*delta/(delta+skew)
+}
+
+// Strictly speaking, the remaining code below deals with IDNA (RFC 5890 and
+// friends) and not Punycode (RFC 3492) per se.
+
+// acePrefix is the ASCII Compatible Encoding prefix.
+const acePrefix = "xn--"
+
+// toASCII converts a domain or domain label to its ASCII form. For example,
+// toASCII("bücher.example.com") is "xn--bcher-kva.example.com", and
+// toASCII("golang") is "golang".
+func toASCII(s string) (string, error) {
+ if ascii(s) {
+ return s, nil
+ }
+ labels := strings.Split(s, ".")
+ for i, label := range labels {
+ if !ascii(label) {
+ a, err := encode(acePrefix, label)
+ if err != nil {
+ return "", err
+ }
+ labels[i] = a
+ }
+ }
+ return strings.Join(labels, "."), nil
+}
+
+func ascii(s string) bool {
+ for i := 0; i < len(s); i++ {
+ if s[i] >= utf8.RuneSelf {
+ return false
+ }
+ }
+ return true
+}
diff --git a/src/pkg/net/http/cookiejar/punycode_test.go b/src/pkg/net/http/cookiejar/punycode_test.go
new file mode 100644
index 000000000..0301de14e
--- /dev/null
+++ b/src/pkg/net/http/cookiejar/punycode_test.go
@@ -0,0 +1,161 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cookiejar
+
+import (
+ "testing"
+)
+
+var punycodeTestCases = [...]struct {
+ s, encoded string
+}{
+ {"", ""},
+ {"-", "--"},
+ {"-a", "-a-"},
+ {"-a-", "-a--"},
+ {"a", "a-"},
+ {"a-", "a--"},
+ {"a-b", "a-b-"},
+ {"books", "books-"},
+ {"bücher", "bcher-kva"},
+ {"Hello世界", "Hello-ck1hg65u"},
+ {"ü", "tda"},
+ {"üý", "tdac"},
+
+ // The test cases below come from RFC 3492 section 7.1 with Errata 3026.
+ {
+ // (A) Arabic (Egyptian).
+ "\u0644\u064A\u0647\u0645\u0627\u0628\u062A\u0643\u0644" +
+ "\u0645\u0648\u0634\u0639\u0631\u0628\u064A\u061F",
+ "egbpdaj6bu4bxfgehfvwxn",
+ },
+ {
+ // (B) Chinese (simplified).
+ "\u4ED6\u4EEC\u4E3A\u4EC0\u4E48\u4E0D\u8BF4\u4E2D\u6587",
+ "ihqwcrb4cv8a8dqg056pqjye",
+ },
+ {
+ // (C) Chinese (traditional).
+ "\u4ED6\u5011\u7232\u4EC0\u9EBD\u4E0D\u8AAA\u4E2D\u6587",
+ "ihqwctvzc91f659drss3x8bo0yb",
+ },
+ {
+ // (D) Czech.
+ "\u0050\u0072\u006F\u010D\u0070\u0072\u006F\u0073\u0074" +
+ "\u011B\u006E\u0065\u006D\u006C\u0075\u0076\u00ED\u010D" +
+ "\u0065\u0073\u006B\u0079",
+ "Proprostnemluvesky-uyb24dma41a",
+ },
+ {
+ // (E) Hebrew.
+ "\u05DC\u05DE\u05D4\u05D4\u05DD\u05E4\u05E9\u05D5\u05D8" +
+ "\u05DC\u05D0\u05DE\u05D3\u05D1\u05E8\u05D9\u05DD\u05E2" +
+ "\u05D1\u05E8\u05D9\u05EA",
+ "4dbcagdahymbxekheh6e0a7fei0b",
+ },
+ {
+ // (F) Hindi (Devanagari).
+ "\u092F\u0939\u0932\u094B\u0917\u0939\u093F\u0928\u094D" +
+ "\u0926\u0940\u0915\u094D\u092F\u094B\u0902\u0928\u0939" +
+ "\u0940\u0902\u092C\u094B\u0932\u0938\u0915\u0924\u0947" +
+ "\u0939\u0948\u0902",
+ "i1baa7eci9glrd9b2ae1bj0hfcgg6iyaf8o0a1dig0cd",
+ },
+ {
+ // (G) Japanese (kanji and hiragana).
+ "\u306A\u305C\u307F\u3093\u306A\u65E5\u672C\u8A9E\u3092" +
+ "\u8A71\u3057\u3066\u304F\u308C\u306A\u3044\u306E\u304B",
+ "n8jok5ay5dzabd5bym9f0cm5685rrjetr6pdxa",
+ },
+ {
+ // (H) Korean (Hangul syllables).
+ "\uC138\uACC4\uC758\uBAA8\uB4E0\uC0AC\uB78C\uB4E4\uC774" +
+ "\uD55C\uAD6D\uC5B4\uB97C\uC774\uD574\uD55C\uB2E4\uBA74" +
+ "\uC5BC\uB9C8\uB098\uC88B\uC744\uAE4C",
+ "989aomsvi5e83db1d2a355cv1e0vak1dwrv93d5xbh15a0dt30a5j" +
+ "psd879ccm6fea98c",
+ },
+ {
+ // (I) Russian (Cyrillic).
+ "\u043F\u043E\u0447\u0435\u043C\u0443\u0436\u0435\u043E" +
+ "\u043D\u0438\u043D\u0435\u0433\u043E\u0432\u043E\u0440" +
+ "\u044F\u0442\u043F\u043E\u0440\u0443\u0441\u0441\u043A" +
+ "\u0438",
+ "b1abfaaepdrnnbgefbadotcwatmq2g4l",
+ },
+ {
+ // (J) Spanish.
+ "\u0050\u006F\u0072\u0071\u0075\u00E9\u006E\u006F\u0070" +
+ "\u0075\u0065\u0064\u0065\u006E\u0073\u0069\u006D\u0070" +
+ "\u006C\u0065\u006D\u0065\u006E\u0074\u0065\u0068\u0061" +
+ "\u0062\u006C\u0061\u0072\u0065\u006E\u0045\u0073\u0070" +
+ "\u0061\u00F1\u006F\u006C",
+ "PorqunopuedensimplementehablarenEspaol-fmd56a",
+ },
+ {
+ // (K) Vietnamese.
+ "\u0054\u1EA1\u0069\u0073\u0061\u006F\u0068\u1ECD\u006B" +
+ "\u0068\u00F4\u006E\u0067\u0074\u0068\u1EC3\u0063\u0068" +
+ "\u1EC9\u006E\u00F3\u0069\u0074\u0069\u1EBF\u006E\u0067" +
+ "\u0056\u0069\u1EC7\u0074",
+ "TisaohkhngthchnitingVit-kjcr8268qyxafd2f1b9g",
+ },
+ {
+ // (L) 3<nen>B<gumi><kinpachi><sensei>.
+ "\u0033\u5E74\u0042\u7D44\u91D1\u516B\u5148\u751F",
+ "3B-ww4c5e180e575a65lsy2b",
+ },
+ {
+ // (M) <amuro><namie>-with-SUPER-MONKEYS.
+ "\u5B89\u5BA4\u5948\u7F8E\u6075\u002D\u0077\u0069\u0074" +
+ "\u0068\u002D\u0053\u0055\u0050\u0045\u0052\u002D\u004D" +
+ "\u004F\u004E\u004B\u0045\u0059\u0053",
+ "-with-SUPER-MONKEYS-pc58ag80a8qai00g7n9n",
+ },
+ {
+ // (N) Hello-Another-Way-<sorezore><no><basho>.
+ "\u0048\u0065\u006C\u006C\u006F\u002D\u0041\u006E\u006F" +
+ "\u0074\u0068\u0065\u0072\u002D\u0057\u0061\u0079\u002D" +
+ "\u305D\u308C\u305E\u308C\u306E\u5834\u6240",
+ "Hello-Another-Way--fc4qua05auwb3674vfr0b",
+ },
+ {
+ // (O) <hitotsu><yane><no><shita>2.
+ "\u3072\u3068\u3064\u5C4B\u6839\u306E\u4E0B\u0032",
+ "2-u9tlzr9756bt3uc0v",
+ },
+ {
+ // (P) Maji<de>Koi<suru>5<byou><mae>
+ "\u004D\u0061\u006A\u0069\u3067\u004B\u006F\u0069\u3059" +
+ "\u308B\u0035\u79D2\u524D",
+ "MajiKoi5-783gue6qz075azm5e",
+ },
+ {
+ // (Q) <pafii>de<runba>
+ "\u30D1\u30D5\u30A3\u30FC\u0064\u0065\u30EB\u30F3\u30D0",
+ "de-jg4avhby1noc0d",
+ },
+ {
+ // (R) <sono><supiido><de>
+ "\u305D\u306E\u30B9\u30D4\u30FC\u30C9\u3067",
+ "d9juau41awczczp",
+ },
+ {
+ // (S) -> $1.00 <-
+ "\u002D\u003E\u0020\u0024\u0031\u002E\u0030\u0030\u0020" +
+ "\u003C\u002D",
+ "-> $1.00 <--",
+ },
+}
+
+func TestPunycode(t *testing.T) {
+ for _, tc := range punycodeTestCases {
+ if got, err := encode("", tc.s); err != nil {
+ t.Errorf(`encode("", %q): %v`, tc.s, err)
+ } else if got != tc.encoded {
+ t.Errorf(`encode("", %q): got %q, want %q`, tc.s, got, tc.encoded)
+ }
+ }
+}
diff --git a/src/pkg/net/http/example_test.go b/src/pkg/net/http/example_test.go
index ec814407d..22073eaf7 100644
--- a/src/pkg/net/http/example_test.go
+++ b/src/pkg/net/http/example_test.go
@@ -43,10 +43,10 @@ func ExampleGet() {
log.Fatal(err)
}
robots, err := ioutil.ReadAll(res.Body)
+ res.Body.Close()
if err != nil {
log.Fatal(err)
}
- res.Body.Close()
fmt.Printf("%s", robots)
}
diff --git a/src/pkg/net/http/export_test.go b/src/pkg/net/http/export_test.go
index 13640ca85..a7bca20a0 100644
--- a/src/pkg/net/http/export_test.go
+++ b/src/pkg/net/http/export_test.go
@@ -7,12 +7,25 @@
package http
-import "time"
+import (
+ "net"
+ "time"
+)
+
+func NewLoggingConn(baseName string, c net.Conn) net.Conn {
+ return newLoggingConn(baseName, c)
+}
+
+func (t *Transport) NumPendingRequestsForTesting() int {
+ t.reqMu.Lock()
+ defer t.reqMu.Unlock()
+ return len(t.reqConn)
+}
func (t *Transport) IdleConnKeysForTesting() (keys []string) {
keys = make([]string, 0)
- t.lk.Lock()
- defer t.lk.Unlock()
+ t.idleMu.Lock()
+ defer t.idleMu.Unlock()
if t.idleConn == nil {
return
}
@@ -23,8 +36,8 @@ func (t *Transport) IdleConnKeysForTesting() (keys []string) {
}
func (t *Transport) IdleConnCountForTesting(cacheKey string) int {
- t.lk.Lock()
- defer t.lk.Unlock()
+ t.idleMu.Lock()
+ defer t.idleMu.Unlock()
if t.idleConn == nil {
return 0
}
diff --git a/src/pkg/net/http/filetransport_test.go b/src/pkg/net/http/filetransport_test.go
index 039926b53..6f1a537e2 100644
--- a/src/pkg/net/http/filetransport_test.go
+++ b/src/pkg/net/http/filetransport_test.go
@@ -2,11 +2,10 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-package http_test
+package http
import (
"io/ioutil"
- "net/http"
"os"
"path/filepath"
"testing"
@@ -32,9 +31,9 @@ func TestFileTransport(t *testing.T) {
defer os.Remove(dname)
defer os.Remove(fname)
- tr := &http.Transport{}
- tr.RegisterProtocol("file", http.NewFileTransport(http.Dir(dname)))
- c := &http.Client{Transport: tr}
+ tr := &Transport{}
+ tr.RegisterProtocol("file", NewFileTransport(Dir(dname)))
+ c := &Client{Transport: tr}
fooURLs := []string{"file:///foo.txt", "file://../foo.txt"}
for _, urlstr := range fooURLs {
@@ -62,4 +61,5 @@ func TestFileTransport(t *testing.T) {
if res.StatusCode != 404 {
t.Errorf("for %s, StatusCode = %d, want 404", badURL, res.StatusCode)
}
+ res.Body.Close()
}
diff --git a/src/pkg/net/http/fs.go b/src/pkg/net/http/fs.go
index f35dd32c3..b6bea0dfa 100644
--- a/src/pkg/net/http/fs.go
+++ b/src/pkg/net/http/fs.go
@@ -11,6 +11,8 @@ import (
"fmt"
"io"
"mime"
+ "mime/multipart"
+ "net/textproto"
"os"
"path"
"path/filepath"
@@ -26,7 +28,8 @@ import (
type Dir string
func (d Dir) Open(name string) (File, error) {
- if filepath.Separator != '/' && strings.IndexRune(name, filepath.Separator) >= 0 {
+ if filepath.Separator != '/' && strings.IndexRune(name, filepath.Separator) >= 0 ||
+ strings.Contains(name, "\x00") {
return nil, errors.New("http: invalid character in file path")
}
dir := string(d)
@@ -97,6 +100,9 @@ func dirList(w ResponseWriter, f File) {
// The content's Seek method must work: ServeContent uses
// a seek to the end of the content to determine its size.
//
+// If the caller has set w's ETag header, ServeContent uses it to
+// handle requests using If-Range and If-None-Match.
+//
// Note that *os.File implements the io.ReadSeeker interface.
func ServeContent(w ResponseWriter, req *Request, name string, modtime time.Time, content io.ReadSeeker) {
size, err := content.Seek(0, os.SEEK_END)
@@ -119,12 +125,17 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time,
if checkLastModified(w, r, modtime) {
return
}
+ rangeReq, done := checkETag(w, r)
+ if done {
+ return
+ }
code := StatusOK
// If Content-Type isn't set, use the file's extension to find it.
- if w.Header().Get("Content-Type") == "" {
- ctype := mime.TypeByExtension(filepath.Ext(name))
+ ctype := w.Header().Get("Content-Type")
+ if ctype == "" {
+ ctype = mime.TypeByExtension(filepath.Ext(name))
if ctype == "" {
// read a chunk to decide between utf-8 text and binary
var buf [1024]byte
@@ -141,18 +152,34 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time,
}
// handle Content-Range header.
- // TODO(adg): handle multiple ranges
sendSize := size
+ var sendContent io.Reader = content
if size >= 0 {
- ranges, err := parseRange(r.Header.Get("Range"), size)
- if err == nil && len(ranges) > 1 {
- err = errors.New("multiple ranges not supported")
- }
+ ranges, err := parseRange(rangeReq, size)
if err != nil {
Error(w, err.Error(), StatusRequestedRangeNotSatisfiable)
return
}
- if len(ranges) == 1 {
+ if sumRangesSize(ranges) >= size {
+ // The total number of bytes in all the ranges
+ // is larger than the size of the file by
+ // itself, so this is probably an attack, or a
+ // dumb client. Ignore the range request.
+ ranges = nil
+ }
+ switch {
+ case len(ranges) == 1:
+ // RFC 2616, Section 14.16:
+ // "When an HTTP message includes the content of a single
+ // range (for example, a response to a request for a
+ // single range, or to a request for a set of ranges
+ // that overlap without any holes), this content is
+ // transmitted with a Content-Range header, and a
+ // Content-Length header showing the number of bytes
+ // actually transferred.
+ // ...
+ // A response to a request for a single range MUST NOT
+ // be sent using the multipart/byteranges media type."
ra := ranges[0]
if _, err := content.Seek(ra.start, os.SEEK_SET); err != nil {
Error(w, err.Error(), StatusRequestedRangeNotSatisfiable)
@@ -160,7 +187,41 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time,
}
sendSize = ra.length
code = StatusPartialContent
- w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", ra.start, ra.start+ra.length-1, size))
+ w.Header().Set("Content-Range", ra.contentRange(size))
+ case len(ranges) > 1:
+ for _, ra := range ranges {
+ if ra.start > size {
+ Error(w, err.Error(), StatusRequestedRangeNotSatisfiable)
+ return
+ }
+ }
+ sendSize = rangesMIMESize(ranges, ctype, size)
+ code = StatusPartialContent
+
+ pr, pw := io.Pipe()
+ mw := multipart.NewWriter(pw)
+ w.Header().Set("Content-Type", "multipart/byteranges; boundary="+mw.Boundary())
+ sendContent = pr
+ defer pr.Close() // cause writing goroutine to fail and exit if CopyN doesn't finish.
+ go func() {
+ for _, ra := range ranges {
+ part, err := mw.CreatePart(ra.mimeHeader(ctype, size))
+ if err != nil {
+ pw.CloseWithError(err)
+ return
+ }
+ if _, err := content.Seek(ra.start, os.SEEK_SET); err != nil {
+ pw.CloseWithError(err)
+ return
+ }
+ if _, err := io.CopyN(part, content, ra.length); err != nil {
+ pw.CloseWithError(err)
+ return
+ }
+ }
+ mw.Close()
+ pw.Close()
+ }()
}
w.Header().Set("Accept-Ranges", "bytes")
@@ -172,11 +233,7 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time,
w.WriteHeader(code)
if r.Method != "HEAD" {
- if sendSize == -1 {
- io.Copy(w, content)
- } else {
- io.CopyN(w, content, sendSize)
- }
+ io.CopyN(w, sendContent, sendSize)
}
}
@@ -190,6 +247,9 @@ func checkLastModified(w ResponseWriter, r *Request, modtime time.Time) bool {
// The Date-Modified header truncates sub-second precision, so
// use mtime < t+1s instead of mtime <= t to check for unmodified.
if t, err := time.Parse(TimeFormat, r.Header.Get("If-Modified-Since")); err == nil && modtime.Before(t.Add(1*time.Second)) {
+ h := w.Header()
+ delete(h, "Content-Type")
+ delete(h, "Content-Length")
w.WriteHeader(StatusNotModified)
return true
}
@@ -197,6 +257,58 @@ func checkLastModified(w ResponseWriter, r *Request, modtime time.Time) bool {
return false
}
+// checkETag implements If-None-Match and If-Range checks.
+// The ETag must have been previously set in the ResponseWriter's headers.
+//
+// The return value is the effective request "Range" header to use and
+// whether this request is now considered done.
+func checkETag(w ResponseWriter, r *Request) (rangeReq string, done bool) {
+ etag := w.Header().get("Etag")
+ rangeReq = r.Header.get("Range")
+
+ // Invalidate the range request if the entity doesn't match the one
+ // the client was expecting.
+ // "If-Range: version" means "ignore the Range: header unless version matches the
+ // current file."
+ // We only support ETag versions.
+ // The caller must have set the ETag on the response already.
+ if ir := r.Header.get("If-Range"); ir != "" && ir != etag {
+ // TODO(bradfitz): handle If-Range requests with Last-Modified
+ // times instead of ETags? I'd rather not, at least for
+ // now. That seems like a bug/compromise in the RFC 2616, and
+ // I've never heard of anybody caring about that (yet).
+ rangeReq = ""
+ }
+
+ if inm := r.Header.get("If-None-Match"); inm != "" {
+ // Must know ETag.
+ if etag == "" {
+ return rangeReq, false
+ }
+
+ // TODO(bradfitz): non-GET/HEAD requests require more work:
+ // sending a different status code on matches, and
+ // also can't use weak cache validators (those with a "W/
+ // prefix). But most users of ServeContent will be using
+ // it on GET or HEAD, so only support those for now.
+ if r.Method != "GET" && r.Method != "HEAD" {
+ return rangeReq, false
+ }
+
+ // TODO(bradfitz): deal with comma-separated or multiple-valued
+ // list of If-None-match values. For now just handle the common
+ // case of a single item.
+ if inm == etag || inm == "*" {
+ h := w.Header()
+ delete(h, "Content-Type")
+ delete(h, "Content-Length")
+ w.WriteHeader(StatusNotModified)
+ return "", true
+ }
+ }
+ return rangeReq, false
+}
+
// name is '/'-separated, not filepath.Separator.
func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirect bool) {
const indexPage = "/index.html"
@@ -243,9 +355,6 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec
// use contents of index.html for directory, if present
if d.IsDir() {
- if checkLastModified(w, r, d.ModTime()) {
- return
- }
index := name + indexPage
ff, err := fs.Open(index)
if err == nil {
@@ -259,11 +368,16 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec
}
}
+ // Still a directory? (we didn't find an index.html file)
if d.IsDir() {
+ if checkLastModified(w, r, d.ModTime()) {
+ return
+ }
dirList(w, f)
return
}
+ // serverContent will check modification time
serveContent(w, r, d.Name(), d.ModTime(), d.Size(), f)
}
@@ -312,6 +426,17 @@ type httpRange struct {
start, length int64
}
+func (r httpRange) contentRange(size int64) string {
+ return fmt.Sprintf("bytes %d-%d/%d", r.start, r.start+r.length-1, size)
+}
+
+func (r httpRange) mimeHeader(contentType string, size int64) textproto.MIMEHeader {
+ return textproto.MIMEHeader{
+ "Content-Range": {r.contentRange(size)},
+ "Content-Type": {contentType},
+ }
+}
+
// parseRange parses a Range header string as per RFC 2616.
func parseRange(s string, size int64) ([]httpRange, error) {
if s == "" {
@@ -323,11 +448,15 @@ func parseRange(s string, size int64) ([]httpRange, error) {
}
var ranges []httpRange
for _, ra := range strings.Split(s[len(b):], ",") {
+ ra = strings.TrimSpace(ra)
+ if ra == "" {
+ continue
+ }
i := strings.Index(ra, "-")
if i < 0 {
return nil, errors.New("invalid range")
}
- start, end := ra[:i], ra[i+1:]
+ start, end := strings.TrimSpace(ra[:i]), strings.TrimSpace(ra[i+1:])
var r httpRange
if start == "" {
// If no start is specified, end specifies the
@@ -365,3 +494,32 @@ func parseRange(s string, size int64) ([]httpRange, error) {
}
return ranges, nil
}
+
+// countingWriter counts how many bytes have been written to it.
+type countingWriter int64
+
+func (w *countingWriter) Write(p []byte) (n int, err error) {
+ *w += countingWriter(len(p))
+ return len(p), nil
+}
+
+// rangesMIMESize returns the nunber of bytes it takes to encode the
+// provided ranges as a multipart response.
+func rangesMIMESize(ranges []httpRange, contentType string, contentSize int64) (encSize int64) {
+ var w countingWriter
+ mw := multipart.NewWriter(&w)
+ for _, ra := range ranges {
+ mw.CreatePart(ra.mimeHeader(contentType, contentSize))
+ encSize += ra.length
+ }
+ mw.Close()
+ encSize += int64(w)
+ return
+}
+
+func sumRangesSize(ranges []httpRange) (size int64) {
+ for _, ra := range ranges {
+ size += ra.length
+ }
+ return
+}
diff --git a/src/pkg/net/http/fs_test.go b/src/pkg/net/http/fs_test.go
index 5aa93ce58..0dd6d0df9 100644
--- a/src/pkg/net/http/fs_test.go
+++ b/src/pkg/net/http/fs_test.go
@@ -10,12 +10,15 @@ import (
"fmt"
"io"
"io/ioutil"
+ "mime"
+ "mime/multipart"
"net"
. "net/http"
"net/http/httptest"
"net/url"
"os"
"os/exec"
+ "path"
"path/filepath"
"regexp"
"runtime"
@@ -25,24 +28,33 @@ import (
)
const (
- testFile = "testdata/file"
- testFileLength = 11
+ testFile = "testdata/file"
+ testFileLen = 11
)
+type wantRange struct {
+ start, end int64 // range [start,end)
+}
+
var ServeFileRangeTests = []struct {
- start, end int
- r string
- code int
+ r string
+ code int
+ ranges []wantRange
}{
- {0, testFileLength, "", StatusOK},
- {0, 5, "0-4", StatusPartialContent},
- {2, testFileLength, "2-", StatusPartialContent},
- {testFileLength - 5, testFileLength, "-5", StatusPartialContent},
- {3, 8, "3-7", StatusPartialContent},
- {0, 0, "20-", StatusRequestedRangeNotSatisfiable},
+ {r: "", code: StatusOK},
+ {r: "bytes=0-4", code: StatusPartialContent, ranges: []wantRange{{0, 5}}},
+ {r: "bytes=2-", code: StatusPartialContent, ranges: []wantRange{{2, testFileLen}}},
+ {r: "bytes=-5", code: StatusPartialContent, ranges: []wantRange{{testFileLen - 5, testFileLen}}},
+ {r: "bytes=3-7", code: StatusPartialContent, ranges: []wantRange{{3, 8}}},
+ {r: "bytes=20-", code: StatusRequestedRangeNotSatisfiable},
+ {r: "bytes=0-0,-2", code: StatusPartialContent, ranges: []wantRange{{0, 1}, {testFileLen - 2, testFileLen}}},
+ {r: "bytes=0-1,5-8", code: StatusPartialContent, ranges: []wantRange{{0, 2}, {5, 9}}},
+ {r: "bytes=0-1,5-", code: StatusPartialContent, ranges: []wantRange{{0, 2}, {5, testFileLen}}},
+ {r: "bytes=0-,1-,2-,3-,4-", code: StatusOK}, // ignore wasteful range request
}
func TestServeFile(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
ServeFile(w, r, "testdata/file")
}))
@@ -65,33 +77,86 @@ func TestServeFile(t *testing.T) {
// straight GET
_, body := getBody(t, "straight get", req)
- if !equal(body, file) {
+ if !bytes.Equal(body, file) {
t.Fatalf("body mismatch: got %q, want %q", body, file)
}
// Range tests
- for i, rt := range ServeFileRangeTests {
- req.Header.Set("Range", "bytes="+rt.r)
- if rt.r == "" {
- req.Header["Range"] = nil
+Cases:
+ for _, rt := range ServeFileRangeTests {
+ if rt.r != "" {
+ req.Header.Set("Range", rt.r)
}
- r, body := getBody(t, fmt.Sprintf("test %d", i), req)
- if r.StatusCode != rt.code {
- t.Errorf("range=%q: StatusCode=%d, want %d", rt.r, r.StatusCode, rt.code)
+ resp, body := getBody(t, fmt.Sprintf("range test %q", rt.r), req)
+ if resp.StatusCode != rt.code {
+ t.Errorf("range=%q: StatusCode=%d, want %d", rt.r, resp.StatusCode, rt.code)
}
if rt.code == StatusRequestedRangeNotSatisfiable {
continue
}
- h := fmt.Sprintf("bytes %d-%d/%d", rt.start, rt.end-1, testFileLength)
- if rt.r == "" {
- h = ""
+ wantContentRange := ""
+ if len(rt.ranges) == 1 {
+ rng := rt.ranges[0]
+ wantContentRange = fmt.Sprintf("bytes %d-%d/%d", rng.start, rng.end-1, testFileLen)
+ }
+ cr := resp.Header.Get("Content-Range")
+ if cr != wantContentRange {
+ t.Errorf("range=%q: Content-Range = %q, want %q", rt.r, cr, wantContentRange)
}
- cr := r.Header.Get("Content-Range")
- if cr != h {
- t.Errorf("header mismatch: range=%q: got %q, want %q", rt.r, cr, h)
+ ct := resp.Header.Get("Content-Type")
+ if len(rt.ranges) == 1 {
+ rng := rt.ranges[0]
+ wantBody := file[rng.start:rng.end]
+ if !bytes.Equal(body, wantBody) {
+ t.Errorf("range=%q: body = %q, want %q", rt.r, body, wantBody)
+ }
+ if strings.HasPrefix(ct, "multipart/byteranges") {
+ t.Errorf("range=%q content-type = %q; unexpected multipart/byteranges", rt.r, ct)
+ }
}
- if !equal(body, file[rt.start:rt.end]) {
- t.Errorf("body mismatch: range=%q: got %q, want %q", rt.r, body, file[rt.start:rt.end])
+ if len(rt.ranges) > 1 {
+ typ, params, err := mime.ParseMediaType(ct)
+ if err != nil {
+ t.Errorf("range=%q content-type = %q; %v", rt.r, ct, err)
+ continue
+ }
+ if typ != "multipart/byteranges" {
+ t.Errorf("range=%q content-type = %q; want multipart/byteranges", rt.r, typ)
+ continue
+ }
+ if params["boundary"] == "" {
+ t.Errorf("range=%q content-type = %q; lacks boundary", rt.r, ct)
+ continue
+ }
+ if g, w := resp.ContentLength, int64(len(body)); g != w {
+ t.Errorf("range=%q Content-Length = %d; want %d", rt.r, g, w)
+ continue
+ }
+ mr := multipart.NewReader(bytes.NewReader(body), params["boundary"])
+ for ri, rng := range rt.ranges {
+ part, err := mr.NextPart()
+ if err != nil {
+ t.Errorf("range=%q, reading part index %d: %v", rt.r, ri, err)
+ continue Cases
+ }
+ wantContentRange = fmt.Sprintf("bytes %d-%d/%d", rng.start, rng.end-1, testFileLen)
+ if g, w := part.Header.Get("Content-Range"), wantContentRange; g != w {
+ t.Errorf("range=%q: part Content-Range = %q; want %q", rt.r, g, w)
+ }
+ body, err := ioutil.ReadAll(part)
+ if err != nil {
+ t.Errorf("range=%q, reading part index %d body: %v", rt.r, ri, err)
+ continue Cases
+ }
+ wantBody := file[rng.start:rng.end]
+ if !bytes.Equal(body, wantBody) {
+ t.Errorf("range=%q: body = %q, want %q", rt.r, body, wantBody)
+ }
+ }
+ _, err = mr.NextPart()
+ if err != io.EOF {
+ t.Errorf("range=%q; expected final error io.EOF; got %v", rt.r, err)
+ }
}
}
}
@@ -105,6 +170,7 @@ var fsRedirectTestData = []struct {
}
func TestFSRedirect(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(StripPrefix("/test", FileServer(Dir("."))))
defer ts.Close()
@@ -129,6 +195,7 @@ func (fs *testFileSystem) Open(name string) (File, error) {
}
func TestFileServerCleans(t *testing.T) {
+ defer checkLeakedTransports(t)
ch := make(chan string, 1)
fs := FileServer(&testFileSystem{func(name string) (File, error) {
ch <- name
@@ -160,6 +227,7 @@ func mustRemoveAll(dir string) {
}
func TestFileServerImplicitLeadingSlash(t *testing.T) {
+ defer checkLeakedTransports(t)
tempDir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatalf("TempDir: %v", err)
@@ -193,8 +261,7 @@ func TestFileServerImplicitLeadingSlash(t *testing.T) {
func TestDirJoin(t *testing.T) {
wfi, err := os.Stat("/etc/hosts")
if err != nil {
- t.Logf("skipping test; no /etc/hosts file")
- return
+ t.Skip("skipping test; no /etc/hosts file")
}
test := func(d Dir, name string) {
f, err := d.Open(name)
@@ -239,6 +306,7 @@ func TestEmptyDirOpenCWD(t *testing.T) {
}
func TestServeFileContentType(t *testing.T) {
+ defer checkLeakedTransports(t)
const ctype = "icecream/chocolate"
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
if r.FormValue("override") == "1" {
@@ -255,12 +323,14 @@ func TestServeFileContentType(t *testing.T) {
if h := resp.Header.Get("Content-Type"); h != want {
t.Errorf("Content-Type mismatch: got %q, want %q", h, want)
}
+ resp.Body.Close()
}
get("0", "text/plain; charset=utf-8")
get("1", ctype)
}
func TestServeFileMimeType(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
ServeFile(w, r, "testdata/style.css")
}))
@@ -269,6 +339,7 @@ func TestServeFileMimeType(t *testing.T) {
if err != nil {
t.Fatal(err)
}
+ resp.Body.Close()
want := "text/css; charset=utf-8"
if h := resp.Header.Get("Content-Type"); h != want {
t.Errorf("Content-Type mismatch: got %q, want %q", h, want)
@@ -276,6 +347,7 @@ func TestServeFileMimeType(t *testing.T) {
}
func TestServeFileFromCWD(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
ServeFile(w, r, "fs_test.go")
}))
@@ -284,12 +356,14 @@ func TestServeFileFromCWD(t *testing.T) {
if err != nil {
t.Fatal(err)
}
+ r.Body.Close()
if r.StatusCode != 200 {
t.Fatalf("expected 200 OK, got %s", r.Status)
}
}
func TestServeFileWithContentEncoding(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("Content-Encoding", "foo")
ServeFile(w, r, "testdata/file")
@@ -299,12 +373,14 @@ func TestServeFileWithContentEncoding(t *testing.T) {
if err != nil {
t.Fatal(err)
}
+ resp.Body.Close()
if g, e := resp.ContentLength, int64(-1); g != e {
t.Errorf("Content-Length mismatch: got %d, want %d", g, e)
}
}
func TestServeIndexHtml(t *testing.T) {
+ defer checkLeakedTransports(t)
const want = "index.html says hello\n"
ts := httptest.NewServer(FileServer(Dir(".")))
defer ts.Close()
@@ -325,64 +401,289 @@ func TestServeIndexHtml(t *testing.T) {
}
}
-func TestServeContent(t *testing.T) {
- type req struct {
- name string
- modtime time.Time
- content io.ReadSeeker
- }
- ch := make(chan req, 1)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
- p := <-ch
- ServeContent(w, r, p.name, p.modtime, p.content)
- }))
+func TestFileServerZeroByte(t *testing.T) {
+ defer checkLeakedTransports(t)
+ ts := httptest.NewServer(FileServer(Dir(".")))
defer ts.Close()
- css, err := os.Open("testdata/style.css")
+ res, err := Get(ts.URL + "/..\x00")
if err != nil {
t.Fatal(err)
}
- defer css.Close()
+ b, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal("reading Body:", err)
+ }
+ if res.StatusCode == 200 {
+ t.Errorf("got status 200; want an error. Body is:\n%s", string(b))
+ }
+}
+
+type fakeFileInfo struct {
+ dir bool
+ basename string
+ modtime time.Time
+ ents []*fakeFileInfo
+ contents string
+}
+
+func (f *fakeFileInfo) Name() string { return f.basename }
+func (f *fakeFileInfo) Sys() interface{} { return nil }
+func (f *fakeFileInfo) ModTime() time.Time { return f.modtime }
+func (f *fakeFileInfo) IsDir() bool { return f.dir }
+func (f *fakeFileInfo) Size() int64 { return int64(len(f.contents)) }
+func (f *fakeFileInfo) Mode() os.FileMode {
+ if f.dir {
+ return 0755 | os.ModeDir
+ }
+ return 0644
+}
+
+type fakeFile struct {
+ io.ReadSeeker
+ fi *fakeFileInfo
+ path string // as opened
+}
+
+func (f *fakeFile) Close() error { return nil }
+func (f *fakeFile) Stat() (os.FileInfo, error) { return f.fi, nil }
+func (f *fakeFile) Readdir(count int) ([]os.FileInfo, error) {
+ if !f.fi.dir {
+ return nil, os.ErrInvalid
+ }
+ var fis []os.FileInfo
+ for _, fi := range f.fi.ents {
+ fis = append(fis, fi)
+ }
+ return fis, nil
+}
+
+type fakeFS map[string]*fakeFileInfo
+
+func (fs fakeFS) Open(name string) (File, error) {
+ name = path.Clean(name)
+ f, ok := fs[name]
+ if !ok {
+ println("fake filesystem didn't find file", name)
+ return nil, os.ErrNotExist
+ }
+ return &fakeFile{ReadSeeker: strings.NewReader(f.contents), fi: f, path: name}, nil
+}
+
+func TestDirectoryIfNotModified(t *testing.T) {
+ defer checkLeakedTransports(t)
+ const indexContents = "I am a fake index.html file"
+ fileMod := time.Unix(1000000000, 0).UTC()
+ fileModStr := fileMod.Format(TimeFormat)
+ dirMod := time.Unix(123, 0).UTC()
+ indexFile := &fakeFileInfo{
+ basename: "index.html",
+ modtime: fileMod,
+ contents: indexContents,
+ }
+ fs := fakeFS{
+ "/": &fakeFileInfo{
+ dir: true,
+ modtime: dirMod,
+ ents: []*fakeFileInfo{indexFile},
+ },
+ "/index.html": indexFile,
+ }
+
+ ts := httptest.NewServer(FileServer(fs))
+ defer ts.Close()
- ch <- req{"style.css", time.Time{}, css}
res, err := Get(ts.URL)
if err != nil {
t.Fatal(err)
}
- if g, e := res.Header.Get("Content-Type"), "text/css; charset=utf-8"; g != e {
- t.Errorf("style.css: content type = %q, want %q", g, e)
+ b, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
}
- if g := res.Header.Get("Last-Modified"); g != "" {
- t.Errorf("want empty Last-Modified; got %q", g)
+ if string(b) != indexContents {
+ t.Fatalf("Got body %q; want %q", b, indexContents)
}
+ res.Body.Close()
+
+ lastMod := res.Header.Get("Last-Modified")
+ if lastMod != fileModStr {
+ t.Fatalf("initial Last-Modified = %q; want %q", lastMod, fileModStr)
+ }
+
+ req, _ := NewRequest("GET", ts.URL, nil)
+ req.Header.Set("If-Modified-Since", lastMod)
- fi, err := css.Stat()
+ res, err = DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
- ch <- req{"style.html", fi.ModTime(), css}
- res, err = Get(ts.URL)
+ if res.StatusCode != 304 {
+ t.Fatalf("Code after If-Modified-Since request = %v; want 304", res.StatusCode)
+ }
+ res.Body.Close()
+
+ // Advance the index.html file's modtime, but not the directory's.
+ indexFile.modtime = indexFile.modtime.Add(1 * time.Hour)
+
+ res, err = DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
- if g, e := res.Header.Get("Content-Type"), "text/html; charset=utf-8"; g != e {
- t.Errorf("style.html: content type = %q, want %q", g, e)
+ if res.StatusCode != 200 {
+ t.Fatalf("Code after second If-Modified-Since request = %v; want 200; res is %#v", res.StatusCode, res)
}
- if g := res.Header.Get("Last-Modified"); g == "" {
- t.Errorf("want non-empty last-modified")
+ res.Body.Close()
+}
+
+func mustStat(t *testing.T, fileName string) os.FileInfo {
+ fi, err := os.Stat(fileName)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return fi
+}
+
+func TestServeContent(t *testing.T) {
+ defer checkLeakedTransports(t)
+ type serveParam struct {
+ name string
+ modtime time.Time
+ content io.ReadSeeker
+ contentType string
+ etag string
+ }
+ servec := make(chan serveParam, 1)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ p := <-servec
+ if p.etag != "" {
+ w.Header().Set("ETag", p.etag)
+ }
+ if p.contentType != "" {
+ w.Header().Set("Content-Type", p.contentType)
+ }
+ ServeContent(w, r, p.name, p.modtime, p.content)
+ }))
+ defer ts.Close()
+
+ type testCase struct {
+ file string
+ modtime time.Time
+ serveETag string // optional
+ serveContentType string // optional
+ reqHeader map[string]string
+ wantLastMod string
+ wantContentType string
+ wantStatus int
+ }
+ htmlModTime := mustStat(t, "testdata/index.html").ModTime()
+ tests := map[string]testCase{
+ "no_last_modified": {
+ file: "testdata/style.css",
+ wantContentType: "text/css; charset=utf-8",
+ wantStatus: 200,
+ },
+ "with_last_modified": {
+ file: "testdata/index.html",
+ wantContentType: "text/html; charset=utf-8",
+ modtime: htmlModTime,
+ wantLastMod: htmlModTime.UTC().Format(TimeFormat),
+ wantStatus: 200,
+ },
+ "not_modified_modtime": {
+ file: "testdata/style.css",
+ modtime: htmlModTime,
+ reqHeader: map[string]string{
+ "If-Modified-Since": htmlModTime.UTC().Format(TimeFormat),
+ },
+ wantStatus: 304,
+ },
+ "not_modified_modtime_with_contenttype": {
+ file: "testdata/style.css",
+ serveContentType: "text/css", // explicit content type
+ modtime: htmlModTime,
+ reqHeader: map[string]string{
+ "If-Modified-Since": htmlModTime.UTC().Format(TimeFormat),
+ },
+ wantStatus: 304,
+ },
+ "not_modified_etag": {
+ file: "testdata/style.css",
+ serveETag: `"foo"`,
+ reqHeader: map[string]string{
+ "If-None-Match": `"foo"`,
+ },
+ wantStatus: 304,
+ },
+ "range_good": {
+ file: "testdata/style.css",
+ serveETag: `"A"`,
+ reqHeader: map[string]string{
+ "Range": "bytes=0-4",
+ },
+ wantStatus: StatusPartialContent,
+ wantContentType: "text/css; charset=utf-8",
+ },
+ // An If-Range resource for entity "A", but entity "B" is now current.
+ // The Range request should be ignored.
+ "range_no_match": {
+ file: "testdata/style.css",
+ serveETag: `"A"`,
+ reqHeader: map[string]string{
+ "Range": "bytes=0-4",
+ "If-Range": `"B"`,
+ },
+ wantStatus: 200,
+ wantContentType: "text/css; charset=utf-8",
+ },
+ }
+ for testName, tt := range tests {
+ f, err := os.Open(tt.file)
+ if err != nil {
+ t.Fatalf("test %q: %v", testName, err)
+ }
+ defer f.Close()
+
+ servec <- serveParam{
+ name: filepath.Base(tt.file),
+ content: f,
+ modtime: tt.modtime,
+ etag: tt.serveETag,
+ contentType: tt.serveContentType,
+ }
+ req, err := NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ for k, v := range tt.reqHeader {
+ req.Header.Set(k, v)
+ }
+ res, err := DefaultClient.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ io.Copy(ioutil.Discard, res.Body)
+ res.Body.Close()
+ if res.StatusCode != tt.wantStatus {
+ t.Errorf("test %q: status = %d; want %d", testName, res.StatusCode, tt.wantStatus)
+ }
+ if g, e := res.Header.Get("Content-Type"), tt.wantContentType; g != e {
+ t.Errorf("test %q: content-type = %q, want %q", testName, g, e)
+ }
+ if g, e := res.Header.Get("Last-Modified"), tt.wantLastMod; g != e {
+ t.Errorf("test %q: last-modified = %q, want %q", testName, g, e)
+ }
}
}
// verifies that sendfile is being used on Linux
func TestLinuxSendfile(t *testing.T) {
+ defer checkLeakedTransports(t)
if runtime.GOOS != "linux" {
- t.Logf("skipping; linux-only test")
- return
+ t.Skip("skipping; linux-only test")
}
- _, err := exec.LookPath("strace")
- if err != nil {
- t.Logf("skipping; strace not found in path")
- return
+ if _, err := exec.LookPath("strace"); err != nil {
+ t.Skip("skipping; strace not found in path")
}
ln, err := net.Listen("tcp", "127.0.0.1:0")
@@ -401,10 +702,8 @@ func TestLinuxSendfile(t *testing.T) {
child.Env = append([]string{"GO_WANT_HELPER_PROCESS=1"}, os.Environ()...)
child.Stdout = &buf
child.Stderr = &buf
- err = child.Start()
- if err != nil {
- t.Logf("skipping; failed to start straced child: %v", err)
- return
+ if err := child.Start(); err != nil {
+ t.Skipf("skipping; failed to start straced child: %v", err)
}
res, err := Get(fmt.Sprintf("http://%s/", ln.Addr()))
@@ -464,15 +763,3 @@ func TestLinuxSendfileChild(*testing.T) {
panic(err)
}
}
-
-func equal(a, b []byte) bool {
- if len(a) != len(b) {
- return false
- }
- for i := range a {
- if a[i] != b[i] {
- return false
- }
- }
- return true
-}
diff --git a/src/pkg/net/http/header.go b/src/pkg/net/http/header.go
index b107c312d..f479b7b4e 100644
--- a/src/pkg/net/http/header.go
+++ b/src/pkg/net/http/header.go
@@ -5,11 +5,11 @@
package http
import (
- "fmt"
"io"
"net/textproto"
"sort"
"strings"
+ "time"
)
// A Header represents the key-value pairs in an HTTP header.
@@ -36,6 +36,14 @@ func (h Header) Get(key string) string {
return textproto.MIMEHeader(h).Get(key)
}
+// get is like Get, but key must already be in CanonicalHeaderKey form.
+func (h Header) get(key string) string {
+ if v := h[key]; len(v) > 0 {
+ return v[0]
+ }
+ return ""
+}
+
// Del deletes the values associated with key.
func (h Header) Del(key string) {
textproto.MIMEHeader(h).Del(key)
@@ -46,24 +54,87 @@ func (h Header) Write(w io.Writer) error {
return h.WriteSubset(w, nil)
}
+func (h Header) clone() Header {
+ h2 := make(Header, len(h))
+ for k, vv := range h {
+ vv2 := make([]string, len(vv))
+ copy(vv2, vv)
+ h2[k] = vv2
+ }
+ return h2
+}
+
+var timeFormats = []string{
+ TimeFormat,
+ time.RFC850,
+ time.ANSIC,
+}
+
+// ParseTime parses a time header (such as the Date: header),
+// trying each of the three formats allowed by HTTP/1.1:
+// TimeFormat, time.RFC850, and time.ANSIC.
+func ParseTime(text string) (t time.Time, err error) {
+ for _, layout := range timeFormats {
+ t, err = time.Parse(layout, text)
+ if err == nil {
+ return
+ }
+ }
+ return
+}
+
var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ")
+type writeStringer interface {
+ WriteString(string) (int, error)
+}
+
+// stringWriter implements WriteString on a Writer.
+type stringWriter struct {
+ w io.Writer
+}
+
+func (w stringWriter) WriteString(s string) (n int, err error) {
+ return w.w.Write([]byte(s))
+}
+
+type keyValues struct {
+ key string
+ values []string
+}
+
+type byKey []keyValues
+
+func (s byKey) Len() int { return len(s) }
+func (s byKey) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+func (s byKey) Less(i, j int) bool { return s[i].key < s[j].key }
+
+func (h Header) sortedKeyValues(exclude map[string]bool) []keyValues {
+ kvs := make([]keyValues, 0, len(h))
+ for k, vv := range h {
+ if !exclude[k] {
+ kvs = append(kvs, keyValues{k, vv})
+ }
+ }
+ sort.Sort(byKey(kvs))
+ return kvs
+}
+
// WriteSubset writes a header in wire format.
// If exclude is not nil, keys where exclude[key] == true are not written.
func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error {
- keys := make([]string, 0, len(h))
- for k := range h {
- if exclude == nil || !exclude[k] {
- keys = append(keys, k)
- }
+ ws, ok := w.(writeStringer)
+ if !ok {
+ ws = stringWriter{w}
}
- sort.Strings(keys)
- for _, k := range keys {
- for _, v := range h[k] {
+ for _, kv := range h.sortedKeyValues(exclude) {
+ for _, v := range kv.values {
v = headerNewlineToSpace.Replace(v)
- v = strings.TrimSpace(v)
- if _, err := fmt.Fprintf(w, "%s: %s\r\n", k, v); err != nil {
- return err
+ v = textproto.TrimString(v)
+ for _, s := range []string{kv.key, ": ", v, "\r\n"} {
+ if _, err := ws.WriteString(s); err != nil {
+ return err
+ }
}
}
}
@@ -76,3 +147,43 @@ func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error {
// the rest are converted to lowercase. For example, the
// canonical key for "accept-encoding" is "Accept-Encoding".
func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) }
+
+// hasToken returns whether token appears with v, ASCII
+// case-insensitive, with space or comma boundaries.
+// token must be all lowercase.
+// v may contain mixed cased.
+func hasToken(v, token string) bool {
+ if len(token) > len(v) || token == "" {
+ return false
+ }
+ if v == token {
+ return true
+ }
+ for sp := 0; sp <= len(v)-len(token); sp++ {
+ // Check that first character is good.
+ // The token is ASCII, so checking only a single byte
+ // is sufficient. We skip this potential starting
+ // position if both the first byte and its potential
+ // ASCII uppercase equivalent (b|0x20) don't match.
+ // False positives ('^' => '~') are caught by EqualFold.
+ if b := v[sp]; b != token[0] && b|0x20 != token[0] {
+ continue
+ }
+ // Check that start pos is on a valid token boundary.
+ if sp > 0 && !isTokenBoundary(v[sp-1]) {
+ continue
+ }
+ // Check that end pos is on a valid token boundary.
+ if endPos := sp + len(token); endPos != len(v) && !isTokenBoundary(v[endPos]) {
+ continue
+ }
+ if strings.EqualFold(v[sp:sp+len(token)], token) {
+ return true
+ }
+ }
+ return false
+}
+
+func isTokenBoundary(b byte) bool {
+ return b == ' ' || b == ',' || b == '\t'
+}
diff --git a/src/pkg/net/http/header_test.go b/src/pkg/net/http/header_test.go
index ccdee8a97..2313b5549 100644
--- a/src/pkg/net/http/header_test.go
+++ b/src/pkg/net/http/header_test.go
@@ -7,6 +7,7 @@ package http
import (
"bytes"
"testing"
+ "time"
)
var headerWriteTests = []struct {
@@ -67,6 +68,24 @@ var headerWriteTests = []struct {
nil,
"Blank: \r\nDouble-Blank: \r\nDouble-Blank: \r\n",
},
+ // Tests header sorting when over the insertion sort threshold side:
+ {
+ Header{
+ "k1": {"1a", "1b"},
+ "k2": {"2a", "2b"},
+ "k3": {"3a", "3b"},
+ "k4": {"4a", "4b"},
+ "k5": {"5a", "5b"},
+ "k6": {"6a", "6b"},
+ "k7": {"7a", "7b"},
+ "k8": {"8a", "8b"},
+ "k9": {"9a", "9b"},
+ },
+ map[string]bool{"k5": true},
+ "k1: 1a\r\nk1: 1b\r\nk2: 2a\r\nk2: 2b\r\nk3: 3a\r\nk3: 3b\r\n" +
+ "k4: 4a\r\nk4: 4b\r\nk6: 6a\r\nk6: 6b\r\n" +
+ "k7: 7a\r\nk7: 7b\r\nk8: 8a\r\nk8: 8b\r\nk9: 9a\r\nk9: 9b\r\n",
+ },
}
func TestHeaderWrite(t *testing.T) {
@@ -79,3 +98,107 @@ func TestHeaderWrite(t *testing.T) {
buf.Reset()
}
}
+
+var parseTimeTests = []struct {
+ h Header
+ err bool
+}{
+ {Header{"Date": {""}}, true},
+ {Header{"Date": {"invalid"}}, true},
+ {Header{"Date": {"1994-11-06T08:49:37Z00:00"}}, true},
+ {Header{"Date": {"Sun, 06 Nov 1994 08:49:37 GMT"}}, false},
+ {Header{"Date": {"Sunday, 06-Nov-94 08:49:37 GMT"}}, false},
+ {Header{"Date": {"Sun Nov 6 08:49:37 1994"}}, false},
+}
+
+func TestParseTime(t *testing.T) {
+ expect := time.Date(1994, 11, 6, 8, 49, 37, 0, time.UTC)
+ for i, test := range parseTimeTests {
+ d, err := ParseTime(test.h.Get("Date"))
+ if err != nil {
+ if !test.err {
+ t.Errorf("#%d:\n got err: %v", i, err)
+ }
+ continue
+ }
+ if test.err {
+ t.Errorf("#%d:\n should err", i)
+ continue
+ }
+ if !expect.Equal(d) {
+ t.Errorf("#%d:\n got: %v\nwant: %v", i, d, expect)
+ }
+ }
+}
+
+type hasTokenTest struct {
+ header string
+ token string
+ want bool
+}
+
+var hasTokenTests = []hasTokenTest{
+ {"", "", false},
+ {"", "foo", false},
+ {"foo", "foo", true},
+ {"foo ", "foo", true},
+ {" foo", "foo", true},
+ {" foo ", "foo", true},
+ {"foo,bar", "foo", true},
+ {"bar,foo", "foo", true},
+ {"bar, foo", "foo", true},
+ {"bar,foo, baz", "foo", true},
+ {"bar, foo,baz", "foo", true},
+ {"bar,foo, baz", "foo", true},
+ {"bar, foo, baz", "foo", true},
+ {"FOO", "foo", true},
+ {"FOO ", "foo", true},
+ {" FOO", "foo", true},
+ {" FOO ", "foo", true},
+ {"FOO,BAR", "foo", true},
+ {"BAR,FOO", "foo", true},
+ {"BAR, FOO", "foo", true},
+ {"BAR,FOO, baz", "foo", true},
+ {"BAR, FOO,BAZ", "foo", true},
+ {"BAR,FOO, BAZ", "foo", true},
+ {"BAR, FOO, BAZ", "foo", true},
+ {"foobar", "foo", false},
+ {"barfoo ", "foo", false},
+}
+
+func TestHasToken(t *testing.T) {
+ for _, tt := range hasTokenTests {
+ if hasToken(tt.header, tt.token) != tt.want {
+ t.Errorf("hasToken(%q, %q) = %v; want %v", tt.header, tt.token, !tt.want, tt.want)
+ }
+ }
+}
+
+var testHeader = Header{
+ "Content-Length": {"123"},
+ "Content-Type": {"text/plain"},
+ "Date": {"some date at some time Z"},
+ "Server": {"Go http package"},
+}
+
+var buf bytes.Buffer
+
+func BenchmarkHeaderWriteSubset(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ buf.Reset()
+ testHeader.WriteSubset(&buf, nil)
+ }
+}
+
+func TestHeaderWriteSubsetMallocs(t *testing.T) {
+ n := testing.AllocsPerRun(100, func() {
+ buf.Reset()
+ testHeader.WriteSubset(&buf, nil)
+ })
+ if n > 1 {
+ // TODO(bradfitz,rsc): once we can sort without allocating,
+ // make this an error. See http://golang.org/issue/3761
+ // t.Errorf("got %v allocs, want <= %v", n, 1)
+ }
+}
diff --git a/src/pkg/net/http/httptest/example_test.go b/src/pkg/net/http/httptest/example_test.go
new file mode 100644
index 000000000..239470d97
--- /dev/null
+++ b/src/pkg/net/http/httptest/example_test.go
@@ -0,0 +1,50 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httptest_test
+
+import (
+ "fmt"
+ "io/ioutil"
+ "log"
+ "net/http"
+ "net/http/httptest"
+)
+
+func ExampleRecorder() {
+ handler := func(w http.ResponseWriter, r *http.Request) {
+ http.Error(w, "something failed", http.StatusInternalServerError)
+ }
+
+ req, err := http.NewRequest("GET", "http://example.com/foo", nil)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ w := httptest.NewRecorder()
+ handler(w, req)
+
+ fmt.Printf("%d - %s", w.Code, w.Body.String())
+ // Output: 500 - something failed
+}
+
+func ExampleServer() {
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintln(w, "Hello, client")
+ }))
+ defer ts.Close()
+
+ res, err := http.Get(ts.URL)
+ if err != nil {
+ log.Fatal(err)
+ }
+ greeting, err := ioutil.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ fmt.Printf("%s", greeting)
+ // Output: Hello, client
+}
diff --git a/src/pkg/net/http/httptest/recorder.go b/src/pkg/net/http/httptest/recorder.go
index 9aa0d510b..5451f5423 100644
--- a/src/pkg/net/http/httptest/recorder.go
+++ b/src/pkg/net/http/httptest/recorder.go
@@ -17,6 +17,8 @@ type ResponseRecorder struct {
HeaderMap http.Header // the HTTP response headers
Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to
Flushed bool
+
+ wroteHeader bool
}
// NewRecorder returns an initialized ResponseRecorder.
@@ -24,6 +26,7 @@ func NewRecorder() *ResponseRecorder {
return &ResponseRecorder{
HeaderMap: make(http.Header),
Body: new(bytes.Buffer),
+ Code: 200,
}
}
@@ -33,26 +36,37 @@ const DefaultRemoteAddr = "1.2.3.4"
// Header returns the response headers.
func (rw *ResponseRecorder) Header() http.Header {
- return rw.HeaderMap
+ m := rw.HeaderMap
+ if m == nil {
+ m = make(http.Header)
+ rw.HeaderMap = m
+ }
+ return m
}
// Write always succeeds and writes to rw.Body, if not nil.
func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
+ if !rw.wroteHeader {
+ rw.WriteHeader(200)
+ }
if rw.Body != nil {
rw.Body.Write(buf)
}
- if rw.Code == 0 {
- rw.Code = http.StatusOK
- }
return len(buf), nil
}
// WriteHeader sets rw.Code.
func (rw *ResponseRecorder) WriteHeader(code int) {
- rw.Code = code
+ if !rw.wroteHeader {
+ rw.Code = code
+ }
+ rw.wroteHeader = true
}
// Flush sets rw.Flushed to true.
func (rw *ResponseRecorder) Flush() {
+ if !rw.wroteHeader {
+ rw.WriteHeader(200)
+ }
rw.Flushed = true
}
diff --git a/src/pkg/net/http/httptest/recorder_test.go b/src/pkg/net/http/httptest/recorder_test.go
new file mode 100644
index 000000000..2b563260c
--- /dev/null
+++ b/src/pkg/net/http/httptest/recorder_test.go
@@ -0,0 +1,90 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httptest
+
+import (
+ "fmt"
+ "net/http"
+ "testing"
+)
+
+func TestRecorder(t *testing.T) {
+ type checkFunc func(*ResponseRecorder) error
+ check := func(fns ...checkFunc) []checkFunc { return fns }
+
+ hasStatus := func(wantCode int) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if rec.Code != wantCode {
+ return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode)
+ }
+ return nil
+ }
+ }
+ hasContents := func(want string) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if rec.Body.String() != want {
+ return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want)
+ }
+ return nil
+ }
+ }
+ hasFlush := func(want bool) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if rec.Flushed != want {
+ return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want)
+ }
+ return nil
+ }
+ }
+
+ tests := []struct {
+ name string
+ h func(w http.ResponseWriter, r *http.Request)
+ checks []checkFunc
+ }{
+ {
+ "200 default",
+ func(w http.ResponseWriter, r *http.Request) {},
+ check(hasStatus(200), hasContents("")),
+ },
+ {
+ "first code only",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(201)
+ w.WriteHeader(202)
+ w.Write([]byte("hi"))
+ },
+ check(hasStatus(201), hasContents("hi")),
+ },
+ {
+ "write sends 200",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("hi first"))
+ w.WriteHeader(201)
+ w.WriteHeader(202)
+ },
+ check(hasStatus(200), hasContents("hi first"), hasFlush(false)),
+ },
+ {
+ "flush",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.(http.Flusher).Flush() // also sends a 200
+ w.WriteHeader(201)
+ },
+ check(hasStatus(200), hasFlush(true)),
+ },
+ }
+ r, _ := http.NewRequest("GET", "http://foo.com/", nil)
+ for _, tt := range tests {
+ h := http.HandlerFunc(tt.h)
+ rec := NewRecorder()
+ h.ServeHTTP(rec, r)
+ for _, check := range tt.checks {
+ if err := check(rec); err != nil {
+ t.Errorf("%s: %v", tt.name, err)
+ }
+ }
+ }
+}
diff --git a/src/pkg/net/http/httptest/server.go b/src/pkg/net/http/httptest/server.go
index 57cf0c941..7f265552f 100644
--- a/src/pkg/net/http/httptest/server.go
+++ b/src/pkg/net/http/httptest/server.go
@@ -21,7 +21,11 @@ import (
type Server struct {
URL string // base URL of form http://ipaddr:port with no trailing slash
Listener net.Listener
- TLS *tls.Config // nil if not using using TLS
+
+ // TLS is the optional TLS configuration, populated with a new config
+ // after TLS is started. If set on an unstarted server before StartTLS
+ // is called, existing fields are copied into the new config.
+ TLS *tls.Config
// Config may be changed after calling NewUnstartedServer and
// before Start or StartTLS.
@@ -36,13 +40,16 @@ type Server struct {
// accepted.
type historyListener struct {
net.Listener
- history []net.Conn
+ sync.Mutex // protects history
+ history []net.Conn
}
func (hs *historyListener) Accept() (c net.Conn, err error) {
c, err = hs.Listener.Accept()
if err == nil {
+ hs.Lock()
hs.history = append(hs.history, c)
+ hs.Unlock()
}
return
}
@@ -96,7 +103,7 @@ func (s *Server) Start() {
if s.URL != "" {
panic("Server already started")
}
- s.Listener = &historyListener{s.Listener, make([]net.Conn, 0)}
+ s.Listener = &historyListener{Listener: s.Listener}
s.URL = "http://" + s.Listener.Addr().String()
s.wrapHandler()
go s.Config.Serve(s.Listener)
@@ -116,13 +123,20 @@ func (s *Server) StartTLS() {
panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
}
- s.TLS = &tls.Config{
- NextProtos: []string{"http/1.1"},
- Certificates: []tls.Certificate{cert},
+ existingConfig := s.TLS
+ s.TLS = new(tls.Config)
+ if existingConfig != nil {
+ *s.TLS = *existingConfig
+ }
+ if s.TLS.NextProtos == nil {
+ s.TLS.NextProtos = []string{"http/1.1"}
+ }
+ if len(s.TLS.Certificates) == 0 {
+ s.TLS.Certificates = []tls.Certificate{cert}
}
tlsListener := tls.NewListener(s.Listener, s.TLS)
- s.Listener = &historyListener{tlsListener, make([]net.Conn, 0)}
+ s.Listener = &historyListener{Listener: tlsListener}
s.URL = "https://" + s.Listener.Addr().String()
s.wrapHandler()
go s.Config.Serve(s.Listener)
@@ -152,6 +166,10 @@ func NewTLSServer(handler http.Handler) *Server {
func (s *Server) Close() {
s.Listener.Close()
s.wg.Wait()
+ s.CloseClientConnections()
+ if t, ok := http.DefaultTransport.(*http.Transport); ok {
+ t.CloseIdleConnections()
+ }
}
// CloseClientConnections closes any currently open HTTP connections
@@ -161,9 +179,11 @@ func (s *Server) CloseClientConnections() {
if !ok {
return
}
+ hl.Lock()
for _, conn := range hl.history {
conn.Close()
}
+ hl.Unlock()
}
// waitGroupHandler wraps a handler, incrementing and decrementing a
@@ -180,28 +200,29 @@ func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.h.ServeHTTP(w, r)
}
-// localhostCert is a PEM-encoded TLS cert with SAN DNS names
+// localhostCert is a PEM-encoded TLS cert with SAN IPs
// "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end
// of ASN.1 time).
+// generated from src/pkg/crypto/tls:
+// go run generate_cert.go --rsa-bits 512 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
-MIIBOTCB5qADAgECAgEAMAsGCSqGSIb3DQEBBTAAMB4XDTcwMDEwMTAwMDAwMFoX
-DTQ5MTIzMTIzNTk1OVowADBaMAsGCSqGSIb3DQEBAQNLADBIAkEAsuA5mAFMj6Q7
-qoBzcvKzIq4kzuT5epSp2AkcQfyBHm7K13Ws7u+0b5Vb9gqTf5cAiIKcrtrXVqkL
-8i1UQF6AzwIDAQABo08wTTAOBgNVHQ8BAf8EBAMCACQwDQYDVR0OBAYEBAECAwQw
-DwYDVR0jBAgwBoAEAQIDBDAbBgNVHREEFDASggkxMjcuMC4wLjGCBVs6OjFdMAsG
-CSqGSIb3DQEBBQNBAJH30zjLWRztrWpOCgJL8RQWLaKzhK79pVhAx6q/3NrF16C7
-+l1BRZstTwIGdoGId8BRpErK1TXkniFb95ZMynM=
------END CERTIFICATE-----
-`)
+MIIBdzCCASOgAwIBAgIBADALBgkqhkiG9w0BAQUwEjEQMA4GA1UEChMHQWNtZSBD
+bzAeFw03MDAxMDEwMDAwMDBaFw00OTEyMzEyMzU5NTlaMBIxEDAOBgNVBAoTB0Fj
+bWUgQ28wWjALBgkqhkiG9w0BAQEDSwAwSAJBAN55NcYKZeInyTuhcCwFMhDHCmwa
+IUSdtXdcbItRB/yfXGBhiex00IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEA
+AaNoMGYwDgYDVR0PAQH/BAQDAgCkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1Ud
+EwEB/wQFMAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAA
+AAAAAAAAAAAAAAEwCwYJKoZIhvcNAQEFA0EAAoQn/ytgqpiLcZu9XKbCJsJcvkgk
+Se6AbGXgSlq+ZCEVo0qIwSgeBqmsJxUu7NCSOwVJLYNEBO2DtIxoYVk+MA==
+-----END CERTIFICATE-----`)
// localhostKey is the private key for localhostCert.
var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
-MIIBPQIBAAJBALLgOZgBTI+kO6qAc3LysyKuJM7k+XqUqdgJHEH8gR5uytd1rO7v
-tG+VW/YKk3+XAIiCnK7a11apC/ItVEBegM8CAwEAAQJBAI5sxq7naeR9ahyqRkJi
-SIv2iMxLuPEHaezf5CYOPWjSjBPyVhyRevkhtqEjF/WkgL7C2nWpYHsUcBDBQVF0
-3KECIQDtEGB2ulnkZAahl3WuJziXGLB+p8Wgx7wzSM6bHu1c6QIhAMEp++CaS+SJ
-/TrU0zwY/fW4SvQeb49BPZUF3oqR8Xz3AiEA1rAJHBzBgdOQKdE3ksMUPcnvNJSN
-poCcELmz2clVXtkCIQCLytuLV38XHToTipR4yMl6O+6arzAjZ56uq7m7ZRV0TwIh
-AM65XAOw8Dsg9Kq78aYXiOEDc5DL0sbFUu/SlmRcCg93
------END RSA PRIVATE KEY-----
-`)
+MIIBPAIBAAJBAN55NcYKZeInyTuhcCwFMhDHCmwaIUSdtXdcbItRB/yfXGBhiex0
+0IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEAAQJBAQdUx66rfh8sYsgfdcvV
+NoafYpnEcB5s4m/vSVe6SU7dCK6eYec9f9wpT353ljhDUHq3EbmE4foNzJngh35d
+AekCIQDhRQG5Li0Wj8TM4obOnnXUXf1jRv0UkzE9AHWLG5q3AwIhAPzSjpYUDjVW
+MCUXgckTpKCuGwbJk7424Nb8bLzf3kllAiA5mUBgjfr/WtFSJdWcPQ4Zt9KTMNKD
+EUO0ukpTwEIl6wIhAMbGqZK3zAAFdq8DD2jPx+UJXnh0rnOkZBzDtJ6/iN69AiEA
+1Aq8MJgTaYsDQWyU/hDq5YkDJc9e9DSCvUIzqxQWMQE=
+-----END RSA PRIVATE KEY-----`)
diff --git a/src/pkg/net/http/httputil/chunked.go b/src/pkg/net/http/httputil/chunked.go
index 29eaf3475..b66d40951 100644
--- a/src/pkg/net/http/httputil/chunked.go
+++ b/src/pkg/net/http/httputil/chunked.go
@@ -13,10 +13,9 @@ package httputil
import (
"bufio"
- "bytes"
"errors"
+ "fmt"
"io"
- "strconv"
)
const maxLineLength = 4096 // assumed <= bufio.defaultBufSize
@@ -24,7 +23,7 @@ const maxLineLength = 4096 // assumed <= bufio.defaultBufSize
var ErrLineTooLong = errors.New("header line too long")
// NewChunkedReader returns a new chunkedReader that translates the data read from r
-// out of HTTP "chunked" format before returning it.
+// out of HTTP "chunked" format before returning it.
// The chunkedReader returns io.EOF when the final 0-length chunk is read.
//
// NewChunkedReader is not needed by normal applications. The http package
@@ -41,16 +40,17 @@ type chunkedReader struct {
r *bufio.Reader
n uint64 // unread bytes in chunk
err error
+ buf [2]byte
}
func (cr *chunkedReader) beginChunk() {
// chunk-size CRLF
- var line string
+ var line []byte
line, cr.err = readLine(cr.r)
if cr.err != nil {
return
}
- cr.n, cr.err = strconv.ParseUint(line, 16, 64)
+ cr.n, cr.err = parseHexUint(line)
if cr.err != nil {
return
}
@@ -76,9 +76,8 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
cr.n -= uint64(n)
if cr.n == 0 && cr.err == nil {
// end of chunk (CRLF)
- b := make([]byte, 2)
- if _, cr.err = io.ReadFull(cr.r, b); cr.err == nil {
- if b[0] != '\r' || b[1] != '\n' {
+ if _, cr.err = io.ReadFull(cr.r, cr.buf[:]); cr.err == nil {
+ if cr.buf[0] != '\r' || cr.buf[1] != '\n' {
cr.err = errors.New("malformed chunked encoding")
}
}
@@ -90,7 +89,7 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
// Give up if the line exceeds maxLineLength.
// The returned bytes are a pointer into storage in
// the bufio, so they are only valid until the next bufio read.
-func readLineBytes(b *bufio.Reader) (p []byte, err error) {
+func readLine(b *bufio.Reader) (p []byte, err error) {
if p, err = b.ReadSlice('\n'); err != nil {
// We always know when EOF is coming.
// If the caller asked for a line, there should be a line.
@@ -104,20 +103,18 @@ func readLineBytes(b *bufio.Reader) (p []byte, err error) {
if len(p) >= maxLineLength {
return nil, ErrLineTooLong
}
-
- // Chop off trailing white space.
- p = bytes.TrimRight(p, " \r\t\n")
-
- return p, nil
+ return trimTrailingWhitespace(p), nil
}
-// readLineBytes, but convert the bytes into a string.
-func readLine(b *bufio.Reader) (s string, err error) {
- p, e := readLineBytes(b)
- if e != nil {
- return "", e
+func trimTrailingWhitespace(b []byte) []byte {
+ for len(b) > 0 && isASCIISpace(b[len(b)-1]) {
+ b = b[:len(b)-1]
}
- return string(p), nil
+ return b
+}
+
+func isASCIISpace(b byte) bool {
+ return b == ' ' || b == '\t' || b == '\n' || b == '\r'
}
// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP
@@ -149,9 +146,7 @@ func (cw *chunkedWriter) Write(data []byte) (n int, err error) {
return 0, nil
}
- head := strconv.FormatInt(int64(len(data)), 16) + "\r\n"
-
- if _, err = io.WriteString(cw.Wire, head); err != nil {
+ if _, err = fmt.Fprintf(cw.Wire, "%x\r\n", len(data)); err != nil {
return 0, err
}
if n, err = cw.Wire.Write(data); err != nil {
@@ -170,3 +165,21 @@ func (cw *chunkedWriter) Close() error {
_, err := io.WriteString(cw.Wire, "0\r\n")
return err
}
+
+func parseHexUint(v []byte) (n uint64, err error) {
+ for _, b := range v {
+ n <<= 4
+ switch {
+ case '0' <= b && b <= '9':
+ b = b - '0'
+ case 'a' <= b && b <= 'f':
+ b = b - 'a' + 10
+ case 'A' <= b && b <= 'F':
+ b = b - 'A' + 10
+ default:
+ return 0, errors.New("invalid byte in chunk length")
+ }
+ n |= uint64(b)
+ }
+ return
+}
diff --git a/src/pkg/net/http/httputil/chunked_test.go b/src/pkg/net/http/httputil/chunked_test.go
index 155a32bdf..a06bffad5 100644
--- a/src/pkg/net/http/httputil/chunked_test.go
+++ b/src/pkg/net/http/httputil/chunked_test.go
@@ -11,7 +11,10 @@ package httputil
import (
"bytes"
+ "fmt"
+ "io"
"io/ioutil"
+ "runtime"
"testing"
)
@@ -39,3 +42,54 @@ func TestChunk(t *testing.T) {
t.Errorf("chunk reader read %q; want %q", g, e)
}
}
+
+func TestChunkReaderAllocs(t *testing.T) {
+ // temporarily set GOMAXPROCS to 1 as we are testing memory allocations
+ defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1))
+ var buf bytes.Buffer
+ w := NewChunkedWriter(&buf)
+ a, b, c := []byte("aaaaaa"), []byte("bbbbbbbbbbbb"), []byte("cccccccccccccccccccccccc")
+ w.Write(a)
+ w.Write(b)
+ w.Write(c)
+ w.Close()
+
+ r := NewChunkedReader(&buf)
+ readBuf := make([]byte, len(a)+len(b)+len(c)+1)
+
+ var ms runtime.MemStats
+ runtime.ReadMemStats(&ms)
+ m0 := ms.Mallocs
+
+ n, err := io.ReadFull(r, readBuf)
+
+ runtime.ReadMemStats(&ms)
+ mallocs := ms.Mallocs - m0
+ if mallocs > 1 {
+ t.Errorf("%d mallocs; want <= 1", mallocs)
+ }
+
+ if n != len(readBuf)-1 {
+ t.Errorf("read %d bytes; want %d", n, len(readBuf)-1)
+ }
+ if err != io.ErrUnexpectedEOF {
+ t.Errorf("read error = %v; want ErrUnexpectedEOF", err)
+ }
+}
+
+func TestParseHexUint(t *testing.T) {
+ for i := uint64(0); i <= 1234; i++ {
+ line := []byte(fmt.Sprintf("%x", i))
+ got, err := parseHexUint(line)
+ if err != nil {
+ t.Fatalf("on %d: %v", i, err)
+ }
+ if got != i {
+ t.Errorf("for input %q = %d; want %d", line, got, i)
+ }
+ }
+ _, err := parseHexUint([]byte("bogus"))
+ if err == nil {
+ t.Error("expected error on bogus input")
+ }
+}
diff --git a/src/pkg/net/http/httputil/dump.go b/src/pkg/net/http/httputil/dump.go
index 892ef4ede..0b0035661 100644
--- a/src/pkg/net/http/httputil/dump.go
+++ b/src/pkg/net/http/httputil/dump.go
@@ -75,7 +75,7 @@ func DumpRequestOut(req *http.Request, body bool) ([]byte, error) {
// Use the actual Transport code to record what we would send
// on the wire, but not using TCP. Use a Transport with a
- // customer dialer that returns a fake net.Conn that waits
+ // custom dialer that returns a fake net.Conn that waits
// for the full input (and recording it), and then responds
// with a dummy response.
var buf bytes.Buffer // records the output
@@ -89,7 +89,7 @@ func DumpRequestOut(req *http.Request, body bool) ([]byte, error) {
t := &http.Transport{
Dial: func(net, addr string) (net.Conn, error) {
- return &dumpConn{io.MultiWriter(pw, &buf), dr}, nil
+ return &dumpConn{io.MultiWriter(&buf, pw), dr}, nil
},
}
diff --git a/src/pkg/net/http/httputil/reverseproxy.go b/src/pkg/net/http/httputil/reverseproxy.go
index 9c4bd6e09..134c45299 100644
--- a/src/pkg/net/http/httputil/reverseproxy.go
+++ b/src/pkg/net/http/httputil/reverseproxy.go
@@ -17,6 +17,10 @@ import (
"time"
)
+// onExitFlushLoop is a callback set by tests to detect the state of the
+// flushLoop() goroutine.
+var onExitFlushLoop func()
+
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
@@ -102,8 +106,14 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
outreq.Header.Del("Connection")
}
- if clientIp, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
- outreq.Header.Set("X-Forwarded-For", clientIp)
+ if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
+ // If we aren't the first proxy retain prior
+ // X-Forwarded-For information as a comma+space
+ // separated list and fold multiple headers into one.
+ if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
+ clientIP = strings.Join(prior, ", ") + ", " + clientIP
+ }
+ outreq.Header.Set("X-Forwarded-For", clientIP)
}
res, err := transport.RoundTrip(outreq)
@@ -112,20 +122,29 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusInternalServerError)
return
}
+ defer res.Body.Close()
copyHeader(rw.Header(), res.Header)
rw.WriteHeader(res.StatusCode)
+ p.copyResponse(rw, res.Body)
+}
- if res.Body != nil {
- var dst io.Writer = rw
- if p.FlushInterval != 0 {
- if wf, ok := rw.(writeFlusher); ok {
- dst = &maxLatencyWriter{dst: wf, latency: p.FlushInterval}
+func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
+ if p.FlushInterval != 0 {
+ if wf, ok := dst.(writeFlusher); ok {
+ mlw := &maxLatencyWriter{
+ dst: wf,
+ latency: p.FlushInterval,
+ done: make(chan bool),
}
+ go mlw.flushLoop()
+ defer mlw.stop()
+ dst = mlw
}
- io.Copy(dst, res.Body)
}
+
+ io.Copy(dst, src)
}
type writeFlusher interface {
@@ -137,22 +156,14 @@ type maxLatencyWriter struct {
dst writeFlusher
latency time.Duration
- lk sync.Mutex // protects init of done, as well Write + Flush
+ lk sync.Mutex // protects Write + Flush
done chan bool
}
-func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
+func (m *maxLatencyWriter) Write(p []byte) (int, error) {
m.lk.Lock()
defer m.lk.Unlock()
- if m.done == nil {
- m.done = make(chan bool)
- go m.flushLoop()
- }
- n, err = m.dst.Write(p)
- if err != nil {
- m.done <- true
- }
- return
+ return m.dst.Write(p)
}
func (m *maxLatencyWriter) flushLoop() {
@@ -160,13 +171,18 @@ func (m *maxLatencyWriter) flushLoop() {
defer t.Stop()
for {
select {
+ case <-m.done:
+ if onExitFlushLoop != nil {
+ onExitFlushLoop()
+ }
+ return
case <-t.C:
m.lk.Lock()
m.dst.Flush()
m.lk.Unlock()
- case <-m.done:
- return
}
}
panic("unreached")
}
+
+func (m *maxLatencyWriter) stop() { m.done <- true }
diff --git a/src/pkg/net/http/httputil/reverseproxy_test.go b/src/pkg/net/http/httputil/reverseproxy_test.go
index 28e9c90ad..863927162 100644
--- a/src/pkg/net/http/httputil/reverseproxy_test.go
+++ b/src/pkg/net/http/httputil/reverseproxy_test.go
@@ -11,7 +11,9 @@ import (
"net/http"
"net/http/httptest"
"net/url"
+ "strings"
"testing"
+ "time"
)
func TestReverseProxy(t *testing.T) {
@@ -70,6 +72,47 @@ func TestReverseProxy(t *testing.T) {
}
}
+func TestXForwardedFor(t *testing.T) {
+ const prevForwardedFor = "client ip"
+ const backendResponse = "I am the backend"
+ const backendStatus = 404
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Header.Get("X-Forwarded-For") == "" {
+ t.Errorf("didn't get X-Forwarded-For header")
+ }
+ if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) {
+ t.Errorf("X-Forwarded-For didn't contain prior data")
+ }
+ w.WriteHeader(backendStatus)
+ w.Write([]byte(backendResponse))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ getReq.Host = "some-name"
+ getReq.Header.Set("Connection", "close")
+ getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
+ getReq.Close = true
+ res, err := http.DefaultClient.Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ if g, e := res.StatusCode, backendStatus; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ bodyBytes, _ := ioutil.ReadAll(res.Body)
+ if g, e := string(bodyBytes), backendResponse; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+}
+
var proxyQueryTests = []struct {
baseSuffix string // suffix to add to backend URL
reqSuffix string // suffix to add to frontend's request URL
@@ -107,3 +150,44 @@ func TestReverseProxyQuery(t *testing.T) {
frontend.Close()
}
}
+
+func TestReverseProxyFlushInterval(t *testing.T) {
+ const expected = "hi"
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(expected))
+ }))
+ defer backend.Close()
+
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.FlushInterval = time.Microsecond
+
+ done := make(chan bool)
+ onExitFlushLoop = func() { done <- true }
+ defer func() { onExitFlushLoop = nil }()
+
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ req, _ := http.NewRequest("GET", frontend.URL, nil)
+ req.Close = true
+ res, err := http.DefaultClient.Do(req)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+ if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected {
+ t.Errorf("got body %q; expected %q", bodyBytes, expected)
+ }
+
+ select {
+ case <-done:
+ // OK
+ case <-time.After(5 * time.Second):
+ t.Error("maxLatencyWriter flushLoop() never exited")
+ }
+}
diff --git a/src/pkg/net/http/jar.go b/src/pkg/net/http/jar.go
index 2c2caa251..5c3de0dad 100644
--- a/src/pkg/net/http/jar.go
+++ b/src/pkg/net/http/jar.go
@@ -8,23 +8,20 @@ import (
"net/url"
)
-// A CookieJar manages storage and use of cookies in HTTP requests.
+// A CookieJar manages storage and use of cookies in HTTP requests.
//
// Implementations of CookieJar must be safe for concurrent use by multiple
// goroutines.
+//
+// The net/http/cookiejar package provides a CookieJar implementation.
type CookieJar interface {
- // SetCookies handles the receipt of the cookies in a reply for the
- // given URL. It may or may not choose to save the cookies, depending
- // on the jar's policy and implementation.
+ // SetCookies handles the receipt of the cookies in a reply for the
+ // given URL. It may or may not choose to save the cookies, depending
+ // on the jar's policy and implementation.
SetCookies(u *url.URL, cookies []*Cookie)
// Cookies returns the cookies to send in a request for the given URL.
- // It is up to the implementation to honor the standard cookie use
- // restrictions such as in RFC 6265.
+ // It is up to the implementation to honor the standard cookie use
+ // restrictions such as in RFC 6265.
Cookies(u *url.URL) []*Cookie
}
-
-type blackHoleJar struct{}
-
-func (blackHoleJar) SetCookies(u *url.URL, cookies []*Cookie) {}
-func (blackHoleJar) Cookies(u *url.URL) []*Cookie { return nil }
diff --git a/src/pkg/net/http/lex.go b/src/pkg/net/http/lex.go
index ffb393ccf..cb33318f4 100644
--- a/src/pkg/net/http/lex.go
+++ b/src/pkg/net/http/lex.go
@@ -6,131 +6,91 @@ package http
// This file deals with lexical matters of HTTP
-func isSeparator(c byte) bool {
- switch c {
- case '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', '?', '=', '{', '}', ' ', '\t':
- return true
- }
- return false
+var isTokenTable = [127]bool{
+ '!': true,
+ '#': true,
+ '$': true,
+ '%': true,
+ '&': true,
+ '\'': true,
+ '*': true,
+ '+': true,
+ '-': true,
+ '.': true,
+ '0': true,
+ '1': true,
+ '2': true,
+ '3': true,
+ '4': true,
+ '5': true,
+ '6': true,
+ '7': true,
+ '8': true,
+ '9': true,
+ 'A': true,
+ 'B': true,
+ 'C': true,
+ 'D': true,
+ 'E': true,
+ 'F': true,
+ 'G': true,
+ 'H': true,
+ 'I': true,
+ 'J': true,
+ 'K': true,
+ 'L': true,
+ 'M': true,
+ 'N': true,
+ 'O': true,
+ 'P': true,
+ 'Q': true,
+ 'R': true,
+ 'S': true,
+ 'T': true,
+ 'U': true,
+ 'W': true,
+ 'V': true,
+ 'X': true,
+ 'Y': true,
+ 'Z': true,
+ '^': true,
+ '_': true,
+ '`': true,
+ 'a': true,
+ 'b': true,
+ 'c': true,
+ 'd': true,
+ 'e': true,
+ 'f': true,
+ 'g': true,
+ 'h': true,
+ 'i': true,
+ 'j': true,
+ 'k': true,
+ 'l': true,
+ 'm': true,
+ 'n': true,
+ 'o': true,
+ 'p': true,
+ 'q': true,
+ 'r': true,
+ 's': true,
+ 't': true,
+ 'u': true,
+ 'v': true,
+ 'w': true,
+ 'x': true,
+ 'y': true,
+ 'z': true,
+ '|': true,
+ '~': true,
}
-func isCtl(c byte) bool { return (0 <= c && c <= 31) || c == 127 }
-
-func isChar(c byte) bool { return 0 <= c && c <= 127 }
-
-func isAnyText(c byte) bool { return !isCtl(c) }
-
-func isQdText(c byte) bool { return isAnyText(c) && c != '"' }
-
-func isToken(c byte) bool { return isChar(c) && !isCtl(c) && !isSeparator(c) }
-
-// Valid escaped sequences are not specified in RFC 2616, so for now, we assume
-// that they coincide with the common sense ones used by GO. Malformed
-// characters should probably not be treated as errors by a robust (forgiving)
-// parser, so we replace them with the '?' character.
-func httpUnquotePair(b byte) byte {
- // skip the first byte, which should always be '\'
- switch b {
- case 'a':
- return '\a'
- case 'b':
- return '\b'
- case 'f':
- return '\f'
- case 'n':
- return '\n'
- case 'r':
- return '\r'
- case 't':
- return '\t'
- case 'v':
- return '\v'
- case '\\':
- return '\\'
- case '\'':
- return '\''
- case '"':
- return '"'
- }
- return '?'
-}
-
-// raw must begin with a valid quoted string. Only the first quoted string is
-// parsed and is unquoted in result. eaten is the number of bytes parsed, or -1
-// upon failure.
-func httpUnquote(raw []byte) (eaten int, result string) {
- buf := make([]byte, len(raw))
- if raw[0] != '"' {
- return -1, ""
- }
- eaten = 1
- j := 0 // # of bytes written in buf
- for i := 1; i < len(raw); i++ {
- switch b := raw[i]; b {
- case '"':
- eaten++
- buf = buf[0:j]
- return i + 1, string(buf)
- case '\\':
- if len(raw) < i+2 {
- return -1, ""
- }
- buf[j] = httpUnquotePair(raw[i+1])
- eaten += 2
- j++
- i++
- default:
- if isQdText(b) {
- buf[j] = b
- } else {
- buf[j] = '?'
- }
- eaten++
- j++
- }
- }
- return -1, ""
+func isToken(r rune) bool {
+ i := int(r)
+ return i < len(isTokenTable) && isTokenTable[i]
}
-// This is a best effort parse, so errors are not returned, instead not all of
-// the input string might be parsed. result is always non-nil.
-func httpSplitFieldValue(fv string) (eaten int, result []string) {
- result = make([]string, 0, len(fv))
- raw := []byte(fv)
- i := 0
- chunk := ""
- for i < len(raw) {
- b := raw[i]
- switch {
- case b == '"':
- eaten, unq := httpUnquote(raw[i:len(raw)])
- if eaten < 0 {
- return i, result
- } else {
- i += eaten
- chunk += unq
- }
- case isSeparator(b):
- if chunk != "" {
- result = result[0 : len(result)+1]
- result[len(result)-1] = chunk
- chunk = ""
- }
- i++
- case isToken(b):
- chunk += string(b)
- i++
- case b == '\n' || b == '\r':
- i++
- default:
- chunk += "?"
- i++
- }
- }
- if chunk != "" {
- result = result[0 : len(result)+1]
- result[len(result)-1] = chunk
- chunk = ""
- }
- return i, result
+func isNotToken(r rune) bool {
+ return !isToken(r)
}
diff --git a/src/pkg/net/http/lex_test.go b/src/pkg/net/http/lex_test.go
index 5386f7534..6d9d294f7 100644
--- a/src/pkg/net/http/lex_test.go
+++ b/src/pkg/net/http/lex_test.go
@@ -8,63 +8,24 @@ import (
"testing"
)
-type lexTest struct {
- Raw string
- Parsed int // # of parsed characters
- Result []string
-}
+func isChar(c rune) bool { return c <= 127 }
-var lexTests = []lexTest{
- {
- Raw: `"abc"def,:ghi`,
- Parsed: 13,
- Result: []string{"abcdef", "ghi"},
- },
- // My understanding of the RFC is that escape sequences outside of
- // quotes are not interpreted?
- {
- Raw: `"\t"\t"\t"`,
- Parsed: 10,
- Result: []string{"\t", "t\t"},
- },
- {
- Raw: `"\yab"\r\n`,
- Parsed: 10,
- Result: []string{"?ab", "r", "n"},
- },
- {
- Raw: "ab\f",
- Parsed: 3,
- Result: []string{"ab?"},
- },
- {
- Raw: "\"ab \" c,de f, gh, ij\n\t\r",
- Parsed: 23,
- Result: []string{"ab ", "c", "de", "f", "gh", "ij"},
- },
-}
+func isCtl(c rune) bool { return c <= 31 || c == 127 }
-func min(x, y int) int {
- if x <= y {
- return x
+func isSeparator(c rune) bool {
+ switch c {
+ case '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', '?', '=', '{', '}', ' ', '\t':
+ return true
}
- return y
+ return false
}
-func TestSplitFieldValue(t *testing.T) {
- for k, l := range lexTests {
- parsed, result := httpSplitFieldValue(l.Raw)
- if parsed != l.Parsed {
- t.Errorf("#%d: Parsed %d, expected %d", k, parsed, l.Parsed)
- }
- if len(result) != len(l.Result) {
- t.Errorf("#%d: Result len %d, expected %d", k, len(result), len(l.Result))
- }
- for i := 0; i < min(len(result), len(l.Result)); i++ {
- if result[i] != l.Result[i] {
- t.Errorf("#%d: %d-th entry mismatch. Have {%s}, expect {%s}",
- k, i, result[i], l.Result[i])
- }
+func TestIsToken(t *testing.T) {
+ for i := 0; i <= 130; i++ {
+ r := rune(i)
+ expected := isChar(r) && !isCtl(r) && !isSeparator(r)
+ if isToken(r) != expected {
+ t.Errorf("isToken(0x%x) = %v", r, !expected)
}
}
}
diff --git a/src/pkg/net/http/npn_test.go b/src/pkg/net/http/npn_test.go
new file mode 100644
index 000000000..98b8930d0
--- /dev/null
+++ b/src/pkg/net/http/npn_test.go
@@ -0,0 +1,118 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+ "bufio"
+ "crypto/tls"
+ "fmt"
+ "io"
+ "io/ioutil"
+ . "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+)
+
+func TestNextProtoUpgrade(t *testing.T) {
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "path=%s,proto=", r.URL.Path)
+ if r.TLS != nil {
+ w.Write([]byte(r.TLS.NegotiatedProtocol))
+ }
+ if r.RemoteAddr == "" {
+ t.Error("request with no RemoteAddr")
+ }
+ if r.Body == nil {
+ t.Errorf("request with nil Body")
+ }
+ }))
+ ts.TLS = &tls.Config{
+ NextProtos: []string{"unhandled-proto", "tls-0.9"},
+ }
+ ts.Config.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){
+ "tls-0.9": handleTLSProtocol09,
+ }
+ ts.StartTLS()
+ defer ts.Close()
+
+ tr := newTLSTransport(t, ts)
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ // Normal request, without NPN.
+ {
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ body, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if want := "path=/,proto="; string(body) != want {
+ t.Errorf("plain request = %q; want %q", body, want)
+ }
+ }
+
+ // Request to an advertised but unhandled NPN protocol.
+ // Server will hang up.
+ {
+ tr.CloseIdleConnections()
+ tr.TLSClientConfig.NextProtos = []string{"unhandled-proto"}
+ _, err := c.Get(ts.URL)
+ if err == nil {
+ t.Errorf("expected error on unhandled-proto request")
+ }
+ }
+
+ // Request using the "tls-0.9" protocol, which we register here.
+ // It is HTTP/0.9 over TLS.
+ {
+ tlsConfig := newTLSTransport(t, ts).TLSClientConfig
+ tlsConfig.NextProtos = []string{"tls-0.9"}
+ conn, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig)
+ if err != nil {
+ t.Fatal(err)
+ }
+ conn.Write([]byte("GET /foo\n"))
+ body, err := ioutil.ReadAll(conn)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if want := "path=/foo,proto=tls-0.9"; string(body) != want {
+ t.Errorf("plain request = %q; want %q", body, want)
+ }
+ }
+}
+
+// handleTLSProtocol09 implements the HTTP/0.9 protocol over TLS, for the
+// TestNextProtoUpgrade test.
+func handleTLSProtocol09(srv *Server, conn *tls.Conn, h Handler) {
+ br := bufio.NewReader(conn)
+ line, err := br.ReadString('\n')
+ if err != nil {
+ return
+ }
+ line = strings.TrimSpace(line)
+ path := strings.TrimPrefix(line, "GET ")
+ if path == line {
+ return
+ }
+ req, _ := NewRequest("GET", path, nil)
+ req.Proto = "HTTP/0.9"
+ req.ProtoMajor = 0
+ req.ProtoMinor = 9
+ rw := &http09Writer{conn, make(Header)}
+ h.ServeHTTP(rw, req)
+}
+
+type http09Writer struct {
+ io.Writer
+ h Header
+}
+
+func (w http09Writer) Header() Header { return w.h }
+func (w http09Writer) WriteHeader(int) {} // no headers
diff --git a/src/pkg/net/http/pprof/pprof.go b/src/pkg/net/http/pprof/pprof.go
index 06fcde144..0c7548e3e 100644
--- a/src/pkg/net/http/pprof/pprof.go
+++ b/src/pkg/net/http/pprof/pprof.go
@@ -14,6 +14,14 @@
// To use pprof, link this package into your program:
// import _ "net/http/pprof"
//
+// If your application is not already running an http server, you
+// need to start one. Add "net/http" and "log" to your imports and
+// the following code to your main function:
+//
+// go func() {
+// log.Println(http.ListenAndServe("localhost:6060", nil))
+// }()
+//
// Then use the pprof tool to look at the heap profile:
//
// go tool pprof http://localhost:6060/debug/pprof/heap
@@ -22,9 +30,12 @@
//
// go tool pprof http://localhost:6060/debug/pprof/profile
//
-// Or to view all available profiles:
+// Or to look at the goroutine blocking profile:
+//
+// go tool pprof http://localhost:6060/debug/pprof/block
//
-// go tool pprof http://localhost:6060/debug/pprof/
+// To view all available profiles, open http://localhost:6060/debug/pprof/
+// in your browser.
//
// For a study of the facility in action, visit
//
@@ -161,7 +172,7 @@ func (name handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// listing the available profiles.
func Index(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.URL.Path, "/debug/pprof/") {
- name := r.URL.Path[len("/debug/pprof/"):]
+ name := strings.TrimPrefix(r.URL.Path, "/debug/pprof/")
if name != "" {
handler(name).ServeHTTP(w, r)
return
diff --git a/src/pkg/net/http/proxy_test.go b/src/pkg/net/http/proxy_test.go
index 5ecffafac..449ccaeea 100644
--- a/src/pkg/net/http/proxy_test.go
+++ b/src/pkg/net/http/proxy_test.go
@@ -25,13 +25,13 @@ var UseProxyTests = []struct {
{"[::2]", true}, // not a loopback address
{"barbaz.net", false}, // match as .barbaz.net
- {"foobar.com", false}, // have a port but match
+ {"foobar.com", false}, // have a port but match
{"foofoobar.com", true}, // not match as a part of foobar.com
{"baz.com", true}, // not match as a part of barbaz.com
{"localhost.net", true}, // not match as suffix of address
{"local.localhost", true}, // not match as prefix as address
{"barbarbaz.net", true}, // not match because NO_PROXY have a '.'
- {"www.foobar.com", true}, // not match because NO_PROXY is not .foobar.com
+ {"www.foobar.com", false}, // match because NO_PROXY includes "foobar.com"
}
func TestUseProxy(t *testing.T) {
diff --git a/src/pkg/net/http/range_test.go b/src/pkg/net/http/range_test.go
index 5274a81fa..ef911af7b 100644
--- a/src/pkg/net/http/range_test.go
+++ b/src/pkg/net/http/range_test.go
@@ -14,15 +14,34 @@ var ParseRangeTests = []struct {
r []httpRange
}{
{"", 0, nil},
+ {"", 1000, nil},
{"foo", 0, nil},
{"bytes=", 0, nil},
+ {"bytes=7", 10, nil},
+ {"bytes= 7 ", 10, nil},
+ {"bytes=1-", 0, nil},
{"bytes=5-4", 10, nil},
{"bytes=0-2,5-4", 10, nil},
+ {"bytes=2-5,4-3", 10, nil},
+ {"bytes=--5,4--3", 10, nil},
+ {"bytes=A-", 10, nil},
+ {"bytes=A- ", 10, nil},
+ {"bytes=A-Z", 10, nil},
+ {"bytes= -Z", 10, nil},
+ {"bytes=5-Z", 10, nil},
+ {"bytes=Ran-dom, garbage", 10, nil},
+ {"bytes=0x01-0x02", 10, nil},
+ {"bytes= ", 10, nil},
+ {"bytes= , , , ", 10, nil},
+
{"bytes=0-9", 10, []httpRange{{0, 10}}},
{"bytes=0-", 10, []httpRange{{0, 10}}},
{"bytes=5-", 10, []httpRange{{5, 5}}},
{"bytes=0-20", 10, []httpRange{{0, 10}}},
{"bytes=15-,0-5", 10, nil},
+ {"bytes=1-2,5-", 10, []httpRange{{1, 2}, {5, 5}}},
+ {"bytes=-2 , 7-", 11, []httpRange{{9, 2}, {7, 4}}},
+ {"bytes=0-0 ,2-2, 7-", 11, []httpRange{{0, 1}, {2, 1}, {7, 4}}},
{"bytes=-5", 10, []httpRange{{5, 5}}},
{"bytes=-15", 10, []httpRange{{0, 10}}},
{"bytes=0-499", 10000, []httpRange{{0, 500}}},
@@ -32,6 +51,9 @@ var ParseRangeTests = []struct {
{"bytes=0-0,-1", 10000, []httpRange{{0, 1}, {9999, 1}}},
{"bytes=500-600,601-999", 10000, []httpRange{{500, 101}, {601, 399}}},
{"bytes=500-700,601-999", 10000, []httpRange{{500, 201}, {601, 399}}},
+
+ // Match Apache laxity:
+ {"bytes= 1 -2 , 4- 5, 7 - 8 , ,,", 11, []httpRange{{1, 2}, {4, 2}, {7, 2}}},
}
func TestParseRange(t *testing.T) {
diff --git a/src/pkg/net/http/readrequest_test.go b/src/pkg/net/http/readrequest_test.go
index 2e03c658a..ffdd6a892 100644
--- a/src/pkg/net/http/readrequest_test.go
+++ b/src/pkg/net/http/readrequest_test.go
@@ -247,6 +247,54 @@ var reqTests = []reqTest{
noTrailer,
noError,
},
+
+ // SSDP Notify request. golang.org/issue/3692
+ {
+ "NOTIFY * HTTP/1.1\r\nServer: foo\r\n\r\n",
+ &Request{
+ Method: "NOTIFY",
+ URL: &url.URL{
+ Path: "*",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{
+ "Server": []string{"foo"},
+ },
+ Close: false,
+ ContentLength: 0,
+ RequestURI: "*",
+ },
+
+ noBody,
+ noTrailer,
+ noError,
+ },
+
+ // OPTIONS request. Similar to golang.org/issue/3692
+ {
+ "OPTIONS * HTTP/1.1\r\nServer: foo\r\n\r\n",
+ &Request{
+ Method: "OPTIONS",
+ URL: &url.URL{
+ Path: "*",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{
+ "Server": []string{"foo"},
+ },
+ Close: false,
+ ContentLength: 0,
+ RequestURI: "*",
+ },
+
+ noBody,
+ noTrailer,
+ noError,
+ },
}
func TestReadRequest(t *testing.T) {
diff --git a/src/pkg/net/http/request.go b/src/pkg/net/http/request.go
index f5bc6eb91..217f35b48 100644
--- a/src/pkg/net/http/request.go
+++ b/src/pkg/net/http/request.go
@@ -19,6 +19,7 @@ import (
"mime/multipart"
"net/textproto"
"net/url"
+ "strconv"
"strings"
)
@@ -70,7 +71,13 @@ var reqWriteExcludeHeader = map[string]bool{
// or to be sent by a client.
type Request struct {
Method string // GET, POST, PUT, etc.
- URL *url.URL
+
+ // URL is created from the URI supplied on the Request-Line
+ // as stored in RequestURI.
+ //
+ // For most requests, fields other than Path and RawQuery
+ // will be empty. (See RFC 2616, Section 5.1.2)
+ URL *url.URL
// The protocol version for incoming requests.
// Outgoing requests always use HTTP/1.1.
@@ -123,6 +130,7 @@ type Request struct {
// The host on which the URL is sought.
// Per RFC 2616, this is either the value of the Host: header
// or the host name given in the URL itself.
+ // It may be of the form "host:port".
Host string
// Form contains the parsed form data, including both the URL
@@ -131,6 +139,12 @@ type Request struct {
// The HTTP client ignores Form and uses Body instead.
Form url.Values
+ // PostForm contains the parsed form data from POST or PUT
+ // body parameters.
+ // This field is only available after ParseForm is called.
+ // The HTTP client ignores PostForm and uses Body instead.
+ PostForm url.Values
+
// MultipartForm is the parsed multipart form, including file uploads.
// This field is only available after ParseMultipartForm is called.
// The HTTP client ignores MultipartForm and uses Body instead.
@@ -317,11 +331,20 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err
}
// TODO(bradfitz): escape at least newlines in ruri?
- bw := bufio.NewWriter(w)
- fmt.Fprintf(bw, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri)
+ // Wrap the writer in a bufio Writer if it's not already buffered.
+ // Don't always call NewWriter, as that forces a bytes.Buffer
+ // and other small bufio Writers to have a minimum 4k buffer
+ // size.
+ var bw *bufio.Writer
+ if _, ok := w.(io.ByteWriter); !ok {
+ bw = bufio.NewWriter(w)
+ w = bw
+ }
+
+ fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri)
// Header lines
- fmt.Fprintf(bw, "Host: %s\r\n", host)
+ fmt.Fprintf(w, "Host: %s\r\n", host)
// Use the defaultUserAgent unless the Header contains one, which
// may be blank to not send the header.
@@ -332,7 +355,7 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err
}
}
if userAgent != "" {
- fmt.Fprintf(bw, "User-Agent: %s\r\n", userAgent)
+ fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent)
}
// Process Body,ContentLength,Close,Trailer
@@ -340,65 +363,61 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err
if err != nil {
return err
}
- err = tw.WriteHeader(bw)
+ err = tw.WriteHeader(w)
if err != nil {
return err
}
// TODO: split long values? (If so, should share code with Conn.Write)
- err = req.Header.WriteSubset(bw, reqWriteExcludeHeader)
+ err = req.Header.WriteSubset(w, reqWriteExcludeHeader)
if err != nil {
return err
}
if extraHeaders != nil {
- err = extraHeaders.Write(bw)
+ err = extraHeaders.Write(w)
if err != nil {
return err
}
}
- io.WriteString(bw, "\r\n")
+ io.WriteString(w, "\r\n")
// Write body and trailer
- err = tw.WriteBody(bw)
+ err = tw.WriteBody(w)
if err != nil {
return err
}
- return bw.Flush()
-}
-
-// Convert decimal at s[i:len(s)] to integer,
-// returning value, string position where the digits stopped,
-// and whether there was a valid number (digits, not too big).
-func atoi(s string, i int) (n, i1 int, ok bool) {
- const Big = 1000000
- if i >= len(s) || s[i] < '0' || s[i] > '9' {
- return 0, 0, false
- }
- n = 0
- for ; i < len(s) && '0' <= s[i] && s[i] <= '9'; i++ {
- n = n*10 + int(s[i]-'0')
- if n > Big {
- return 0, 0, false
- }
+ if bw != nil {
+ return bw.Flush()
}
- return n, i, true
+ return nil
}
// ParseHTTPVersion parses a HTTP version string.
// "HTTP/1.0" returns (1, 0, true).
func ParseHTTPVersion(vers string) (major, minor int, ok bool) {
- if len(vers) < 5 || vers[0:5] != "HTTP/" {
+ const Big = 1000000 // arbitrary upper bound
+ switch vers {
+ case "HTTP/1.1":
+ return 1, 1, true
+ case "HTTP/1.0":
+ return 1, 0, true
+ }
+ if !strings.HasPrefix(vers, "HTTP/") {
return 0, 0, false
}
- major, i, ok := atoi(vers, 5)
- if !ok || i >= len(vers) || vers[i] != '.' {
+ dot := strings.Index(vers, ".")
+ if dot < 0 {
return 0, 0, false
}
- minor, i, ok = atoi(vers, i+1)
- if !ok || i != len(vers) {
+ major, err := strconv.Atoi(vers[5:dot])
+ if err != nil || major < 0 || major > Big {
+ return 0, 0, false
+ }
+ minor, err = strconv.Atoi(vers[dot+1:])
+ if err != nil || minor < 0 || minor > Big {
return 0, 0, false
}
return major, minor, true
@@ -426,10 +445,12 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) {
}
if body != nil {
switch v := body.(type) {
- case *strings.Reader:
- req.ContentLength = int64(v.Len())
case *bytes.Buffer:
req.ContentLength = int64(v.Len())
+ case *bytes.Reader:
+ req.ContentLength = int64(v.Len())
+ case *strings.Reader:
+ req.ContentLength = int64(v.Len())
}
}
@@ -513,9 +534,9 @@ func ReadRequest(b *bufio.Reader) (req *Request, err error) {
// the same. In the second case, any Host line is ignored.
req.Host = req.URL.Host
if req.Host == "" {
- req.Host = req.Header.Get("Host")
+ req.Host = req.Header.get("Host")
}
- req.Header.Del("Host")
+ delete(req.Header, "Host")
fixPragmaCacheControl(req.Header)
@@ -594,66 +615,97 @@ func (l *maxBytesReader) Close() error {
return l.r.Close()
}
-// ParseForm parses the raw query from the URL.
+func copyValues(dst, src url.Values) {
+ for k, vs := range src {
+ for _, value := range vs {
+ dst.Add(k, value)
+ }
+ }
+}
+
+func parsePostForm(r *Request) (vs url.Values, err error) {
+ if r.Body == nil {
+ err = errors.New("missing form body")
+ return
+ }
+ ct := r.Header.Get("Content-Type")
+ ct, _, err = mime.ParseMediaType(ct)
+ switch {
+ case ct == "application/x-www-form-urlencoded":
+ var reader io.Reader = r.Body
+ maxFormSize := int64(1<<63 - 1)
+ if _, ok := r.Body.(*maxBytesReader); !ok {
+ maxFormSize = int64(10 << 20) // 10 MB is a lot of text.
+ reader = io.LimitReader(r.Body, maxFormSize+1)
+ }
+ b, e := ioutil.ReadAll(reader)
+ if e != nil {
+ if err == nil {
+ err = e
+ }
+ break
+ }
+ if int64(len(b)) > maxFormSize {
+ err = errors.New("http: POST too large")
+ return
+ }
+ vs, e = url.ParseQuery(string(b))
+ if err == nil {
+ err = e
+ }
+ case ct == "multipart/form-data":
+ // handled by ParseMultipartForm (which is calling us, or should be)
+ // TODO(bradfitz): there are too many possible
+ // orders to call too many functions here.
+ // Clean this up and write more tests.
+ // request_test.go contains the start of this,
+ // in TestRequestMultipartCallOrder.
+ }
+ return
+}
+
+// ParseForm parses the raw query from the URL and updates r.Form.
+//
+// For POST or PUT requests, it also parses the request body as a form and
+// put the results into both r.PostForm and r.Form.
+// POST and PUT body parameters take precedence over URL query string values
+// in r.Form.
//
-// For POST or PUT requests, it also parses the request body as a form.
// If the request Body's size has not already been limited by MaxBytesReader,
// the size is capped at 10MB.
//
// ParseMultipartForm calls ParseForm automatically.
// It is idempotent.
-func (r *Request) ParseForm() (err error) {
- if r.Form != nil {
- return
- }
- if r.URL != nil {
- r.Form, err = url.ParseQuery(r.URL.RawQuery)
+func (r *Request) ParseForm() error {
+ var err error
+ if r.PostForm == nil {
+ if r.Method == "POST" || r.Method == "PUT" {
+ r.PostForm, err = parsePostForm(r)
+ }
+ if r.PostForm == nil {
+ r.PostForm = make(url.Values)
+ }
}
- if r.Method == "POST" || r.Method == "PUT" {
- if r.Body == nil {
- return errors.New("missing form body")
+ if r.Form == nil {
+ if len(r.PostForm) > 0 {
+ r.Form = make(url.Values)
+ copyValues(r.Form, r.PostForm)
}
- ct := r.Header.Get("Content-Type")
- ct, _, err = mime.ParseMediaType(ct)
- switch {
- case ct == "application/x-www-form-urlencoded":
- var reader io.Reader = r.Body
- maxFormSize := int64(1<<63 - 1)
- if _, ok := r.Body.(*maxBytesReader); !ok {
- maxFormSize = int64(10 << 20) // 10 MB is a lot of text.
- reader = io.LimitReader(r.Body, maxFormSize+1)
- }
- b, e := ioutil.ReadAll(reader)
- if e != nil {
- if err == nil {
- err = e
- }
- break
- }
- if int64(len(b)) > maxFormSize {
- return errors.New("http: POST too large")
- }
- var newValues url.Values
- newValues, e = url.ParseQuery(string(b))
+ var newValues url.Values
+ if r.URL != nil {
+ var e error
+ newValues, e = url.ParseQuery(r.URL.RawQuery)
if err == nil {
err = e
}
- if r.Form == nil {
- r.Form = make(url.Values)
- }
- // Copy values into r.Form. TODO: make this smoother.
- for k, vs := range newValues {
- for _, value := range vs {
- r.Form.Add(k, value)
- }
- }
- case ct == "multipart/form-data":
- // handled by ParseMultipartForm (which is calling us, or should be)
- // TODO(bradfitz): there are too many possible
- // orders to call too many functions here.
- // Clean this up and write more tests.
- // request_test.go contains the start of this,
- // in TestRequestMultipartCallOrder.
+ }
+ if newValues == nil {
+ newValues = make(url.Values)
+ }
+ if r.Form == nil {
+ r.Form = newValues
+ } else {
+ copyValues(r.Form, newValues)
}
}
return err
@@ -699,7 +751,9 @@ func (r *Request) ParseMultipartForm(maxMemory int64) error {
}
// FormValue returns the first value for the named component of the query.
+// POST and PUT body parameters take precedence over URL query string values.
// FormValue calls ParseMultipartForm and ParseForm if necessary.
+// To access multiple values of the same key use ParseForm.
func (r *Request) FormValue(key string) string {
if r.Form == nil {
r.ParseMultipartForm(defaultMaxMemory)
@@ -710,6 +764,19 @@ func (r *Request) FormValue(key string) string {
return ""
}
+// PostFormValue returns the first value for the named component of the POST
+// or PUT request body. URL query parameters are ignored.
+// PostFormValue calls ParseMultipartForm and ParseForm if necessary.
+func (r *Request) PostFormValue(key string) string {
+ if r.PostForm == nil {
+ r.ParseMultipartForm(defaultMaxMemory)
+ }
+ if vs := r.PostForm[key]; len(vs) > 0 {
+ return vs[0]
+ }
+ return ""
+}
+
// FormFile returns the first file for the provided form key.
// FormFile calls ParseMultipartForm and ParseForm if necessary.
func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, error) {
@@ -732,12 +799,16 @@ func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, e
}
func (r *Request) expectsContinue() bool {
- return strings.ToLower(r.Header.Get("Expect")) == "100-continue"
+ return hasToken(r.Header.get("Expect"), "100-continue")
}
func (r *Request) wantsHttp10KeepAlive() bool {
if r.ProtoMajor != 1 || r.ProtoMinor != 0 {
return false
}
- return strings.Contains(strings.ToLower(r.Header.Get("Connection")), "keep-alive")
+ return hasToken(r.Header.get("Connection"), "keep-alive")
+}
+
+func (r *Request) wantsClose() bool {
+ return hasToken(r.Header.get("Connection"), "close")
}
diff --git a/src/pkg/net/http/request_test.go b/src/pkg/net/http/request_test.go
index 6e00b9bfd..00ad791de 100644
--- a/src/pkg/net/http/request_test.go
+++ b/src/pkg/net/http/request_test.go
@@ -30,8 +30,8 @@ func TestQuery(t *testing.T) {
}
func TestPostQuery(t *testing.T) {
- req, _ := NewRequest("POST", "http://www.google.com/search?q=foo&q=bar&both=x",
- strings.NewReader("z=post&both=y"))
+ req, _ := NewRequest("POST", "http://www.google.com/search?q=foo&q=bar&both=x&prio=1&empty=not",
+ strings.NewReader("z=post&both=y&prio=2&empty="))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
if q := req.FormValue("q"); q != "foo" {
@@ -40,8 +40,23 @@ func TestPostQuery(t *testing.T) {
if z := req.FormValue("z"); z != "post" {
t.Errorf(`req.FormValue("z") = %q, want "post"`, z)
}
- if both := req.Form["both"]; !reflect.DeepEqual(both, []string{"x", "y"}) {
- t.Errorf(`req.FormValue("both") = %q, want ["x", "y"]`, both)
+ if bq, found := req.PostForm["q"]; found {
+ t.Errorf(`req.PostForm["q"] = %q, want no entry in map`, bq)
+ }
+ if bz := req.PostFormValue("z"); bz != "post" {
+ t.Errorf(`req.PostFormValue("z") = %q, want "post"`, bz)
+ }
+ if qs := req.Form["q"]; !reflect.DeepEqual(qs, []string{"foo", "bar"}) {
+ t.Errorf(`req.Form["q"] = %q, want ["foo", "bar"]`, qs)
+ }
+ if both := req.Form["both"]; !reflect.DeepEqual(both, []string{"y", "x"}) {
+ t.Errorf(`req.Form["both"] = %q, want ["y", "x"]`, both)
+ }
+ if prio := req.FormValue("prio"); prio != "2" {
+ t.Errorf(`req.FormValue("prio") = %q, want "2" (from body)`, prio)
+ }
+ if empty := req.FormValue("empty"); empty != "" {
+ t.Errorf(`req.FormValue("empty") = %q, want "" (from body)`, empty)
}
}
@@ -76,6 +91,23 @@ func TestParseFormUnknownContentType(t *testing.T) {
}
}
+func TestParseFormInitializeOnError(t *testing.T) {
+ nilBody, _ := NewRequest("POST", "http://www.google.com/search?q=foo", nil)
+ tests := []*Request{
+ nilBody,
+ {Method: "GET", URL: nil},
+ }
+ for i, req := range tests {
+ err := req.ParseForm()
+ if req.Form == nil {
+ t.Errorf("%d. Form not initialized, error %v", i, err)
+ }
+ if req.PostForm == nil {
+ t.Errorf("%d. PostForm not initialized, error %v", i, err)
+ }
+ }
+}
+
func TestMultipartReader(t *testing.T) {
req := &Request{
Method: "POST",
@@ -129,7 +161,7 @@ func TestSetBasicAuth(t *testing.T) {
}
func TestMultipartRequest(t *testing.T) {
- // Test that we can read the values and files of a
+ // Test that we can read the values and files of a
// multipart request with FormValue and FormFile,
// and that ParseMultipartForm can be called multiple times.
req := newTestMultipartRequest(t)
@@ -196,6 +228,75 @@ func TestReadRequestErrors(t *testing.T) {
}
}
+func TestNewRequestHost(t *testing.T) {
+ req, err := NewRequest("GET", "http://localhost:1234/", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if req.Host != "localhost:1234" {
+ t.Errorf("Host = %q; want localhost:1234", req.Host)
+ }
+}
+
+func TestNewRequestContentLength(t *testing.T) {
+ readByte := func(r io.Reader) io.Reader {
+ var b [1]byte
+ r.Read(b[:])
+ return r
+ }
+ tests := []struct {
+ r io.Reader
+ want int64
+ }{
+ {bytes.NewReader([]byte("123")), 3},
+ {bytes.NewBuffer([]byte("1234")), 4},
+ {strings.NewReader("12345"), 5},
+ // Not detected:
+ {struct{ io.Reader }{strings.NewReader("xyz")}, 0},
+ {io.NewSectionReader(strings.NewReader("x"), 0, 6), 0},
+ {readByte(io.NewSectionReader(strings.NewReader("xy"), 0, 6)), 0},
+ }
+ for _, tt := range tests {
+ req, err := NewRequest("POST", "http://localhost/", tt.r)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if req.ContentLength != tt.want {
+ t.Errorf("ContentLength(%T) = %d; want %d", tt.r, req.ContentLength, tt.want)
+ }
+ }
+}
+
+type logWrites struct {
+ t *testing.T
+ dst *[]string
+}
+
+func (l logWrites) WriteByte(c byte) error {
+ l.t.Fatalf("unexpected WriteByte call")
+ return nil
+}
+
+func (l logWrites) Write(p []byte) (n int, err error) {
+ *l.dst = append(*l.dst, string(p))
+ return len(p), nil
+}
+
+func TestRequestWriteBufferedWriter(t *testing.T) {
+ got := []string{}
+ req, _ := NewRequest("GET", "http://foo.com/", nil)
+ req.Write(logWrites{t, &got})
+ want := []string{
+ "GET / HTTP/1.1\r\n",
+ "Host: foo.com\r\n",
+ "User-Agent: Go http package\r\n",
+ "\r\n",
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("Writes = %q\n Want = %q", got, want)
+ }
+}
+
func testMissingFile(t *testing.T, req *Request) {
f, fh, err := req.FormFile("missing")
if f != nil {
@@ -300,3 +401,81 @@ Content-Disposition: form-data; name="textb"
` + textbValue + `
--MyBoundary--
`
+
+func benchmarkReadRequest(b *testing.B, request string) {
+ request = request + "\n" // final \n
+ request = strings.Replace(request, "\n", "\r\n", -1) // expand \n to \r\n
+ b.SetBytes(int64(len(request)))
+ r := bufio.NewReader(&infiniteReader{buf: []byte(request)})
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _, err := ReadRequest(r)
+ if err != nil {
+ b.Fatalf("failed to read request: %v", err)
+ }
+ }
+}
+
+// infiniteReader satisfies Read requests as if the contents of buf
+// loop indefinitely.
+type infiniteReader struct {
+ buf []byte
+ offset int
+}
+
+func (r *infiniteReader) Read(b []byte) (int, error) {
+ n := copy(b, r.buf[r.offset:])
+ r.offset = (r.offset + n) % len(r.buf)
+ return n, nil
+}
+
+func BenchmarkReadRequestChrome(b *testing.B) {
+ // https://github.com/felixge/node-http-perf/blob/master/fixtures/get.http
+ benchmarkReadRequest(b, `GET / HTTP/1.1
+Host: localhost:8080
+Connection: keep-alive
+Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
+User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
+Accept-Encoding: gzip,deflate,sdch
+Accept-Language: en-US,en;q=0.8
+Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
+Cookie: __utma=1.1978842379.1323102373.1323102373.1323102373.1; EPi:NumberOfVisits=1,2012-02-28T13:42:18; CrmSession=5b707226b9563e1bc69084d07a107c98; plushContainerWidth=100%25; plushNoTopMenu=0; hudson_auto_refresh=false
+`)
+}
+
+func BenchmarkReadRequestCurl(b *testing.B) {
+ // curl http://localhost:8080/
+ benchmarkReadRequest(b, `GET / HTTP/1.1
+User-Agent: curl/7.27.0
+Host: localhost:8080
+Accept: */*
+`)
+}
+
+func BenchmarkReadRequestApachebench(b *testing.B) {
+ // ab -n 1 -c 1 http://localhost:8080/
+ benchmarkReadRequest(b, `GET / HTTP/1.0
+Host: localhost:8080
+User-Agent: ApacheBench/2.3
+Accept: */*
+`)
+}
+
+func BenchmarkReadRequestSiege(b *testing.B) {
+ // siege -r 1 -c 1 http://localhost:8080/
+ benchmarkReadRequest(b, `GET / HTTP/1.1
+Host: localhost:8080
+Accept: */*
+Accept-Encoding: gzip
+User-Agent: JoeDog/1.00 [en] (X11; I; Siege 2.70)
+Connection: keep-alive
+`)
+}
+
+func BenchmarkReadRequestWrk(b *testing.B) {
+ // wrk -t 1 -r 1 -c 1 http://localhost:8080/
+ benchmarkReadRequest(b, `GET / HTTP/1.1
+Host: localhost:8080
+`)
+}
diff --git a/src/pkg/net/http/requestwrite_test.go b/src/pkg/net/http/requestwrite_test.go
index fc3186f0c..bc637f18b 100644
--- a/src/pkg/net/http/requestwrite_test.go
+++ b/src/pkg/net/http/requestwrite_test.go
@@ -328,6 +328,69 @@ var reqWriteTests = []reqWriteTest{
"User-Agent: Go http package\r\n" +
"X-Foo: X-Bar\r\n\r\n",
},
+
+ // If no Request.Host and no Request.URL.Host, we send
+ // an empty Host header, and don't use
+ // Request.Header["Host"]. This is just testing that
+ // we don't change Go 1.0 behavior.
+ {
+ Req: Request{
+ Method: "GET",
+ Host: "",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "",
+ Path: "/search",
+ },
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{
+ "Host": []string{"bad.example.com"},
+ },
+ },
+
+ WantWrite: "GET /search HTTP/1.1\r\n" +
+ "Host: \r\n" +
+ "User-Agent: Go http package\r\n\r\n",
+ },
+
+ // Opaque test #1 from golang.org/issue/4860
+ {
+ Req: Request{
+ Method: "GET",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Opaque: "/%2F/%2F/",
+ },
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ },
+
+ WantWrite: "GET /%2F/%2F/ HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go http package\r\n\r\n",
+ },
+
+ // Opaque test #2 from golang.org/issue/4860
+ {
+ Req: Request{
+ Method: "GET",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "x.google.com",
+ Opaque: "//y.google.com/%2F/%2F/",
+ },
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ },
+
+ WantWrite: "GET http://y.google.com/%2F/%2F/ HTTP/1.1\r\n" +
+ "Host: x.google.com\r\n" +
+ "User-Agent: Go http package\r\n\r\n",
+ },
}
func TestRequestWrite(t *testing.T) {
diff --git a/src/pkg/net/http/response.go b/src/pkg/net/http/response.go
index 945ecd8a4..391ebbf6d 100644
--- a/src/pkg/net/http/response.go
+++ b/src/pkg/net/http/response.go
@@ -49,7 +49,7 @@ type Response struct {
Body io.ReadCloser
// ContentLength records the length of the associated content. The
- // value -1 indicates that the length is unknown. Unless RequestMethod
+ // value -1 indicates that the length is unknown. Unless Request.Method
// is "HEAD", values >= 0 indicate that the given number of bytes may
// be read from Body.
ContentLength int64
@@ -107,7 +107,6 @@ func ReadResponse(r *bufio.Reader, req *Request) (resp *Response, err error) {
resp = new(Response)
resp.Request = req
- resp.Request.Method = strings.ToUpper(resp.Request.Method)
// Parse the first line of the response.
line, err := tp.ReadLine()
@@ -179,7 +178,7 @@ func (r *Response) ProtoAtLeast(major, minor int) bool {
// StatusCode
// ProtoMajor
// ProtoMinor
-// RequestMethod
+// Request.Method
// TransferEncoding
// Trailer
// Body
@@ -188,11 +187,6 @@ func (r *Response) ProtoAtLeast(major, minor int) bool {
//
func (r *Response) Write(w io.Writer) error {
- // RequestMethod should be upper-case
- if r.Request != nil {
- r.Request.Method = strings.ToUpper(r.Request.Method)
- }
-
// Status line
text := r.Status
if text == "" {
@@ -204,9 +198,7 @@ func (r *Response) Write(w io.Writer) error {
}
protoMajor, protoMinor := strconv.Itoa(r.ProtoMajor), strconv.Itoa(r.ProtoMinor)
statusCode := strconv.Itoa(r.StatusCode) + " "
- if strings.HasPrefix(text, statusCode) {
- text = text[len(statusCode):]
- }
+ text = strings.TrimPrefix(text, statusCode)
io.WriteString(w, "HTTP/"+protoMajor+"."+protoMinor+" "+statusCode+text+"\r\n")
// Process Body,ContentLength,Close,Trailer
diff --git a/src/pkg/net/http/response_test.go b/src/pkg/net/http/response_test.go
index 6eed4887d..2f5f77369 100644
--- a/src/pkg/net/http/response_test.go
+++ b/src/pkg/net/http/response_test.go
@@ -124,7 +124,7 @@ var respTests = []respTest{
// Chunked response without Content-Length.
{
- "HTTP/1.0 200 OK\r\n" +
+ "HTTP/1.1 200 OK\r\n" +
"Transfer-Encoding: chunked\r\n" +
"\r\n" +
"0a\r\n" +
@@ -137,12 +137,12 @@ var respTests = []respTest{
Response{
Status: "200 OK",
StatusCode: 200,
- Proto: "HTTP/1.0",
+ Proto: "HTTP/1.1",
ProtoMajor: 1,
- ProtoMinor: 0,
+ ProtoMinor: 1,
Request: dummyReq("GET"),
Header: Header{},
- Close: true,
+ Close: false,
ContentLength: -1,
TransferEncoding: []string{"chunked"},
},
@@ -152,24 +152,24 @@ var respTests = []respTest{
// Chunked response with Content-Length.
{
- "HTTP/1.0 200 OK\r\n" +
+ "HTTP/1.1 200 OK\r\n" +
"Transfer-Encoding: chunked\r\n" +
"Content-Length: 10\r\n" +
"\r\n" +
"0a\r\n" +
- "Body here\n" +
+ "Body here\n\r\n" +
"0\r\n" +
"\r\n",
Response{
Status: "200 OK",
StatusCode: 200,
- Proto: "HTTP/1.0",
+ Proto: "HTTP/1.1",
ProtoMajor: 1,
- ProtoMinor: 0,
+ ProtoMinor: 1,
Request: dummyReq("GET"),
Header: Header{},
- Close: true,
+ Close: false,
ContentLength: -1, // TODO(rsc): Fix?
TransferEncoding: []string{"chunked"},
},
@@ -177,23 +177,88 @@ var respTests = []respTest{
"Body here\n",
},
- // Chunked response in response to a HEAD request (the "chunked" should
- // be ignored, as HEAD responses never have bodies)
+ // Chunked response in response to a HEAD request
{
- "HTTP/1.0 200 OK\r\n" +
+ "HTTP/1.1 200 OK\r\n" +
"Transfer-Encoding: chunked\r\n" +
"\r\n",
Response{
- Status: "200 OK",
- StatusCode: 200,
- Proto: "HTTP/1.0",
- ProtoMajor: 1,
- ProtoMinor: 0,
- Request: dummyReq("HEAD"),
- Header: Header{},
- Close: true,
- ContentLength: 0,
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("HEAD"),
+ Header: Header{},
+ TransferEncoding: []string{"chunked"},
+ Close: false,
+ ContentLength: -1,
+ },
+
+ "",
+ },
+
+ // Content-Length in response to a HEAD request
+ {
+ "HTTP/1.0 200 OK\r\n" +
+ "Content-Length: 256\r\n" +
+ "\r\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("HEAD"),
+ Header: Header{"Content-Length": {"256"}},
+ TransferEncoding: nil,
+ Close: true,
+ ContentLength: 256,
+ },
+
+ "",
+ },
+
+ // Content-Length in response to a HEAD request with HTTP/1.1
+ {
+ "HTTP/1.1 200 OK\r\n" +
+ "Content-Length: 256\r\n" +
+ "\r\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("HEAD"),
+ Header: Header{"Content-Length": {"256"}},
+ TransferEncoding: nil,
+ Close: false,
+ ContentLength: 256,
+ },
+
+ "",
+ },
+
+ // No Content-Length or Chunked in response to a HEAD request
+ {
+ "HTTP/1.0 200 OK\r\n" +
+ "\r\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("HEAD"),
+ Header: Header{},
+ TransferEncoding: nil,
+ Close: true,
+ ContentLength: -1,
},
"",
@@ -259,16 +324,37 @@ var respTests = []respTest{
"",
},
+
+ // golang.org/issue/4767: don't special-case multipart/byteranges responses
+ {
+ `HTTP/1.1 206 Partial Content
+Connection: close
+Content-Type: multipart/byteranges; boundary=18a75608c8f47cef
+
+some body`,
+ Response{
+ Status: "206 Partial Content",
+ StatusCode: 206,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("GET"),
+ Header: Header{
+ "Content-Type": []string{"multipart/byteranges; boundary=18a75608c8f47cef"},
+ },
+ Close: true,
+ ContentLength: -1,
+ },
+
+ "some body",
+ },
}
func TestReadResponse(t *testing.T) {
- for i := range respTests {
- tt := &respTests[i]
- var braw bytes.Buffer
- braw.WriteString(tt.Raw)
- resp, err := ReadResponse(bufio.NewReader(&braw), tt.Resp.Request)
+ for i, tt := range respTests {
+ resp, err := ReadResponse(bufio.NewReader(strings.NewReader(tt.Raw)), tt.Resp.Request)
if err != nil {
- t.Errorf("#%d: %s", i, err)
+ t.Errorf("#%d: %v", i, err)
continue
}
rbody := resp.Body
@@ -276,7 +362,11 @@ func TestReadResponse(t *testing.T) {
diff(t, fmt.Sprintf("#%d Response", i), resp, &tt.Resp)
var bout bytes.Buffer
if rbody != nil {
- io.Copy(&bout, rbody)
+ _, err = io.Copy(&bout, rbody)
+ if err != nil {
+ t.Errorf("#%d: %v", i, err)
+ continue
+ }
rbody.Close()
}
body := bout.String()
@@ -286,6 +376,22 @@ func TestReadResponse(t *testing.T) {
}
}
+func TestWriteResponse(t *testing.T) {
+ for i, tt := range respTests {
+ resp, err := ReadResponse(bufio.NewReader(strings.NewReader(tt.Raw)), tt.Resp.Request)
+ if err != nil {
+ t.Errorf("#%d: %v", i, err)
+ continue
+ }
+ bout := bytes.NewBuffer(nil)
+ err = resp.Write(bout)
+ if err != nil {
+ t.Errorf("#%d: %v", i, err)
+ continue
+ }
+ }
+}
+
var readResponseCloseInMiddleTests = []struct {
chunked, compressed bool
}{
diff --git a/src/pkg/net/http/responsewrite_test.go b/src/pkg/net/http/responsewrite_test.go
index f8e63acf4..5c10e2161 100644
--- a/src/pkg/net/http/responsewrite_test.go
+++ b/src/pkg/net/http/responsewrite_test.go
@@ -15,83 +15,83 @@ type respWriteTest struct {
Raw string
}
-var respWriteTests = []respWriteTest{
- // HTTP/1.0, identity coding; no trailer
- {
- Response{
- StatusCode: 503,
- ProtoMajor: 1,
- ProtoMinor: 0,
- Request: dummyReq("GET"),
- Header: Header{},
- Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")),
- ContentLength: 6,
- },
+func TestResponseWrite(t *testing.T) {
+ respWriteTests := []respWriteTest{
+ // HTTP/1.0, identity coding; no trailer
+ {
+ Response{
+ StatusCode: 503,
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")),
+ ContentLength: 6,
+ },
- "HTTP/1.0 503 Service Unavailable\r\n" +
- "Content-Length: 6\r\n\r\n" +
- "abcdef",
- },
- // Unchunked response without Content-Length.
- {
- Response{
- StatusCode: 200,
- ProtoMajor: 1,
- ProtoMinor: 0,
- Request: dummyReq("GET"),
- Header: Header{},
- Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")),
- ContentLength: -1,
+ "HTTP/1.0 503 Service Unavailable\r\n" +
+ "Content-Length: 6\r\n\r\n" +
+ "abcdef",
},
- "HTTP/1.0 200 OK\r\n" +
- "\r\n" +
- "abcdef",
- },
- // HTTP/1.1, chunked coding; empty trailer; close
- {
- Response{
- StatusCode: 200,
- ProtoMajor: 1,
- ProtoMinor: 1,
- Request: dummyReq("GET"),
- Header: Header{},
- Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")),
- ContentLength: 6,
- TransferEncoding: []string{"chunked"},
- Close: true,
+ // Unchunked response without Content-Length.
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")),
+ ContentLength: -1,
+ },
+ "HTTP/1.0 200 OK\r\n" +
+ "\r\n" +
+ "abcdef",
},
+ // HTTP/1.1, chunked coding; empty trailer; close
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")),
+ ContentLength: 6,
+ TransferEncoding: []string{"chunked"},
+ Close: true,
+ },
- "HTTP/1.1 200 OK\r\n" +
- "Connection: close\r\n" +
- "Transfer-Encoding: chunked\r\n\r\n" +
- "6\r\nabcdef\r\n0\r\n\r\n",
- },
+ "HTTP/1.1 200 OK\r\n" +
+ "Connection: close\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ "6\r\nabcdef\r\n0\r\n\r\n",
+ },
- // Header value with a newline character (Issue 914).
- // Also tests removal of leading and trailing whitespace.
- {
- Response{
- StatusCode: 204,
- ProtoMajor: 1,
- ProtoMinor: 1,
- Request: dummyReq("GET"),
- Header: Header{
- "Foo": []string{" Bar\nBaz "},
+ // Header value with a newline character (Issue 914).
+ // Also tests removal of leading and trailing whitespace.
+ {
+ Response{
+ StatusCode: 204,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("GET"),
+ Header: Header{
+ "Foo": []string{" Bar\nBaz "},
+ },
+ Body: nil,
+ ContentLength: 0,
+ TransferEncoding: []string{"chunked"},
+ Close: true,
},
- Body: nil,
- ContentLength: 0,
- TransferEncoding: []string{"chunked"},
- Close: true,
- },
- "HTTP/1.1 204 No Content\r\n" +
- "Connection: close\r\n" +
- "Foo: Bar Baz\r\n" +
- "\r\n",
- },
-}
+ "HTTP/1.1 204 No Content\r\n" +
+ "Connection: close\r\n" +
+ "Foo: Bar Baz\r\n" +
+ "\r\n",
+ },
+ }
-func TestResponseWrite(t *testing.T) {
for i := range respWriteTests {
tt := &respWriteTests[i]
var braw bytes.Buffer
diff --git a/src/pkg/net/http/serve_test.go b/src/pkg/net/http/serve_test.go
index b6a6b4c77..3300fef59 100644
--- a/src/pkg/net/http/serve_test.go
+++ b/src/pkg/net/http/serve_test.go
@@ -20,8 +20,13 @@ import (
"net/http/httputil"
"net/url"
"os"
+ "os/exec"
"reflect"
+ "runtime"
+ "strconv"
"strings"
+ "sync"
+ "sync/atomic"
"syscall"
"testing"
"time"
@@ -62,6 +67,7 @@ func (a dummyAddr) String() string {
type testConn struct {
readBuf bytes.Buffer
writeBuf bytes.Buffer
+ closec chan bool // if non-nil, send value to it on close
}
func (c *testConn) Read(b []byte) (int, error) {
@@ -73,6 +79,10 @@ func (c *testConn) Write(b []byte) (int, error) {
}
func (c *testConn) Close() error {
+ select {
+ case c.closec <- true:
+ default:
+ }
return nil
}
@@ -168,13 +178,18 @@ var vtests = []struct {
{"http://someHost.com/someDir/apage", "someHost.com/someDir"},
{"http://otherHost.com/someDir/apage", "someDir"},
{"http://otherHost.com/aDir/apage", "Default"},
+ // redirections for trees
+ {"http://localhost/someDir", "/someDir/"},
+ {"http://someHost.com/someDir", "/someDir/"},
}
func TestHostHandlers(t *testing.T) {
+ defer checkLeakedTransports(t)
+ mux := NewServeMux()
for _, h := range handlers {
- Handle(h.pattern, stringHandler(h.msg))
+ mux.Handle(h.pattern, stringHandler(h.msg))
}
- ts := httptest.NewServer(nil)
+ ts := httptest.NewServer(mux)
defer ts.Close()
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
@@ -199,9 +214,19 @@ func TestHostHandlers(t *testing.T) {
t.Errorf("reading response: %v", err)
continue
}
- s := r.Header.Get("Result")
- if s != vt.expected {
- t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
+ switch r.StatusCode {
+ case StatusOK:
+ s := r.Header.Get("Result")
+ if s != vt.expected {
+ t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
+ }
+ case StatusMovedPermanently:
+ s := r.Header.Get("Location")
+ if s != vt.expected {
+ t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
+ }
+ default:
+ t.Errorf("Get(%q) unhandled status code %d", vt.url, r.StatusCode)
}
}
}
@@ -232,28 +257,22 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) {
}
func TestServerTimeouts(t *testing.T) {
- // TODO(bradfitz): convert this to use httptest.Server
- l, err := net.Listen("tcp", "127.0.0.1:0")
- if err != nil {
- t.Fatalf("listen error: %v", err)
- }
- addr, _ := l.Addr().(*net.TCPAddr)
-
+ defer checkLeakedTransports(t)
reqNum := 0
- handler := HandlerFunc(func(res ResponseWriter, req *Request) {
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) {
reqNum++
fmt.Fprintf(res, "req=%d", reqNum)
- })
-
- server := &Server{Handler: handler, ReadTimeout: 250 * time.Millisecond, WriteTimeout: 250 * time.Millisecond}
- go server.Serve(l)
-
- url := fmt.Sprintf("http://%s/", addr)
+ }))
+ ts.Config.ReadTimeout = 250 * time.Millisecond
+ ts.Config.WriteTimeout = 250 * time.Millisecond
+ ts.Start()
+ defer ts.Close()
// Hit the HTTP server successfully.
tr := &Transport{DisableKeepAlives: true} // they interfere with this test
+ defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
- r, err := c.Get(url)
+ r, err := c.Get(ts.URL)
if err != nil {
t.Fatalf("http Get #1: %v", err)
}
@@ -266,13 +285,13 @@ func TestServerTimeouts(t *testing.T) {
// Slow client that should timeout.
t1 := time.Now()
- conn, err := net.Dial("tcp", addr.String())
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
t.Fatalf("Dial: %v", err)
}
buf := make([]byte, 1)
n, err := conn.Read(buf)
- latency := time.Now().Sub(t1)
+ latency := time.Since(t1)
if n != 0 || err != io.EOF {
t.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF)
}
@@ -283,7 +302,7 @@ func TestServerTimeouts(t *testing.T) {
// Hit the HTTP server successfully again, verifying that the
// previous slow connection didn't run our handler. (that we
// get "req=2", not "req=3")
- r, err = Get(url)
+ r, err = Get(ts.URL)
if err != nil {
t.Fatalf("http Get #2: %v", err)
}
@@ -293,11 +312,87 @@ func TestServerTimeouts(t *testing.T) {
t.Errorf("Get #2 got %q, want %q", string(got), expected)
}
- l.Close()
+ if !testing.Short() {
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer conn.Close()
+ go io.Copy(ioutil.Discard, conn)
+ for i := 0; i < 5; i++ {
+ _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
+ if err != nil {
+ t.Fatalf("on write %d: %v", i, err)
+ }
+ time.Sleep(ts.Config.ReadTimeout / 2)
+ }
+ }
+}
+
+// golang.org/issue/4741 -- setting only a write timeout that triggers
+// shouldn't cause a handler to block forever on reads (next HTTP
+// request) that will never happen.
+func TestOnlyWriteTimeout(t *testing.T) {
+ defer checkLeakedTransports(t)
+ var conn net.Conn
+ var afterTimeoutErrc = make(chan error, 1)
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, req *Request) {
+ buf := make([]byte, 512<<10)
+ _, err := w.Write(buf)
+ if err != nil {
+ t.Errorf("handler Write error: %v", err)
+ return
+ }
+ conn.SetWriteDeadline(time.Now().Add(-30 * time.Second))
+ _, err = w.Write(buf)
+ afterTimeoutErrc <- err
+ }))
+ ts.Listener = trackLastConnListener{ts.Listener, &conn}
+ ts.Start()
+ defer ts.Close()
+
+ tr := &Transport{DisableKeepAlives: false}
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ errc := make(chan error)
+ go func() {
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ errc <- err
+ return
+ }
+ _, err = io.Copy(ioutil.Discard, res.Body)
+ errc <- err
+ }()
+ select {
+ case err := <-errc:
+ if err == nil {
+ t.Errorf("expected an error from Get request")
+ }
+ case <-time.After(5 * time.Second):
+ t.Fatal("timeout waiting for Get error")
+ }
+ if err := <-afterTimeoutErrc; err == nil {
+ t.Error("expected write error after timeout")
+ }
+}
+
+// trackLastConnListener tracks the last net.Conn that was accepted.
+type trackLastConnListener struct {
+ net.Listener
+ last *net.Conn // destination
}
-// TestIdentityResponse verifies that a handler can unset
+func (l trackLastConnListener) Accept() (c net.Conn, err error) {
+ c, err = l.Listener.Accept()
+ *l.last = c
+ return
+}
+
+// TestIdentityResponse verifies that a handler can unset
func TestIdentityResponse(t *testing.T) {
+ defer checkLeakedTransports(t)
handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
rw.Header().Set("Content-Length", "3")
rw.Header().Set("Transfer-Encoding", req.FormValue("te"))
@@ -343,10 +438,12 @@ func TestIdentityResponse(t *testing.T) {
// Verify that ErrContentLength is returned
url := ts.URL + "/?overwrite=1"
- _, err := Get(url)
+ res, err := Get(url)
if err != nil {
t.Fatalf("error with Get of %s: %v", url, err)
}
+ res.Body.Close()
+
// Verify that the connection is closed when the declared Content-Length
// is larger than what the handler wrote.
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
@@ -370,7 +467,8 @@ func TestIdentityResponse(t *testing.T) {
})
}
-func testTcpConnectionCloses(t *testing.T, req string, h Handler) {
+func testTCPConnectionCloses(t *testing.T, req string, h Handler) {
+ defer checkLeakedTransports(t)
s := httptest.NewServer(h)
defer s.Close()
@@ -386,17 +484,18 @@ func testTcpConnectionCloses(t *testing.T, req string, h Handler) {
}
r := bufio.NewReader(conn)
- _, err = ReadResponse(r, &Request{Method: "GET"})
+ res, err := ReadResponse(r, &Request{Method: "GET"})
if err != nil {
t.Fatal("ReadResponse error:", err)
}
- success := make(chan bool)
+ didReadAll := make(chan bool, 1)
go func() {
select {
case <-time.After(5 * time.Second):
- t.Fatal("body not closed after 5s")
- case <-success:
+ t.Error("body not closed after 5s")
+ return
+ case <-didReadAll:
}
}()
@@ -404,32 +503,43 @@ func testTcpConnectionCloses(t *testing.T, req string, h Handler) {
if err != nil {
t.Fatal("read error:", err)
}
+ didReadAll <- true
- success <- true
+ if !res.Close {
+ t.Errorf("Response.Close = false; want true")
+ }
}
// TestServeHTTP10Close verifies that HTTP/1.0 requests won't be kept alive.
func TestServeHTTP10Close(t *testing.T) {
- testTcpConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
+ testTCPConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
ServeFile(w, r, "testdata/file")
}))
}
+// TestClientCanClose verifies that clients can also force a connection to close.
+func TestClientCanClose(t *testing.T) {
+ testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
+ // Nothing.
+ }))
+}
+
// TestHandlersCanSetConnectionClose verifies that handlers can force a connection to close,
// even for HTTP/1.1 requests.
func TestHandlersCanSetConnectionClose11(t *testing.T) {
- testTcpConnectionCloses(t, "GET / HTTP/1.1\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
+ testTCPConnectionCloses(t, "GET / HTTP/1.1\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("Connection", "close")
}))
}
func TestHandlersCanSetConnectionClose10(t *testing.T) {
- testTcpConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
+ testTCPConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("Connection", "close")
}))
}
func TestSetsRemoteAddr(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
fmt.Fprintf(w, "%s", r.RemoteAddr)
}))
@@ -450,11 +560,13 @@ func TestSetsRemoteAddr(t *testing.T) {
}
func TestChunkedResponseHeaders(t *testing.T) {
+ defer checkLeakedTransports(t)
log.SetOutput(ioutil.Discard) // is noisy otherwise
defer log.SetOutput(os.Stderr)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted
+ w.(Flusher).Flush()
fmt.Fprintf(w, "I am a chunked response.")
}))
defer ts.Close()
@@ -463,6 +575,7 @@ func TestChunkedResponseHeaders(t *testing.T) {
if err != nil {
t.Fatalf("Get error: %v", err)
}
+ defer res.Body.Close()
if g, e := res.ContentLength, int64(-1); g != e {
t.Errorf("expected ContentLength of %d; got %d", e, g)
}
@@ -478,6 +591,7 @@ func TestChunkedResponseHeaders(t *testing.T) {
// chunking in their response headers and aren't allowed to produce
// output.
func Test304Responses(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
w.WriteHeader(StatusNotModified)
_, err := w.Write([]byte("illegal body"))
@@ -507,6 +621,7 @@ func Test304Responses(t *testing.T) {
// allowed to produce output, and don't set a Content-Type since
// the real type of the body data cannot be inferred.
func TestHeadResponses(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
_, err := w.Write([]byte("Ignored body"))
if err != ErrBodyNotAllowed {
@@ -541,6 +656,7 @@ func TestHeadResponses(t *testing.T) {
}
func TestTLSHandshakeTimeout(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
ts.Config.ReadTimeout = 250 * time.Millisecond
ts.StartTLS()
@@ -560,6 +676,7 @@ func TestTLSHandshakeTimeout(t *testing.T) {
}
func TestTLSServer(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
if r.TLS != nil {
w.Header().Set("X-TLS-Set", "true")
@@ -642,6 +759,7 @@ var serverExpectTests = []serverExpectTest{
// Tests that the server responds to the "Expect" request header
// correctly.
func TestServerExpect(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
// Note using r.FormValue("readbody") because for POST
// requests that would read from r.Body, which we only
@@ -661,30 +779,51 @@ func TestServerExpect(t *testing.T) {
t.Fatalf("Dial: %v", err)
}
defer conn.Close()
- sendf := func(format string, args ...interface{}) {
- _, err := fmt.Fprintf(conn, format, args...)
- if err != nil {
- t.Fatalf("On test %#v, error writing %q: %v", test, format, err)
- }
- }
+
+ // Only send the body immediately if we're acting like an HTTP client
+ // that doesn't send 100-continue expectations.
+ writeBody := test.contentLength > 0 && strings.ToLower(test.expectation) != "100-continue"
+
go func() {
- sendf("POST /?readbody=%v HTTP/1.1\r\n"+
+ _, err := fmt.Fprintf(conn, "POST /?readbody=%v HTTP/1.1\r\n"+
"Connection: close\r\n"+
"Content-Length: %d\r\n"+
"Expect: %s\r\nHost: foo\r\n\r\n",
test.readBody, test.contentLength, test.expectation)
- if test.contentLength > 0 && strings.ToLower(test.expectation) != "100-continue" {
+ if err != nil {
+ t.Errorf("On test %#v, error writing request headers: %v", test, err)
+ return
+ }
+ if writeBody {
body := strings.Repeat("A", test.contentLength)
- sendf(body)
+ _, err = fmt.Fprint(conn, body)
+ if err != nil {
+ if !test.readBody {
+ // Server likely already hung up on us.
+ // See larger comment below.
+ t.Logf("On test %#v, acceptable error writing request body: %v", test, err)
+ return
+ }
+ t.Errorf("On test %#v, error writing request body: %v", test, err)
+ }
}
}()
bufr := bufio.NewReader(conn)
line, err := bufr.ReadString('\n')
if err != nil {
- t.Fatalf("ReadString: %v", err)
+ if writeBody && !test.readBody {
+ // This is an acceptable failure due to a possible TCP race:
+ // We were still writing data and the server hung up on us. A TCP
+ // implementation may send a RST if our request body data was known
+ // to be lost, which may trigger our reads to fail.
+ // See RFC 1122 page 88.
+ t.Logf("On test %#v, acceptable error from ReadString: %v", test, err)
+ return
+ }
+ t.Fatalf("On test %#v, ReadString: %v", test, err)
}
if !strings.Contains(line, test.expectedResponse) {
- t.Errorf("for test %#v got first line=%q", test, line)
+ t.Errorf("On test %#v, got first line = %q; want %q", test, line, test.expectedResponse)
}
}
@@ -714,6 +853,7 @@ func TestServerUnreadRequestBodyLittle(t *testing.T) {
t.Errorf("on request, read buffer length is %d; expected about 100 KB", conn.readBuf.Len())
}
rw.WriteHeader(200)
+ rw.(Flusher).Flush()
if g, e := conn.readBuf.Len(), 0; g != e {
t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e)
}
@@ -736,27 +876,28 @@ func TestServerUnreadRequestBodyLarge(t *testing.T) {
"Content-Length: %d\r\n"+
"\r\n", len(body))))
conn.readBuf.Write([]byte(body))
-
- done := make(chan bool)
+ conn.closec = make(chan bool, 1)
ls := &oneConnListener{conn}
go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
- defer close(done)
if conn.readBuf.Len() < len(body)/2 {
t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
}
rw.WriteHeader(200)
+ rw.(Flusher).Flush()
if conn.readBuf.Len() < len(body)/2 {
t.Errorf("post-WriteHeader, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
}
- if c := rw.Header().Get("Connection"); c != "close" {
- t.Errorf(`Connection header = %q; want "close"`, c)
- }
}))
- <-done
+ <-conn.closec
+
+ if res := conn.writeBuf.String(); !strings.Contains(res, "Connection: close") {
+ t.Errorf("Expected a Connection: close header; got response: %s", res)
+ }
}
func TestTimeoutHandler(t *testing.T) {
+ defer checkLeakedTransports(t)
sendHi := make(chan bool, 1)
writeErrors := make(chan error, 1)
sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -831,6 +972,7 @@ func TestRedirectMunging(t *testing.T) {
// the previous request's body, which is not optimal for zero-lengthed bodies,
// as the client would then see http.ErrBodyReadAfterClose and not 0, io.EOF.
func TestZeroLengthPostAndResponse(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
all, err := ioutil.ReadAll(r.Body)
if err != nil {
@@ -868,15 +1010,20 @@ func TestZeroLengthPostAndResponse(t *testing.T) {
}
}
+func TestHandlerPanicNil(t *testing.T) {
+ testHandlerPanic(t, false, nil)
+}
+
func TestHandlerPanic(t *testing.T) {
- testHandlerPanic(t, false)
+ testHandlerPanic(t, false, "intentional death for testing")
}
func TestHandlerPanicWithHijack(t *testing.T) {
- testHandlerPanic(t, true)
+ testHandlerPanic(t, true, "intentional death for testing")
}
-func testHandlerPanic(t *testing.T, withHijack bool) {
+func testHandlerPanic(t *testing.T, withHijack bool, panicValue interface{}) {
+ defer checkLeakedTransports(t)
// Unlike the other tests that set the log output to ioutil.Discard
// to quiet the output, this test uses a pipe. The pipe serves three
// purposes:
@@ -896,6 +1043,7 @@ func testHandlerPanic(t *testing.T, withHijack bool) {
pr, pw := io.Pipe()
log.SetOutput(pw)
defer log.SetOutput(os.Stderr)
+ defer pw.Close()
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
if withHijack {
@@ -905,7 +1053,7 @@ func testHandlerPanic(t *testing.T, withHijack bool) {
}
defer rwc.Close()
}
- panic("intentional death for testing")
+ panic(panicValue)
}))
defer ts.Close()
@@ -917,8 +1065,8 @@ func testHandlerPanic(t *testing.T, withHijack bool) {
buf := make([]byte, 4<<10)
_, err := pr.Read(buf)
pr.Close()
- if err != nil {
- t.Fatal(err)
+ if err != nil && err != io.EOF {
+ t.Error(err)
}
done <- true
}()
@@ -928,6 +1076,10 @@ func testHandlerPanic(t *testing.T, withHijack bool) {
t.Logf("expected an error")
}
+ if panicValue == nil {
+ return
+ }
+
select {
case <-done:
return
@@ -937,6 +1089,7 @@ func testHandlerPanic(t *testing.T, withHijack bool) {
}
func TestNoDate(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header()["Date"] = nil
}))
@@ -952,6 +1105,7 @@ func TestNoDate(t *testing.T) {
}
func TestStripPrefix(t *testing.T) {
+ defer checkLeakedTransports(t)
h := HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("X-Path", r.URL.Path)
})
@@ -965,6 +1119,7 @@ func TestStripPrefix(t *testing.T) {
if g, e := res.Header.Get("X-Path"), "/bar"; g != e {
t.Errorf("test 1: got %s, want %s", g, e)
}
+ res.Body.Close()
res, err = Get(ts.URL + "/bar")
if err != nil {
@@ -973,9 +1128,11 @@ func TestStripPrefix(t *testing.T) {
if g, e := res.StatusCode, 404; g != e {
t.Errorf("test 2: got status %v, want %v", g, e)
}
+ res.Body.Close()
}
func TestRequestLimit(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
t.Fatalf("didn't expect to get request in Handler")
}))
@@ -992,6 +1149,7 @@ func TestRequestLimit(t *testing.T) {
// we do support it (at least currently), so we expect a response below.
t.Fatalf("Do: %v", err)
}
+ defer res.Body.Close()
if res.StatusCode != 413 {
t.Fatalf("expected 413 response status; got: %d %s", res.StatusCode, res.Status)
}
@@ -1013,11 +1171,12 @@ type countReader struct {
func (cr countReader) Read(p []byte) (n int, err error) {
n, err = cr.r.Read(p)
- *cr.n += int64(n)
+ atomic.AddInt64(cr.n, int64(n))
return
}
func TestRequestBodyLimit(t *testing.T) {
+ defer checkLeakedTransports(t)
const limit = 1 << 20
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
r.Body = MaxBytesReader(w, r.Body, limit)
@@ -1031,8 +1190,8 @@ func TestRequestBodyLimit(t *testing.T) {
}))
defer ts.Close()
- nWritten := int64(0)
- req, _ := NewRequest("POST", ts.URL, io.LimitReader(countReader{neverEnding('a'), &nWritten}, limit*200))
+ nWritten := new(int64)
+ req, _ := NewRequest("POST", ts.URL, io.LimitReader(countReader{neverEnding('a'), nWritten}, limit*200))
// Send the POST, but don't care it succeeds or not. The
// remote side is going to reply and then close the TCP
@@ -1045,7 +1204,7 @@ func TestRequestBodyLimit(t *testing.T) {
// the remote side hung up on us before we wrote too much.
_, _ = DefaultClient.Do(req)
- if nWritten > limit*100 {
+ if atomic.LoadInt64(nWritten) > limit*100 {
t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d",
limit, nWritten)
}
@@ -1054,6 +1213,7 @@ func TestRequestBodyLimit(t *testing.T) {
// TestClientWriteShutdown tests that if the client shuts down the write
// side of their TCP connection, the server doesn't send a 400 Bad Request.
func TestClientWriteShutdown(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
defer ts.Close()
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
@@ -1086,28 +1246,207 @@ func TestClientWriteShutdown(t *testing.T) {
// Tests that chunked server responses that write 1 byte at a time are
// buffered before chunk headers are added, not after chunk headers.
func TestServerBufferedChunking(t *testing.T) {
- if true {
- t.Logf("Skipping known broken test; see Issue 2357")
- return
- }
conn := new(testConn)
conn.readBuf.Write([]byte("GET / HTTP/1.1\r\n\r\n"))
- done := make(chan bool)
+ conn.closec = make(chan bool, 1)
ls := &oneConnListener{conn}
go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
- defer close(done)
- rw.Header().Set("Content-Type", "text/plain") // prevent sniffing, which buffers
+ rw.(Flusher).Flush() // force the Header to be sent, in chunking mode, not counting the length
rw.Write([]byte{'x'})
rw.Write([]byte{'y'})
rw.Write([]byte{'z'})
}))
- <-done
+ <-conn.closec
if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) {
t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q",
conn.writeBuf.Bytes())
}
}
+// Tests that the server flushes its response headers out when it's
+// ignoring the response body and waits a bit before forcefully
+// closing the TCP connection, causing the client to get a RST.
+// See http://golang.org/issue/3595
+func TestServerGracefulClose(t *testing.T) {
+ defer checkLeakedTransports(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ Error(w, "bye", StatusUnauthorized)
+ }))
+ defer ts.Close()
+
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+ const bodySize = 5 << 20
+ req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize))
+ for i := 0; i < bodySize; i++ {
+ req = append(req, 'x')
+ }
+ writeErr := make(chan error)
+ go func() {
+ _, err := conn.Write(req)
+ writeErr <- err
+ }()
+ br := bufio.NewReader(conn)
+ lineNum := 0
+ for {
+ line, err := br.ReadString('\n')
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ t.Fatalf("ReadLine: %v", err)
+ }
+ lineNum++
+ if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") {
+ t.Errorf("Response line = %q; want a 401", line)
+ }
+ }
+ // Wait for write to finish. This is a broken pipe on both
+ // Darwin and Linux, but checking this isn't the point of
+ // the test.
+ <-writeErr
+}
+
+func TestCaseSensitiveMethod(t *testing.T) {
+ defer checkLeakedTransports(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method != "get" {
+ t.Errorf(`Got method %q; want "get"`, r.Method)
+ }
+ }))
+ defer ts.Close()
+ req, _ := NewRequest("get", ts.URL, nil)
+ res, err := DefaultClient.Do(req)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ res.Body.Close()
+}
+
+// TestContentLengthZero tests that for both an HTTP/1.0 and HTTP/1.1
+// request (both keep-alive), when a Handler never writes any
+// response, the net/http package adds a "Content-Length: 0" response
+// header.
+func TestContentLengthZero(t *testing.T) {
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {}))
+ defer ts.Close()
+
+ for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} {
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("error dialing: %v", err)
+ }
+ _, err = fmt.Fprintf(conn, "GET / %v\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n", version)
+ if err != nil {
+ t.Fatalf("error writing: %v", err)
+ }
+ req, _ := NewRequest("GET", "/", nil)
+ res, err := ReadResponse(bufio.NewReader(conn), req)
+ if err != nil {
+ t.Fatalf("error reading response: %v", err)
+ }
+ if te := res.TransferEncoding; len(te) > 0 {
+ t.Errorf("For version %q, Transfer-Encoding = %q; want none", version, te)
+ }
+ if cl := res.ContentLength; cl != 0 {
+ t.Errorf("For version %q, Content-Length = %v; want 0", version, cl)
+ }
+ conn.Close()
+ }
+}
+
+func TestCloseNotifier(t *testing.T) {
+ gotReq := make(chan bool, 1)
+ sawClose := make(chan bool, 1)
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ gotReq <- true
+ cc := rw.(CloseNotifier).CloseNotify()
+ <-cc
+ sawClose <- true
+ }))
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("error dialing: %v", err)
+ }
+ diec := make(chan bool)
+ go func() {
+ _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
+ if err != nil {
+ t.Fatal(err)
+ }
+ <-diec
+ conn.Close()
+ }()
+For:
+ for {
+ select {
+ case <-gotReq:
+ diec <- true
+ case <-sawClose:
+ break For
+ case <-time.After(5 * time.Second):
+ t.Fatal("timeout")
+ }
+ }
+ ts.Close()
+}
+
+func TestOptions(t *testing.T) {
+ uric := make(chan string, 2) // only expect 1, but leave space for 2
+ mux := NewServeMux()
+ mux.HandleFunc("/", func(w ResponseWriter, r *Request) {
+ uric <- r.RequestURI
+ })
+ ts := httptest.NewServer(mux)
+ defer ts.Close()
+
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+
+ // An OPTIONS * request should succeed.
+ _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ br := bufio.NewReader(conn)
+ res, err := ReadResponse(br, &Request{Method: "OPTIONS"})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != 200 {
+ t.Errorf("Got non-200 response to OPTIONS *: %#v", res)
+ }
+
+ // A GET * request on a ServeMux should fail.
+ _, err = conn.Write([]byte("GET * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ res, err = ReadResponse(br, &Request{Method: "GET"})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != 400 {
+ t.Errorf("Got non-400 response to GET *: %#v", res)
+ }
+
+ res, err = Get(ts.URL + "/second")
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if got := <-uric; got != "/second" {
+ t.Errorf("Handler saw request for %q; want /second", got)
+ }
+}
+
// goTimeout runs f, failing t if f takes more than ns to complete.
func goTimeout(t *testing.T, d time.Duration, f func()) {
ch := make(chan bool, 2)
@@ -1184,3 +1523,100 @@ func BenchmarkClientServer(b *testing.B) {
b.StopTimer()
}
+
+func BenchmarkClientServerParallel4(b *testing.B) {
+ benchmarkClientServerParallel(b, 4)
+}
+
+func BenchmarkClientServerParallel64(b *testing.B) {
+ benchmarkClientServerParallel(b, 64)
+}
+
+func benchmarkClientServerParallel(b *testing.B, conc int) {
+ b.StopTimer()
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
+ fmt.Fprintf(rw, "Hello world.\n")
+ }))
+ defer ts.Close()
+ b.StartTimer()
+
+ numProcs := runtime.GOMAXPROCS(-1) * conc
+ var wg sync.WaitGroup
+ wg.Add(numProcs)
+ n := int32(b.N)
+ for p := 0; p < numProcs; p++ {
+ go func() {
+ for atomic.AddInt32(&n, -1) >= 0 {
+ res, err := Get(ts.URL)
+ if err != nil {
+ b.Logf("Get: %v", err)
+ continue
+ }
+ all, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ b.Logf("ReadAll: %v", err)
+ continue
+ }
+ body := string(all)
+ if body != "Hello world.\n" {
+ panic("Got body: " + body)
+ }
+ }
+ wg.Done()
+ }()
+ }
+ wg.Wait()
+}
+
+// A benchmark for profiling the server without the HTTP client code.
+// The client code runs in a subprocess.
+//
+// For use like:
+// $ go test -c
+// $ ./http.test -test.run=XX -test.bench=BenchmarkServer -test.benchtime=15s -test.cpuprofile=http.prof
+// $ go tool pprof http.test http.prof
+// (pprof) web
+func BenchmarkServer(b *testing.B) {
+ // Child process mode;
+ if url := os.Getenv("TEST_BENCH_SERVER_URL"); url != "" {
+ n, err := strconv.Atoi(os.Getenv("TEST_BENCH_CLIENT_N"))
+ if err != nil {
+ panic(err)
+ }
+ for i := 0; i < n; i++ {
+ res, err := Get(url)
+ if err != nil {
+ log.Panicf("Get: %v", err)
+ }
+ all, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ log.Panicf("ReadAll: %v", err)
+ }
+ body := string(all)
+ if body != "Hello world.\n" {
+ log.Panicf("Got body: %q", body)
+ }
+ }
+ os.Exit(0)
+ return
+ }
+
+ var res = []byte("Hello world.\n")
+ b.StopTimer()
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
+ rw.Header().Set("Content-Type", "text/html; charset=utf-8")
+ rw.Write(res)
+ }))
+ defer ts.Close()
+ b.StartTimer()
+
+ cmd := exec.Command(os.Args[0], "-test.run=XXXX", "-test.bench=BenchmarkServer")
+ cmd.Env = append([]string{
+ fmt.Sprintf("TEST_BENCH_CLIENT_N=%d", b.N),
+ fmt.Sprintf("TEST_BENCH_SERVER_URL=%s", ts.URL),
+ }, os.Environ()...)
+ out, err := cmd.CombinedOutput()
+ if err != nil {
+ b.Errorf("Test failure: %v, with output: %s", err, out)
+ }
+}
diff --git a/src/pkg/net/http/server.go b/src/pkg/net/http/server.go
index 0572b4ae3..b6ab78228 100644
--- a/src/pkg/net/http/server.go
+++ b/src/pkg/net/http/server.go
@@ -11,7 +11,6 @@ package http
import (
"bufio"
- "bytes"
"crypto/tls"
"errors"
"fmt"
@@ -21,7 +20,7 @@ import (
"net"
"net/url"
"path"
- "runtime/debug"
+ "runtime"
"strconv"
"strings"
"sync"
@@ -94,30 +93,188 @@ type Hijacker interface {
Hijack() (net.Conn, *bufio.ReadWriter, error)
}
+// The CloseNotifier interface is implemented by ResponseWriters which
+// allow detecting when the underlying connection has gone away.
+//
+// This mechanism can be used to cancel long operations on the server
+// if the client has disconnected before the response is ready.
+type CloseNotifier interface {
+ // CloseNotify returns a channel that receives a single value
+ // when the client connection has gone away.
+ CloseNotify() <-chan bool
+}
+
// A conn represents the server side of an HTTP connection.
type conn struct {
remoteAddr string // network address of remote side
server *Server // the Server on which the connection arrived
rwc net.Conn // i/o connection
- lr *io.LimitedReader // io.LimitReader(rwc)
- buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->rwc
- hijacked bool // connection has been hijacked by handler
+ sr switchReader // where the LimitReader reads from; usually the rwc
+ lr *io.LimitedReader // io.LimitReader(sr)
+ buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->sr->rwc
tlsState *tls.ConnectionState // or nil when not using TLS
- body []byte
+
+ mu sync.Mutex // guards the following
+ clientGone bool // if client has disconnected mid-request
+ closeNotifyc chan bool // made lazily
+ hijackedv bool // connection has been hijacked by handler
+}
+
+func (c *conn) hijacked() bool {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ return c.hijackedv
+}
+
+func (c *conn) hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.hijackedv {
+ return nil, nil, ErrHijacked
+ }
+ if c.closeNotifyc != nil {
+ return nil, nil, errors.New("http: Hijack is incompatible with use of CloseNotifier")
+ }
+ c.hijackedv = true
+ rwc = c.rwc
+ buf = c.buf
+ c.rwc = nil
+ c.buf = nil
+ return
+}
+
+func (c *conn) closeNotify() <-chan bool {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.closeNotifyc == nil {
+ c.closeNotifyc = make(chan bool)
+ if c.hijackedv {
+ // to obey the function signature, even though
+ // it'll never receive a value.
+ return c.closeNotifyc
+ }
+ pr, pw := io.Pipe()
+
+ readSource := c.sr.r
+ c.sr.Lock()
+ c.sr.r = pr
+ c.sr.Unlock()
+ go func() {
+ _, err := io.Copy(pw, readSource)
+ if err == nil {
+ err = io.EOF
+ }
+ pw.CloseWithError(err)
+ c.noteClientGone()
+ }()
+ }
+ return c.closeNotifyc
+}
+
+func (c *conn) noteClientGone() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.closeNotifyc != nil && !c.clientGone {
+ c.closeNotifyc <- true
+ }
+ c.clientGone = true
+}
+
+type switchReader struct {
+ sync.Mutex
+ r io.Reader
+}
+
+func (sr *switchReader) Read(p []byte) (n int, err error) {
+ sr.Lock()
+ r := sr.r
+ sr.Unlock()
+ return r.Read(p)
+}
+
+// This should be >= 512 bytes for DetectContentType,
+// but otherwise it's somewhat arbitrary.
+const bufferBeforeChunkingSize = 2048
+
+// chunkWriter writes to a response's conn buffer, and is the writer
+// wrapped by the response.bufw buffered writer.
+//
+// chunkWriter also is responsible for finalizing the Header, including
+// conditionally setting the Content-Type and setting a Content-Length
+// in cases where the handler's final output is smaller than the buffer
+// size. It also conditionally adds chunk headers, when in chunking mode.
+//
+// See the comment above (*response).Write for the entire write flow.
+type chunkWriter struct {
+ res *response
+ header Header // a deep copy of r.Header, once WriteHeader is called
+ wroteHeader bool // whether the header's been sent
+
+ // set by the writeHeader method:
+ chunking bool // using chunked transfer encoding for reply body
+}
+
+var crlf = []byte("\r\n")
+
+func (cw *chunkWriter) Write(p []byte) (n int, err error) {
+ if !cw.wroteHeader {
+ cw.writeHeader(p)
+ }
+ if cw.chunking {
+ _, err = fmt.Fprintf(cw.res.conn.buf, "%x\r\n", len(p))
+ if err != nil {
+ cw.res.conn.rwc.Close()
+ return
+ }
+ }
+ n, err = cw.res.conn.buf.Write(p)
+ if cw.chunking && err == nil {
+ _, err = cw.res.conn.buf.Write(crlf)
+ }
+ if err != nil {
+ cw.res.conn.rwc.Close()
+ }
+ return
+}
+
+func (cw *chunkWriter) flush() {
+ if !cw.wroteHeader {
+ cw.writeHeader(nil)
+ }
+ cw.res.conn.buf.Flush()
+}
+
+func (cw *chunkWriter) close() {
+ if !cw.wroteHeader {
+ cw.writeHeader(nil)
+ }
+ if cw.chunking {
+ // zero EOF chunk, trailer key/value pairs (currently
+ // unsupported in Go's server), followed by a blank
+ // line.
+ io.WriteString(cw.res.conn.buf, "0\r\n\r\n")
+ }
}
// A response represents the server side of an HTTP response.
type response struct {
conn *conn
req *Request // request for this response
- chunking bool // using chunked transfer encoding for reply body
- wroteHeader bool // reply header has been written
+ wroteHeader bool // reply header has been (logically) written
wroteContinue bool // 100 Continue response was written
- header Header // reply header parameters
- written int64 // number of bytes written in body
- contentLength int64 // explicitly-declared Content-Length; or -1
- status int // status code passed to WriteHeader
- needSniff bool // need to sniff to find Content-Type
+
+ w *bufio.Writer // buffers output in chunks to chunkWriter
+ cw *chunkWriter
+
+ // handlerHeader is the Header that Handlers get access to,
+ // which may be retained and mutated even after WriteHeader.
+ // handlerHeader is copied into cw.header at WriteHeader
+ // time, and privately mutated thereafter.
+ handlerHeader Header
+
+ written int64 // number of bytes written in body
+ contentLength int64 // explicitly-declared Content-Length; or -1
+ status int // status code passed to WriteHeader
// close connection after this reply. set on request and
// updated after response from handler if there's a
@@ -127,12 +284,14 @@ type response struct {
// requestBodyLimitHit is set by requestTooLarge when
// maxBytesReader hits its max size. It is checked in
- // WriteHeader, to make sure we don't consume the the
+ // WriteHeader, to make sure we don't consume the
// remaining request body to try to advance to the next HTTP
- // request. Instead, when this is set, we stop doing
+ // request. Instead, when this is set, we stop reading
// subsequent requests on this connection and stop reading
// input from it.
requestBodyLimitHit bool
+
+ handlerDone bool // set true when the handler exits
}
// requestTooLarge is called by maxBytesReader when too much input has
@@ -145,42 +304,68 @@ func (w *response) requestTooLarge() {
}
}
+// needsSniff returns whether a Content-Type still needs to be sniffed.
+func (w *response) needsSniff() bool {
+ return !w.cw.wroteHeader && w.handlerHeader.Get("Content-Type") == "" && w.written < sniffLen
+}
+
type writerOnly struct {
io.Writer
}
func (w *response) ReadFrom(src io.Reader) (n int64, err error) {
- // Call WriteHeader before checking w.chunking if it hasn't
- // been called yet, since WriteHeader is what sets w.chunking.
if !w.wroteHeader {
w.WriteHeader(StatusOK)
}
- if !w.chunking && w.bodyAllowed() && !w.needSniff {
- w.Flush()
+
+ if w.needsSniff() {
+ n0, err := io.Copy(writerOnly{w}, io.LimitReader(src, sniffLen))
+ n += n0
+ if err != nil {
+ return n, err
+ }
+ }
+
+ w.w.Flush() // get rid of any previous writes
+ w.cw.flush() // make sure Header is written; flush data to rwc
+
+ // Now that cw has been flushed, its chunking field is guaranteed initialized.
+ if !w.cw.chunking && w.bodyAllowed() {
if rf, ok := w.conn.rwc.(io.ReaderFrom); ok {
- n, err = rf.ReadFrom(src)
- w.written += n
- return
+ n0, err := rf.ReadFrom(src)
+ n += n0
+ w.written += n0
+ return n, err
}
}
+
// Fall back to default io.Copy implementation.
// Use wrapper to hide w.ReadFrom from io.Copy.
- return io.Copy(writerOnly{w}, src)
+ n0, err := io.Copy(writerOnly{w}, src)
+ n += n0
+ return n, err
}
// noLimit is an effective infinite upper bound for io.LimitedReader
const noLimit int64 = (1 << 63) - 1
+// debugServerConnections controls whether all server connections are wrapped
+// with a verbose logging wrapper.
+const debugServerConnections = false
+
// Create new connection from rwc.
func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) {
c = new(conn)
c.remoteAddr = rwc.RemoteAddr().String()
c.server = srv
c.rwc = rwc
- c.body = make([]byte, sniffLen)
- c.lr = io.LimitReader(rwc, noLimit).(*io.LimitedReader)
+ if debugServerConnections {
+ c.rwc = newLoggingConn("server", c.rwc)
+ }
+ c.sr = switchReader{r: c.rwc}
+ c.lr = io.LimitReader(&c.sr, noLimit).(*io.LimitedReader)
br := bufio.NewReader(c.lr)
- bw := bufio.NewWriter(rwc)
+ bw := bufio.NewWriter(c.rwc)
c.buf = bufio.NewReadWriter(br, bw)
return c, nil
}
@@ -207,9 +392,9 @@ type expectContinueReader struct {
func (ecr *expectContinueReader) Read(p []byte) (n int, err error) {
if ecr.closed {
- return 0, errors.New("http: Read after Close on request Body")
+ return 0, ErrBodyReadAfterClose
}
- if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked {
+ if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked() {
ecr.resp.wroteContinue = true
io.WriteString(ecr.resp.conn.buf, "HTTP/1.1 100 Continue\r\n\r\n")
ecr.resp.conn.buf.Flush()
@@ -232,9 +417,19 @@ var errTooLarge = errors.New("http: request too large")
// Read next request from connection.
func (c *conn) readRequest() (w *response, err error) {
- if c.hijacked {
+ if c.hijacked() {
return nil, ErrHijacked
}
+
+ if d := c.server.ReadTimeout; d != 0 {
+ c.rwc.SetReadDeadline(time.Now().Add(d))
+ }
+ if d := c.server.WriteTimeout; d != 0 {
+ defer func() {
+ c.rwc.SetWriteDeadline(time.Now().Add(d))
+ }()
+ }
+
c.lr.N = int64(c.server.maxHeaderBytes()) + 4096 /* bufio slop */
var req *Request
if req, err = ReadRequest(c.buf.Reader); err != nil {
@@ -248,17 +443,20 @@ func (c *conn) readRequest() (w *response, err error) {
req.RemoteAddr = c.remoteAddr
req.TLS = c.tlsState
- w = new(response)
- w.conn = c
- w.req = req
- w.header = make(Header)
- w.contentLength = -1
- c.body = c.body[:0]
+ w = &response{
+ conn: c,
+ req: req,
+ handlerHeader: make(Header),
+ contentLength: -1,
+ cw: new(chunkWriter),
+ }
+ w.cw.res = w
+ w.w = bufio.NewWriterSize(w.cw, bufferBeforeChunkingSize)
return w, nil
}
func (w *response) Header() Header {
- return w.header
+ return w.handlerHeader
}
// maxPostHandlerReadBytes is the max number of Request.Body bytes not
@@ -273,7 +471,7 @@ func (w *response) Header() Header {
const maxPostHandlerReadBytes = 256 << 10
func (w *response) WriteHeader(code int) {
- if w.conn.hijacked {
+ if w.conn.hijacked() {
log.Print("http: response.WriteHeader on hijacked connection")
return
}
@@ -284,31 +482,68 @@ func (w *response) WriteHeader(code int) {
w.wroteHeader = true
w.status = code
- // Check for a explicit (and valid) Content-Length header.
- var hasCL bool
- var contentLength int64
- if clenStr := w.header.Get("Content-Length"); clenStr != "" {
- var err error
- contentLength, err = strconv.ParseInt(clenStr, 10, 64)
- if err == nil {
- hasCL = true
+ w.cw.header = w.handlerHeader.clone()
+
+ if cl := w.cw.header.get("Content-Length"); cl != "" {
+ v, err := strconv.ParseInt(cl, 10, 64)
+ if err == nil && v >= 0 {
+ w.contentLength = v
} else {
- log.Printf("http: invalid Content-Length of %q sent", clenStr)
- w.header.Del("Content-Length")
+ log.Printf("http: invalid Content-Length of %q", cl)
+ w.cw.header.Del("Content-Length")
+ }
+ }
+}
+
+// writeHeader finalizes the header sent to the client and writes it
+// to cw.res.conn.buf.
+//
+// p is not written by writeHeader, but is the first chunk of the body
+// that will be written. It is sniffed for a Content-Type if none is
+// set explicitly. It's also used to set the Content-Length, if the
+// total body size was small and the handler has already finished
+// running.
+func (cw *chunkWriter) writeHeader(p []byte) {
+ if cw.wroteHeader {
+ return
+ }
+ cw.wroteHeader = true
+
+ w := cw.res
+ code := w.status
+ done := w.handlerDone
+
+ // If the handler is done but never sent a Content-Length
+ // response header and this is our first (and last) write, set
+ // it, even to zero. This helps HTTP/1.0 clients keep their
+ // "keep-alive" connections alive.
+ if done && cw.header.get("Content-Length") == "" && w.req.Method != "HEAD" {
+ w.contentLength = int64(len(p))
+ cw.header.Set("Content-Length", strconv.Itoa(len(p)))
+ }
+
+ // If this was an HTTP/1.0 request with keep-alive and we sent a
+ // Content-Length back, we can make this a keep-alive response ...
+ if w.req.wantsHttp10KeepAlive() {
+ sentLength := cw.header.get("Content-Length") != ""
+ if sentLength && cw.header.get("Connection") == "keep-alive" {
+ w.closeAfterReply = false
}
}
+ // Check for a explicit (and valid) Content-Length header.
+ hasCL := w.contentLength != -1
+
if w.req.wantsHttp10KeepAlive() && (w.req.Method == "HEAD" || hasCL) {
- _, connectionHeaderSet := w.header["Connection"]
+ _, connectionHeaderSet := cw.header["Connection"]
if !connectionHeaderSet {
- w.header.Set("Connection", "keep-alive")
+ cw.header.Set("Connection", "keep-alive")
}
- } else if !w.req.ProtoAtLeast(1, 1) {
- // Client did not ask to keep connection alive.
+ } else if !w.req.ProtoAtLeast(1, 1) || w.req.wantsClose() {
w.closeAfterReply = true
}
- if w.header.Get("Connection") == "close" {
+ if cw.header.get("Connection") == "close" {
w.closeAfterReply = true
}
@@ -322,7 +557,7 @@ func (w *response) WriteHeader(code int) {
n, _ := io.CopyN(ioutil.Discard, w.req.Body, maxPostHandlerReadBytes+1)
if n >= maxPostHandlerReadBytes {
w.requestTooLarge()
- w.header.Set("Connection", "close")
+ cw.header.Set("Connection", "close")
} else {
w.req.Body.Close()
}
@@ -332,64 +567,67 @@ func (w *response) WriteHeader(code int) {
if code == StatusNotModified {
// Must not have body.
for _, header := range []string{"Content-Type", "Content-Length", "Transfer-Encoding"} {
- if w.header.Get(header) != "" {
- // TODO: return an error if WriteHeader gets a return parameter
- // or set a flag on w to make future Writes() write an error page?
- // for now just log and drop the header.
- log.Printf("http: StatusNotModified response with header %q defined", header)
- w.header.Del(header)
+ // RFC 2616 section 10.3.5: "the response MUST NOT include other entity-headers"
+ if cw.header.get(header) != "" {
+ cw.header.Del(header)
}
}
} else {
// If no content type, apply sniffing algorithm to body.
- if w.header.Get("Content-Type") == "" && w.req.Method != "HEAD" {
- w.needSniff = true
+ if cw.header.get("Content-Type") == "" && w.req.Method != "HEAD" {
+ cw.header.Set("Content-Type", DetectContentType(p))
}
}
- if _, ok := w.header["Date"]; !ok {
- w.Header().Set("Date", time.Now().UTC().Format(TimeFormat))
+ if _, ok := cw.header["Date"]; !ok {
+ cw.header.Set("Date", time.Now().UTC().Format(TimeFormat))
}
- te := w.header.Get("Transfer-Encoding")
+ te := cw.header.get("Transfer-Encoding")
hasTE := te != ""
if hasCL && hasTE && te != "identity" {
// TODO: return an error if WriteHeader gets a return parameter
// For now just ignore the Content-Length.
log.Printf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d",
- te, contentLength)
- w.header.Del("Content-Length")
+ te, w.contentLength)
+ cw.header.Del("Content-Length")
hasCL = false
}
if w.req.Method == "HEAD" || code == StatusNotModified {
// do nothing
+ } else if code == StatusNoContent {
+ cw.header.Del("Transfer-Encoding")
} else if hasCL {
- w.contentLength = contentLength
- w.header.Del("Transfer-Encoding")
+ cw.header.Del("Transfer-Encoding")
} else if w.req.ProtoAtLeast(1, 1) {
// HTTP/1.1 or greater: use chunked transfer encoding
// to avoid closing the connection at EOF.
// TODO: this blows away any custom or stacked Transfer-Encoding they
// might have set. Deal with that as need arises once we have a valid
// use case.
- w.chunking = true
- w.header.Set("Transfer-Encoding", "chunked")
+ cw.chunking = true
+ cw.header.Set("Transfer-Encoding", "chunked")
} else {
// HTTP version < 1.1: cannot do chunked transfer
// encoding and we don't know the Content-Length so
// signal EOF by closing connection.
w.closeAfterReply = true
- w.header.Del("Transfer-Encoding") // in case already set
+ cw.header.Del("Transfer-Encoding") // in case already set
}
// Cannot use Content-Length with non-identity Transfer-Encoding.
- if w.chunking {
- w.header.Del("Content-Length")
+ if cw.chunking {
+ cw.header.Del("Content-Length")
}
if !w.req.ProtoAtLeast(1, 0) {
return
}
+
+ if w.closeAfterReply && !hasToken(cw.header.get("Connection"), "close") {
+ cw.header.Set("Connection", "close")
+ }
+
proto := "HTTP/1.0"
if w.req.ProtoAtLeast(1, 1) {
proto = "HTTP/1.1"
@@ -400,37 +638,8 @@ func (w *response) WriteHeader(code int) {
text = "status code " + codestring
}
io.WriteString(w.conn.buf, proto+" "+codestring+" "+text+"\r\n")
- w.header.Write(w.conn.buf)
-
- // If we need to sniff the body, leave the header open.
- // Otherwise, end it here.
- if !w.needSniff {
- io.WriteString(w.conn.buf, "\r\n")
- }
-}
-
-// sniff uses the first block of written data,
-// stored in w.conn.body, to decide the Content-Type
-// for the HTTP body.
-func (w *response) sniff() {
- if !w.needSniff {
- return
- }
- w.needSniff = false
-
- data := w.conn.body
- fmt.Fprintf(w.conn.buf, "Content-Type: %s\r\n\r\n", DetectContentType(data))
-
- if len(data) == 0 {
- return
- }
- if w.chunking {
- fmt.Fprintf(w.conn.buf, "%x\r\n", len(data))
- }
- _, err := w.conn.buf.Write(data)
- if w.chunking && err == nil {
- io.WriteString(w.conn.buf, "\r\n")
- }
+ cw.header.Write(w.conn.buf)
+ w.conn.buf.Write(crlf)
}
// bodyAllowed returns true if a Write is allowed for this response type.
@@ -442,8 +651,40 @@ func (w *response) bodyAllowed() bool {
return w.status != StatusNotModified && w.req.Method != "HEAD"
}
+// The Life Of A Write is like this:
+//
+// Handler starts. No header has been sent. The handler can either
+// write a header, or just start writing. Writing before sending a header
+// sends an implicity empty 200 OK header.
+//
+// If the handler didn't declare a Content-Length up front, we either
+// go into chunking mode or, if the handler finishes running before
+// the chunking buffer size, we compute a Content-Length and send that
+// in the header instead.
+//
+// Likewise, if the handler didn't set a Content-Type, we sniff that
+// from the initial chunk of output.
+//
+// The Writers are wired together like:
+//
+// 1. *response (the ResponseWriter) ->
+// 2. (*response).w, a *bufio.Writer of bufferBeforeChunkingSize bytes
+// 3. chunkWriter.Writer (whose writeHeader finalizes Content-Length/Type)
+// and which writes the chunk headers, if needed.
+// 4. conn.buf, a bufio.Writer of default (4kB) bytes
+// 5. the rwc, the net.Conn.
+//
+// TODO(bradfitz): short-circuit some of the buffering when the
+// initial header contains both a Content-Type and Content-Length.
+// Also short-circuit in (1) when the header's been sent and not in
+// chunking mode, writing directly to (4) instead, if (2) has no
+// buffered data. More generally, we could short-circuit from (1) to
+// (3) even in chunking mode if the write size from (1) is over some
+// threshold and nothing is in (2). The answer might be mostly making
+// bufferBeforeChunkingSize smaller and having bufio's fast-paths deal
+// with this instead.
func (w *response) Write(data []byte) (n int, err error) {
- if w.conn.hijacked {
+ if w.conn.hijacked() {
log.Print("http: response.Write on hijacked connection")
return 0, ErrHijacked
}
@@ -461,73 +702,20 @@ func (w *response) Write(data []byte) (n int, err error) {
if w.contentLength != -1 && w.written > w.contentLength {
return 0, ErrContentLength
}
-
- var m int
- if w.needSniff {
- // We need to sniff the beginning of the output to
- // determine the content type. Accumulate the
- // initial writes in w.conn.body.
- // Cap m so that append won't allocate.
- m = cap(w.conn.body) - len(w.conn.body)
- if m > len(data) {
- m = len(data)
- }
- w.conn.body = append(w.conn.body, data[:m]...)
- data = data[m:]
- if len(data) == 0 {
- // Copied everything into the buffer.
- // Wait for next write.
- return m, nil
- }
-
- // Filled the buffer; more data remains.
- // Sniff the content (flushes the buffer)
- // and then proceed with the remainder
- // of the data as a normal Write.
- // Calling sniff clears needSniff.
- w.sniff()
- }
-
- // TODO(rsc): if chunking happened after the buffering,
- // then there would be fewer chunk headers.
- // On the other hand, it would make hijacking more difficult.
- if w.chunking {
- fmt.Fprintf(w.conn.buf, "%x\r\n", len(data)) // TODO(rsc): use strconv not fmt
- }
- n, err = w.conn.buf.Write(data)
- if err == nil && w.chunking {
- if n != len(data) {
- err = io.ErrShortWrite
- }
- if err == nil {
- io.WriteString(w.conn.buf, "\r\n")
- }
- }
-
- return m + n, err
+ return w.w.Write(data)
}
func (w *response) finishRequest() {
- // If this was an HTTP/1.0 request with keep-alive and we sent a Content-Length
- // back, we can make this a keep-alive response ...
- if w.req.wantsHttp10KeepAlive() {
- sentLength := w.header.Get("Content-Length") != ""
- if sentLength && w.header.Get("Connection") == "keep-alive" {
- w.closeAfterReply = false
- }
- }
+ w.handlerDone = true
+
if !w.wroteHeader {
w.WriteHeader(StatusOK)
}
- if w.needSniff {
- w.sniff()
- }
- if w.chunking {
- io.WriteString(w.conn.buf, "0\r\n")
- // trailer key/value pairs, followed by blank line
- io.WriteString(w.conn.buf, "\r\n")
- }
+
+ w.w.Flush()
+ w.cw.close()
w.conn.buf.Flush()
+
// Close the body, unless we're about to close the whole TCP connection
// anyway.
if !w.closeAfterReply {
@@ -537,7 +725,7 @@ func (w *response) finishRequest() {
w.req.MultipartForm.RemoveAll()
}
- if w.contentLength != -1 && w.contentLength != w.written {
+ if w.contentLength != -1 && w.bodyAllowed() && w.contentLength != w.written {
// Did not write enough. Avoid getting out of sync.
w.closeAfterReply = true
}
@@ -547,66 +735,114 @@ func (w *response) Flush() {
if !w.wroteHeader {
w.WriteHeader(StatusOK)
}
- w.sniff()
- w.conn.buf.Flush()
+ w.w.Flush()
+ w.cw.flush()
}
-// Close the connection.
-func (c *conn) close() {
+func (c *conn) finalFlush() {
if c.buf != nil {
c.buf.Flush()
c.buf = nil
}
+}
+
+// Close the connection.
+func (c *conn) close() {
+ c.finalFlush()
if c.rwc != nil {
c.rwc.Close()
c.rwc = nil
}
}
+// rstAvoidanceDelay is the amount of time we sleep after closing the
+// write side of a TCP connection before closing the entire socket.
+// By sleeping, we increase the chances that the client sees our FIN
+// and processes its final data before they process the subsequent RST
+// from closing a connection with known unread data.
+// This RST seems to occur mostly on BSD systems. (And Windows?)
+// This timeout is somewhat arbitrary (~latency around the planet).
+const rstAvoidanceDelay = 500 * time.Millisecond
+
+// closeWrite flushes any outstanding data and sends a FIN packet (if
+// client is connected via TCP), signalling that we're done. We then
+// pause for a bit, hoping the client processes it before `any
+// subsequent RST.
+//
+// See http://golang.org/issue/3595
+func (c *conn) closeWriteAndWait() {
+ c.finalFlush()
+ if tcp, ok := c.rwc.(*net.TCPConn); ok {
+ tcp.CloseWrite()
+ }
+ time.Sleep(rstAvoidanceDelay)
+}
+
+// validNPN returns whether the proto is not a blacklisted Next
+// Protocol Negotiation protocol. Empty and built-in protocol types
+// are blacklisted and can't be overridden with alternate
+// implementations.
+func validNPN(proto string) bool {
+ switch proto {
+ case "", "http/1.1", "http/1.0":
+ return false
+ }
+ return true
+}
+
// Serve a new connection.
func (c *conn) serve() {
defer func() {
- err := recover()
- if err == nil {
- return
+ if err := recover(); err != nil {
+ const size = 4096
+ buf := make([]byte, size)
+ buf = buf[:runtime.Stack(buf, false)]
+ log.Printf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf)
}
-
- var buf bytes.Buffer
- fmt.Fprintf(&buf, "http: panic serving %v: %v\n", c.remoteAddr, err)
- buf.Write(debug.Stack())
- log.Print(buf.String())
-
- if c.rwc != nil { // may be nil if connection hijacked
- c.rwc.Close()
+ if !c.hijacked() {
+ c.close()
}
}()
if tlsConn, ok := c.rwc.(*tls.Conn); ok {
+ if d := c.server.ReadTimeout; d != 0 {
+ c.rwc.SetReadDeadline(time.Now().Add(d))
+ }
+ if d := c.server.WriteTimeout; d != 0 {
+ c.rwc.SetWriteDeadline(time.Now().Add(d))
+ }
if err := tlsConn.Handshake(); err != nil {
- c.close()
return
}
c.tlsState = new(tls.ConnectionState)
*c.tlsState = tlsConn.ConnectionState()
+ if proto := c.tlsState.NegotiatedProtocol; validNPN(proto) {
+ if fn := c.server.TLSNextProto[proto]; fn != nil {
+ h := initNPNRequest{tlsConn, serverHandler{c.server}}
+ fn(c.server, tlsConn, h)
+ }
+ return
+ }
}
for {
w, err := c.readRequest()
if err != nil {
- msg := "400 Bad Request"
if err == errTooLarge {
// Their HTTP client may or may not be
// able to read this if we're
// responding to them and hanging up
// while they're still writing their
// request. Undefined behavior.
- msg = "413 Request Entity Too Large"
+ io.WriteString(c.rwc, "HTTP/1.1 413 Request Entity Too Large\r\n\r\n")
+ c.closeWriteAndWait()
+ break
} else if err == io.EOF {
break // Don't reply
} else if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
break // Don't reply
}
- fmt.Fprintf(c.rwc, "HTTP/1.1 %s\r\n\r\n", msg)
+ io.WriteString(c.rwc, "HTTP/1.1 400 Bad Request\r\n\r\n")
break
}
@@ -624,59 +860,59 @@ func (c *conn) serve() {
break
}
req.Header.Del("Expect")
- } else if req.Header.Get("Expect") != "" {
- // TODO(bradfitz): let ServeHTTP handlers handle
- // requests with non-standard expectation[s]? Seems
- // theoretical at best, and doesn't fit into the
- // current ServeHTTP model anyway. We'd need to
- // make the ResponseWriter an optional
- // "ExpectReplier" interface or something.
- //
- // For now we'll just obey RFC 2616 14.20 which says
- // "If a server receives a request containing an
- // Expect field that includes an expectation-
- // extension that it does not support, it MUST
- // respond with a 417 (Expectation Failed) status."
- w.Header().Set("Connection", "close")
- w.WriteHeader(StatusExpectationFailed)
- w.finishRequest()
+ } else if req.Header.get("Expect") != "" {
+ w.sendExpectationFailed()
break
}
- handler := c.server.Handler
- if handler == nil {
- handler = DefaultServeMux
- }
-
// HTTP cannot have multiple simultaneous active requests.[*]
// Until the server replies to this request, it can't read another,
// so we might as well run the handler in this goroutine.
// [*] Not strictly true: HTTP pipelining. We could let them all process
// in parallel even if their responses need to be serialized.
- handler.ServeHTTP(w, w.req)
- if c.hijacked {
+ serverHandler{c.server}.ServeHTTP(w, w.req)
+ if c.hijacked() {
return
}
w.finishRequest()
if w.closeAfterReply {
+ if w.requestBodyLimitHit {
+ c.closeWriteAndWait()
+ }
break
}
}
- c.close()
+}
+
+func (w *response) sendExpectationFailed() {
+ // TODO(bradfitz): let ServeHTTP handlers handle
+ // requests with non-standard expectation[s]? Seems
+ // theoretical at best, and doesn't fit into the
+ // current ServeHTTP model anyway. We'd need to
+ // make the ResponseWriter an optional
+ // "ExpectReplier" interface or something.
+ //
+ // For now we'll just obey RFC 2616 14.20 which says
+ // "If a server receives a request containing an
+ // Expect field that includes an expectation-
+ // extension that it does not support, it MUST
+ // respond with a 417 (Expectation Failed) status."
+ w.Header().Set("Connection", "close")
+ w.WriteHeader(StatusExpectationFailed)
+ w.finishRequest()
}
// Hijack implements the Hijacker.Hijack method. Our response is both a ResponseWriter
// and a Hijacker.
func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
- if w.conn.hijacked {
- return nil, nil, ErrHijacked
+ if w.wroteHeader {
+ w.cw.flush()
}
- w.conn.hijacked = true
- rwc = w.conn.rwc
- buf = w.conn.buf
- w.conn.rwc = nil
- w.conn.buf = nil
- return
+ return w.conn.hijack()
+}
+
+func (w *response) CloseNotify() <-chan bool {
+ return w.conn.closeNotify()
}
// The HandlerFunc type is an adapter to allow the use of
@@ -817,13 +1053,13 @@ func RedirectHandler(url string, code int) Handler {
// patterns and calls the handler for the pattern that
// most closely matches the URL.
//
-// Patterns named fixed, rooted paths, like "/favicon.ico",
+// Patterns name fixed, rooted paths, like "/favicon.ico",
// or rooted subtrees, like "/images/" (note the trailing slash).
// Longer patterns take precedence over shorter ones, so that
// if there are handlers registered for both "/images/"
// and "/images/thumbnails/", the latter handler will be
// called for paths beginning "/images/thumbnails/" and the
-// former will receiver requests for any other paths in the
+// former will receive requests for any other paths in the
// "/images/" subtree.
//
// Patterns may optionally begin with a host name, restricting matches to
@@ -836,13 +1072,15 @@ func RedirectHandler(url string, code int) Handler {
// redirecting any request containing . or .. elements to an
// equivalent .- and ..-free URL.
type ServeMux struct {
- mu sync.RWMutex
- m map[string]muxEntry
+ mu sync.RWMutex
+ m map[string]muxEntry
+ hosts bool // whether any patterns contain hostnames
}
type muxEntry struct {
explicit bool
h Handler
+ pattern string
}
// NewServeMux allocates and returns a new ServeMux.
@@ -883,8 +1121,7 @@ func cleanPath(p string) string {
// Find a handler on a handler map given a path string
// Most-specific (longest) pattern wins
-func (mux *ServeMux) match(path string) Handler {
- var h Handler
+func (mux *ServeMux) match(path string) (h Handler, pattern string) {
var n = 0
for k, v := range mux.m {
if !pathMatch(k, path) {
@@ -893,37 +1130,64 @@ func (mux *ServeMux) match(path string) Handler {
if h == nil || len(k) > n {
n = len(k)
h = v.h
+ pattern = v.pattern
+ }
+ }
+ return
+}
+
+// Handler returns the handler to use for the given request,
+// consulting r.Method, r.Host, and r.URL.Path. It always returns
+// a non-nil handler. If the path is not in its canonical form, the
+// handler will be an internally-generated handler that redirects
+// to the canonical path.
+//
+// Handler also returns the registered pattern that matches the
+// request or, in the case of internally-generated redirects,
+// the pattern that will match after following the redirect.
+//
+// If there is no registered handler that applies to the request,
+// Handler returns a ``page not found'' handler and an empty pattern.
+func (mux *ServeMux) Handler(r *Request) (h Handler, pattern string) {
+ if r.Method != "CONNECT" {
+ if p := cleanPath(r.URL.Path); p != r.URL.Path {
+ _, pattern = mux.handler(r.Host, p)
+ return RedirectHandler(p, StatusMovedPermanently), pattern
}
}
- return h
+
+ return mux.handler(r.Host, r.URL.Path)
}
-// handler returns the handler to use for the request r.
-func (mux *ServeMux) handler(r *Request) Handler {
+// handler is the main implementation of Handler.
+// The path is known to be in canonical form, except for CONNECT methods.
+func (mux *ServeMux) handler(host, path string) (h Handler, pattern string) {
mux.mu.RLock()
defer mux.mu.RUnlock()
// Host-specific pattern takes precedence over generic ones
- h := mux.match(r.Host + r.URL.Path)
+ if mux.hosts {
+ h, pattern = mux.match(host + path)
+ }
if h == nil {
- h = mux.match(r.URL.Path)
+ h, pattern = mux.match(path)
}
if h == nil {
- h = NotFoundHandler()
+ h, pattern = NotFoundHandler(), ""
}
- return h
+ return
}
// ServeHTTP dispatches the request to the handler whose
// pattern most closely matches the request URL.
func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) {
- // Clean path to canonical form and redirect.
- if p := cleanPath(r.URL.Path); p != r.URL.Path {
- w.Header().Set("Location", p)
- w.WriteHeader(StatusMovedPermanently)
+ if r.RequestURI == "*" {
+ w.Header().Set("Connection", "close")
+ w.WriteHeader(StatusBadRequest)
return
}
- mux.handler(r).ServeHTTP(w, r)
+ h, _ := mux.Handler(r)
+ h.ServeHTTP(w, r)
}
// Handle registers the handler for the given pattern.
@@ -942,14 +1206,26 @@ func (mux *ServeMux) Handle(pattern string, handler Handler) {
panic("http: multiple registrations for " + pattern)
}
- mux.m[pattern] = muxEntry{explicit: true, h: handler}
+ mux.m[pattern] = muxEntry{explicit: true, h: handler, pattern: pattern}
+
+ if pattern[0] != '/' {
+ mux.hosts = true
+ }
// Helpful behavior:
// If pattern is /tree/, insert an implicit permanent redirect for /tree.
// It can be overridden by an explicit registration.
n := len(pattern)
if n > 0 && pattern[n-1] == '/' && !mux.m[pattern[0:n-1]].explicit {
- mux.m[pattern[0:n-1]] = muxEntry{h: RedirectHandler(pattern, StatusMovedPermanently)}
+ // If pattern contains a host name, strip it and use remaining
+ // path for redirect.
+ path := pattern
+ if pattern[0] != '/' {
+ // In pattern, at least the last character is a '/', so
+ // strings.Index can't be -1.
+ path = pattern[strings.Index(pattern, "/"):]
+ }
+ mux.m[pattern[0:n-1]] = muxEntry{h: RedirectHandler(path, StatusMovedPermanently), pattern: pattern}
}
}
@@ -971,7 +1247,7 @@ func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) {
}
// Serve accepts incoming HTTP connections on the listener l,
-// creating a new service thread for each. The service threads
+// creating a new service goroutine for each. The service goroutines
// read requests and then call handler to reply to them.
// Handler is typically nil, in which case the DefaultServeMux is used.
func Serve(l net.Listener, handler Handler) error {
@@ -987,6 +1263,32 @@ type Server struct {
WriteTimeout time.Duration // maximum duration before timing out write of the response
MaxHeaderBytes int // maximum size of request headers, DefaultMaxHeaderBytes if 0
TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS
+
+ // TLSNextProto optionally specifies a function to take over
+ // ownership of the provided TLS connection when an NPN
+ // protocol upgrade has occured. The map key is the protocol
+ // name negotiated. The Handler argument should be used to
+ // handle HTTP requests and will initialize the Request's TLS
+ // and RemoteAddr if not already set. The connection is
+ // automatically closed when the function returns.
+ TLSNextProto map[string]func(*Server, *tls.Conn, Handler)
+}
+
+// serverHandler delegates to either the server's Handler or
+// DefaultServeMux and also handles "OPTIONS *" requests.
+type serverHandler struct {
+ srv *Server
+}
+
+func (sh serverHandler) ServeHTTP(rw ResponseWriter, req *Request) {
+ handler := sh.srv.Handler
+ if handler == nil {
+ handler = DefaultServeMux
+ }
+ if req.RequestURI == "*" && req.Method == "OPTIONS" {
+ handler = globalOptionsHandler{}
+ }
+ handler.ServeHTTP(rw, req)
}
// ListenAndServe listens on the TCP network address srv.Addr and then
@@ -1005,7 +1307,7 @@ func (srv *Server) ListenAndServe() error {
}
// Serve accepts incoming connections on the Listener l, creating a
-// new service thread for each. The service threads read requests and
+// new service goroutine for each. The service goroutines read requests and
// then call srv.Handler to reply to them.
func (srv *Server) Serve(l net.Listener) error {
defer l.Close()
@@ -1029,12 +1331,6 @@ func (srv *Server) Serve(l net.Listener) error {
return e
}
tempDelay = 0
- if srv.ReadTimeout != 0 {
- rw.SetReadDeadline(time.Now().Add(srv.ReadTimeout))
- }
- if srv.WriteTimeout != 0 {
- rw.SetWriteDeadline(time.Now().Add(srv.WriteTimeout))
- }
c, err := srv.newConn(rw)
if err != nil {
continue
@@ -1150,7 +1446,7 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error {
// TimeoutHandler returns a Handler that runs h with the given time limit.
//
// The new Handler calls h.ServeHTTP to handle each request, but if a
-// call runs for more than ns nanoseconds, the handler responds with
+// call runs for longer than its time limit, the handler responds with
// a 503 Service Unavailable error and the given message in its body.
// (If msg is empty, a suitable default message will be sent.)
// After such a timeout, writes by h to its ResponseWriter will return
@@ -1180,7 +1476,7 @@ func (h *timeoutHandler) errorBody() string {
}
func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) {
- done := make(chan bool)
+ done := make(chan bool, 1)
tw := &timeoutWriter{w: w}
go func() {
h.handler.ServeHTTP(tw, r)
@@ -1232,3 +1528,86 @@ func (tw *timeoutWriter) WriteHeader(code int) {
tw.mu.Unlock()
tw.w.WriteHeader(code)
}
+
+// globalOptionsHandler responds to "OPTIONS *" requests.
+type globalOptionsHandler struct{}
+
+func (globalOptionsHandler) ServeHTTP(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "0")
+ if r.ContentLength != 0 {
+ // Read up to 4KB of OPTIONS body (as mentioned in the
+ // spec as being reserved for future use), but anything
+ // over that is considered a waste of server resources
+ // (or an attack) and we abort and close the connection,
+ // courtesy of MaxBytesReader's EOF behavior.
+ mb := MaxBytesReader(w, r.Body, 4<<10)
+ io.Copy(ioutil.Discard, mb)
+ }
+}
+
+// eofReader is a non-nil io.ReadCloser that always returns EOF.
+var eofReader = ioutil.NopCloser(strings.NewReader(""))
+
+// initNPNRequest is an HTTP handler that initializes certain
+// uninitialized fields in its *Request. Such partially-initialized
+// Requests come from NPN protocol handlers.
+type initNPNRequest struct {
+ c *tls.Conn
+ h serverHandler
+}
+
+func (h initNPNRequest) ServeHTTP(rw ResponseWriter, req *Request) {
+ if req.TLS == nil {
+ req.TLS = &tls.ConnectionState{}
+ *req.TLS = h.c.ConnectionState()
+ }
+ if req.Body == nil {
+ req.Body = eofReader
+ }
+ if req.RemoteAddr == "" {
+ req.RemoteAddr = h.c.RemoteAddr().String()
+ }
+ h.h.ServeHTTP(rw, req)
+}
+
+// loggingConn is used for debugging.
+type loggingConn struct {
+ name string
+ net.Conn
+}
+
+var (
+ uniqNameMu sync.Mutex
+ uniqNameNext = make(map[string]int)
+)
+
+func newLoggingConn(baseName string, c net.Conn) net.Conn {
+ uniqNameMu.Lock()
+ defer uniqNameMu.Unlock()
+ uniqNameNext[baseName]++
+ return &loggingConn{
+ name: fmt.Sprintf("%s-%d", baseName, uniqNameNext[baseName]),
+ Conn: c,
+ }
+}
+
+func (c *loggingConn) Write(p []byte) (n int, err error) {
+ log.Printf("%s.Write(%d) = ....", c.name, len(p))
+ n, err = c.Conn.Write(p)
+ log.Printf("%s.Write(%d) = %d, %v", c.name, len(p), n, err)
+ return
+}
+
+func (c *loggingConn) Read(p []byte) (n int, err error) {
+ log.Printf("%s.Read(%d) = ....", c.name, len(p))
+ n, err = c.Conn.Read(p)
+ log.Printf("%s.Read(%d) = %d, %v", c.name, len(p), n, err)
+ return
+}
+
+func (c *loggingConn) Close() (err error) {
+ log.Printf("%s.Close() = ...", c.name)
+ err = c.Conn.Close()
+ log.Printf("%s.Close() = %v", c.name, err)
+ return
+}
diff --git a/src/pkg/net/http/server_test.go b/src/pkg/net/http/server_test.go
new file mode 100644
index 000000000..8b4e8c6d6
--- /dev/null
+++ b/src/pkg/net/http/server_test.go
@@ -0,0 +1,95 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "net/url"
+ "testing"
+)
+
+var serveMuxRegister = []struct {
+ pattern string
+ h Handler
+}{
+ {"/dir/", serve(200)},
+ {"/search", serve(201)},
+ {"codesearch.google.com/search", serve(202)},
+ {"codesearch.google.com/", serve(203)},
+}
+
+// serve returns a handler that sends a response with the given code.
+func serve(code int) HandlerFunc {
+ return func(w ResponseWriter, r *Request) {
+ w.WriteHeader(code)
+ }
+}
+
+var serveMuxTests = []struct {
+ method string
+ host string
+ path string
+ code int
+ pattern string
+}{
+ {"GET", "google.com", "/", 404, ""},
+ {"GET", "google.com", "/dir", 301, "/dir/"},
+ {"GET", "google.com", "/dir/", 200, "/dir/"},
+ {"GET", "google.com", "/dir/file", 200, "/dir/"},
+ {"GET", "google.com", "/search", 201, "/search"},
+ {"GET", "google.com", "/search/", 404, ""},
+ {"GET", "google.com", "/search/foo", 404, ""},
+ {"GET", "codesearch.google.com", "/search", 202, "codesearch.google.com/search"},
+ {"GET", "codesearch.google.com", "/search/", 203, "codesearch.google.com/"},
+ {"GET", "codesearch.google.com", "/search/foo", 203, "codesearch.google.com/"},
+ {"GET", "codesearch.google.com", "/", 203, "codesearch.google.com/"},
+ {"GET", "images.google.com", "/search", 201, "/search"},
+ {"GET", "images.google.com", "/search/", 404, ""},
+ {"GET", "images.google.com", "/search/foo", 404, ""},
+ {"GET", "google.com", "/../search", 301, "/search"},
+ {"GET", "google.com", "/dir/..", 301, ""},
+ {"GET", "google.com", "/dir/..", 301, ""},
+ {"GET", "google.com", "/dir/./file", 301, "/dir/"},
+
+ // The /foo -> /foo/ redirect applies to CONNECT requests
+ // but the path canonicalization does not.
+ {"CONNECT", "google.com", "/dir", 301, "/dir/"},
+ {"CONNECT", "google.com", "/../search", 404, ""},
+ {"CONNECT", "google.com", "/dir/..", 200, "/dir/"},
+ {"CONNECT", "google.com", "/dir/..", 200, "/dir/"},
+ {"CONNECT", "google.com", "/dir/./file", 200, "/dir/"},
+}
+
+func TestServeMuxHandler(t *testing.T) {
+ mux := NewServeMux()
+ for _, e := range serveMuxRegister {
+ mux.Handle(e.pattern, e.h)
+ }
+
+ for _, tt := range serveMuxTests {
+ r := &Request{
+ Method: tt.method,
+ Host: tt.host,
+ URL: &url.URL{
+ Path: tt.path,
+ },
+ }
+ h, pattern := mux.Handler(r)
+ cs := &codeSaver{h: Header{}}
+ h.ServeHTTP(cs, r)
+ if pattern != tt.pattern || cs.code != tt.code {
+ t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, cs.code, pattern, tt.code, tt.pattern)
+ }
+ }
+}
+
+// A codeSaver is a ResponseWriter that saves the code passed to WriteHeader.
+type codeSaver struct {
+ h Header
+ code int
+}
+
+func (cs *codeSaver) Header() Header { return cs.h }
+func (cs *codeSaver) Write(p []byte) (int, error) { return len(p), nil }
+func (cs *codeSaver) WriteHeader(code int) { cs.code = code }
diff --git a/src/pkg/net/http/transfer.go b/src/pkg/net/http/transfer.go
index 9e9d84172..43c6023a3 100644
--- a/src/pkg/net/http/transfer.go
+++ b/src/pkg/net/http/transfer.go
@@ -87,10 +87,8 @@ func newTransferWriter(r interface{}) (t *transferWriter, err error) {
// Sanitize Body,ContentLength,TransferEncoding
if t.ResponseToHEAD {
t.Body = nil
- t.TransferEncoding = nil
- // ContentLength is expected to hold Content-Length
- if t.ContentLength < 0 {
- return nil, ErrMissingContentLength
+ if chunked(t.TransferEncoding) {
+ t.ContentLength = -1
}
} else {
if !atLeastHTTP11 || t.Body == nil {
@@ -122,9 +120,6 @@ func (t *transferWriter) shouldSendContentLength() bool {
if t.ContentLength > 0 {
return true
}
- if t.ResponseToHEAD {
- return true
- }
// Many servers expect a Content-Length for these methods
if t.Method == "POST" || t.Method == "PUT" {
return true
@@ -199,10 +194,11 @@ func (t *transferWriter) WriteBody(w io.Writer) (err error) {
ncopy, err = io.Copy(w, t.Body)
} else {
ncopy, err = io.Copy(w, io.LimitReader(t.Body, t.ContentLength))
- nextra, err := io.Copy(ioutil.Discard, t.Body)
if err != nil {
return err
}
+ var nextra int64
+ nextra, err = io.Copy(ioutil.Discard, t.Body)
ncopy += nextra
}
if err != nil {
@@ -213,7 +209,7 @@ func (t *transferWriter) WriteBody(w io.Writer) (err error) {
}
}
- if t.ContentLength != -1 && t.ContentLength != ncopy {
+ if !t.ResponseToHEAD && t.ContentLength != -1 && t.ContentLength != ncopy {
return fmt.Errorf("http: Request.ContentLength=%d with Body length %d",
t.ContentLength, ncopy)
}
@@ -294,10 +290,19 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) {
return err
}
- t.ContentLength, err = fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.TransferEncoding)
+ realLength, err := fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.TransferEncoding)
if err != nil {
return err
}
+ if isResponse && t.RequestMethod == "HEAD" {
+ if n, err := parseContentLength(t.Header.get("Content-Length")); err != nil {
+ return err
+ } else {
+ t.ContentLength = n
+ }
+ } else {
+ t.ContentLength = realLength
+ }
// Trailer
t.Trailer, err = fixTrailer(t.Header, t.TransferEncoding)
@@ -310,7 +315,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) {
// See RFC2616, section 4.4.
switch msg.(type) {
case *Response:
- if t.ContentLength == -1 &&
+ if realLength == -1 &&
!chunked(t.TransferEncoding) &&
bodyAllowedForStatus(t.StatusCode) {
// Unbounded body.
@@ -322,12 +327,16 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) {
// or close connection when finished, since multipart is not supported yet
switch {
case chunked(t.TransferEncoding):
- t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close}
- case t.ContentLength >= 0:
+ if noBodyExpected(t.RequestMethod) {
+ t.Body = &body{Reader: io.LimitReader(r, 0), closing: t.Close}
+ } else {
+ t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close}
+ }
+ case realLength >= 0:
// TODO: limit the Content-Length. This is an easy DoS vector.
- t.Body = &body{Reader: io.LimitReader(r, t.ContentLength), closing: t.Close}
+ t.Body = &body{Reader: io.LimitReader(r, realLength), closing: t.Close}
default:
- // t.ContentLength < 0, i.e. "Content-Length" not mentioned in header
+ // realLength < 0, i.e. "Content-Length" not mentioned in header
if t.Close {
// Close semantics (i.e. HTTP/1.0)
t.Body = &body{Reader: r, closing: t.Close}
@@ -371,12 +380,6 @@ func fixTransferEncoding(requestMethod string, header Header) ([]string, error)
delete(header, "Transfer-Encoding")
- // Head responses have no bodies, so the transfer encoding
- // should be ignored.
- if requestMethod == "HEAD" {
- return nil, nil
- }
-
encodings := strings.Split(raw[0], ",")
te := make([]string, 0, len(encodings))
// TODO: Even though we only support "identity" and "chunked"
@@ -432,11 +435,11 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header,
}
// Logic based on Content-Length
- cl := strings.TrimSpace(header.Get("Content-Length"))
+ cl := strings.TrimSpace(header.get("Content-Length"))
if cl != "" {
- n, err := strconv.ParseInt(cl, 10, 64)
- if err != nil || n < 0 {
- return -1, &badStringError{"bad Content-Length", cl}
+ n, err := parseContentLength(cl)
+ if err != nil {
+ return -1, err
}
return n, nil
} else {
@@ -451,13 +454,6 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header,
return 0, nil
}
- // Logic based on media type. The purpose of the following code is just
- // to detect whether the unsupported "multipart/byteranges" is being
- // used. A proper Content-Type parser is needed in the future.
- if strings.Contains(strings.ToLower(header.Get("Content-Type")), "multipart/byteranges") {
- return -1, ErrNotSupported
- }
-
// Body-EOF logic based on other methods (like closing, or chunked coding)
return -1, nil
}
@@ -469,14 +465,14 @@ func shouldClose(major, minor int, header Header) bool {
if major < 1 {
return true
} else if major == 1 && minor == 0 {
- if !strings.Contains(strings.ToLower(header.Get("Connection")), "keep-alive") {
+ if !strings.Contains(strings.ToLower(header.get("Connection")), "keep-alive") {
return true
}
return false
} else {
// TODO: Should split on commas, toss surrounding white space,
// and check each field.
- if strings.ToLower(header.Get("Connection")) == "close" {
+ if strings.ToLower(header.get("Connection")) == "close" {
header.Del("Connection")
return true
}
@@ -486,7 +482,7 @@ func shouldClose(major, minor int, header Header) bool {
// Parse the trailer header
func fixTrailer(header Header, te []string) (Header, error) {
- raw := header.Get("Trailer")
+ raw := header.get("Trailer")
if raw == "" {
return nil, nil
}
@@ -525,11 +521,11 @@ type body struct {
res *response // response writer for server requests, else nil
}
-// ErrBodyReadAfterClose is returned when reading a Request Body after
-// the body has been closed. This typically happens when the body is
+// ErrBodyReadAfterClose is returned when reading a Request or Response
+// Body after the body has been closed. This typically happens when the body is
// read after an HTTP Handler calls WriteHeader or Write on its
// ResponseWriter.
-var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed request Body")
+var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed Body")
func (b *body) Read(p []byte) (n int, err error) {
if b.closed {
@@ -567,14 +563,22 @@ func seeUpcomingDoubleCRLF(r *bufio.Reader) bool {
return false
}
+var errTrailerEOF = errors.New("http: unexpected EOF reading trailer")
+
func (b *body) readTrailer() error {
// The common case, since nobody uses trailers.
- buf, _ := b.r.Peek(2)
+ buf, err := b.r.Peek(2)
if bytes.Equal(buf, singleCRLF) {
b.r.ReadByte()
b.r.ReadByte()
return nil
}
+ if len(buf) < 2 {
+ return errTrailerEOF
+ }
+ if err != nil {
+ return err
+ }
// Make sure there's a header terminator coming up, to prevent
// a DoS with an unbounded size Trailer. It's not easy to
@@ -590,6 +594,9 @@ func (b *body) readTrailer() error {
hdr, err := textproto.NewReader(b.r).ReadMIMEHeader()
if err != nil {
+ if err == io.EOF {
+ return errTrailerEOF
+ }
return err
}
switch rr := b.hdr.(type) {
@@ -630,3 +637,18 @@ func (b *body) Close() error {
}
return nil
}
+
+// parseContentLength trims whitespace from s and returns -1 if no value
+// is set, or the value if it's >= 0.
+func parseContentLength(cl string) (int64, error) {
+ cl = strings.TrimSpace(cl)
+ if cl == "" {
+ return -1, nil
+ }
+ n, err := strconv.ParseInt(cl, 10, 64)
+ if err != nil || n < 0 {
+ return 0, &badStringError{"bad Content-Length", cl}
+ }
+ return n, nil
+
+}
diff --git a/src/pkg/net/http/transfer_test.go b/src/pkg/net/http/transfer_test.go
new file mode 100644
index 000000000..8627a374c
--- /dev/null
+++ b/src/pkg/net/http/transfer_test.go
@@ -0,0 +1,37 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bufio"
+ "strings"
+ "testing"
+)
+
+func TestBodyReadBadTrailer(t *testing.T) {
+ b := &body{
+ Reader: strings.NewReader("foobar"),
+ hdr: true, // force reading the trailer
+ r: bufio.NewReader(strings.NewReader("")),
+ }
+ buf := make([]byte, 7)
+ n, err := b.Read(buf[:3])
+ got := string(buf[:n])
+ if got != "foo" || err != nil {
+ t.Fatalf(`first Read = %d (%q), %v; want 3 ("foo")`, n, got, err)
+ }
+
+ n, err = b.Read(buf[:])
+ got = string(buf[:n])
+ if got != "bar" || err != nil {
+ t.Fatalf(`second Read = %d (%q), %v; want 3 ("bar")`, n, got, err)
+ }
+
+ n, err = b.Read(buf[:])
+ got = string(buf[:n])
+ if err == nil {
+ t.Errorf("final Read was successful (%q), expected error from trailer read", got)
+ }
+}
diff --git a/src/pkg/net/http/transport.go b/src/pkg/net/http/transport.go
index 6efe191eb..685d7d56c 100644
--- a/src/pkg/net/http/transport.go
+++ b/src/pkg/net/http/transport.go
@@ -3,7 +3,7 @@
// license that can be found in the LICENSE file.
// HTTP client implementation. See RFC 2616.
-//
+//
// This is the low-level Transport implementation of RoundTripper.
// The high-level interface is in client.go.
@@ -24,13 +24,14 @@ import (
"os"
"strings"
"sync"
+ "time"
)
// DefaultTransport is the default implementation of Transport and is
-// used by DefaultClient. It establishes a new network connection for
-// each call to Do and uses HTTP proxies as directed by the
-// $HTTP_PROXY and $NO_PROXY (or $http_proxy and $no_proxy)
-// environment variables.
+// used by DefaultClient. It establishes network connections as needed
+// and caches them for reuse by subsequent calls. It uses HTTP proxies
+// as directed by the $HTTP_PROXY and $NO_PROXY (or $http_proxy and
+// $no_proxy) environment variables.
var DefaultTransport RoundTripper = &Transport{Proxy: ProxyFromEnvironment}
// DefaultMaxIdleConnsPerHost is the default value of Transport's
@@ -41,8 +42,11 @@ const DefaultMaxIdleConnsPerHost = 2
// https, and http proxies (for either http or https with CONNECT).
// Transport can also cache connections for future re-use.
type Transport struct {
- lk sync.Mutex
+ idleMu sync.Mutex
idleConn map[string][]*persistConn
+ reqMu sync.Mutex
+ reqConn map[*Request]*persistConn
+ altMu sync.RWMutex
altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper
// TODO: tunable on global max cached connections
@@ -68,9 +72,15 @@ type Transport struct {
DisableCompression bool
// MaxIdleConnsPerHost, if non-zero, controls the maximum idle
- // (keep-alive) to keep to keep per-host. If zero,
+ // (keep-alive) to keep per-host. If zero,
// DefaultMaxIdleConnsPerHost is used.
MaxIdleConnsPerHost int
+
+ // ResponseHeaderTimeout, if non-zero, specifies the amount of
+ // time to wait for a server's response headers after fully
+ // writing the request (including its body, if any). This
+ // time does not include the time to read the response body.
+ ResponseHeaderTimeout time.Duration
}
// ProxyFromEnvironment returns the URL of the proxy to use for a
@@ -88,7 +98,7 @@ func ProxyFromEnvironment(req *Request) (*url.URL, error) {
return nil, nil
}
proxyURL, err := url.Parse(proxy)
- if err != nil || proxyURL.Scheme == "" {
+ if err != nil || !strings.HasPrefix(proxyURL.Scheme, "http") {
if u, err := url.Parse("http://" + proxy); err == nil {
proxyURL = u
err = nil
@@ -131,17 +141,20 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) {
return nil, errors.New("http: nil Request.Header")
}
if req.URL.Scheme != "http" && req.URL.Scheme != "https" {
- t.lk.Lock()
+ t.altMu.RLock()
var rt RoundTripper
if t.altProto != nil {
rt = t.altProto[req.URL.Scheme]
}
- t.lk.Unlock()
+ t.altMu.RUnlock()
if rt == nil {
return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme}
}
return rt.RoundTrip(req)
}
+ if req.URL.Host == "" {
+ return nil, errors.New("http: no Host in request URL")
+ }
treq := &transportRequest{Request: req}
cm, err := t.connectMethodForRequest(treq)
if err != nil {
@@ -170,8 +183,8 @@ func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) {
if scheme == "http" || scheme == "https" {
panic("protocol " + scheme + " already registered")
}
- t.lk.Lock()
- defer t.lk.Unlock()
+ t.altMu.Lock()
+ defer t.altMu.Unlock()
if t.altProto == nil {
t.altProto = make(map[string]RoundTripper)
}
@@ -186,17 +199,29 @@ func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) {
// a "keep-alive" state. It does not interrupt any connections currently
// in use.
func (t *Transport) CloseIdleConnections() {
- t.lk.Lock()
- defer t.lk.Unlock()
- if t.idleConn == nil {
+ t.idleMu.Lock()
+ m := t.idleConn
+ t.idleConn = nil
+ t.idleMu.Unlock()
+ if m == nil {
return
}
- for _, conns := range t.idleConn {
+ for _, conns := range m {
for _, pconn := range conns {
pconn.close()
}
}
- t.idleConn = make(map[string][]*persistConn)
+}
+
+// CancelRequest cancels an in-flight request by closing its
+// connection.
+func (t *Transport) CancelRequest(req *Request) {
+ t.reqMu.Lock()
+ pc := t.reqConn[req]
+ t.reqMu.Unlock()
+ if pc != nil {
+ pc.conn.Close()
+ }
}
//
@@ -242,8 +267,6 @@ func (cm *connectMethod) proxyAuth() string {
// If pconn is no longer needed or not in a good state, putIdleConn
// returns false.
func (t *Transport) putIdleConn(pconn *persistConn) bool {
- t.lk.Lock()
- defer t.lk.Unlock()
if t.DisableKeepAlives || t.MaxIdleConnsPerHost < 0 {
pconn.close()
return false
@@ -256,21 +279,32 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool {
if max == 0 {
max = DefaultMaxIdleConnsPerHost
}
+ t.idleMu.Lock()
+ if t.idleConn == nil {
+ t.idleConn = make(map[string][]*persistConn)
+ }
if len(t.idleConn[key]) >= max {
+ t.idleMu.Unlock()
pconn.close()
return false
}
+ for _, exist := range t.idleConn[key] {
+ if exist == pconn {
+ log.Fatalf("dup idle pconn %p in freelist", pconn)
+ }
+ }
t.idleConn[key] = append(t.idleConn[key], pconn)
+ t.idleMu.Unlock()
return true
}
func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) {
- t.lk.Lock()
- defer t.lk.Unlock()
+ key := cm.String()
+ t.idleMu.Lock()
+ defer t.idleMu.Unlock()
if t.idleConn == nil {
- t.idleConn = make(map[string][]*persistConn)
+ return nil
}
- key := cm.String()
for {
pconns, ok := t.idleConn[key]
if !ok {
@@ -289,7 +323,20 @@ func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) {
return
}
}
- return
+ panic("unreachable")
+}
+
+func (t *Transport) setReqConn(r *Request, pc *persistConn) {
+ t.reqMu.Lock()
+ defer t.reqMu.Unlock()
+ if t.reqConn == nil {
+ t.reqConn = make(map[*Request]*persistConn)
+ }
+ if pc != nil {
+ t.reqConn[r] = pc
+ } else {
+ delete(t.reqConn, r)
+ }
}
func (t *Transport) dial(network, addr string) (c net.Conn, err error) {
@@ -323,6 +370,8 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) {
cacheKey: cm.String(),
conn: conn,
reqch: make(chan requestAndChan, 50),
+ writech: make(chan writeRequest, 50),
+ closech: make(chan struct{}),
}
switch {
@@ -365,7 +414,18 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) {
if cm.targetScheme == "https" {
// Initiate TLS and check remote host name against certificate.
- conn = tls.Client(conn, t.TLSClientConfig)
+ cfg := t.TLSClientConfig
+ if cfg == nil || cfg.ServerName == "" {
+ host := cm.tlsHost()
+ if cfg == nil {
+ cfg = &tls.Config{ServerName: host}
+ } else {
+ clone := *cfg // shallow clone
+ clone.ServerName = host
+ cfg = &clone
+ }
+ }
+ conn = tls.Client(conn, cfg)
if err = conn.(*tls.Conn).Handshake(); err != nil {
return nil, err
}
@@ -380,6 +440,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) {
pconn.br = bufio.NewReader(pconn.conn)
pconn.bw = bufio.NewWriter(pconn.conn)
go pconn.readLoop()
+ go pconn.writeLoop()
return pconn, nil
}
@@ -421,7 +482,15 @@ func useProxy(addr string) bool {
if hasPort(p) {
p = p[:strings.LastIndex(p, ":")]
}
- if addr == p || (p[0] == '.' && (strings.HasSuffix(addr, p) || addr == p[1:])) {
+ if addr == p {
+ return false
+ }
+ if p[0] == '.' && (strings.HasSuffix(addr, p) || addr == p[1:]) {
+ // no_proxy ".foo.com" matches "bar.foo.com" or "foo.com"
+ return false
+ }
+ if p[0] != '.' && strings.HasSuffix(addr, p) && addr[len(addr)-len(p)-1] == '.' {
+ // no_proxy "foo.com" matches "bar.foo.com"
return false
}
}
@@ -484,25 +553,28 @@ type persistConn struct {
t *Transport
cacheKey string // its connectMethod.String()
conn net.Conn
+ closed bool // whether conn has been closed
br *bufio.Reader // from conn
bw *bufio.Writer // to conn
- reqch chan requestAndChan // written by roundTrip(); read by readLoop()
+ reqch chan requestAndChan // written by roundTrip; read by readLoop
+ writech chan writeRequest // written by roundTrip; read by writeLoop
+ closech chan struct{} // broadcast close when readLoop (TCP connection) closes
isProxy bool
+ lk sync.Mutex // guards following 3 fields
+ numExpectedResponses int
+ broken bool // an error has happened on this connection; marked broken so it's not reused.
// mutateHeaderFunc is an optional func to modify extra
// headers on each outbound request before it's written. (the
// original Request given to RoundTrip is not modified)
mutateHeaderFunc func(Header)
-
- lk sync.Mutex // guards numExpectedResponses and broken
- numExpectedResponses int
- broken bool // an error has happened on this connection; marked broken so it's not reused.
}
func (pc *persistConn) isBroken() bool {
pc.lk.Lock()
- defer pc.lk.Unlock()
- return pc.broken
+ b := pc.broken
+ pc.lk.Unlock()
+ return b
}
var remoteSideClosedFunc func(error) bool // or nil to use default
@@ -518,6 +590,7 @@ func remoteSideClosed(err error) bool {
}
func (pc *persistConn) readLoop() {
+ defer close(pc.closech)
alive := true
var lastbody io.ReadCloser // last response body, if any, read on this connection
@@ -544,12 +617,16 @@ func (pc *persistConn) readLoop() {
lastbody.Close() // assumed idempotent
lastbody = nil
}
- resp, err := ReadResponse(pc.br, rc.req)
+
+ var resp *Response
+ if err == nil {
+ resp, err = ReadResponse(pc.br, rc.req)
+ }
+ hasBody := resp != nil && rc.req.Method != "HEAD" && resp.ContentLength != 0
if err != nil {
pc.close()
} else {
- hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0
if rc.addedGzip && hasBody && resp.Header.Get("Content-Encoding") == "gzip" {
resp.Header.Del("Content-Encoding")
resp.Header.Del("Content-Length")
@@ -569,31 +646,37 @@ func (pc *persistConn) readLoop() {
alive = false
}
- hasBody := resp != nil && resp.ContentLength != 0
var waitForBodyRead chan bool
- if alive {
- if hasBody {
- lastbody = resp.Body
- waitForBodyRead = make(chan bool)
- resp.Body.(*bodyEOFSignal).fn = func() {
- if !pc.t.putIdleConn(pc) {
- alive = false
- }
- waitForBodyRead <- true
+ if hasBody {
+ lastbody = resp.Body
+ waitForBodyRead = make(chan bool, 1)
+ resp.Body.(*bodyEOFSignal).fn = func(err error) {
+ alive1 := alive
+ if err != nil {
+ alive1 = false
}
- } else {
- // When there's no response body, we immediately
- // reuse the TCP connection (putIdleConn), but
- // we need to prevent ClientConn.Read from
- // closing the Response.Body on the next
- // loop, otherwise it might close the body
- // before the client code has had a chance to
- // read it (even though it'll just be 0, EOF).
- lastbody = nil
-
- if !pc.t.putIdleConn(pc) {
- alive = false
+ if alive1 && !pc.t.putIdleConn(pc) {
+ alive1 = false
+ }
+ if !alive1 || pc.isBroken() {
+ pc.close()
}
+ waitForBodyRead <- alive1
+ }
+ }
+
+ if alive && !hasBody {
+ // When there's no response body, we immediately
+ // reuse the TCP connection (putIdleConn), but
+ // we need to prevent ClientConn.Read from
+ // closing the Response.Body on the next
+ // loop, otherwise it might close the body
+ // before the client code has had a chance to
+ // read it (even though it'll just be 0, EOF).
+ lastbody = nil
+
+ if !pc.t.putIdleConn(pc) {
+ alive = false
}
}
@@ -602,7 +685,35 @@ func (pc *persistConn) readLoop() {
// Wait for the just-returned response body to be fully consumed
// before we race and peek on the underlying bufio reader.
if waitForBodyRead != nil {
- <-waitForBodyRead
+ alive = <-waitForBodyRead
+ }
+
+ pc.t.setReqConn(rc.req, nil)
+
+ if !alive {
+ pc.close()
+ }
+ }
+}
+
+func (pc *persistConn) writeLoop() {
+ for {
+ select {
+ case wr := <-pc.writech:
+ if pc.isBroken() {
+ wr.ch <- errors.New("http: can't write HTTP request on broken connection")
+ continue
+ }
+ err := wr.req.Request.write(pc.bw, pc.isProxy, wr.req.extra)
+ if err == nil {
+ err = pc.bw.Flush()
+ }
+ if err != nil {
+ pc.markBroken()
+ }
+ wr.ch <- err
+ case <-pc.closech:
+ return
}
}
}
@@ -622,9 +733,24 @@ type requestAndChan struct {
addedGzip bool
}
+// A writeRequest is sent by the readLoop's goroutine to the
+// writeLoop's goroutine to write a request while the read loop
+// concurrently waits on both the write response and the server's
+// reply.
+type writeRequest struct {
+ req *transportRequest
+ ch chan<- error
+}
+
func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
- if pc.mutateHeaderFunc != nil {
- pc.mutateHeaderFunc(req.extraHeaders())
+ pc.t.setReqConn(req.Request, pc)
+ pc.lk.Lock()
+ pc.numExpectedResponses++
+ headerFn := pc.mutateHeaderFunc
+ pc.lk.Unlock()
+
+ if headerFn != nil {
+ headerFn(req.extraHeaders())
}
// Ask for a compressed version if the caller didn't set their
@@ -633,34 +759,84 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
// requested it.
requestedGzip := false
if !pc.t.DisableCompression && req.Header.Get("Accept-Encoding") == "" {
- // Request gzip only, not deflate. Deflate is ambiguous and
+ // Request gzip only, not deflate. Deflate is ambiguous and
// not as universally supported anyway.
// See: http://www.gzip.org/zlib/zlib_faq.html#faq38
requestedGzip = true
req.extraHeaders().Set("Accept-Encoding", "gzip")
}
- pc.lk.Lock()
- pc.numExpectedResponses++
- pc.lk.Unlock()
+ // Write the request concurrently with waiting for a response,
+ // in case the server decides to reply before reading our full
+ // request body.
+ writeErrCh := make(chan error, 1)
+ pc.writech <- writeRequest{req, writeErrCh}
- err = req.Request.write(pc.bw, pc.isProxy, req.extra)
- if err != nil {
- pc.close()
- return
+ resc := make(chan responseAndError, 1)
+ pc.reqch <- requestAndChan{req.Request, resc, requestedGzip}
+
+ var re responseAndError
+ var pconnDeadCh = pc.closech
+ var failTicker <-chan time.Time
+ var respHeaderTimer <-chan time.Time
+WaitResponse:
+ for {
+ select {
+ case err := <-writeErrCh:
+ if err != nil {
+ re = responseAndError{nil, err}
+ pc.close()
+ break WaitResponse
+ }
+ if d := pc.t.ResponseHeaderTimeout; d > 0 {
+ respHeaderTimer = time.After(d)
+ }
+ case <-pconnDeadCh:
+ // The persist connection is dead. This shouldn't
+ // usually happen (only with Connection: close responses
+ // with no response bodies), but if it does happen it
+ // means either a) the remote server hung up on us
+ // prematurely, or b) the readLoop sent us a response &
+ // closed its closech at roughly the same time, and we
+ // selected this case first, in which case a response
+ // might still be coming soon.
+ //
+ // We can't avoid the select race in b) by using a unbuffered
+ // resc channel instead, because then goroutines can
+ // leak if we exit due to other errors.
+ pconnDeadCh = nil // avoid spinning
+ failTicker = time.After(100 * time.Millisecond) // arbitrary time to wait for resc
+ case <-failTicker:
+ re = responseAndError{err: errors.New("net/http: transport closed before response was received")}
+ break WaitResponse
+ case <-respHeaderTimer:
+ pc.close()
+ re = responseAndError{err: errors.New("net/http: timeout awaiting response headers")}
+ break WaitResponse
+ case re = <-resc:
+ break WaitResponse
+ }
}
- pc.bw.Flush()
- ch := make(chan responseAndError, 1)
- pc.reqch <- requestAndChan{req.Request, ch, requestedGzip}
- re := <-ch
pc.lk.Lock()
pc.numExpectedResponses--
pc.lk.Unlock()
+ if re.err != nil {
+ pc.t.setReqConn(req.Request, nil)
+ }
return re.res, re.err
}
+// markBroken marks a connection as broken (so it's not reused).
+// It differs from close in that it doesn't close the underlying
+// connection for use when it's still being read.
+func (pc *persistConn) markBroken() {
+ pc.lk.Lock()
+ defer pc.lk.Unlock()
+ pc.broken = true
+}
+
func (pc *persistConn) close() {
pc.lk.Lock()
defer pc.lk.Unlock()
@@ -669,7 +845,10 @@ func (pc *persistConn) close() {
func (pc *persistConn) closeLocked() {
pc.broken = true
- pc.conn.Close()
+ if !pc.closed {
+ pc.conn.Close()
+ pc.closed = true
+ }
pc.mutateHeaderFunc = nil
}
@@ -687,43 +866,62 @@ func canonicalAddr(url *url.URL) string {
return addr
}
-func responseIsKeepAlive(res *Response) bool {
- // TODO: implement. for now just always shutting down the connection.
- return false
-}
-
// bodyEOFSignal wraps a ReadCloser but runs fn (if non-nil) at most
-// once, right before the final Read() or Close() call returns, but after
-// EOF has been seen.
+// once, right before its final (error-producing) Read or Close call
+// returns.
type bodyEOFSignal struct {
- body io.ReadCloser
- fn func()
- isClosed bool
+ body io.ReadCloser
+ mu sync.Mutex // guards closed, rerr and fn
+ closed bool // whether Close has been called
+ rerr error // sticky Read error
+ fn func(error) // error will be nil on Read io.EOF
}
func (es *bodyEOFSignal) Read(p []byte) (n int, err error) {
- n, err = es.body.Read(p)
- if es.isClosed && n > 0 {
- panic("http: unexpected bodyEOFSignal Read after Close; see issue 1725")
+ es.mu.Lock()
+ closed, rerr := es.closed, es.rerr
+ es.mu.Unlock()
+ if closed {
+ return 0, errors.New("http: read on closed response body")
+ }
+ if rerr != nil {
+ return 0, rerr
}
- if err == io.EOF && es.fn != nil {
- es.fn()
- es.fn = nil
+
+ n, err = es.body.Read(p)
+ if err != nil {
+ es.mu.Lock()
+ defer es.mu.Unlock()
+ if es.rerr == nil {
+ es.rerr = err
+ }
+ es.condfn(err)
}
return
}
-func (es *bodyEOFSignal) Close() (err error) {
- if es.isClosed {
+func (es *bodyEOFSignal) Close() error {
+ es.mu.Lock()
+ defer es.mu.Unlock()
+ if es.closed {
return nil
}
- es.isClosed = true
- err = es.body.Close()
- if err == nil && es.fn != nil {
- es.fn()
- es.fn = nil
+ es.closed = true
+ err := es.body.Close()
+ es.condfn(err)
+ return err
+}
+
+// caller must hold es.mu.
+func (es *bodyEOFSignal) condfn(err error) {
+ if es.fn == nil {
+ return
}
- return
+ if err == io.EOF {
+ err = nil
+ }
+ es.fn(err)
+ es.fn = nil
}
type readFirstCloseBoth struct {
diff --git a/src/pkg/net/http/transport_test.go b/src/pkg/net/http/transport_test.go
index a9e401de5..68010e68b 100644
--- a/src/pkg/net/http/transport_test.go
+++ b/src/pkg/net/http/transport_test.go
@@ -13,6 +13,7 @@ import (
"fmt"
"io"
"io/ioutil"
+ "net"
. "net/http"
"net/http/httptest"
"net/url"
@@ -20,6 +21,7 @@ import (
"runtime"
"strconv"
"strings"
+ "sync"
"testing"
"time"
)
@@ -35,14 +37,78 @@ var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
w.Write([]byte(r.RemoteAddr))
})
+// testCloseConn is a net.Conn tracked by a testConnSet.
+type testCloseConn struct {
+ net.Conn
+ set *testConnSet
+}
+
+func (c *testCloseConn) Close() error {
+ c.set.remove(c)
+ return c.Conn.Close()
+}
+
+// testConnSet tracks a set of TCP connections and whether they've
+// been closed.
+type testConnSet struct {
+ t *testing.T
+ closed map[net.Conn]bool
+ list []net.Conn // in order created
+ mutex sync.Mutex
+}
+
+func (tcs *testConnSet) insert(c net.Conn) {
+ tcs.mutex.Lock()
+ defer tcs.mutex.Unlock()
+ tcs.closed[c] = false
+ tcs.list = append(tcs.list, c)
+}
+
+func (tcs *testConnSet) remove(c net.Conn) {
+ tcs.mutex.Lock()
+ defer tcs.mutex.Unlock()
+ tcs.closed[c] = true
+}
+
+// some tests use this to manage raw tcp connections for later inspection
+func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
+ connSet := &testConnSet{
+ t: t,
+ closed: make(map[net.Conn]bool),
+ }
+ dial := func(n, addr string) (net.Conn, error) {
+ c, err := net.Dial(n, addr)
+ if err != nil {
+ return nil, err
+ }
+ tc := &testCloseConn{c, connSet}
+ connSet.insert(tc)
+ return tc, nil
+ }
+ return connSet, dial
+}
+
+func (tcs *testConnSet) check(t *testing.T) {
+ tcs.mutex.Lock()
+ defer tcs.mutex.Unlock()
+
+ for i, c := range tcs.list {
+ if !tcs.closed[c] {
+ t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
+ }
+ }
+}
+
// Two subsequent requests and verify their response is the same.
// The response from the server is our own IP:port
func TestTransportKeepAlives(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(hostPortHandler)
defer ts.Close()
for _, disableKeepAlive := range []bool{false, true} {
tr := &Transport{DisableKeepAlives: disableKeepAlive}
+ defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
fetch := func(n int) string {
@@ -69,11 +135,16 @@ func TestTransportKeepAlives(t *testing.T) {
}
func TestTransportConnectionCloseOnResponse(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(hostPortHandler)
defer ts.Close()
+ connSet, testDial := makeTestDial(t)
+
for _, connectionClose := range []bool{false, true} {
- tr := &Transport{}
+ tr := &Transport{
+ Dial: testDial,
+ }
c := &Client{Transport: tr}
fetch := func(n int) string {
@@ -92,8 +163,8 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) {
if err != nil {
t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
}
- body, err := ioutil.ReadAll(res.Body)
defer res.Body.Close()
+ body, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
}
@@ -107,15 +178,24 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) {
t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
connectionClose, bodiesDiffer, body1, body2)
}
+
+ tr.CloseIdleConnections()
}
+
+ connSet.check(t)
}
func TestTransportConnectionCloseOnRequest(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(hostPortHandler)
defer ts.Close()
+ connSet, testDial := makeTestDial(t)
+
for _, connectionClose := range []bool{false, true} {
- tr := &Transport{}
+ tr := &Transport{
+ Dial: testDial,
+ }
c := &Client{Transport: tr}
fetch := func(n int) string {
@@ -149,10 +229,15 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) {
t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
connectionClose, bodiesDiffer, body1, body2)
}
+
+ tr.CloseIdleConnections()
}
+
+ connSet.check(t)
}
func TestTransportIdleCacheKeys(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(hostPortHandler)
defer ts.Close()
@@ -185,6 +270,7 @@ func TestTransportIdleCacheKeys(t *testing.T) {
}
func TestTransportMaxPerHostIdleConns(t *testing.T) {
+ defer checkLeakedTransports(t)
resch := make(chan string)
gotReq := make(chan bool)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -201,7 +287,7 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) {
c := &Client{Transport: tr}
// Start 3 outstanding requests and wait for the server to get them.
- // Their responses will hang until we we write to resch, though.
+ // Their responses will hang until we write to resch, though.
donech := make(chan bool)
doReq := func() {
resp, err := c.Get(ts.URL)
@@ -253,6 +339,7 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) {
}
func TestTransportServerClosingUnexpectedly(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(hostPortHandler)
defer ts.Close()
@@ -309,9 +396,9 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) {
// Test for http://golang.org/issue/2616 (appropriate issue number)
// This fails pretty reliably with GOMAXPROCS=100 or something high.
func TestStressSurpriseServerCloses(t *testing.T) {
+ defer checkLeakedTransports(t)
if testing.Short() {
- t.Logf("skipping test in short mode")
- return
+ t.Skip("skipping test in short mode")
}
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("Content-Length", "5")
@@ -365,6 +452,7 @@ func TestStressSurpriseServerCloses(t *testing.T) {
// TestTransportHeadResponses verifies that we deal with Content-Lengths
// with no bodies properly
func TestTransportHeadResponses(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
if r.Method != "HEAD" {
panic("expected HEAD; got " + r.Method)
@@ -384,7 +472,7 @@ func TestTransportHeadResponses(t *testing.T) {
if e, g := "123", res.Header.Get("Content-Length"); e != g {
t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
}
- if e, g := int64(0), res.ContentLength; e != g {
+ if e, g := int64(123), res.ContentLength; e != g {
t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
}
}
@@ -393,6 +481,7 @@ func TestTransportHeadResponses(t *testing.T) {
// TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding
// on responses to HEAD requests.
func TestTransportHeadChunkedResponse(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
if r.Method != "HEAD" {
panic("expected HEAD; got " + r.Method)
@@ -434,6 +523,7 @@ var roundTripTests = []struct {
// Test that the modification made to the Request by the RoundTripper is cleaned up
func TestRoundTripGzip(t *testing.T) {
+ defer checkLeakedTransports(t)
const responseBody = "test response body"
ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
accept := req.Header.Get("Accept-Encoding")
@@ -490,6 +580,7 @@ func TestRoundTripGzip(t *testing.T) {
}
func TestTransportGzip(t *testing.T) {
+ defer checkLeakedTransports(t)
const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
const nRandBytes = 1024 * 1024
ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
@@ -582,6 +673,7 @@ func TestTransportGzip(t *testing.T) {
}
func TestTransportProxy(t *testing.T) {
+ defer checkLeakedTransports(t)
ch := make(chan string, 1)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
ch <- "real server"
@@ -610,6 +702,7 @@ func TestTransportProxy(t *testing.T) {
// but checks that we don't recurse forever, and checks that
// Content-Encoding is removed.
func TestTransportGzipRecursive(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("Content-Encoding", "gzip")
w.Write(rgz)
@@ -636,6 +729,7 @@ func TestTransportGzipRecursive(t *testing.T) {
// tests that persistent goroutine connections shut down when no longer desired.
func TestTransportPersistConnLeak(t *testing.T) {
+ defer checkLeakedTransports(t)
gotReqCh := make(chan bool)
unblockCh := make(chan bool)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -698,8 +792,49 @@ func TestTransportPersistConnLeak(t *testing.T) {
}
}
+// golang.org/issue/4531: Transport leaks goroutines when
+// request.ContentLength is explicitly short
+func TestTransportPersistConnLeakShortBody(t *testing.T) {
+ defer checkLeakedTransports(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ }))
+ defer ts.Close()
+
+ tr := &Transport{}
+ c := &Client{Transport: tr}
+
+ n0 := runtime.NumGoroutine()
+ body := []byte("Hello")
+ for i := 0; i < 20; i++ {
+ req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.ContentLength = int64(len(body) - 2) // explicitly short
+ _, err = c.Do(req)
+ if err == nil {
+ t.Fatal("Expect an error from writing too long of a body.")
+ }
+ }
+ nhigh := runtime.NumGoroutine()
+ tr.CloseIdleConnections()
+ time.Sleep(50 * time.Millisecond)
+ runtime.GC()
+ nfinal := runtime.NumGoroutine()
+
+ growth := nfinal - n0
+
+ // We expect 0 or 1 extra goroutine, empirically. Allow up to 5.
+ // Previously we were leaking one per numReq.
+ t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
+ if int(growth) > 5 {
+ t.Error("too many new goroutines")
+ }
+}
+
// This used to crash; http://golang.org/issue/3266
func TestTransportIdleConnCrash(t *testing.T) {
+ defer checkLeakedTransports(t)
tr := &Transport{}
c := &Client{Transport: tr}
@@ -724,6 +859,361 @@ func TestTransportIdleConnCrash(t *testing.T) {
<-didreq
}
+// Test that the transport doesn't close the TCP connection early,
+// before the response body has been read. This was a regression
+// which sadly lacked a triggering test. The large response body made
+// the old race easier to trigger.
+func TestIssue3644(t *testing.T) {
+ defer checkLeakedTransports(t)
+ const numFoos = 5000
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Connection", "close")
+ for i := 0; i < numFoos; i++ {
+ w.Write([]byte("foo "))
+ }
+ }))
+ defer ts.Close()
+ tr := &Transport{}
+ c := &Client{Transport: tr}
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ bs, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(bs) != numFoos*len("foo ") {
+ t.Errorf("unexpected response length")
+ }
+}
+
+// Test that a client receives a server's reply, even if the server doesn't read
+// the entire request body.
+func TestIssue3595(t *testing.T) {
+ defer checkLeakedTransports(t)
+ const deniedMsg = "sorry, denied."
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ Error(w, deniedMsg, StatusUnauthorized)
+ }))
+ defer ts.Close()
+ tr := &Transport{}
+ c := &Client{Transport: tr}
+ res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
+ if err != nil {
+ t.Errorf("Post: %v", err)
+ return
+ }
+ got, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("Body ReadAll: %v", err)
+ }
+ if !strings.Contains(string(got), deniedMsg) {
+ t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
+ }
+}
+
+// From http://golang.org/issue/4454 ,
+// "client fails to handle requests with no body and chunked encoding"
+func TestChunkedNoContent(t *testing.T) {
+ defer checkLeakedTransports(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.WriteHeader(StatusNoContent)
+ }))
+ defer ts.Close()
+
+ for _, closeBody := range []bool{true, false} {
+ c := &Client{Transport: &Transport{}}
+ const n = 4
+ for i := 1; i <= n; i++ {
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err)
+ } else {
+ if closeBody {
+ res.Body.Close()
+ }
+ }
+ }
+ }
+}
+
+func TestTransportConcurrency(t *testing.T) {
+ defer checkLeakedTransports(t)
+ const maxProcs = 16
+ const numReqs = 500
+ defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "%v", r.FormValue("echo"))
+ }))
+ defer ts.Close()
+ tr := &Transport{}
+ c := &Client{Transport: tr}
+ reqs := make(chan string)
+ defer close(reqs)
+
+ var wg sync.WaitGroup
+ wg.Add(numReqs)
+ for i := 0; i < maxProcs*2; i++ {
+ go func() {
+ for req := range reqs {
+ res, err := c.Get(ts.URL + "/?echo=" + req)
+ if err != nil {
+ t.Errorf("error on req %s: %v", req, err)
+ wg.Done()
+ continue
+ }
+ all, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Errorf("read error on req %s: %v", req, err)
+ wg.Done()
+ continue
+ }
+ if string(all) != req {
+ t.Errorf("body of req %s = %q; want %q", req, all, req)
+ }
+ wg.Done()
+ res.Body.Close()
+ }
+ }()
+ }
+ for i := 0; i < numReqs; i++ {
+ reqs <- fmt.Sprintf("request-%d", i)
+ }
+ wg.Wait()
+}
+
+func TestIssue4191_InfiniteGetTimeout(t *testing.T) {
+ defer checkLeakedTransports(t)
+ const debug = false
+ mux := NewServeMux()
+ mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
+ io.Copy(w, neverEnding('a'))
+ })
+ ts := httptest.NewServer(mux)
+ timeout := 100 * time.Millisecond
+
+ client := &Client{
+ Transport: &Transport{
+ Dial: func(n, addr string) (net.Conn, error) {
+ conn, err := net.Dial(n, addr)
+ if err != nil {
+ return nil, err
+ }
+ conn.SetDeadline(time.Now().Add(timeout))
+ if debug {
+ conn = NewLoggingConn("client", conn)
+ }
+ return conn, nil
+ },
+ DisableKeepAlives: true,
+ },
+ }
+
+ getFailed := false
+ nRuns := 5
+ if testing.Short() {
+ nRuns = 1
+ }
+ for i := 0; i < nRuns; i++ {
+ if debug {
+ println("run", i+1, "of", nRuns)
+ }
+ sres, err := client.Get(ts.URL + "/get")
+ if err != nil {
+ if !getFailed {
+ // Make the timeout longer, once.
+ getFailed = true
+ t.Logf("increasing timeout")
+ i--
+ timeout *= 10
+ continue
+ }
+ t.Errorf("Error issuing GET: %v", err)
+ break
+ }
+ _, err = io.Copy(ioutil.Discard, sres.Body)
+ if err == nil {
+ t.Errorf("Unexpected successful copy")
+ break
+ }
+ }
+ if debug {
+ println("tests complete; waiting for handlers to finish")
+ }
+ ts.Close()
+}
+
+func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
+ defer checkLeakedTransports(t)
+ const debug = false
+ mux := NewServeMux()
+ mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
+ io.Copy(w, neverEnding('a'))
+ })
+ mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
+ defer r.Body.Close()
+ io.Copy(ioutil.Discard, r.Body)
+ })
+ ts := httptest.NewServer(mux)
+ timeout := 100 * time.Millisecond
+
+ client := &Client{
+ Transport: &Transport{
+ Dial: func(n, addr string) (net.Conn, error) {
+ conn, err := net.Dial(n, addr)
+ if err != nil {
+ return nil, err
+ }
+ conn.SetDeadline(time.Now().Add(timeout))
+ if debug {
+ conn = NewLoggingConn("client", conn)
+ }
+ return conn, nil
+ },
+ DisableKeepAlives: true,
+ },
+ }
+
+ getFailed := false
+ nRuns := 5
+ if testing.Short() {
+ nRuns = 1
+ }
+ for i := 0; i < nRuns; i++ {
+ if debug {
+ println("run", i+1, "of", nRuns)
+ }
+ sres, err := client.Get(ts.URL + "/get")
+ if err != nil {
+ if !getFailed {
+ // Make the timeout longer, once.
+ getFailed = true
+ t.Logf("increasing timeout")
+ i--
+ timeout *= 10
+ continue
+ }
+ t.Errorf("Error issuing GET: %v", err)
+ break
+ }
+ req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
+ _, err = client.Do(req)
+ if err == nil {
+ sres.Body.Close()
+ t.Errorf("Unexpected successful PUT")
+ break
+ }
+ sres.Body.Close()
+ }
+ if debug {
+ println("tests complete; waiting for handlers to finish")
+ }
+ ts.Close()
+}
+
+func TestTransportResponseHeaderTimeout(t *testing.T) {
+ defer checkLeakedTransports(t)
+ if testing.Short() {
+ t.Skip("skipping timeout test in -short mode")
+ }
+ mux := NewServeMux()
+ mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {})
+ mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
+ time.Sleep(2 * time.Second)
+ })
+ ts := httptest.NewServer(mux)
+ defer ts.Close()
+
+ tr := &Transport{
+ ResponseHeaderTimeout: 500 * time.Millisecond,
+ }
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ tests := []struct {
+ path string
+ want int
+ wantErr string
+ }{
+ {path: "/fast", want: 200},
+ {path: "/slow", wantErr: "timeout awaiting response headers"},
+ {path: "/fast", want: 200},
+ }
+ for i, tt := range tests {
+ res, err := c.Get(ts.URL + tt.path)
+ if err != nil {
+ if strings.Contains(err.Error(), tt.wantErr) {
+ continue
+ }
+ t.Errorf("%d. unexpected error: %v", i, err)
+ continue
+ }
+ if tt.wantErr != "" {
+ t.Errorf("%d. no error. expected error: %v", i, tt.wantErr)
+ continue
+ }
+ if res.StatusCode != tt.want {
+ t.Errorf("%d for path %q status = %d; want %d", i, tt.path, res.StatusCode, tt.want)
+ }
+ }
+}
+
+func TestTransportCancelRequest(t *testing.T) {
+ defer checkLeakedTransports(t)
+ if testing.Short() {
+ t.Skip("skipping test in -short mode")
+ }
+ unblockc := make(chan bool)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "Hello")
+ w.(Flusher).Flush() // send headers and some body
+ <-unblockc
+ }))
+ defer ts.Close()
+ defer close(unblockc)
+
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ req, _ := NewRequest("GET", ts.URL, nil)
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ go func() {
+ time.Sleep(1 * time.Second)
+ tr.CancelRequest(req)
+ }()
+ t0 := time.Now()
+ body, err := ioutil.ReadAll(res.Body)
+ d := time.Since(t0)
+
+ if err == nil {
+ t.Error("expected an error reading the body")
+ }
+ if string(body) != "Hello" {
+ t.Errorf("Body = %q; want Hello", body)
+ }
+ if d < 500*time.Millisecond {
+ t.Errorf("expected ~1 second delay; got %v", d)
+ }
+ // Verify no outstanding requests after readLoop/writeLoop
+ // goroutines shut down.
+ for tries := 3; tries > 0; tries-- {
+ n := tr.NumPendingRequestsForTesting()
+ if n == 0 {
+ break
+ }
+ time.Sleep(100 * time.Millisecond)
+ if tries == 1 {
+ t.Errorf("pending requests = %d; want 0", n)
+ }
+ }
+}
+
type fooProto struct{}
func (fooProto) RoundTrip(req *Request) (*Response, error) {
@@ -737,6 +1227,7 @@ func (fooProto) RoundTrip(req *Request) (*Response, error) {
}
func TestTransportAltProto(t *testing.T) {
+ defer checkLeakedTransports(t)
tr := &Transport{}
c := &Client{Transport: tr}
tr.RegisterProtocol("foo", fooProto{})
@@ -754,15 +1245,58 @@ func TestTransportAltProto(t *testing.T) {
}
}
-var proxyFromEnvTests = []struct {
+func TestTransportNoHost(t *testing.T) {
+ defer checkLeakedTransports(t)
+ tr := &Transport{}
+ _, err := tr.RoundTrip(&Request{
+ Header: make(Header),
+ URL: &url.URL{
+ Scheme: "http",
+ },
+ })
+ want := "http: no Host in request URL"
+ if got := fmt.Sprint(err); got != want {
+ t.Errorf("error = %v; want %q", err, want)
+ }
+}
+
+type proxyFromEnvTest struct {
+ req string // URL to fetch; blank means "http://example.com"
env string
- wanturl string
+ noenv string
+ want string
wanterr error
-}{
- {"127.0.0.1:8080", "http://127.0.0.1:8080", nil},
- {"http://127.0.0.1:8080", "http://127.0.0.1:8080", nil},
- {"https://127.0.0.1:8080", "https://127.0.0.1:8080", nil},
- {"", "<nil>", nil},
+}
+
+func (t proxyFromEnvTest) String() string {
+ var buf bytes.Buffer
+ if t.env != "" {
+ fmt.Fprintf(&buf, "http_proxy=%q", t.env)
+ }
+ if t.noenv != "" {
+ fmt.Fprintf(&buf, " no_proxy=%q", t.noenv)
+ }
+ req := "http://example.com"
+ if t.req != "" {
+ req = t.req
+ }
+ fmt.Fprintf(&buf, " req=%q", req)
+ return strings.TrimSpace(buf.String())
+}
+
+var proxyFromEnvTests = []proxyFromEnvTest{
+ {env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"},
+ {env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"},
+ {env: "cache.corp.example.com", want: "http://cache.corp.example.com"},
+ {env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"},
+ {env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"},
+ {env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"},
+ {want: "<nil>"},
+ {noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
+ {noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
+ {noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
+ {noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"},
+ {noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
}
func TestProxyFromEnvironment(t *testing.T) {
@@ -770,16 +1304,21 @@ func TestProxyFromEnvironment(t *testing.T) {
os.Setenv("http_proxy", "")
os.Setenv("NO_PROXY", "")
os.Setenv("no_proxy", "")
- for i, tt := range proxyFromEnvTests {
+ for _, tt := range proxyFromEnvTests {
os.Setenv("HTTP_PROXY", tt.env)
- req, _ := NewRequest("GET", "http://example.com", nil)
+ os.Setenv("NO_PROXY", tt.noenv)
+ reqURL := tt.req
+ if reqURL == "" {
+ reqURL = "http://example.com"
+ }
+ req, _ := NewRequest("GET", reqURL, nil)
url, err := ProxyFromEnvironment(req)
if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e {
- t.Errorf("%d. got error = %q, want %q", i, g, e)
+ t.Errorf("%v: got error = %q, want %q", tt, g, e)
continue
}
- if got := fmt.Sprintf("%s", url); got != tt.wanturl {
- t.Errorf("%d. got URL = %q, want %q", i, url, tt.wanturl)
+ if got := fmt.Sprintf("%s", url); got != tt.want {
+ t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want)
}
}
}
diff --git a/src/pkg/net/http/z_last_test.go b/src/pkg/net/http/z_last_test.go
new file mode 100644
index 000000000..44095a8d9
--- /dev/null
+++ b/src/pkg/net/http/z_last_test.go
@@ -0,0 +1,60 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+ "net/http"
+ "runtime"
+ "strings"
+ "testing"
+ "time"
+)
+
+// Verify the other tests didn't leave any goroutines running.
+// This is in a file named z_last_test.go so it sorts at the end.
+func TestGoroutinesRunning(t *testing.T) {
+ n := runtime.NumGoroutine()
+ t.Logf("num goroutines = %d", n)
+ if n > 20 {
+ // Currently 14 on Linux (blocked in epoll_wait,
+ // waiting for on fds that are closed?), but give some
+ // slop for now.
+ buf := make([]byte, 1<<20)
+ buf = buf[:runtime.Stack(buf, true)]
+ t.Errorf("Too many goroutines:\n%s", buf)
+ }
+}
+
+func checkLeakedTransports(t *testing.T) {
+ http.DefaultTransport.(*http.Transport).CloseIdleConnections()
+ if testing.Short() {
+ return
+ }
+ buf := make([]byte, 1<<20)
+ var stacks string
+ var bad string
+ badSubstring := map[string]string{
+ ").readLoop(": "a Transport",
+ ").writeLoop(": "a Transport",
+ "created by net/http/httptest.(*Server).Start": "an httptest.Server",
+ "timeoutHandler": "a TimeoutHandler",
+ }
+ for i := 0; i < 4; i++ {
+ bad = ""
+ stacks = string(buf[:runtime.Stack(buf, true)])
+ for substr, what := range badSubstring {
+ if strings.Contains(stacks, substr) {
+ bad = what
+ }
+ }
+ if bad == "" {
+ return
+ }
+ // Bad stuff found, but goroutines might just still be
+ // shutting down, so give it some time.
+ time.Sleep(250 * time.Millisecond)
+ }
+ t.Errorf("Test appears to have leaked %s:\n%s", bad, stacks)
+}