summaryrefslogtreecommitdiff
path: root/src/net/http
diff options
context:
space:
mode:
Diffstat (limited to 'src/net/http')
-rw-r--r--src/net/http/cgi/child.go206
-rw-r--r--src/net/http/cgi/child_test.go131
-rw-r--r--src/net/http/cgi/host.go377
-rw-r--r--src/net/http/cgi/host_test.go461
-rw-r--r--src/net/http/cgi/matryoshka_test.go228
-rw-r--r--src/net/http/cgi/plan9_test.go18
-rw-r--r--src/net/http/cgi/posix_test.go21
-rwxr-xr-xsrc/net/http/cgi/testdata/test.cgi91
-rw-r--r--src/net/http/client.go511
-rw-r--r--src/net/http/client_test.go1075
-rw-r--r--src/net/http/cookie.go363
-rw-r--r--src/net/http/cookie_test.go412
-rw-r--r--src/net/http/cookiejar/jar.go497
-rw-r--r--src/net/http/cookiejar/jar_test.go1267
-rw-r--r--src/net/http/cookiejar/punycode.go159
-rw-r--r--src/net/http/cookiejar/punycode_test.go161
-rw-r--r--src/net/http/doc.go80
-rw-r--r--src/net/http/example_test.go88
-rw-r--r--src/net/http/export_test.go108
-rw-r--r--src/net/http/fcgi/child.go305
-rw-r--r--src/net/http/fcgi/fcgi.go274
-rw-r--r--src/net/http/fcgi/fcgi_test.go150
-rw-r--r--src/net/http/filetransport.go123
-rw-r--r--src/net/http/filetransport_test.go65
-rw-r--r--src/net/http/fs.go556
-rw-r--r--src/net/http/fs_test.go917
-rw-r--r--src/net/http/header.go211
-rw-r--r--src/net/http/header_test.go212
-rw-r--r--src/net/http/httptest/example_test.go50
-rw-r--r--src/net/http/httptest/recorder.go72
-rw-r--r--src/net/http/httptest/recorder_test.go90
-rw-r--r--src/net/http/httptest/server.go228
-rw-r--r--src/net/http/httptest/server_test.go29
-rw-r--r--src/net/http/httputil/dump.go284
-rw-r--r--src/net/http/httputil/dump_test.go291
-rw-r--r--src/net/http/httputil/httputil.go39
-rw-r--r--src/net/http/httputil/persist.go429
-rw-r--r--src/net/http/httputil/reverseproxy.go225
-rw-r--r--src/net/http/httputil/reverseproxy_test.go213
-rw-r--r--src/net/http/internal/chunked.go202
-rw-r--r--src/net/http/internal/chunked_test.go156
-rw-r--r--src/net/http/jar.go27
-rw-r--r--src/net/http/lex.go96
-rw-r--r--src/net/http/lex_test.go31
-rw-r--r--src/net/http/main_test.go109
-rw-r--r--src/net/http/npn_test.go118
-rw-r--r--src/net/http/pprof/pprof.go209
-rw-r--r--src/net/http/proxy_test.go81
-rw-r--r--src/net/http/race.go11
-rw-r--r--src/net/http/range_test.go79
-rw-r--r--src/net/http/readrequest_test.go358
-rw-r--r--src/net/http/request.go921
-rw-r--r--src/net/http/request_test.go680
-rw-r--r--src/net/http/requestwrite_test.go623
-rw-r--r--src/net/http/response.go291
-rw-r--r--src/net/http/response_test.go674
-rw-r--r--src/net/http/responsewrite_test.go226
-rw-r--r--src/net/http/serve_test.go3094
-rw-r--r--src/net/http/server.go2096
-rw-r--r--src/net/http/sniff.go214
-rw-r--r--src/net/http/sniff_test.go171
-rw-r--r--src/net/http/status.go120
-rw-r--r--src/net/http/testdata/file1
-rw-r--r--src/net/http/testdata/index.html1
-rw-r--r--src/net/http/testdata/style.css1
-rw-r--r--src/net/http/transfer.go737
-rw-r--r--src/net/http/transfer_test.go64
-rw-r--r--src/net/http/transport.go1275
-rw-r--r--src/net/http/transport_test.go2324
-rw-r--r--src/net/http/triv.go141
70 files changed, 26148 insertions, 0 deletions
diff --git a/src/net/http/cgi/child.go b/src/net/http/cgi/child.go
new file mode 100644
index 000000000..45fc2e57c
--- /dev/null
+++ b/src/net/http/cgi/child.go
@@ -0,0 +1,206 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file implements CGI from the perspective of a child
+// process.
+
+package cgi
+
+import (
+ "bufio"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net"
+ "net/http"
+ "net/url"
+ "os"
+ "strconv"
+ "strings"
+)
+
+// Request returns the HTTP request as represented in the current
+// environment. This assumes the current program is being run
+// by a web server in a CGI environment.
+// The returned Request's Body is populated, if applicable.
+func Request() (*http.Request, error) {
+ r, err := RequestFromMap(envMap(os.Environ()))
+ if err != nil {
+ return nil, err
+ }
+ if r.ContentLength > 0 {
+ r.Body = ioutil.NopCloser(io.LimitReader(os.Stdin, r.ContentLength))
+ }
+ return r, nil
+}
+
+func envMap(env []string) map[string]string {
+ m := make(map[string]string)
+ for _, kv := range env {
+ if idx := strings.Index(kv, "="); idx != -1 {
+ m[kv[:idx]] = kv[idx+1:]
+ }
+ }
+ return m
+}
+
+// RequestFromMap creates an http.Request from CGI variables.
+// The returned Request's Body field is not populated.
+func RequestFromMap(params map[string]string) (*http.Request, error) {
+ r := new(http.Request)
+ r.Method = params["REQUEST_METHOD"]
+ if r.Method == "" {
+ return nil, errors.New("cgi: no REQUEST_METHOD in environment")
+ }
+
+ r.Proto = params["SERVER_PROTOCOL"]
+ var ok bool
+ r.ProtoMajor, r.ProtoMinor, ok = http.ParseHTTPVersion(r.Proto)
+ if !ok {
+ return nil, errors.New("cgi: invalid SERVER_PROTOCOL version")
+ }
+
+ r.Close = true
+ r.Trailer = http.Header{}
+ r.Header = http.Header{}
+
+ r.Host = params["HTTP_HOST"]
+
+ if lenstr := params["CONTENT_LENGTH"]; lenstr != "" {
+ clen, err := strconv.ParseInt(lenstr, 10, 64)
+ if err != nil {
+ return nil, errors.New("cgi: bad CONTENT_LENGTH in environment: " + lenstr)
+ }
+ r.ContentLength = clen
+ }
+
+ if ct := params["CONTENT_TYPE"]; ct != "" {
+ r.Header.Set("Content-Type", ct)
+ }
+
+ // Copy "HTTP_FOO_BAR" variables to "Foo-Bar" Headers
+ for k, v := range params {
+ if !strings.HasPrefix(k, "HTTP_") || k == "HTTP_HOST" {
+ continue
+ }
+ r.Header.Add(strings.Replace(k[5:], "_", "-", -1), v)
+ }
+
+ // TODO: cookies. parsing them isn't exported, though.
+
+ uriStr := params["REQUEST_URI"]
+ if uriStr == "" {
+ // Fallback to SCRIPT_NAME, PATH_INFO and QUERY_STRING.
+ uriStr = params["SCRIPT_NAME"] + params["PATH_INFO"]
+ s := params["QUERY_STRING"]
+ if s != "" {
+ uriStr += "?" + s
+ }
+ }
+
+ // There's apparently a de-facto standard for this.
+ // http://docstore.mik.ua/orelly/linux/cgi/ch03_02.htm#ch03-35636
+ if s := params["HTTPS"]; s == "on" || s == "ON" || s == "1" {
+ r.TLS = &tls.ConnectionState{HandshakeComplete: true}
+ }
+
+ if r.Host != "" {
+ // Hostname is provided, so we can reasonably construct a URL.
+ rawurl := r.Host + uriStr
+ if r.TLS == nil {
+ rawurl = "http://" + rawurl
+ } else {
+ rawurl = "https://" + rawurl
+ }
+ url, err := url.Parse(rawurl)
+ if err != nil {
+ return nil, errors.New("cgi: failed to parse host and REQUEST_URI into a URL: " + rawurl)
+ }
+ r.URL = url
+ }
+ // Fallback logic if we don't have a Host header or the URL
+ // failed to parse
+ if r.URL == nil {
+ url, err := url.Parse(uriStr)
+ if err != nil {
+ return nil, errors.New("cgi: failed to parse REQUEST_URI into a URL: " + uriStr)
+ }
+ r.URL = url
+ }
+
+ // Request.RemoteAddr has its port set by Go's standard http
+ // server, so we do here too. We don't have one, though, so we
+ // use a dummy one.
+ r.RemoteAddr = net.JoinHostPort(params["REMOTE_ADDR"], "0")
+
+ return r, nil
+}
+
+// Serve executes the provided Handler on the currently active CGI
+// request, if any. If there's no current CGI environment
+// an error is returned. The provided handler may be nil to use
+// http.DefaultServeMux.
+func Serve(handler http.Handler) error {
+ req, err := Request()
+ if err != nil {
+ return err
+ }
+ if handler == nil {
+ handler = http.DefaultServeMux
+ }
+ rw := &response{
+ req: req,
+ header: make(http.Header),
+ bufw: bufio.NewWriter(os.Stdout),
+ }
+ handler.ServeHTTP(rw, req)
+ rw.Write(nil) // make sure a response is sent
+ if err = rw.bufw.Flush(); err != nil {
+ return err
+ }
+ return nil
+}
+
+type response struct {
+ req *http.Request
+ header http.Header
+ bufw *bufio.Writer
+ headerSent bool
+}
+
+func (r *response) Flush() {
+ r.bufw.Flush()
+}
+
+func (r *response) Header() http.Header {
+ return r.header
+}
+
+func (r *response) Write(p []byte) (n int, err error) {
+ if !r.headerSent {
+ r.WriteHeader(http.StatusOK)
+ }
+ return r.bufw.Write(p)
+}
+
+func (r *response) WriteHeader(code int) {
+ if r.headerSent {
+ // Note: explicitly using Stderr, as Stdout is our HTTP output.
+ fmt.Fprintf(os.Stderr, "CGI attempted to write header twice on request for %s", r.req.URL)
+ return
+ }
+ r.headerSent = true
+ fmt.Fprintf(r.bufw, "Status: %d %s\r\n", code, http.StatusText(code))
+
+ // Set a default Content-Type
+ if _, hasType := r.header["Content-Type"]; !hasType {
+ r.header.Add("Content-Type", "text/html; charset=utf-8")
+ }
+
+ r.header.Write(r.bufw)
+ r.bufw.WriteString("\r\n")
+ r.bufw.Flush()
+}
diff --git a/src/net/http/cgi/child_test.go b/src/net/http/cgi/child_test.go
new file mode 100644
index 000000000..075d8411b
--- /dev/null
+++ b/src/net/http/cgi/child_test.go
@@ -0,0 +1,131 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tests for CGI (the child process perspective)
+
+package cgi
+
+import (
+ "testing"
+)
+
+func TestRequest(t *testing.T) {
+ env := map[string]string{
+ "SERVER_PROTOCOL": "HTTP/1.1",
+ "REQUEST_METHOD": "GET",
+ "HTTP_HOST": "example.com",
+ "HTTP_REFERER": "elsewhere",
+ "HTTP_USER_AGENT": "goclient",
+ "HTTP_FOO_BAR": "baz",
+ "REQUEST_URI": "/path?a=b",
+ "CONTENT_LENGTH": "123",
+ "CONTENT_TYPE": "text/xml",
+ "REMOTE_ADDR": "5.6.7.8",
+ }
+ req, err := RequestFromMap(env)
+ if err != nil {
+ t.Fatalf("RequestFromMap: %v", err)
+ }
+ if g, e := req.UserAgent(), "goclient"; e != g {
+ t.Errorf("expected UserAgent %q; got %q", e, g)
+ }
+ if g, e := req.Method, "GET"; e != g {
+ t.Errorf("expected Method %q; got %q", e, g)
+ }
+ if g, e := req.Header.Get("Content-Type"), "text/xml"; e != g {
+ t.Errorf("expected Content-Type %q; got %q", e, g)
+ }
+ if g, e := req.ContentLength, int64(123); e != g {
+ t.Errorf("expected ContentLength %d; got %d", e, g)
+ }
+ if g, e := req.Referer(), "elsewhere"; e != g {
+ t.Errorf("expected Referer %q; got %q", e, g)
+ }
+ if req.Header == nil {
+ t.Fatalf("unexpected nil Header")
+ }
+ if g, e := req.Header.Get("Foo-Bar"), "baz"; e != g {
+ t.Errorf("expected Foo-Bar %q; got %q", e, g)
+ }
+ if g, e := req.URL.String(), "http://example.com/path?a=b"; e != g {
+ t.Errorf("expected URL %q; got %q", e, g)
+ }
+ if g, e := req.FormValue("a"), "b"; e != g {
+ t.Errorf("expected FormValue(a) %q; got %q", e, g)
+ }
+ if req.Trailer == nil {
+ t.Errorf("unexpected nil Trailer")
+ }
+ if req.TLS != nil {
+ t.Errorf("expected nil TLS")
+ }
+ if e, g := "5.6.7.8:0", req.RemoteAddr; e != g {
+ t.Errorf("RemoteAddr: got %q; want %q", g, e)
+ }
+}
+
+func TestRequestWithTLS(t *testing.T) {
+ env := map[string]string{
+ "SERVER_PROTOCOL": "HTTP/1.1",
+ "REQUEST_METHOD": "GET",
+ "HTTP_HOST": "example.com",
+ "HTTP_REFERER": "elsewhere",
+ "REQUEST_URI": "/path?a=b",
+ "CONTENT_TYPE": "text/xml",
+ "HTTPS": "1",
+ "REMOTE_ADDR": "5.6.7.8",
+ }
+ req, err := RequestFromMap(env)
+ if err != nil {
+ t.Fatalf("RequestFromMap: %v", err)
+ }
+ if g, e := req.URL.String(), "https://example.com/path?a=b"; e != g {
+ t.Errorf("expected URL %q; got %q", e, g)
+ }
+ if req.TLS == nil {
+ t.Errorf("expected non-nil TLS")
+ }
+}
+
+func TestRequestWithoutHost(t *testing.T) {
+ env := map[string]string{
+ "SERVER_PROTOCOL": "HTTP/1.1",
+ "HTTP_HOST": "",
+ "REQUEST_METHOD": "GET",
+ "REQUEST_URI": "/path?a=b",
+ "CONTENT_LENGTH": "123",
+ }
+ req, err := RequestFromMap(env)
+ if err != nil {
+ t.Fatalf("RequestFromMap: %v", err)
+ }
+ if req.URL == nil {
+ t.Fatalf("unexpected nil URL")
+ }
+ if g, e := req.URL.String(), "/path?a=b"; e != g {
+ t.Errorf("URL = %q; want %q", g, e)
+ }
+}
+
+func TestRequestWithoutRequestURI(t *testing.T) {
+ env := map[string]string{
+ "SERVER_PROTOCOL": "HTTP/1.1",
+ "HTTP_HOST": "example.com",
+ "REQUEST_METHOD": "GET",
+ "SCRIPT_NAME": "/dir/scriptname",
+ "PATH_INFO": "/p1/p2",
+ "QUERY_STRING": "a=1&b=2",
+ "CONTENT_LENGTH": "123",
+ }
+ req, err := RequestFromMap(env)
+ if err != nil {
+ t.Fatalf("RequestFromMap: %v", err)
+ }
+ if req.URL == nil {
+ t.Fatalf("unexpected nil URL")
+ }
+ if g, e := req.URL.String(), "http://example.com/dir/scriptname/p1/p2?a=1&b=2"; e != g {
+ t.Errorf("URL = %q; want %q", g, e)
+ }
+}
diff --git a/src/net/http/cgi/host.go b/src/net/http/cgi/host.go
new file mode 100644
index 000000000..ec95a972c
--- /dev/null
+++ b/src/net/http/cgi/host.go
@@ -0,0 +1,377 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file implements the host side of CGI (being the webserver
+// parent process).
+
+// Package cgi implements CGI (Common Gateway Interface) as specified
+// in RFC 3875.
+//
+// Note that using CGI means starting a new process to handle each
+// request, which is typically less efficient than using a
+// long-running server. This package is intended primarily for
+// compatibility with existing systems.
+package cgi
+
+import (
+ "bufio"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "regexp"
+ "runtime"
+ "strconv"
+ "strings"
+)
+
+var trailingPort = regexp.MustCompile(`:([0-9]+)$`)
+
+var osDefaultInheritEnv = map[string][]string{
+ "darwin": {"DYLD_LIBRARY_PATH"},
+ "freebsd": {"LD_LIBRARY_PATH"},
+ "hpux": {"LD_LIBRARY_PATH", "SHLIB_PATH"},
+ "irix": {"LD_LIBRARY_PATH", "LD_LIBRARYN32_PATH", "LD_LIBRARY64_PATH"},
+ "linux": {"LD_LIBRARY_PATH"},
+ "openbsd": {"LD_LIBRARY_PATH"},
+ "solaris": {"LD_LIBRARY_PATH", "LD_LIBRARY_PATH_32", "LD_LIBRARY_PATH_64"},
+ "windows": {"SystemRoot", "COMSPEC", "PATHEXT", "WINDIR"},
+}
+
+// Handler runs an executable in a subprocess with a CGI environment.
+type Handler struct {
+ Path string // path to the CGI executable
+ Root string // root URI prefix of handler or empty for "/"
+
+ // Dir specifies the CGI executable's working directory.
+ // If Dir is empty, the base directory of Path is used.
+ // If Path has no base directory, the current working
+ // directory is used.
+ Dir string
+
+ Env []string // extra environment variables to set, if any, as "key=value"
+ InheritEnv []string // environment variables to inherit from host, as "key"
+ Logger *log.Logger // optional log for errors or nil to use log.Print
+ Args []string // optional arguments to pass to child process
+
+ // PathLocationHandler specifies the root http Handler that
+ // should handle internal redirects when the CGI process
+ // returns a Location header value starting with a "/", as
+ // specified in RFC 3875 ยง 6.3.2. This will likely be
+ // http.DefaultServeMux.
+ //
+ // If nil, a CGI response with a local URI path is instead sent
+ // back to the client and not redirected internally.
+ PathLocationHandler http.Handler
+}
+
+// removeLeadingDuplicates remove leading duplicate in environments.
+// It's possible to override environment like following.
+// cgi.Handler{
+// ...
+// Env: []string{"SCRIPT_FILENAME=foo.php"},
+// }
+func removeLeadingDuplicates(env []string) (ret []string) {
+ n := len(env)
+ for i := 0; i < n; i++ {
+ e := env[i]
+ s := strings.SplitN(e, "=", 2)[0]
+ found := false
+ for j := i + 1; j < n; j++ {
+ if s == strings.SplitN(env[j], "=", 2)[0] {
+ found = true
+ break
+ }
+ }
+ if !found {
+ ret = append(ret, e)
+ }
+ }
+ return
+}
+
+func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
+ root := h.Root
+ if root == "" {
+ root = "/"
+ }
+
+ if len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked" {
+ rw.WriteHeader(http.StatusBadRequest)
+ rw.Write([]byte("Chunked request bodies are not supported by CGI."))
+ return
+ }
+
+ pathInfo := req.URL.Path
+ if root != "/" && strings.HasPrefix(pathInfo, root) {
+ pathInfo = pathInfo[len(root):]
+ }
+
+ port := "80"
+ if matches := trailingPort.FindStringSubmatch(req.Host); len(matches) != 0 {
+ port = matches[1]
+ }
+
+ env := []string{
+ "SERVER_SOFTWARE=go",
+ "SERVER_NAME=" + req.Host,
+ "SERVER_PROTOCOL=HTTP/1.1",
+ "HTTP_HOST=" + req.Host,
+ "GATEWAY_INTERFACE=CGI/1.1",
+ "REQUEST_METHOD=" + req.Method,
+ "QUERY_STRING=" + req.URL.RawQuery,
+ "REQUEST_URI=" + req.URL.RequestURI(),
+ "PATH_INFO=" + pathInfo,
+ "SCRIPT_NAME=" + root,
+ "SCRIPT_FILENAME=" + h.Path,
+ "REMOTE_ADDR=" + req.RemoteAddr,
+ "REMOTE_HOST=" + req.RemoteAddr,
+ "SERVER_PORT=" + port,
+ }
+
+ if req.TLS != nil {
+ env = append(env, "HTTPS=on")
+ }
+
+ for k, v := range req.Header {
+ k = strings.Map(upperCaseAndUnderscore, k)
+ joinStr := ", "
+ if k == "COOKIE" {
+ joinStr = "; "
+ }
+ env = append(env, "HTTP_"+k+"="+strings.Join(v, joinStr))
+ }
+
+ if req.ContentLength > 0 {
+ env = append(env, fmt.Sprintf("CONTENT_LENGTH=%d", req.ContentLength))
+ }
+ if ctype := req.Header.Get("Content-Type"); ctype != "" {
+ env = append(env, "CONTENT_TYPE="+ctype)
+ }
+
+ if h.Env != nil {
+ env = append(env, h.Env...)
+ }
+
+ envPath := os.Getenv("PATH")
+ if envPath == "" {
+ envPath = "/bin:/usr/bin:/usr/ucb:/usr/bsd:/usr/local/bin"
+ }
+ env = append(env, "PATH="+envPath)
+
+ for _, e := range h.InheritEnv {
+ if v := os.Getenv(e); v != "" {
+ env = append(env, e+"="+v)
+ }
+ }
+
+ for _, e := range osDefaultInheritEnv[runtime.GOOS] {
+ if v := os.Getenv(e); v != "" {
+ env = append(env, e+"="+v)
+ }
+ }
+
+ env = removeLeadingDuplicates(env)
+
+ var cwd, path string
+ if h.Dir != "" {
+ path = h.Path
+ cwd = h.Dir
+ } else {
+ cwd, path = filepath.Split(h.Path)
+ }
+ if cwd == "" {
+ cwd = "."
+ }
+
+ internalError := func(err error) {
+ rw.WriteHeader(http.StatusInternalServerError)
+ h.printf("CGI error: %v", err)
+ }
+
+ cmd := &exec.Cmd{
+ Path: path,
+ Args: append([]string{h.Path}, h.Args...),
+ Dir: cwd,
+ Env: env,
+ Stderr: os.Stderr, // for now
+ }
+ if req.ContentLength != 0 {
+ cmd.Stdin = req.Body
+ }
+ stdoutRead, err := cmd.StdoutPipe()
+ if err != nil {
+ internalError(err)
+ return
+ }
+
+ err = cmd.Start()
+ if err != nil {
+ internalError(err)
+ return
+ }
+ if hook := testHookStartProcess; hook != nil {
+ hook(cmd.Process)
+ }
+ defer cmd.Wait()
+ defer stdoutRead.Close()
+
+ linebody := bufio.NewReaderSize(stdoutRead, 1024)
+ headers := make(http.Header)
+ statusCode := 0
+ headerLines := 0
+ sawBlankLine := false
+ for {
+ line, isPrefix, err := linebody.ReadLine()
+ if isPrefix {
+ rw.WriteHeader(http.StatusInternalServerError)
+ h.printf("cgi: long header line from subprocess.")
+ return
+ }
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ rw.WriteHeader(http.StatusInternalServerError)
+ h.printf("cgi: error reading headers: %v", err)
+ return
+ }
+ if len(line) == 0 {
+ sawBlankLine = true
+ break
+ }
+ headerLines++
+ parts := strings.SplitN(string(line), ":", 2)
+ if len(parts) < 2 {
+ h.printf("cgi: bogus header line: %s", string(line))
+ continue
+ }
+ header, val := parts[0], parts[1]
+ header = strings.TrimSpace(header)
+ val = strings.TrimSpace(val)
+ switch {
+ case header == "Status":
+ if len(val) < 3 {
+ h.printf("cgi: bogus status (short): %q", val)
+ return
+ }
+ code, err := strconv.Atoi(val[0:3])
+ if err != nil {
+ h.printf("cgi: bogus status: %q", val)
+ h.printf("cgi: line was %q", line)
+ return
+ }
+ statusCode = code
+ default:
+ headers.Add(header, val)
+ }
+ }
+ if headerLines == 0 || !sawBlankLine {
+ rw.WriteHeader(http.StatusInternalServerError)
+ h.printf("cgi: no headers")
+ return
+ }
+
+ if loc := headers.Get("Location"); loc != "" {
+ if strings.HasPrefix(loc, "/") && h.PathLocationHandler != nil {
+ h.handleInternalRedirect(rw, req, loc)
+ return
+ }
+ if statusCode == 0 {
+ statusCode = http.StatusFound
+ }
+ }
+
+ if statusCode == 0 && headers.Get("Content-Type") == "" {
+ rw.WriteHeader(http.StatusInternalServerError)
+ h.printf("cgi: missing required Content-Type in headers")
+ return
+ }
+
+ if statusCode == 0 {
+ statusCode = http.StatusOK
+ }
+
+ // Copy headers to rw's headers, after we've decided not to
+ // go into handleInternalRedirect, which won't want its rw
+ // headers to have been touched.
+ for k, vv := range headers {
+ for _, v := range vv {
+ rw.Header().Add(k, v)
+ }
+ }
+
+ rw.WriteHeader(statusCode)
+
+ _, err = io.Copy(rw, linebody)
+ if err != nil {
+ h.printf("cgi: copy error: %v", err)
+ // And kill the child CGI process so we don't hang on
+ // the deferred cmd.Wait above if the error was just
+ // the client (rw) going away. If it was a read error
+ // (because the child died itself), then the extra
+ // kill of an already-dead process is harmless (the PID
+ // won't be reused until the Wait above).
+ cmd.Process.Kill()
+ }
+}
+
+func (h *Handler) printf(format string, v ...interface{}) {
+ if h.Logger != nil {
+ h.Logger.Printf(format, v...)
+ } else {
+ log.Printf(format, v...)
+ }
+}
+
+func (h *Handler) handleInternalRedirect(rw http.ResponseWriter, req *http.Request, path string) {
+ url, err := req.URL.Parse(path)
+ if err != nil {
+ rw.WriteHeader(http.StatusInternalServerError)
+ h.printf("cgi: error resolving local URI path %q: %v", path, err)
+ return
+ }
+ // TODO: RFC 3875 isn't clear if only GET is supported, but it
+ // suggests so: "Note that any message-body attached to the
+ // request (such as for a POST request) may not be available
+ // to the resource that is the target of the redirect." We
+ // should do some tests against Apache to see how it handles
+ // POST, HEAD, etc. Does the internal redirect get the same
+ // method or just GET? What about incoming headers?
+ // (e.g. Cookies) Which headers, if any, are copied into the
+ // second request?
+ newReq := &http.Request{
+ Method: "GET",
+ URL: url,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: make(http.Header),
+ Host: url.Host,
+ RemoteAddr: req.RemoteAddr,
+ TLS: req.TLS,
+ }
+ h.PathLocationHandler.ServeHTTP(rw, newReq)
+}
+
+func upperCaseAndUnderscore(r rune) rune {
+ switch {
+ case r >= 'a' && r <= 'z':
+ return r - ('a' - 'A')
+ case r == '-':
+ return '_'
+ case r == '=':
+ // Maybe not part of the CGI 'spec' but would mess up
+ // the environment in any case, as Go represents the
+ // environment as a slice of "key=value" strings.
+ return '_'
+ }
+ // TODO: other transformations in spec or practice?
+ return r
+}
+
+var testHookStartProcess func(*os.Process) // nil except for some tests
diff --git a/src/net/http/cgi/host_test.go b/src/net/http/cgi/host_test.go
new file mode 100644
index 000000000..8c16e6897
--- /dev/null
+++ b/src/net/http/cgi/host_test.go
@@ -0,0 +1,461 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tests for package cgi
+
+package cgi
+
+import (
+ "bufio"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "runtime"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+)
+
+func newRequest(httpreq string) *http.Request {
+ buf := bufio.NewReader(strings.NewReader(httpreq))
+ req, err := http.ReadRequest(buf)
+ if err != nil {
+ panic("cgi: bogus http request in test: " + httpreq)
+ }
+ req.RemoteAddr = "1.2.3.4"
+ return req
+}
+
+func runCgiTest(t *testing.T, h *Handler, httpreq string, expectedMap map[string]string) *httptest.ResponseRecorder {
+ rw := httptest.NewRecorder()
+ req := newRequest(httpreq)
+ h.ServeHTTP(rw, req)
+
+ // Make a map to hold the test map that the CGI returns.
+ m := make(map[string]string)
+ m["_body"] = rw.Body.String()
+ linesRead := 0
+readlines:
+ for {
+ line, err := rw.Body.ReadString('\n')
+ switch {
+ case err == io.EOF:
+ break readlines
+ case err != nil:
+ t.Fatalf("unexpected error reading from CGI: %v", err)
+ }
+ linesRead++
+ trimmedLine := strings.TrimRight(line, "\r\n")
+ split := strings.SplitN(trimmedLine, "=", 2)
+ if len(split) != 2 {
+ t.Fatalf("Unexpected %d parts from invalid line number %v: %q; existing map=%v",
+ len(split), linesRead, line, m)
+ }
+ m[split[0]] = split[1]
+ }
+
+ for key, expected := range expectedMap {
+ got := m[key]
+ if key == "cwd" {
+ // For Windows. golang.org/issue/4645.
+ fi1, _ := os.Stat(got)
+ fi2, _ := os.Stat(expected)
+ if os.SameFile(fi1, fi2) {
+ got = expected
+ }
+ }
+ if got != expected {
+ t.Errorf("for key %q got %q; expected %q", key, got, expected)
+ }
+ }
+ return rw
+}
+
+var cgiTested, cgiWorks bool
+
+func check(t *testing.T) {
+ if !cgiTested {
+ cgiTested = true
+ cgiWorks = exec.Command("./testdata/test.cgi").Run() == nil
+ }
+ if !cgiWorks {
+ // No Perl on Windows, needed by test.cgi
+ // TODO: make the child process be Go, not Perl.
+ t.Skip("Skipping test: test.cgi failed.")
+ }
+}
+
+func TestCGIBasicGet(t *testing.T) {
+ check(t)
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ }
+ expectedMap := map[string]string{
+ "test": "Hello CGI",
+ "param-a": "b",
+ "param-foo": "bar",
+ "env-GATEWAY_INTERFACE": "CGI/1.1",
+ "env-HTTP_HOST": "example.com",
+ "env-PATH_INFO": "",
+ "env-QUERY_STRING": "foo=bar&a=b",
+ "env-REMOTE_ADDR": "1.2.3.4",
+ "env-REMOTE_HOST": "1.2.3.4",
+ "env-REQUEST_METHOD": "GET",
+ "env-REQUEST_URI": "/test.cgi?foo=bar&a=b",
+ "env-SCRIPT_FILENAME": "testdata/test.cgi",
+ "env-SCRIPT_NAME": "/test.cgi",
+ "env-SERVER_NAME": "example.com",
+ "env-SERVER_PORT": "80",
+ "env-SERVER_SOFTWARE": "go",
+ }
+ replay := runCgiTest(t, h, "GET /test.cgi?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
+
+ if expected, got := "text/html", replay.Header().Get("Content-Type"); got != expected {
+ t.Errorf("got a Content-Type of %q; expected %q", got, expected)
+ }
+ if expected, got := "X-Test-Value", replay.Header().Get("X-Test-Header"); got != expected {
+ t.Errorf("got a X-Test-Header of %q; expected %q", got, expected)
+ }
+}
+
+func TestCGIBasicGetAbsPath(t *testing.T) {
+ check(t)
+ pwd, err := os.Getwd()
+ if err != nil {
+ t.Fatalf("getwd error: %v", err)
+ }
+ h := &Handler{
+ Path: pwd + "/testdata/test.cgi",
+ Root: "/test.cgi",
+ }
+ expectedMap := map[string]string{
+ "env-REQUEST_URI": "/test.cgi?foo=bar&a=b",
+ "env-SCRIPT_FILENAME": pwd + "/testdata/test.cgi",
+ "env-SCRIPT_NAME": "/test.cgi",
+ }
+ runCgiTest(t, h, "GET /test.cgi?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
+
+func TestPathInfo(t *testing.T) {
+ check(t)
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ }
+ expectedMap := map[string]string{
+ "param-a": "b",
+ "env-PATH_INFO": "/extrapath",
+ "env-QUERY_STRING": "a=b",
+ "env-REQUEST_URI": "/test.cgi/extrapath?a=b",
+ "env-SCRIPT_FILENAME": "testdata/test.cgi",
+ "env-SCRIPT_NAME": "/test.cgi",
+ }
+ runCgiTest(t, h, "GET /test.cgi/extrapath?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
+
+func TestPathInfoDirRoot(t *testing.T) {
+ check(t)
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/myscript/",
+ }
+ expectedMap := map[string]string{
+ "env-PATH_INFO": "bar",
+ "env-QUERY_STRING": "a=b",
+ "env-REQUEST_URI": "/myscript/bar?a=b",
+ "env-SCRIPT_FILENAME": "testdata/test.cgi",
+ "env-SCRIPT_NAME": "/myscript/",
+ }
+ runCgiTest(t, h, "GET /myscript/bar?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
+
+func TestDupHeaders(t *testing.T) {
+ check(t)
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ }
+ expectedMap := map[string]string{
+ "env-REQUEST_URI": "/myscript/bar?a=b",
+ "env-SCRIPT_FILENAME": "testdata/test.cgi",
+ "env-HTTP_COOKIE": "nom=NOM; yum=YUM",
+ "env-HTTP_X_FOO": "val1, val2",
+ }
+ runCgiTest(t, h, "GET /myscript/bar?a=b HTTP/1.0\n"+
+ "Cookie: nom=NOM\n"+
+ "Cookie: yum=YUM\n"+
+ "X-Foo: val1\n"+
+ "X-Foo: val2\n"+
+ "Host: example.com\n\n",
+ expectedMap)
+}
+
+func TestPathInfoNoRoot(t *testing.T) {
+ check(t)
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "",
+ }
+ expectedMap := map[string]string{
+ "env-PATH_INFO": "/bar",
+ "env-QUERY_STRING": "a=b",
+ "env-REQUEST_URI": "/bar?a=b",
+ "env-SCRIPT_FILENAME": "testdata/test.cgi",
+ "env-SCRIPT_NAME": "/",
+ }
+ runCgiTest(t, h, "GET /bar?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
+
+func TestCGIBasicPost(t *testing.T) {
+ check(t)
+ postReq := `POST /test.cgi?a=b HTTP/1.0
+Host: example.com
+Content-Type: application/x-www-form-urlencoded
+Content-Length: 15
+
+postfoo=postbar`
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ }
+ expectedMap := map[string]string{
+ "test": "Hello CGI",
+ "param-postfoo": "postbar",
+ "env-REQUEST_METHOD": "POST",
+ "env-CONTENT_LENGTH": "15",
+ "env-REQUEST_URI": "/test.cgi?a=b",
+ }
+ runCgiTest(t, h, postReq, expectedMap)
+}
+
+func chunk(s string) string {
+ return fmt.Sprintf("%x\r\n%s\r\n", len(s), s)
+}
+
+// The CGI spec doesn't allow chunked requests.
+func TestCGIPostChunked(t *testing.T) {
+ check(t)
+ postReq := `POST /test.cgi?a=b HTTP/1.1
+Host: example.com
+Content-Type: application/x-www-form-urlencoded
+Transfer-Encoding: chunked
+
+` + chunk("postfoo") + chunk("=") + chunk("postbar") + chunk("")
+
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ }
+ expectedMap := map[string]string{}
+ resp := runCgiTest(t, h, postReq, expectedMap)
+ if got, expected := resp.Code, http.StatusBadRequest; got != expected {
+ t.Fatalf("Expected %v response code from chunked request body; got %d",
+ expected, got)
+ }
+}
+
+func TestRedirect(t *testing.T) {
+ check(t)
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ }
+ rec := runCgiTest(t, h, "GET /test.cgi?loc=http://foo.com/ HTTP/1.0\nHost: example.com\n\n", nil)
+ if e, g := 302, rec.Code; e != g {
+ t.Errorf("expected status code %d; got %d", e, g)
+ }
+ if e, g := "http://foo.com/", rec.Header().Get("Location"); e != g {
+ t.Errorf("expected Location header of %q; got %q", e, g)
+ }
+}
+
+func TestInternalRedirect(t *testing.T) {
+ check(t)
+ baseHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
+ fmt.Fprintf(rw, "basepath=%s\n", req.URL.Path)
+ fmt.Fprintf(rw, "remoteaddr=%s\n", req.RemoteAddr)
+ })
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ PathLocationHandler: baseHandler,
+ }
+ expectedMap := map[string]string{
+ "basepath": "/foo",
+ "remoteaddr": "1.2.3.4",
+ }
+ runCgiTest(t, h, "GET /test.cgi?loc=/foo HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
+
+// TestCopyError tests that we kill the process if there's an error copying
+// its output. (for example, from the client having gone away)
+func TestCopyError(t *testing.T) {
+ check(t)
+ if runtime.GOOS == "windows" {
+ t.Skipf("skipping test on %q", runtime.GOOS)
+ }
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ }
+ ts := httptest.NewServer(h)
+ defer ts.Close()
+
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ req, _ := http.NewRequest("GET", "http://example.com/test.cgi?bigresponse=1", nil)
+ err = req.Write(conn)
+ if err != nil {
+ t.Fatalf("Write: %v", err)
+ }
+
+ res, err := http.ReadResponse(bufio.NewReader(conn), req)
+ if err != nil {
+ t.Fatalf("ReadResponse: %v", err)
+ }
+
+ pidstr := res.Header.Get("X-CGI-Pid")
+ if pidstr == "" {
+ t.Fatalf("expected an X-CGI-Pid header in response")
+ }
+ pid, err := strconv.Atoi(pidstr)
+ if err != nil {
+ t.Fatalf("invalid X-CGI-Pid value")
+ }
+
+ var buf [5000]byte
+ n, err := io.ReadFull(res.Body, buf[:])
+ if err != nil {
+ t.Fatalf("ReadFull: %d bytes, %v", n, err)
+ }
+
+ childRunning := func() bool {
+ return isProcessRunning(t, pid)
+ }
+
+ if !childRunning() {
+ t.Fatalf("pre-conn.Close, expected child to be running")
+ }
+ conn.Close()
+
+ tries := 0
+ for tries < 25 && childRunning() {
+ time.Sleep(50 * time.Millisecond * time.Duration(tries))
+ tries++
+ }
+ if childRunning() {
+ t.Fatalf("post-conn.Close, expected child to be gone")
+ }
+}
+
+func TestDirUnix(t *testing.T) {
+ check(t)
+ if runtime.GOOS == "windows" {
+ t.Skipf("skipping test on %q", runtime.GOOS)
+ }
+ cwd, _ := os.Getwd()
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ Dir: cwd,
+ }
+ expectedMap := map[string]string{
+ "cwd": cwd,
+ }
+ runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap)
+
+ cwd, _ = os.Getwd()
+ cwd = filepath.Join(cwd, "testdata")
+ h = &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ }
+ expectedMap = map[string]string{
+ "cwd": cwd,
+ }
+ runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
+
+func TestDirWindows(t *testing.T) {
+ if runtime.GOOS != "windows" {
+ t.Skip("Skipping windows specific test.")
+ }
+
+ cgifile, _ := filepath.Abs("testdata/test.cgi")
+
+ var perl string
+ var err error
+ perl, err = exec.LookPath("perl")
+ if err != nil {
+ t.Skip("Skipping test: perl not found.")
+ }
+ perl, _ = filepath.Abs(perl)
+
+ cwd, _ := os.Getwd()
+ h := &Handler{
+ Path: perl,
+ Root: "/test.cgi",
+ Dir: cwd,
+ Args: []string{cgifile},
+ Env: []string{"SCRIPT_FILENAME=" + cgifile},
+ }
+ expectedMap := map[string]string{
+ "cwd": cwd,
+ }
+ runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap)
+
+ // If not specify Dir on windows, working directory should be
+ // base directory of perl.
+ cwd, _ = filepath.Split(perl)
+ if cwd != "" && cwd[len(cwd)-1] == filepath.Separator {
+ cwd = cwd[:len(cwd)-1]
+ }
+ h = &Handler{
+ Path: perl,
+ Root: "/test.cgi",
+ Args: []string{cgifile},
+ Env: []string{"SCRIPT_FILENAME=" + cgifile},
+ }
+ expectedMap = map[string]string{
+ "cwd": cwd,
+ }
+ runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
+
+func TestEnvOverride(t *testing.T) {
+ cgifile, _ := filepath.Abs("testdata/test.cgi")
+
+ var perl string
+ var err error
+ perl, err = exec.LookPath("perl")
+ if err != nil {
+ t.Skipf("Skipping test: perl not found.")
+ }
+ perl, _ = filepath.Abs(perl)
+
+ cwd, _ := os.Getwd()
+ h := &Handler{
+ Path: perl,
+ Root: "/test.cgi",
+ Dir: cwd,
+ Args: []string{cgifile},
+ Env: []string{
+ "SCRIPT_FILENAME=" + cgifile,
+ "REQUEST_URI=/foo/bar"},
+ }
+ expectedMap := map[string]string{
+ "cwd": cwd,
+ "env-SCRIPT_FILENAME": cgifile,
+ "env-REQUEST_URI": "/foo/bar",
+ }
+ runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
diff --git a/src/net/http/cgi/matryoshka_test.go b/src/net/http/cgi/matryoshka_test.go
new file mode 100644
index 000000000..18c4803e7
--- /dev/null
+++ b/src/net/http/cgi/matryoshka_test.go
@@ -0,0 +1,228 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tests a Go CGI program running under a Go CGI host process.
+// Further, the two programs are the same binary, just checking
+// their environment to figure out what mode to run in.
+
+package cgi
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "runtime"
+ "testing"
+ "time"
+)
+
+// This test is a CGI host (testing host.go) that runs its own binary
+// as a child process testing the other half of CGI (child.go).
+func TestHostingOurselves(t *testing.T) {
+ if runtime.GOOS == "nacl" {
+ t.Skip("skipping on nacl")
+ }
+
+ h := &Handler{
+ Path: os.Args[0],
+ Root: "/test.go",
+ Args: []string{"-test.run=TestBeChildCGIProcess"},
+ }
+ expectedMap := map[string]string{
+ "test": "Hello CGI-in-CGI",
+ "param-a": "b",
+ "param-foo": "bar",
+ "env-GATEWAY_INTERFACE": "CGI/1.1",
+ "env-HTTP_HOST": "example.com",
+ "env-PATH_INFO": "",
+ "env-QUERY_STRING": "foo=bar&a=b",
+ "env-REMOTE_ADDR": "1.2.3.4",
+ "env-REMOTE_HOST": "1.2.3.4",
+ "env-REQUEST_METHOD": "GET",
+ "env-REQUEST_URI": "/test.go?foo=bar&a=b",
+ "env-SCRIPT_FILENAME": os.Args[0],
+ "env-SCRIPT_NAME": "/test.go",
+ "env-SERVER_NAME": "example.com",
+ "env-SERVER_PORT": "80",
+ "env-SERVER_SOFTWARE": "go",
+ }
+ replay := runCgiTest(t, h, "GET /test.go?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
+
+ if expected, got := "text/html; charset=utf-8", replay.Header().Get("Content-Type"); got != expected {
+ t.Errorf("got a Content-Type of %q; expected %q", got, expected)
+ }
+ if expected, got := "X-Test-Value", replay.Header().Get("X-Test-Header"); got != expected {
+ t.Errorf("got a X-Test-Header of %q; expected %q", got, expected)
+ }
+}
+
+type customWriterRecorder struct {
+ w io.Writer
+ *httptest.ResponseRecorder
+}
+
+func (r *customWriterRecorder) Write(p []byte) (n int, err error) {
+ return r.w.Write(p)
+}
+
+type limitWriter struct {
+ w io.Writer
+ n int
+}
+
+func (w *limitWriter) Write(p []byte) (n int, err error) {
+ if len(p) > w.n {
+ p = p[:w.n]
+ }
+ if len(p) > 0 {
+ n, err = w.w.Write(p)
+ w.n -= n
+ }
+ if w.n == 0 {
+ err = errors.New("past write limit")
+ }
+ return
+}
+
+// If there's an error copying the child's output to the parent, test
+// that we kill the child.
+func TestKillChildAfterCopyError(t *testing.T) {
+ if runtime.GOOS == "nacl" {
+ t.Skip("skipping on nacl")
+ }
+
+ defer func() { testHookStartProcess = nil }()
+ proc := make(chan *os.Process, 1)
+ testHookStartProcess = func(p *os.Process) {
+ proc <- p
+ }
+
+ h := &Handler{
+ Path: os.Args[0],
+ Root: "/test.go",
+ Args: []string{"-test.run=TestBeChildCGIProcess"},
+ }
+ req, _ := http.NewRequest("GET", "http://example.com/test.cgi?write-forever=1", nil)
+ rec := httptest.NewRecorder()
+ var out bytes.Buffer
+ const writeLen = 50 << 10
+ rw := &customWriterRecorder{&limitWriter{&out, writeLen}, rec}
+
+ donec := make(chan bool, 1)
+ go func() {
+ h.ServeHTTP(rw, req)
+ donec <- true
+ }()
+
+ select {
+ case <-donec:
+ if out.Len() != writeLen || out.Bytes()[0] != 'a' {
+ t.Errorf("unexpected output: %q", out.Bytes())
+ }
+ case <-time.After(5 * time.Second):
+ t.Errorf("timeout. ServeHTTP hung and didn't kill the child process?")
+ select {
+ case p := <-proc:
+ p.Kill()
+ t.Logf("killed process")
+ default:
+ t.Logf("didn't kill process")
+ }
+ }
+}
+
+// Test that a child handler writing only headers works.
+// golang.org/issue/7196
+func TestChildOnlyHeaders(t *testing.T) {
+ if runtime.GOOS == "nacl" {
+ t.Skip("skipping on nacl")
+ }
+
+ h := &Handler{
+ Path: os.Args[0],
+ Root: "/test.go",
+ Args: []string{"-test.run=TestBeChildCGIProcess"},
+ }
+ expectedMap := map[string]string{
+ "_body": "",
+ }
+ replay := runCgiTest(t, h, "GET /test.go?no-body=1 HTTP/1.0\nHost: example.com\n\n", expectedMap)
+ if expected, got := "X-Test-Value", replay.Header().Get("X-Test-Header"); got != expected {
+ t.Errorf("got a X-Test-Header of %q; expected %q", got, expected)
+ }
+}
+
+// golang.org/issue/7198
+func Test500WithNoHeaders(t *testing.T) { want500Test(t, "/immediate-disconnect") }
+func Test500WithNoContentType(t *testing.T) { want500Test(t, "/no-content-type") }
+func Test500WithEmptyHeaders(t *testing.T) { want500Test(t, "/empty-headers") }
+
+func want500Test(t *testing.T, path string) {
+ h := &Handler{
+ Path: os.Args[0],
+ Root: "/test.go",
+ Args: []string{"-test.run=TestBeChildCGIProcess"},
+ }
+ expectedMap := map[string]string{
+ "_body": "",
+ }
+ replay := runCgiTest(t, h, "GET "+path+" HTTP/1.0\nHost: example.com\n\n", expectedMap)
+ if replay.Code != 500 {
+ t.Errorf("Got code %d; want 500", replay.Code)
+ }
+}
+
+type neverEnding byte
+
+func (b neverEnding) Read(p []byte) (n int, err error) {
+ for i := range p {
+ p[i] = byte(b)
+ }
+ return len(p), nil
+}
+
+// Note: not actually a test.
+func TestBeChildCGIProcess(t *testing.T) {
+ if os.Getenv("REQUEST_METHOD") == "" {
+ // Not in a CGI environment; skipping test.
+ return
+ }
+ switch os.Getenv("REQUEST_URI") {
+ case "/immediate-disconnect":
+ os.Exit(0)
+ case "/no-content-type":
+ fmt.Printf("Content-Length: 6\n\nHello\n")
+ os.Exit(0)
+ case "/empty-headers":
+ fmt.Printf("\nHello")
+ os.Exit(0)
+ }
+ Serve(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
+ rw.Header().Set("X-Test-Header", "X-Test-Value")
+ req.ParseForm()
+ if req.FormValue("no-body") == "1" {
+ return
+ }
+ if req.FormValue("write-forever") == "1" {
+ io.Copy(rw, neverEnding('a'))
+ for {
+ time.Sleep(5 * time.Second) // hang forever, until killed
+ }
+ }
+ fmt.Fprintf(rw, "test=Hello CGI-in-CGI\n")
+ for k, vv := range req.Form {
+ for _, v := range vv {
+ fmt.Fprintf(rw, "param-%s=%s\n", k, v)
+ }
+ }
+ for _, kv := range os.Environ() {
+ fmt.Fprintf(rw, "env-%s\n", kv)
+ }
+ }))
+ os.Exit(0)
+}
diff --git a/src/net/http/cgi/plan9_test.go b/src/net/http/cgi/plan9_test.go
new file mode 100644
index 000000000..c8235831b
--- /dev/null
+++ b/src/net/http/cgi/plan9_test.go
@@ -0,0 +1,18 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build plan9
+
+package cgi
+
+import (
+ "os"
+ "strconv"
+ "testing"
+)
+
+func isProcessRunning(t *testing.T, pid int) bool {
+ _, err := os.Stat("/proc/" + strconv.Itoa(pid))
+ return err == nil
+}
diff --git a/src/net/http/cgi/posix_test.go b/src/net/http/cgi/posix_test.go
new file mode 100644
index 000000000..5ff9e7d5e
--- /dev/null
+++ b/src/net/http/cgi/posix_test.go
@@ -0,0 +1,21 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !plan9
+
+package cgi
+
+import (
+ "os"
+ "syscall"
+ "testing"
+)
+
+func isProcessRunning(t *testing.T, pid int) bool {
+ p, err := os.FindProcess(pid)
+ if err != nil {
+ return false
+ }
+ return p.Signal(syscall.Signal(0)) == nil
+}
diff --git a/src/net/http/cgi/testdata/test.cgi b/src/net/http/cgi/testdata/test.cgi
new file mode 100755
index 000000000..3214df6f0
--- /dev/null
+++ b/src/net/http/cgi/testdata/test.cgi
@@ -0,0 +1,91 @@
+#!/usr/bin/perl
+# Copyright 2011 The Go Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+#
+# Test script run as a child process under cgi_test.go
+
+use strict;
+use Cwd;
+
+binmode STDOUT;
+
+my $q = MiniCGI->new;
+my $params = $q->Vars;
+
+if ($params->{"loc"}) {
+ print "Location: $params->{loc}\r\n\r\n";
+ exit(0);
+}
+
+print "Content-Type: text/html\r\n";
+print "X-CGI-Pid: $$\r\n";
+print "X-Test-Header: X-Test-Value\r\n";
+print "\r\n";
+
+if ($params->{"bigresponse"}) {
+ # 17 MB, for OS X: golang.org/issue/4958
+ for (1..(17 * 1024)) {
+ print "A" x 1024, "\r\n";
+ }
+ exit 0;
+}
+
+print "test=Hello CGI\r\n";
+
+foreach my $k (sort keys %$params) {
+ print "param-$k=$params->{$k}\r\n";
+}
+
+foreach my $k (sort keys %ENV) {
+ my $clean_env = $ENV{$k};
+ $clean_env =~ s/[\n\r]//g;
+ print "env-$k=$clean_env\r\n";
+}
+
+# NOTE: msys perl returns /c/go/src/... not C:\go\....
+my $dir = getcwd();
+if ($^O eq 'MSWin32' || $^O eq 'msys') {
+ if ($dir =~ /^.:/) {
+ $dir =~ s!/!\\!g;
+ } else {
+ my $cmd = $ENV{'COMSPEC'} || 'c:\\windows\\system32\\cmd.exe';
+ $cmd =~ s!\\!/!g;
+ $dir = `$cmd /c cd`;
+ chomp $dir;
+ }
+}
+print "cwd=$dir\r\n";
+
+# A minimal version of CGI.pm, for people without the perl-modules
+# package installed. (CGI.pm used to be part of the Perl core, but
+# some distros now bundle perl-base and perl-modules separately...)
+package MiniCGI;
+
+sub new {
+ my $class = shift;
+ return bless {}, $class;
+}
+
+sub Vars {
+ my $self = shift;
+ my $pairs;
+ if ($ENV{CONTENT_LENGTH}) {
+ $pairs = do { local $/; <STDIN> };
+ } else {
+ $pairs = $ENV{QUERY_STRING};
+ }
+ my $vars = {};
+ foreach my $kv (split(/&/, $pairs)) {
+ my ($k, $v) = split(/=/, $kv, 2);
+ $vars->{_urldecode($k)} = _urldecode($v);
+ }
+ return $vars;
+}
+
+sub _urldecode {
+ my $v = shift;
+ $v =~ tr/+/ /;
+ $v =~ s/%([a-fA-F0-9][a-fA-F0-9])/pack("C", hex($1))/eg;
+ return $v;
+}
diff --git a/src/net/http/client.go b/src/net/http/client.go
new file mode 100644
index 000000000..ce884d1f0
--- /dev/null
+++ b/src/net/http/client.go
@@ -0,0 +1,511 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// HTTP client. See RFC 2616.
+//
+// This is the high-level Client interface.
+// The low-level implementation is in transport.go.
+
+package http
+
+import (
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "net/url"
+ "strings"
+ "sync"
+ "time"
+)
+
+// A Client is an HTTP client. Its zero value (DefaultClient) is a
+// usable client that uses DefaultTransport.
+//
+// The Client's Transport typically has internal state (cached TCP
+// connections), so Clients should be reused instead of created as
+// needed. Clients are safe for concurrent use by multiple goroutines.
+//
+// A Client is higher-level than a RoundTripper (such as Transport)
+// and additionally handles HTTP details such as cookies and
+// redirects.
+type Client struct {
+ // Transport specifies the mechanism by which individual
+ // HTTP requests are made.
+ // If nil, DefaultTransport is used.
+ Transport RoundTripper
+
+ // CheckRedirect specifies the policy for handling redirects.
+ // If CheckRedirect is not nil, the client calls it before
+ // following an HTTP redirect. The arguments req and via are
+ // the upcoming request and the requests made already, oldest
+ // first. If CheckRedirect returns an error, the Client's Get
+ // method returns both the previous Response and
+ // CheckRedirect's error (wrapped in a url.Error) instead of
+ // issuing the Request req.
+ //
+ // If CheckRedirect is nil, the Client uses its default policy,
+ // which is to stop after 10 consecutive requests.
+ CheckRedirect func(req *Request, via []*Request) error
+
+ // Jar specifies the cookie jar.
+ // If Jar is nil, cookies are not sent in requests and ignored
+ // in responses.
+ Jar CookieJar
+
+ // Timeout specifies a time limit for requests made by this
+ // Client. The timeout includes connection time, any
+ // redirects, and reading the response body. The timer remains
+ // running after Get, Head, Post, or Do return and will
+ // interrupt reading of the Response.Body.
+ //
+ // A Timeout of zero means no timeout.
+ //
+ // The Client's Transport must support the CancelRequest
+ // method or Client will return errors when attempting to make
+ // a request with Get, Head, Post, or Do. Client's default
+ // Transport (DefaultTransport) supports CancelRequest.
+ Timeout time.Duration
+}
+
+// DefaultClient is the default Client and is used by Get, Head, and Post.
+var DefaultClient = &Client{}
+
+// RoundTripper is an interface representing the ability to execute a
+// single HTTP transaction, obtaining the Response for a given Request.
+//
+// A RoundTripper must be safe for concurrent use by multiple
+// goroutines.
+type RoundTripper interface {
+ // RoundTrip executes a single HTTP transaction, returning
+ // the Response for the request req. RoundTrip should not
+ // attempt to interpret the response. In particular,
+ // RoundTrip must return err == nil if it obtained a response,
+ // regardless of the response's HTTP status code. A non-nil
+ // err should be reserved for failure to obtain a response.
+ // Similarly, RoundTrip should not attempt to handle
+ // higher-level protocol details such as redirects,
+ // authentication, or cookies.
+ //
+ // RoundTrip should not modify the request, except for
+ // consuming and closing the Body, including on errors. The
+ // request's URL and Header fields are guaranteed to be
+ // initialized.
+ RoundTrip(*Request) (*Response, error)
+}
+
+// Given a string of the form "host", "host:port", or "[ipv6::address]:port",
+// return true if the string includes a port.
+func hasPort(s string) bool { return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") }
+
+// refererForURL returns a referer without any authentication info or
+// an empty string if lastReq scheme is https and newReq scheme is http.
+func refererForURL(lastReq, newReq *url.URL) string {
+ // https://tools.ietf.org/html/rfc7231#section-5.5.2
+ // "Clients SHOULD NOT include a Referer header field in a
+ // (non-secure) HTTP request if the referring page was
+ // transferred with a secure protocol."
+ if lastReq.Scheme == "https" && newReq.Scheme == "http" {
+ return ""
+ }
+ referer := lastReq.String()
+ if lastReq.User != nil {
+ // This is not very efficient, but is the best we can
+ // do without:
+ // - introducing a new method on URL
+ // - creating a race condition
+ // - copying the URL struct manually, which would cause
+ // maintenance problems down the line
+ auth := lastReq.User.String() + "@"
+ referer = strings.Replace(referer, auth, "", 1)
+ }
+ return referer
+}
+
+// Used in Send to implement io.ReadCloser by bundling together the
+// bufio.Reader through which we read the response, and the underlying
+// network connection.
+type readClose struct {
+ io.Reader
+ io.Closer
+}
+
+func (c *Client) send(req *Request) (*Response, error) {
+ if c.Jar != nil {
+ for _, cookie := range c.Jar.Cookies(req.URL) {
+ req.AddCookie(cookie)
+ }
+ }
+ resp, err := send(req, c.transport())
+ if err != nil {
+ return nil, err
+ }
+ if c.Jar != nil {
+ if rc := resp.Cookies(); len(rc) > 0 {
+ c.Jar.SetCookies(req.URL, rc)
+ }
+ }
+ return resp, err
+}
+
+// Do sends an HTTP request and returns an HTTP response, following
+// policy (e.g. redirects, cookies, auth) as configured on the client.
+//
+// An error is returned if caused by client policy (such as
+// CheckRedirect), or if there was an HTTP protocol error.
+// A non-2xx response doesn't cause an error.
+//
+// When err is nil, resp always contains a non-nil resp.Body.
+//
+// Callers should close resp.Body when done reading from it. If
+// resp.Body is not closed, the Client's underlying RoundTripper
+// (typically Transport) may not be able to re-use a persistent TCP
+// connection to the server for a subsequent "keep-alive" request.
+//
+// The request Body, if non-nil, will be closed by the underlying
+// Transport, even on errors.
+//
+// Generally Get, Post, or PostForm will be used instead of Do.
+func (c *Client) Do(req *Request) (resp *Response, err error) {
+ if req.Method == "GET" || req.Method == "HEAD" {
+ return c.doFollowingRedirects(req, shouldRedirectGet)
+ }
+ if req.Method == "POST" || req.Method == "PUT" {
+ return c.doFollowingRedirects(req, shouldRedirectPost)
+ }
+ return c.send(req)
+}
+
+func (c *Client) transport() RoundTripper {
+ if c.Transport != nil {
+ return c.Transport
+ }
+ return DefaultTransport
+}
+
+// send issues an HTTP request.
+// Caller should close resp.Body when done reading from it.
+func send(req *Request, t RoundTripper) (resp *Response, err error) {
+ if t == nil {
+ req.closeBody()
+ return nil, errors.New("http: no Client.Transport or DefaultTransport")
+ }
+
+ if req.URL == nil {
+ req.closeBody()
+ return nil, errors.New("http: nil Request.URL")
+ }
+
+ if req.RequestURI != "" {
+ req.closeBody()
+ return nil, errors.New("http: Request.RequestURI can't be set in client requests.")
+ }
+
+ // Most the callers of send (Get, Post, et al) don't need
+ // Headers, leaving it uninitialized. We guarantee to the
+ // Transport that this has been initialized, though.
+ if req.Header == nil {
+ req.Header = make(Header)
+ }
+
+ if u := req.URL.User; u != nil {
+ username := u.Username()
+ password, _ := u.Password()
+ req.Header.Set("Authorization", "Basic "+basicAuth(username, password))
+ }
+ resp, err = t.RoundTrip(req)
+ if err != nil {
+ if resp != nil {
+ log.Printf("RoundTripper returned a response & error; ignoring response")
+ }
+ return nil, err
+ }
+ return resp, nil
+}
+
+// See 2 (end of page 4) http://www.ietf.org/rfc/rfc2617.txt
+// "To receive authorization, the client sends the userid and password,
+// separated by a single colon (":") character, within a base64
+// encoded string in the credentials."
+// It is not meant to be urlencoded.
+func basicAuth(username, password string) string {
+ auth := username + ":" + password
+ return base64.StdEncoding.EncodeToString([]byte(auth))
+}
+
+// True if the specified HTTP status code is one for which the Get utility should
+// automatically redirect.
+func shouldRedirectGet(statusCode int) bool {
+ switch statusCode {
+ case StatusMovedPermanently, StatusFound, StatusSeeOther, StatusTemporaryRedirect:
+ return true
+ }
+ return false
+}
+
+// True if the specified HTTP status code is one for which the Post utility should
+// automatically redirect.
+func shouldRedirectPost(statusCode int) bool {
+ switch statusCode {
+ case StatusFound, StatusSeeOther:
+ return true
+ }
+ return false
+}
+
+// Get issues a GET to the specified URL. If the response is one of the following
+// redirect codes, Get follows the redirect, up to a maximum of 10 redirects:
+//
+// 301 (Moved Permanently)
+// 302 (Found)
+// 303 (See Other)
+// 307 (Temporary Redirect)
+//
+// An error is returned if there were too many redirects or if there
+// was an HTTP protocol error. A non-2xx response doesn't cause an
+// error.
+//
+// When err is nil, resp always contains a non-nil resp.Body.
+// Caller should close resp.Body when done reading from it.
+//
+// Get is a wrapper around DefaultClient.Get.
+func Get(url string) (resp *Response, err error) {
+ return DefaultClient.Get(url)
+}
+
+// Get issues a GET to the specified URL. If the response is one of the
+// following redirect codes, Get follows the redirect after calling the
+// Client's CheckRedirect function.
+//
+// 301 (Moved Permanently)
+// 302 (Found)
+// 303 (See Other)
+// 307 (Temporary Redirect)
+//
+// An error is returned if the Client's CheckRedirect function fails
+// or if there was an HTTP protocol error. A non-2xx response doesn't
+// cause an error.
+//
+// When err is nil, resp always contains a non-nil resp.Body.
+// Caller should close resp.Body when done reading from it.
+func (c *Client) Get(url string) (resp *Response, err error) {
+ req, err := NewRequest("GET", url, nil)
+ if err != nil {
+ return nil, err
+ }
+ return c.doFollowingRedirects(req, shouldRedirectGet)
+}
+
+func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bool) (resp *Response, err error) {
+ var base *url.URL
+ redirectChecker := c.CheckRedirect
+ if redirectChecker == nil {
+ redirectChecker = defaultCheckRedirect
+ }
+ var via []*Request
+
+ if ireq.URL == nil {
+ ireq.closeBody()
+ return nil, errors.New("http: nil Request.URL")
+ }
+
+ var reqmu sync.Mutex // guards req
+ req := ireq
+
+ var timer *time.Timer
+ if c.Timeout > 0 {
+ type canceler interface {
+ CancelRequest(*Request)
+ }
+ tr, ok := c.transport().(canceler)
+ if !ok {
+ return nil, fmt.Errorf("net/http: Client Transport of type %T doesn't support CancelRequest; Timeout not supported", c.transport())
+ }
+ timer = time.AfterFunc(c.Timeout, func() {
+ reqmu.Lock()
+ defer reqmu.Unlock()
+ tr.CancelRequest(req)
+ })
+ }
+
+ urlStr := "" // next relative or absolute URL to fetch (after first request)
+ redirectFailed := false
+ for redirect := 0; ; redirect++ {
+ if redirect != 0 {
+ nreq := new(Request)
+ nreq.Method = ireq.Method
+ if ireq.Method == "POST" || ireq.Method == "PUT" {
+ nreq.Method = "GET"
+ }
+ nreq.Header = make(Header)
+ nreq.URL, err = base.Parse(urlStr)
+ if err != nil {
+ break
+ }
+ if len(via) > 0 {
+ // Add the Referer header.
+ lastReq := via[len(via)-1]
+ if ref := refererForURL(lastReq.URL, nreq.URL); ref != "" {
+ nreq.Header.Set("Referer", ref)
+ }
+
+ err = redirectChecker(nreq, via)
+ if err != nil {
+ redirectFailed = true
+ break
+ }
+ }
+ reqmu.Lock()
+ req = nreq
+ reqmu.Unlock()
+ }
+
+ urlStr = req.URL.String()
+ if resp, err = c.send(req); err != nil {
+ break
+ }
+
+ if shouldRedirect(resp.StatusCode) {
+ // Read the body if small so underlying TCP connection will be re-used.
+ // No need to check for errors: if it fails, Transport won't reuse it anyway.
+ const maxBodySlurpSize = 2 << 10
+ if resp.ContentLength == -1 || resp.ContentLength <= maxBodySlurpSize {
+ io.CopyN(ioutil.Discard, resp.Body, maxBodySlurpSize)
+ }
+ resp.Body.Close()
+ if urlStr = resp.Header.Get("Location"); urlStr == "" {
+ err = errors.New(fmt.Sprintf("%d response missing Location header", resp.StatusCode))
+ break
+ }
+ base = req.URL
+ via = append(via, req)
+ continue
+ }
+ if timer != nil {
+ resp.Body = &cancelTimerBody{timer, resp.Body}
+ }
+ return resp, nil
+ }
+
+ method := ireq.Method
+ urlErr := &url.Error{
+ Op: method[0:1] + strings.ToLower(method[1:]),
+ URL: urlStr,
+ Err: err,
+ }
+
+ if redirectFailed {
+ // Special case for Go 1 compatibility: return both the response
+ // and an error if the CheckRedirect function failed.
+ // See http://golang.org/issue/3795
+ return resp, urlErr
+ }
+
+ if resp != nil {
+ resp.Body.Close()
+ }
+ return nil, urlErr
+}
+
+func defaultCheckRedirect(req *Request, via []*Request) error {
+ if len(via) >= 10 {
+ return errors.New("stopped after 10 redirects")
+ }
+ return nil
+}
+
+// Post issues a POST to the specified URL.
+//
+// Caller should close resp.Body when done reading from it.
+//
+// Post is a wrapper around DefaultClient.Post
+func Post(url string, bodyType string, body io.Reader) (resp *Response, err error) {
+ return DefaultClient.Post(url, bodyType, body)
+}
+
+// Post issues a POST to the specified URL.
+//
+// Caller should close resp.Body when done reading from it.
+//
+// If the provided body is also an io.Closer, it is closed after the
+// request.
+func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *Response, err error) {
+ req, err := NewRequest("POST", url, body)
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", bodyType)
+ return c.doFollowingRedirects(req, shouldRedirectPost)
+}
+
+// PostForm issues a POST to the specified URL, with data's keys and
+// values URL-encoded as the request body.
+//
+// When err is nil, resp always contains a non-nil resp.Body.
+// Caller should close resp.Body when done reading from it.
+//
+// PostForm is a wrapper around DefaultClient.PostForm
+func PostForm(url string, data url.Values) (resp *Response, err error) {
+ return DefaultClient.PostForm(url, data)
+}
+
+// PostForm issues a POST to the specified URL,
+// with data's keys and values urlencoded as the request body.
+//
+// When err is nil, resp always contains a non-nil resp.Body.
+// Caller should close resp.Body when done reading from it.
+func (c *Client) PostForm(url string, data url.Values) (resp *Response, err error) {
+ return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
+}
+
+// Head issues a HEAD to the specified URL. If the response is one of the
+// following redirect codes, Head follows the redirect after calling the
+// Client's CheckRedirect function.
+//
+// 301 (Moved Permanently)
+// 302 (Found)
+// 303 (See Other)
+// 307 (Temporary Redirect)
+//
+// Head is a wrapper around DefaultClient.Head
+func Head(url string) (resp *Response, err error) {
+ return DefaultClient.Head(url)
+}
+
+// Head issues a HEAD to the specified URL. If the response is one of the
+// following redirect codes, Head follows the redirect after calling the
+// Client's CheckRedirect function.
+//
+// 301 (Moved Permanently)
+// 302 (Found)
+// 303 (See Other)
+// 307 (Temporary Redirect)
+func (c *Client) Head(url string) (resp *Response, err error) {
+ req, err := NewRequest("HEAD", url, nil)
+ if err != nil {
+ return nil, err
+ }
+ return c.doFollowingRedirects(req, shouldRedirectGet)
+}
+
+type cancelTimerBody struct {
+ t *time.Timer
+ rc io.ReadCloser
+}
+
+func (b *cancelTimerBody) Read(p []byte) (n int, err error) {
+ n, err = b.rc.Read(p)
+ if err == io.EOF {
+ b.t.Stop()
+ }
+ return
+}
+
+func (b *cancelTimerBody) Close() error {
+ err := b.rc.Close()
+ b.t.Stop()
+ return err
+}
diff --git a/src/net/http/client_test.go b/src/net/http/client_test.go
new file mode 100644
index 000000000..56b6563c4
--- /dev/null
+++ b/src/net/http/client_test.go
@@ -0,0 +1,1075 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tests for client.go
+
+package http_test
+
+import (
+ "bytes"
+ "crypto/tls"
+ "crypto/x509"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "net"
+ . "net/http"
+ "net/http/httptest"
+ "net/url"
+ "reflect"
+ "sort"
+ "strconv"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+)
+
+var robotsTxtHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Last-Modified", "sometime")
+ fmt.Fprintf(w, "User-agent: go\nDisallow: /something/")
+})
+
+// pedanticReadAll works like ioutil.ReadAll but additionally
+// verifies that r obeys the documented io.Reader contract.
+func pedanticReadAll(r io.Reader) (b []byte, err error) {
+ var bufa [64]byte
+ buf := bufa[:]
+ for {
+ n, err := r.Read(buf)
+ if n == 0 && err == nil {
+ return nil, fmt.Errorf("Read: n=0 with err=nil")
+ }
+ b = append(b, buf[:n]...)
+ if err == io.EOF {
+ n, err := r.Read(buf)
+ if n != 0 || err != io.EOF {
+ return nil, fmt.Errorf("Read: n=%d err=%#v after EOF", n, err)
+ }
+ return b, nil
+ }
+ if err != nil {
+ return b, err
+ }
+ }
+}
+
+type chanWriter chan string
+
+func (w chanWriter) Write(p []byte) (n int, err error) {
+ w <- string(p)
+ return len(p), nil
+}
+
+func TestClient(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(robotsTxtHandler)
+ defer ts.Close()
+
+ r, err := Get(ts.URL)
+ var b []byte
+ if err == nil {
+ b, err = pedanticReadAll(r.Body)
+ r.Body.Close()
+ }
+ if err != nil {
+ t.Error(err)
+ } else if s := string(b); !strings.HasPrefix(s, "User-agent:") {
+ t.Errorf("Incorrect page body (did not begin with User-agent): %q", s)
+ }
+}
+
+func TestClientHead(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(robotsTxtHandler)
+ defer ts.Close()
+
+ r, err := Head(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, ok := r.Header["Last-Modified"]; !ok {
+ t.Error("Last-Modified header not found.")
+ }
+}
+
+type recordingTransport struct {
+ req *Request
+}
+
+func (t *recordingTransport) RoundTrip(req *Request) (resp *Response, err error) {
+ t.req = req
+ return nil, errors.New("dummy impl")
+}
+
+func TestGetRequestFormat(t *testing.T) {
+ defer afterTest(t)
+ tr := &recordingTransport{}
+ client := &Client{Transport: tr}
+ url := "http://dummy.faketld/"
+ client.Get(url) // Note: doesn't hit network
+ if tr.req.Method != "GET" {
+ t.Errorf("expected method %q; got %q", "GET", tr.req.Method)
+ }
+ if tr.req.URL.String() != url {
+ t.Errorf("expected URL %q; got %q", url, tr.req.URL.String())
+ }
+ if tr.req.Header == nil {
+ t.Errorf("expected non-nil request Header")
+ }
+}
+
+func TestPostRequestFormat(t *testing.T) {
+ defer afterTest(t)
+ tr := &recordingTransport{}
+ client := &Client{Transport: tr}
+
+ url := "http://dummy.faketld/"
+ json := `{"key":"value"}`
+ b := strings.NewReader(json)
+ client.Post(url, "application/json", b) // Note: doesn't hit network
+
+ if tr.req.Method != "POST" {
+ t.Errorf("got method %q, want %q", tr.req.Method, "POST")
+ }
+ if tr.req.URL.String() != url {
+ t.Errorf("got URL %q, want %q", tr.req.URL.String(), url)
+ }
+ if tr.req.Header == nil {
+ t.Fatalf("expected non-nil request Header")
+ }
+ if tr.req.Close {
+ t.Error("got Close true, want false")
+ }
+ if g, e := tr.req.ContentLength, int64(len(json)); g != e {
+ t.Errorf("got ContentLength %d, want %d", g, e)
+ }
+}
+
+func TestPostFormRequestFormat(t *testing.T) {
+ defer afterTest(t)
+ tr := &recordingTransport{}
+ client := &Client{Transport: tr}
+
+ urlStr := "http://dummy.faketld/"
+ form := make(url.Values)
+ form.Set("foo", "bar")
+ form.Add("foo", "bar2")
+ form.Set("bar", "baz")
+ client.PostForm(urlStr, form) // Note: doesn't hit network
+
+ if tr.req.Method != "POST" {
+ t.Errorf("got method %q, want %q", tr.req.Method, "POST")
+ }
+ if tr.req.URL.String() != urlStr {
+ t.Errorf("got URL %q, want %q", tr.req.URL.String(), urlStr)
+ }
+ if tr.req.Header == nil {
+ t.Fatalf("expected non-nil request Header")
+ }
+ if g, e := tr.req.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; g != e {
+ t.Errorf("got Content-Type %q, want %q", g, e)
+ }
+ if tr.req.Close {
+ t.Error("got Close true, want false")
+ }
+ // Depending on map iteration, body can be either of these.
+ expectedBody := "foo=bar&foo=bar2&bar=baz"
+ expectedBody1 := "bar=baz&foo=bar&foo=bar2"
+ if g, e := tr.req.ContentLength, int64(len(expectedBody)); g != e {
+ t.Errorf("got ContentLength %d, want %d", g, e)
+ }
+ bodyb, err := ioutil.ReadAll(tr.req.Body)
+ if err != nil {
+ t.Fatalf("ReadAll on req.Body: %v", err)
+ }
+ if g := string(bodyb); g != expectedBody && g != expectedBody1 {
+ t.Errorf("got body %q, want %q or %q", g, expectedBody, expectedBody1)
+ }
+}
+
+func TestClientRedirects(t *testing.T) {
+ defer afterTest(t)
+ var ts *httptest.Server
+ ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ n, _ := strconv.Atoi(r.FormValue("n"))
+ // Test Referer header. (7 is arbitrary position to test at)
+ if n == 7 {
+ if g, e := r.Referer(), ts.URL+"/?n=6"; e != g {
+ t.Errorf("on request ?n=7, expected referer of %q; got %q", e, g)
+ }
+ }
+ if n < 15 {
+ Redirect(w, r, fmt.Sprintf("/?n=%d", n+1), StatusFound)
+ return
+ }
+ fmt.Fprintf(w, "n=%d", n)
+ }))
+ defer ts.Close()
+
+ c := &Client{}
+ _, err := c.Get(ts.URL)
+ if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g {
+ t.Errorf("with default client Get, expected error %q, got %q", e, g)
+ }
+
+ // HEAD request should also have the ability to follow redirects.
+ _, err = c.Head(ts.URL)
+ if e, g := "Head /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g {
+ t.Errorf("with default client Head, expected error %q, got %q", e, g)
+ }
+
+ // Do should also follow redirects.
+ greq, _ := NewRequest("GET", ts.URL, nil)
+ _, err = c.Do(greq)
+ if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g {
+ t.Errorf("with default client Do, expected error %q, got %q", e, g)
+ }
+
+ var checkErr error
+ var lastVia []*Request
+ c = &Client{CheckRedirect: func(_ *Request, via []*Request) error {
+ lastVia = via
+ return checkErr
+ }}
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatalf("Get error: %v", err)
+ }
+ res.Body.Close()
+ finalUrl := res.Request.URL.String()
+ if e, g := "<nil>", fmt.Sprintf("%v", err); e != g {
+ t.Errorf("with custom client, expected error %q, got %q", e, g)
+ }
+ if !strings.HasSuffix(finalUrl, "/?n=15") {
+ t.Errorf("expected final url to end in /?n=15; got url %q", finalUrl)
+ }
+ if e, g := 15, len(lastVia); e != g {
+ t.Errorf("expected lastVia to have contained %d elements; got %d", e, g)
+ }
+
+ checkErr = errors.New("no redirects allowed")
+ res, err = c.Get(ts.URL)
+ if urlError, ok := err.(*url.Error); !ok || urlError.Err != checkErr {
+ t.Errorf("with redirects forbidden, expected a *url.Error with our 'no redirects allowed' error inside; got %#v (%q)", err, err)
+ }
+ if res == nil {
+ t.Fatalf("Expected a non-nil Response on CheckRedirect failure (http://golang.org/issue/3795)")
+ }
+ res.Body.Close()
+ if res.Header.Get("Location") == "" {
+ t.Errorf("no Location header in Response")
+ }
+}
+
+func TestPostRedirects(t *testing.T) {
+ defer afterTest(t)
+ var log struct {
+ sync.Mutex
+ bytes.Buffer
+ }
+ var ts *httptest.Server
+ ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ log.Lock()
+ fmt.Fprintf(&log.Buffer, "%s %s ", r.Method, r.RequestURI)
+ log.Unlock()
+ if v := r.URL.Query().Get("code"); v != "" {
+ code, _ := strconv.Atoi(v)
+ if code/100 == 3 {
+ w.Header().Set("Location", ts.URL)
+ }
+ w.WriteHeader(code)
+ }
+ }))
+ defer ts.Close()
+ tests := []struct {
+ suffix string
+ want int // response code
+ }{
+ {"/", 200},
+ {"/?code=301", 301},
+ {"/?code=302", 200},
+ {"/?code=303", 200},
+ {"/?code=404", 404},
+ }
+ for _, tt := range tests {
+ res, err := Post(ts.URL+tt.suffix, "text/plain", strings.NewReader("Some content"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != tt.want {
+ t.Errorf("POST %s: status code = %d; want %d", tt.suffix, res.StatusCode, tt.want)
+ }
+ }
+ log.Lock()
+ got := log.String()
+ log.Unlock()
+ want := "POST / POST /?code=301 POST /?code=302 GET / POST /?code=303 GET / POST /?code=404 "
+ if got != want {
+ t.Errorf("Log differs.\n Got: %q\nWant: %q", got, want)
+ }
+}
+
+var expectedCookies = []*Cookie{
+ {Name: "ChocolateChip", Value: "tasty"},
+ {Name: "First", Value: "Hit"},
+ {Name: "Second", Value: "Hit"},
+}
+
+var echoCookiesRedirectHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
+ for _, cookie := range r.Cookies() {
+ SetCookie(w, cookie)
+ }
+ if r.URL.Path == "/" {
+ SetCookie(w, expectedCookies[1])
+ Redirect(w, r, "/second", StatusMovedPermanently)
+ } else {
+ SetCookie(w, expectedCookies[2])
+ w.Write([]byte("hello"))
+ }
+})
+
+func TestClientSendsCookieFromJar(t *testing.T) {
+ tr := &recordingTransport{}
+ client := &Client{Transport: tr}
+ client.Jar = &TestJar{perURL: make(map[string][]*Cookie)}
+ us := "http://dummy.faketld/"
+ u, _ := url.Parse(us)
+ client.Jar.SetCookies(u, expectedCookies)
+
+ client.Get(us) // Note: doesn't hit network
+ matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
+
+ client.Head(us) // Note: doesn't hit network
+ matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
+
+ client.Post(us, "text/plain", strings.NewReader("body")) // Note: doesn't hit network
+ matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
+
+ client.PostForm(us, url.Values{}) // Note: doesn't hit network
+ matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
+
+ req, _ := NewRequest("GET", us, nil)
+ client.Do(req) // Note: doesn't hit network
+ matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
+
+ req, _ = NewRequest("POST", us, nil)
+ client.Do(req) // Note: doesn't hit network
+ matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
+}
+
+// Just enough correctness for our redirect tests. Uses the URL.Host as the
+// scope of all cookies.
+type TestJar struct {
+ m sync.Mutex
+ perURL map[string][]*Cookie
+}
+
+func (j *TestJar) SetCookies(u *url.URL, cookies []*Cookie) {
+ j.m.Lock()
+ defer j.m.Unlock()
+ if j.perURL == nil {
+ j.perURL = make(map[string][]*Cookie)
+ }
+ j.perURL[u.Host] = cookies
+}
+
+func (j *TestJar) Cookies(u *url.URL) []*Cookie {
+ j.m.Lock()
+ defer j.m.Unlock()
+ return j.perURL[u.Host]
+}
+
+func TestRedirectCookiesJar(t *testing.T) {
+ defer afterTest(t)
+ var ts *httptest.Server
+ ts = httptest.NewServer(echoCookiesRedirectHandler)
+ defer ts.Close()
+ c := &Client{
+ Jar: new(TestJar),
+ }
+ u, _ := url.Parse(ts.URL)
+ c.Jar.SetCookies(u, []*Cookie{expectedCookies[0]})
+ resp, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ resp.Body.Close()
+ matchReturnedCookies(t, expectedCookies, resp.Cookies())
+}
+
+func matchReturnedCookies(t *testing.T, expected, given []*Cookie) {
+ if len(given) != len(expected) {
+ t.Logf("Received cookies: %v", given)
+ t.Errorf("Expected %d cookies, got %d", len(expected), len(given))
+ }
+ for _, ec := range expected {
+ foundC := false
+ for _, c := range given {
+ if ec.Name == c.Name && ec.Value == c.Value {
+ foundC = true
+ break
+ }
+ }
+ if !foundC {
+ t.Errorf("Missing cookie %v", ec)
+ }
+ }
+}
+
+func TestJarCalls(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ pathSuffix := r.RequestURI[1:]
+ if r.RequestURI == "/nosetcookie" {
+ return // dont set cookies for this path
+ }
+ SetCookie(w, &Cookie{Name: "name" + pathSuffix, Value: "val" + pathSuffix})
+ if r.RequestURI == "/" {
+ Redirect(w, r, "http://secondhost.fake/secondpath", 302)
+ }
+ }))
+ defer ts.Close()
+ jar := new(RecordingJar)
+ c := &Client{
+ Jar: jar,
+ Transport: &Transport{
+ Dial: func(_ string, _ string) (net.Conn, error) {
+ return net.Dial("tcp", ts.Listener.Addr().String())
+ },
+ },
+ }
+ _, err := c.Get("http://firsthost.fake/")
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = c.Get("http://firsthost.fake/nosetcookie")
+ if err != nil {
+ t.Fatal(err)
+ }
+ got := jar.log.String()
+ want := `Cookies("http://firsthost.fake/")
+SetCookie("http://firsthost.fake/", [name=val])
+Cookies("http://secondhost.fake/secondpath")
+SetCookie("http://secondhost.fake/secondpath", [namesecondpath=valsecondpath])
+Cookies("http://firsthost.fake/nosetcookie")
+`
+ if got != want {
+ t.Errorf("Got Jar calls:\n%s\nWant:\n%s", got, want)
+ }
+}
+
+// RecordingJar keeps a log of calls made to it, without
+// tracking any cookies.
+type RecordingJar struct {
+ mu sync.Mutex
+ log bytes.Buffer
+}
+
+func (j *RecordingJar) SetCookies(u *url.URL, cookies []*Cookie) {
+ j.logf("SetCookie(%q, %v)\n", u, cookies)
+}
+
+func (j *RecordingJar) Cookies(u *url.URL) []*Cookie {
+ j.logf("Cookies(%q)\n", u)
+ return nil
+}
+
+func (j *RecordingJar) logf(format string, args ...interface{}) {
+ j.mu.Lock()
+ defer j.mu.Unlock()
+ fmt.Fprintf(&j.log, format, args...)
+}
+
+func TestStreamingGet(t *testing.T) {
+ defer afterTest(t)
+ say := make(chan string)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.(Flusher).Flush()
+ for str := range say {
+ w.Write([]byte(str))
+ w.(Flusher).Flush()
+ }
+ }))
+ defer ts.Close()
+
+ c := &Client{}
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ var buf [10]byte
+ for _, str := range []string{"i", "am", "also", "known", "as", "comet"} {
+ say <- str
+ n, err := io.ReadFull(res.Body, buf[0:len(str)])
+ if err != nil {
+ t.Fatalf("ReadFull on %q: %v", str, err)
+ }
+ if n != len(str) {
+ t.Fatalf("Receiving %q, only read %d bytes", str, n)
+ }
+ got := string(buf[0:n])
+ if got != str {
+ t.Fatalf("Expected %q, got %q", str, got)
+ }
+ }
+ close(say)
+ _, err = io.ReadFull(res.Body, buf[0:1])
+ if err != io.EOF {
+ t.Fatalf("at end expected EOF, got %v", err)
+ }
+}
+
+type writeCountingConn struct {
+ net.Conn
+ count *int
+}
+
+func (c *writeCountingConn) Write(p []byte) (int, error) {
+ *c.count++
+ return c.Conn.Write(p)
+}
+
+// TestClientWrites verifies that client requests are buffered and we
+// don't send a TCP packet per line of the http request + body.
+func TestClientWrites(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ }))
+ defer ts.Close()
+
+ writes := 0
+ dialer := func(netz string, addr string) (net.Conn, error) {
+ c, err := net.Dial(netz, addr)
+ if err == nil {
+ c = &writeCountingConn{c, &writes}
+ }
+ return c, err
+ }
+ c := &Client{Transport: &Transport{Dial: dialer}}
+
+ _, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if writes != 1 {
+ t.Errorf("Get request did %d Write calls, want 1", writes)
+ }
+
+ writes = 0
+ _, err = c.PostForm(ts.URL, url.Values{"foo": {"bar"}})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if writes != 1 {
+ t.Errorf("Post request did %d Write calls, want 1", writes)
+ }
+}
+
+func TestClientInsecureTransport(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Write([]byte("Hello"))
+ }))
+ errc := make(chanWriter, 10) // but only expecting 1
+ ts.Config.ErrorLog = log.New(errc, "", 0)
+ defer ts.Close()
+
+ // TODO(bradfitz): add tests for skipping hostname checks too?
+ // would require a new cert for testing, and probably
+ // redundant with these tests.
+ for _, insecure := range []bool{true, false} {
+ tr := &Transport{
+ TLSClientConfig: &tls.Config{
+ InsecureSkipVerify: insecure,
+ },
+ }
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+ res, err := c.Get(ts.URL)
+ if (err == nil) != insecure {
+ t.Errorf("insecure=%v: got unexpected err=%v", insecure, err)
+ }
+ if res != nil {
+ res.Body.Close()
+ }
+ }
+
+ select {
+ case v := <-errc:
+ if !strings.Contains(v, "TLS handshake error") {
+ t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", v)
+ }
+ case <-time.After(5 * time.Second):
+ t.Errorf("timeout waiting for logged error")
+ }
+
+}
+
+func TestClientErrorWithRequestURI(t *testing.T) {
+ defer afterTest(t)
+ req, _ := NewRequest("GET", "http://localhost:1234/", nil)
+ req.RequestURI = "/this/field/is/illegal/and/should/error/"
+ _, err := DefaultClient.Do(req)
+ if err == nil {
+ t.Fatalf("expected an error")
+ }
+ if !strings.Contains(err.Error(), "RequestURI") {
+ t.Errorf("wanted error mentioning RequestURI; got error: %v", err)
+ }
+}
+
+func newTLSTransport(t *testing.T, ts *httptest.Server) *Transport {
+ certs := x509.NewCertPool()
+ for _, c := range ts.TLS.Certificates {
+ roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
+ if err != nil {
+ t.Fatalf("error parsing server's root cert: %v", err)
+ }
+ for _, root := range roots {
+ certs.AddCert(root)
+ }
+ }
+ return &Transport{
+ TLSClientConfig: &tls.Config{RootCAs: certs},
+ }
+}
+
+func TestClientWithCorrectTLSServerName(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.TLS.ServerName != "127.0.0.1" {
+ t.Errorf("expected client to set ServerName 127.0.0.1, got: %q", r.TLS.ServerName)
+ }
+ }))
+ defer ts.Close()
+
+ c := &Client{Transport: newTLSTransport(t, ts)}
+ if _, err := c.Get(ts.URL); err != nil {
+ t.Fatalf("expected successful TLS connection, got error: %v", err)
+ }
+}
+
+func TestClientWithIncorrectTLSServerName(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
+ defer ts.Close()
+ errc := make(chanWriter, 10) // but only expecting 1
+ ts.Config.ErrorLog = log.New(errc, "", 0)
+
+ trans := newTLSTransport(t, ts)
+ trans.TLSClientConfig.ServerName = "badserver"
+ c := &Client{Transport: trans}
+ _, err := c.Get(ts.URL)
+ if err == nil {
+ t.Fatalf("expected an error")
+ }
+ if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") {
+ t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err)
+ }
+ select {
+ case v := <-errc:
+ if !strings.Contains(v, "TLS handshake error") {
+ t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", v)
+ }
+ case <-time.After(5 * time.Second):
+ t.Errorf("timeout waiting for logged error")
+ }
+}
+
+// Test for golang.org/issue/5829; the Transport should respect TLSClientConfig.ServerName
+// when not empty.
+//
+// tls.Config.ServerName (non-empty, set to "example.com") takes
+// precedence over "some-other-host.tld" which previously incorrectly
+// took precedence. We don't actually connect to (or even resolve)
+// "some-other-host.tld", though, because of the Transport.Dial hook.
+//
+// The httptest.Server has a cert with "example.com" as its name.
+func TestTransportUsesTLSConfigServerName(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Write([]byte("Hello"))
+ }))
+ defer ts.Close()
+
+ tr := newTLSTransport(t, ts)
+ tr.TLSClientConfig.ServerName = "example.com" // one of httptest's Server cert names
+ tr.Dial = func(netw, addr string) (net.Conn, error) {
+ return net.Dial(netw, ts.Listener.Addr().String())
+ }
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+ res, err := c.Get("https://some-other-host.tld/")
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+}
+
+func TestResponseSetsTLSConnectionState(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Write([]byte("Hello"))
+ }))
+ defer ts.Close()
+
+ tr := newTLSTransport(t, ts)
+ tr.TLSClientConfig.CipherSuites = []uint16{tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA}
+ tr.Dial = func(netw, addr string) (net.Conn, error) {
+ return net.Dial(netw, ts.Listener.Addr().String())
+ }
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+ res, err := c.Get("https://example.com/")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if res.TLS == nil {
+ t.Fatal("Response didn't set TLS Connection State.")
+ }
+ if got, want := res.TLS.CipherSuite, tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA; got != want {
+ t.Errorf("TLS Cipher Suite = %d; want %d", got, want)
+ }
+}
+
+// Verify Response.ContentLength is populated. http://golang.org/issue/4126
+func TestClientHeadContentLength(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if v := r.FormValue("cl"); v != "" {
+ w.Header().Set("Content-Length", v)
+ }
+ }))
+ defer ts.Close()
+ tests := []struct {
+ suffix string
+ want int64
+ }{
+ {"/?cl=1234", 1234},
+ {"/?cl=0", 0},
+ {"", -1},
+ }
+ for _, tt := range tests {
+ req, _ := NewRequest("HEAD", ts.URL+tt.suffix, nil)
+ res, err := DefaultClient.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.ContentLength != tt.want {
+ t.Errorf("Content-Length = %d; want %d", res.ContentLength, tt.want)
+ }
+ bs, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(bs) != 0 {
+ t.Errorf("Unexpected content: %q", bs)
+ }
+ }
+}
+
+func TestEmptyPasswordAuth(t *testing.T) {
+ defer afterTest(t)
+ gopher := "gopher"
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ auth := r.Header.Get("Authorization")
+ if strings.HasPrefix(auth, "Basic ") {
+ encoded := auth[6:]
+ decoded, err := base64.StdEncoding.DecodeString(encoded)
+ if err != nil {
+ t.Fatal(err)
+ }
+ expected := gopher + ":"
+ s := string(decoded)
+ if expected != s {
+ t.Errorf("Invalid Authorization header. Got %q, wanted %q", s, expected)
+ }
+ } else {
+ t.Errorf("Invalid auth %q", auth)
+ }
+ }))
+ defer ts.Close()
+ c := &Client{}
+ req, err := NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.URL.User = url.User(gopher)
+ resp, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer resp.Body.Close()
+}
+
+func TestBasicAuth(t *testing.T) {
+ defer afterTest(t)
+ tr := &recordingTransport{}
+ client := &Client{Transport: tr}
+
+ url := "http://My%20User:My%20Pass@dummy.faketld/"
+ expected := "My User:My Pass"
+ client.Get(url)
+
+ if tr.req.Method != "GET" {
+ t.Errorf("got method %q, want %q", tr.req.Method, "GET")
+ }
+ if tr.req.URL.String() != url {
+ t.Errorf("got URL %q, want %q", tr.req.URL.String(), url)
+ }
+ if tr.req.Header == nil {
+ t.Fatalf("expected non-nil request Header")
+ }
+ auth := tr.req.Header.Get("Authorization")
+ if strings.HasPrefix(auth, "Basic ") {
+ encoded := auth[6:]
+ decoded, err := base64.StdEncoding.DecodeString(encoded)
+ if err != nil {
+ t.Fatal(err)
+ }
+ s := string(decoded)
+ if expected != s {
+ t.Errorf("Invalid Authorization header. Got %q, wanted %q", s, expected)
+ }
+ } else {
+ t.Errorf("Invalid auth %q", auth)
+ }
+}
+
+func TestClientTimeout(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ defer afterTest(t)
+ sawRoot := make(chan bool, 1)
+ sawSlow := make(chan bool, 1)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.URL.Path == "/" {
+ sawRoot <- true
+ Redirect(w, r, "/slow", StatusFound)
+ return
+ }
+ if r.URL.Path == "/slow" {
+ w.Write([]byte("Hello"))
+ w.(Flusher).Flush()
+ sawSlow <- true
+ time.Sleep(2 * time.Second)
+ return
+ }
+ }))
+ defer ts.Close()
+ const timeout = 500 * time.Millisecond
+ c := &Client{
+ Timeout: timeout,
+ }
+
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ select {
+ case <-sawRoot:
+ // good.
+ default:
+ t.Fatal("handler never got / request")
+ }
+
+ select {
+ case <-sawSlow:
+ // good.
+ default:
+ t.Fatal("handler never got /slow request")
+ }
+
+ errc := make(chan error, 1)
+ go func() {
+ _, err := ioutil.ReadAll(res.Body)
+ errc <- err
+ res.Body.Close()
+ }()
+
+ const failTime = timeout * 2
+ select {
+ case err := <-errc:
+ if err == nil {
+ t.Error("expected error from ReadAll")
+ }
+ // Expected error.
+ case <-time.After(failTime):
+ t.Errorf("timeout after %v waiting for timeout of %v", failTime, timeout)
+ }
+}
+
+func TestClientRedirectEatsBody(t *testing.T) {
+ defer afterTest(t)
+ saw := make(chan string, 2)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ saw <- r.RemoteAddr
+ if r.URL.Path == "/" {
+ Redirect(w, r, "/foo", StatusFound) // which includes a body
+ }
+ }))
+ defer ts.Close()
+
+ res, err := Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+
+ var first string
+ select {
+ case first = <-saw:
+ default:
+ t.Fatal("server didn't see a request")
+ }
+
+ var second string
+ select {
+ case second = <-saw:
+ default:
+ t.Fatal("server didn't see a second request")
+ }
+
+ if first != second {
+ t.Fatal("server saw different client ports before & after the redirect")
+ }
+}
+
+// eofReaderFunc is an io.Reader that runs itself, and then returns io.EOF.
+type eofReaderFunc func()
+
+func (f eofReaderFunc) Read(p []byte) (n int, err error) {
+ f()
+ return 0, io.EOF
+}
+
+func TestClientTrailers(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Connection", "close")
+ w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
+ w.Header().Add("Trailer", "Server-Trailer-C")
+
+ var decl []string
+ for k := range r.Trailer {
+ decl = append(decl, k)
+ }
+ sort.Strings(decl)
+
+ slurp, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ t.Errorf("Server reading request body: %v", err)
+ }
+ if string(slurp) != "foo" {
+ t.Errorf("Server read request body %q; want foo", slurp)
+ }
+ if r.Trailer == nil {
+ io.WriteString(w, "nil Trailer")
+ } else {
+ fmt.Fprintf(w, "decl: %v, vals: %s, %s",
+ decl,
+ r.Trailer.Get("Client-Trailer-A"),
+ r.Trailer.Get("Client-Trailer-B"))
+ }
+
+ // TODO: golang.org/issue/7759: there's no way yet for
+ // the server to set trailers without hijacking, so do
+ // that for now, just to test the client. Later, in
+ // Go 1.4, it should be implicit that any mutations
+ // to w.Header() after the initial write are the
+ // trailers to be sent, if and only if they were
+ // previously declared with w.Header().Set("Trailer",
+ // ..keys..)
+ w.(Flusher).Flush()
+ conn, buf, _ := w.(Hijacker).Hijack()
+ t := Header{}
+ t.Set("Server-Trailer-A", "valuea")
+ t.Set("Server-Trailer-C", "valuec") // skipping B
+ buf.WriteString("0\r\n") // eof
+ t.Write(buf)
+ buf.WriteString("\r\n") // end of trailers
+ buf.Flush()
+ conn.Close()
+ }))
+ defer ts.Close()
+
+ var req *Request
+ req, _ = NewRequest("POST", ts.URL, io.MultiReader(
+ eofReaderFunc(func() {
+ req.Trailer["Client-Trailer-A"] = []string{"valuea"}
+ }),
+ strings.NewReader("foo"),
+ eofReaderFunc(func() {
+ req.Trailer["Client-Trailer-B"] = []string{"valueb"}
+ }),
+ ))
+ req.Trailer = Header{
+ "Client-Trailer-A": nil, // to be set later
+ "Client-Trailer-B": nil, // to be set later
+ }
+ req.ContentLength = -1
+ res, err := DefaultClient.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil {
+ t.Error(err)
+ }
+ want := Header{
+ "Server-Trailer-A": []string{"valuea"},
+ "Server-Trailer-B": nil,
+ "Server-Trailer-C": []string{"valuec"},
+ }
+ if !reflect.DeepEqual(res.Trailer, want) {
+ t.Errorf("Response trailers = %#v; want %#v", res.Trailer, want)
+ }
+}
+
+func TestReferer(t *testing.T) {
+ tests := []struct {
+ lastReq, newReq string // from -> to URLs
+ want string
+ }{
+ // don't send user:
+ {"http://gopher@test.com", "http://link.com", "http://test.com"},
+ {"https://gopher@test.com", "https://link.com", "https://test.com"},
+
+ // don't send a user and password:
+ {"http://gopher:go@test.com", "http://link.com", "http://test.com"},
+ {"https://gopher:go@test.com", "https://link.com", "https://test.com"},
+
+ // nothing to do:
+ {"http://test.com", "http://link.com", "http://test.com"},
+ {"https://test.com", "https://link.com", "https://test.com"},
+
+ // https to http doesn't send a referer:
+ {"https://test.com", "http://link.com", ""},
+ {"https://gopher:go@test.com", "http://link.com", ""},
+ }
+ for _, tt := range tests {
+ l, err := url.Parse(tt.lastReq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ n, err := url.Parse(tt.newReq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ r := ExportRefererForURL(l, n)
+ if r != tt.want {
+ t.Errorf("refererForURL(%q, %q) = %q; want %q", tt.lastReq, tt.newReq, r, tt.want)
+ }
+ }
+}
diff --git a/src/net/http/cookie.go b/src/net/http/cookie.go
new file mode 100644
index 000000000..a0d0fdbbd
--- /dev/null
+++ b/src/net/http/cookie.go
@@ -0,0 +1,363 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bytes"
+ "fmt"
+ "log"
+ "net"
+ "strconv"
+ "strings"
+ "time"
+)
+
+// This implementation is done according to RFC 6265:
+//
+// http://tools.ietf.org/html/rfc6265
+
+// A Cookie represents an HTTP cookie as sent in the Set-Cookie header of an
+// HTTP response or the Cookie header of an HTTP request.
+type Cookie struct {
+ Name string
+ Value string
+ Path string
+ Domain string
+ Expires time.Time
+ RawExpires string
+
+ // MaxAge=0 means no 'Max-Age' attribute specified.
+ // MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0'
+ // MaxAge>0 means Max-Age attribute present and given in seconds
+ MaxAge int
+ Secure bool
+ HttpOnly bool
+ Raw string
+ Unparsed []string // Raw text of unparsed attribute-value pairs
+}
+
+// readSetCookies parses all "Set-Cookie" values from
+// the header h and returns the successfully parsed Cookies.
+func readSetCookies(h Header) []*Cookie {
+ cookies := []*Cookie{}
+ for _, line := range h["Set-Cookie"] {
+ parts := strings.Split(strings.TrimSpace(line), ";")
+ if len(parts) == 1 && parts[0] == "" {
+ continue
+ }
+ parts[0] = strings.TrimSpace(parts[0])
+ j := strings.Index(parts[0], "=")
+ if j < 0 {
+ continue
+ }
+ name, value := parts[0][:j], parts[0][j+1:]
+ if !isCookieNameValid(name) {
+ continue
+ }
+ value, success := parseCookieValue(value, true)
+ if !success {
+ continue
+ }
+ c := &Cookie{
+ Name: name,
+ Value: value,
+ Raw: line,
+ }
+ for i := 1; i < len(parts); i++ {
+ parts[i] = strings.TrimSpace(parts[i])
+ if len(parts[i]) == 0 {
+ continue
+ }
+
+ attr, val := parts[i], ""
+ if j := strings.Index(attr, "="); j >= 0 {
+ attr, val = attr[:j], attr[j+1:]
+ }
+ lowerAttr := strings.ToLower(attr)
+ val, success = parseCookieValue(val, false)
+ if !success {
+ c.Unparsed = append(c.Unparsed, parts[i])
+ continue
+ }
+ switch lowerAttr {
+ case "secure":
+ c.Secure = true
+ continue
+ case "httponly":
+ c.HttpOnly = true
+ continue
+ case "domain":
+ c.Domain = val
+ continue
+ case "max-age":
+ secs, err := strconv.Atoi(val)
+ if err != nil || secs != 0 && val[0] == '0' {
+ break
+ }
+ if secs <= 0 {
+ c.MaxAge = -1
+ } else {
+ c.MaxAge = secs
+ }
+ continue
+ case "expires":
+ c.RawExpires = val
+ exptime, err := time.Parse(time.RFC1123, val)
+ if err != nil {
+ exptime, err = time.Parse("Mon, 02-Jan-2006 15:04:05 MST", val)
+ if err != nil {
+ c.Expires = time.Time{}
+ break
+ }
+ }
+ c.Expires = exptime.UTC()
+ continue
+ case "path":
+ c.Path = val
+ continue
+ }
+ c.Unparsed = append(c.Unparsed, parts[i])
+ }
+ cookies = append(cookies, c)
+ }
+ return cookies
+}
+
+// SetCookie adds a Set-Cookie header to the provided ResponseWriter's headers.
+func SetCookie(w ResponseWriter, cookie *Cookie) {
+ w.Header().Add("Set-Cookie", cookie.String())
+}
+
+// String returns the serialization of the cookie for use in a Cookie
+// header (if only Name and Value are set) or a Set-Cookie response
+// header (if other fields are set).
+func (c *Cookie) String() string {
+ var b bytes.Buffer
+ fmt.Fprintf(&b, "%s=%s", sanitizeCookieName(c.Name), sanitizeCookieValue(c.Value))
+ if len(c.Path) > 0 {
+ fmt.Fprintf(&b, "; Path=%s", sanitizeCookiePath(c.Path))
+ }
+ if len(c.Domain) > 0 {
+ if validCookieDomain(c.Domain) {
+ // A c.Domain containing illegal characters is not
+ // sanitized but simply dropped which turns the cookie
+ // into a host-only cookie. A leading dot is okay
+ // but won't be sent.
+ d := c.Domain
+ if d[0] == '.' {
+ d = d[1:]
+ }
+ fmt.Fprintf(&b, "; Domain=%s", d)
+ } else {
+ log.Printf("net/http: invalid Cookie.Domain %q; dropping domain attribute",
+ c.Domain)
+ }
+ }
+ if c.Expires.Unix() > 0 {
+ fmt.Fprintf(&b, "; Expires=%s", c.Expires.UTC().Format(time.RFC1123))
+ }
+ if c.MaxAge > 0 {
+ fmt.Fprintf(&b, "; Max-Age=%d", c.MaxAge)
+ } else if c.MaxAge < 0 {
+ fmt.Fprintf(&b, "; Max-Age=0")
+ }
+ if c.HttpOnly {
+ fmt.Fprintf(&b, "; HttpOnly")
+ }
+ if c.Secure {
+ fmt.Fprintf(&b, "; Secure")
+ }
+ return b.String()
+}
+
+// readCookies parses all "Cookie" values from the header h and
+// returns the successfully parsed Cookies.
+//
+// if filter isn't empty, only cookies of that name are returned
+func readCookies(h Header, filter string) []*Cookie {
+ cookies := []*Cookie{}
+ lines, ok := h["Cookie"]
+ if !ok {
+ return cookies
+ }
+
+ for _, line := range lines {
+ parts := strings.Split(strings.TrimSpace(line), ";")
+ if len(parts) == 1 && parts[0] == "" {
+ continue
+ }
+ // Per-line attributes
+ parsedPairs := 0
+ for i := 0; i < len(parts); i++ {
+ parts[i] = strings.TrimSpace(parts[i])
+ if len(parts[i]) == 0 {
+ continue
+ }
+ name, val := parts[i], ""
+ if j := strings.Index(name, "="); j >= 0 {
+ name, val = name[:j], name[j+1:]
+ }
+ if !isCookieNameValid(name) {
+ continue
+ }
+ if filter != "" && filter != name {
+ continue
+ }
+ val, success := parseCookieValue(val, true)
+ if !success {
+ continue
+ }
+ cookies = append(cookies, &Cookie{Name: name, Value: val})
+ parsedPairs++
+ }
+ }
+ return cookies
+}
+
+// validCookieDomain returns wheter v is a valid cookie domain-value.
+func validCookieDomain(v string) bool {
+ if isCookieDomainName(v) {
+ return true
+ }
+ if net.ParseIP(v) != nil && !strings.Contains(v, ":") {
+ return true
+ }
+ return false
+}
+
+// isCookieDomainName returns whether s is a valid domain name or a valid
+// domain name with a leading dot '.'. It is almost a direct copy of
+// package net's isDomainName.
+func isCookieDomainName(s string) bool {
+ if len(s) == 0 {
+ return false
+ }
+ if len(s) > 255 {
+ return false
+ }
+
+ if s[0] == '.' {
+ // A cookie a domain attribute may start with a leading dot.
+ s = s[1:]
+ }
+ last := byte('.')
+ ok := false // Ok once we've seen a letter.
+ partlen := 0
+ for i := 0; i < len(s); i++ {
+ c := s[i]
+ switch {
+ default:
+ return false
+ case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z':
+ // No '_' allowed here (in contrast to package net).
+ ok = true
+ partlen++
+ case '0' <= c && c <= '9':
+ // fine
+ partlen++
+ case c == '-':
+ // Byte before dash cannot be dot.
+ if last == '.' {
+ return false
+ }
+ partlen++
+ case c == '.':
+ // Byte before dot cannot be dot, dash.
+ if last == '.' || last == '-' {
+ return false
+ }
+ if partlen > 63 || partlen == 0 {
+ return false
+ }
+ partlen = 0
+ }
+ last = c
+ }
+ if last == '-' || partlen > 63 {
+ return false
+ }
+
+ return ok
+}
+
+var cookieNameSanitizer = strings.NewReplacer("\n", "-", "\r", "-")
+
+func sanitizeCookieName(n string) string {
+ return cookieNameSanitizer.Replace(n)
+}
+
+// http://tools.ietf.org/html/rfc6265#section-4.1.1
+// cookie-value = *cookie-octet / ( DQUOTE *cookie-octet DQUOTE )
+// cookie-octet = %x21 / %x23-2B / %x2D-3A / %x3C-5B / %x5D-7E
+// ; US-ASCII characters excluding CTLs,
+// ; whitespace DQUOTE, comma, semicolon,
+// ; and backslash
+// We loosen this as spaces and commas are common in cookie values
+// but we produce a quoted cookie-value in when value starts or ends
+// with a comma or space.
+// See http://golang.org/issue/7243 for the discussion.
+func sanitizeCookieValue(v string) string {
+ v = sanitizeOrWarn("Cookie.Value", validCookieValueByte, v)
+ if len(v) == 0 {
+ return v
+ }
+ if v[0] == ' ' || v[0] == ',' || v[len(v)-1] == ' ' || v[len(v)-1] == ',' {
+ return `"` + v + `"`
+ }
+ return v
+}
+
+func validCookieValueByte(b byte) bool {
+ return 0x20 <= b && b < 0x7f && b != '"' && b != ';' && b != '\\'
+}
+
+// path-av = "Path=" path-value
+// path-value = <any CHAR except CTLs or ";">
+func sanitizeCookiePath(v string) string {
+ return sanitizeOrWarn("Cookie.Path", validCookiePathByte, v)
+}
+
+func validCookiePathByte(b byte) bool {
+ return 0x20 <= b && b < 0x7f && b != ';'
+}
+
+func sanitizeOrWarn(fieldName string, valid func(byte) bool, v string) string {
+ ok := true
+ for i := 0; i < len(v); i++ {
+ if valid(v[i]) {
+ continue
+ }
+ log.Printf("net/http: invalid byte %q in %s; dropping invalid bytes", v[i], fieldName)
+ ok = false
+ break
+ }
+ if ok {
+ return v
+ }
+ buf := make([]byte, 0, len(v))
+ for i := 0; i < len(v); i++ {
+ if b := v[i]; valid(b) {
+ buf = append(buf, b)
+ }
+ }
+ return string(buf)
+}
+
+func parseCookieValue(raw string, allowDoubleQuote bool) (string, bool) {
+ // Strip the quotes, if present.
+ if allowDoubleQuote && len(raw) > 1 && raw[0] == '"' && raw[len(raw)-1] == '"' {
+ raw = raw[1 : len(raw)-1]
+ }
+ for i := 0; i < len(raw); i++ {
+ if !validCookieValueByte(raw[i]) {
+ return "", false
+ }
+ }
+ return raw, true
+}
+
+func isCookieNameValid(raw string) bool {
+ return strings.IndexFunc(raw, isNotToken) < 0
+}
diff --git a/src/net/http/cookie_test.go b/src/net/http/cookie_test.go
new file mode 100644
index 000000000..98dc2fade
--- /dev/null
+++ b/src/net/http/cookie_test.go
@@ -0,0 +1,412 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "log"
+ "os"
+ "reflect"
+ "strings"
+ "testing"
+ "time"
+)
+
+var writeSetCookiesTests = []struct {
+ Cookie *Cookie
+ Raw string
+}{
+ {
+ &Cookie{Name: "cookie-1", Value: "v$1"},
+ "cookie-1=v$1",
+ },
+ {
+ &Cookie{Name: "cookie-2", Value: "two", MaxAge: 3600},
+ "cookie-2=two; Max-Age=3600",
+ },
+ {
+ &Cookie{Name: "cookie-3", Value: "three", Domain: ".example.com"},
+ "cookie-3=three; Domain=example.com",
+ },
+ {
+ &Cookie{Name: "cookie-4", Value: "four", Path: "/restricted/"},
+ "cookie-4=four; Path=/restricted/",
+ },
+ {
+ &Cookie{Name: "cookie-5", Value: "five", Domain: "wrong;bad.abc"},
+ "cookie-5=five",
+ },
+ {
+ &Cookie{Name: "cookie-6", Value: "six", Domain: "bad-.abc"},
+ "cookie-6=six",
+ },
+ {
+ &Cookie{Name: "cookie-7", Value: "seven", Domain: "127.0.0.1"},
+ "cookie-7=seven; Domain=127.0.0.1",
+ },
+ {
+ &Cookie{Name: "cookie-8", Value: "eight", Domain: "::1"},
+ "cookie-8=eight",
+ },
+ // The "special" cookies have values containing commas or spaces which
+ // are disallowed by RFC 6265 but are common in the wild.
+ {
+ &Cookie{Name: "special-1", Value: "a z"},
+ `special-1=a z`,
+ },
+ {
+ &Cookie{Name: "special-2", Value: " z"},
+ `special-2=" z"`,
+ },
+ {
+ &Cookie{Name: "special-3", Value: "a "},
+ `special-3="a "`,
+ },
+ {
+ &Cookie{Name: "special-4", Value: " "},
+ `special-4=" "`,
+ },
+ {
+ &Cookie{Name: "special-5", Value: "a,z"},
+ `special-5=a,z`,
+ },
+ {
+ &Cookie{Name: "special-6", Value: ",z"},
+ `special-6=",z"`,
+ },
+ {
+ &Cookie{Name: "special-7", Value: "a,"},
+ `special-7="a,"`,
+ },
+ {
+ &Cookie{Name: "special-8", Value: ","},
+ `special-8=","`,
+ },
+ {
+ &Cookie{Name: "empty-value", Value: ""},
+ `empty-value=`,
+ },
+}
+
+func TestWriteSetCookies(t *testing.T) {
+ defer log.SetOutput(os.Stderr)
+ var logbuf bytes.Buffer
+ log.SetOutput(&logbuf)
+
+ for i, tt := range writeSetCookiesTests {
+ if g, e := tt.Cookie.String(), tt.Raw; g != e {
+ t.Errorf("Test %d, expecting:\n%s\nGot:\n%s\n", i, e, g)
+ continue
+ }
+ }
+
+ if got, sub := logbuf.String(), "dropping domain attribute"; !strings.Contains(got, sub) {
+ t.Errorf("Expected substring %q in log output. Got:\n%s", sub, got)
+ }
+}
+
+type headerOnlyResponseWriter Header
+
+func (ho headerOnlyResponseWriter) Header() Header {
+ return Header(ho)
+}
+
+func (ho headerOnlyResponseWriter) Write([]byte) (int, error) {
+ panic("NOIMPL")
+}
+
+func (ho headerOnlyResponseWriter) WriteHeader(int) {
+ panic("NOIMPL")
+}
+
+func TestSetCookie(t *testing.T) {
+ m := make(Header)
+ SetCookie(headerOnlyResponseWriter(m), &Cookie{Name: "cookie-1", Value: "one", Path: "/restricted/"})
+ SetCookie(headerOnlyResponseWriter(m), &Cookie{Name: "cookie-2", Value: "two", MaxAge: 3600})
+ if l := len(m["Set-Cookie"]); l != 2 {
+ t.Fatalf("expected %d cookies, got %d", 2, l)
+ }
+ if g, e := m["Set-Cookie"][0], "cookie-1=one; Path=/restricted/"; g != e {
+ t.Errorf("cookie #1: want %q, got %q", e, g)
+ }
+ if g, e := m["Set-Cookie"][1], "cookie-2=two; Max-Age=3600"; g != e {
+ t.Errorf("cookie #2: want %q, got %q", e, g)
+ }
+}
+
+var addCookieTests = []struct {
+ Cookies []*Cookie
+ Raw string
+}{
+ {
+ []*Cookie{},
+ "",
+ },
+ {
+ []*Cookie{{Name: "cookie-1", Value: "v$1"}},
+ "cookie-1=v$1",
+ },
+ {
+ []*Cookie{
+ {Name: "cookie-1", Value: "v$1"},
+ {Name: "cookie-2", Value: "v$2"},
+ {Name: "cookie-3", Value: "v$3"},
+ },
+ "cookie-1=v$1; cookie-2=v$2; cookie-3=v$3",
+ },
+}
+
+func TestAddCookie(t *testing.T) {
+ for i, tt := range addCookieTests {
+ req, _ := NewRequest("GET", "http://example.com/", nil)
+ for _, c := range tt.Cookies {
+ req.AddCookie(c)
+ }
+ if g := req.Header.Get("Cookie"); g != tt.Raw {
+ t.Errorf("Test %d:\nwant: %s\n got: %s\n", i, tt.Raw, g)
+ continue
+ }
+ }
+}
+
+var readSetCookiesTests = []struct {
+ Header Header
+ Cookies []*Cookie
+}{
+ {
+ Header{"Set-Cookie": {"Cookie-1=v$1"}},
+ []*Cookie{{Name: "Cookie-1", Value: "v$1", Raw: "Cookie-1=v$1"}},
+ },
+ {
+ Header{"Set-Cookie": {"NID=99=YsDT5i3E-CXax-; expires=Wed, 23-Nov-2011 01:05:03 GMT; path=/; domain=.google.ch; HttpOnly"}},
+ []*Cookie{{
+ Name: "NID",
+ Value: "99=YsDT5i3E-CXax-",
+ Path: "/",
+ Domain: ".google.ch",
+ HttpOnly: true,
+ Expires: time.Date(2011, 11, 23, 1, 5, 3, 0, time.UTC),
+ RawExpires: "Wed, 23-Nov-2011 01:05:03 GMT",
+ Raw: "NID=99=YsDT5i3E-CXax-; expires=Wed, 23-Nov-2011 01:05:03 GMT; path=/; domain=.google.ch; HttpOnly",
+ }},
+ },
+ {
+ Header{"Set-Cookie": {".ASPXAUTH=7E3AA; expires=Wed, 07-Mar-2012 14:25:06 GMT; path=/; HttpOnly"}},
+ []*Cookie{{
+ Name: ".ASPXAUTH",
+ Value: "7E3AA",
+ Path: "/",
+ Expires: time.Date(2012, 3, 7, 14, 25, 6, 0, time.UTC),
+ RawExpires: "Wed, 07-Mar-2012 14:25:06 GMT",
+ HttpOnly: true,
+ Raw: ".ASPXAUTH=7E3AA; expires=Wed, 07-Mar-2012 14:25:06 GMT; path=/; HttpOnly",
+ }},
+ },
+ {
+ Header{"Set-Cookie": {"ASP.NET_SessionId=foo; path=/; HttpOnly"}},
+ []*Cookie{{
+ Name: "ASP.NET_SessionId",
+ Value: "foo",
+ Path: "/",
+ HttpOnly: true,
+ Raw: "ASP.NET_SessionId=foo; path=/; HttpOnly",
+ }},
+ },
+ // Make sure we can properly read back the Set-Cookie headers we create
+ // for values containing spaces or commas:
+ {
+ Header{"Set-Cookie": {`special-1=a z`}},
+ []*Cookie{{Name: "special-1", Value: "a z", Raw: `special-1=a z`}},
+ },
+ {
+ Header{"Set-Cookie": {`special-2=" z"`}},
+ []*Cookie{{Name: "special-2", Value: " z", Raw: `special-2=" z"`}},
+ },
+ {
+ Header{"Set-Cookie": {`special-3="a "`}},
+ []*Cookie{{Name: "special-3", Value: "a ", Raw: `special-3="a "`}},
+ },
+ {
+ Header{"Set-Cookie": {`special-4=" "`}},
+ []*Cookie{{Name: "special-4", Value: " ", Raw: `special-4=" "`}},
+ },
+ {
+ Header{"Set-Cookie": {`special-5=a,z`}},
+ []*Cookie{{Name: "special-5", Value: "a,z", Raw: `special-5=a,z`}},
+ },
+ {
+ Header{"Set-Cookie": {`special-6=",z"`}},
+ []*Cookie{{Name: "special-6", Value: ",z", Raw: `special-6=",z"`}},
+ },
+ {
+ Header{"Set-Cookie": {`special-7=a,`}},
+ []*Cookie{{Name: "special-7", Value: "a,", Raw: `special-7=a,`}},
+ },
+ {
+ Header{"Set-Cookie": {`special-8=","`}},
+ []*Cookie{{Name: "special-8", Value: ",", Raw: `special-8=","`}},
+ },
+
+ // TODO(bradfitz): users have reported seeing this in the
+ // wild, but do browsers handle it? RFC 6265 just says "don't
+ // do that" (section 3) and then never mentions header folding
+ // again.
+ // Header{"Set-Cookie": {"ASP.NET_SessionId=foo; path=/; HttpOnly, .ASPXAUTH=7E3AA; expires=Wed, 07-Mar-2012 14:25:06 GMT; path=/; HttpOnly"}},
+}
+
+func toJSON(v interface{}) string {
+ b, err := json.Marshal(v)
+ if err != nil {
+ return fmt.Sprintf("%#v", v)
+ }
+ return string(b)
+}
+
+func TestReadSetCookies(t *testing.T) {
+ for i, tt := range readSetCookiesTests {
+ for n := 0; n < 2; n++ { // to verify readSetCookies doesn't mutate its input
+ c := readSetCookies(tt.Header)
+ if !reflect.DeepEqual(c, tt.Cookies) {
+ t.Errorf("#%d readSetCookies: have\n%s\nwant\n%s\n", i, toJSON(c), toJSON(tt.Cookies))
+ continue
+ }
+ }
+ }
+}
+
+var readCookiesTests = []struct {
+ Header Header
+ Filter string
+ Cookies []*Cookie
+}{
+ {
+ Header{"Cookie": {"Cookie-1=v$1", "c2=v2"}},
+ "",
+ []*Cookie{
+ {Name: "Cookie-1", Value: "v$1"},
+ {Name: "c2", Value: "v2"},
+ },
+ },
+ {
+ Header{"Cookie": {"Cookie-1=v$1", "c2=v2"}},
+ "c2",
+ []*Cookie{
+ {Name: "c2", Value: "v2"},
+ },
+ },
+ {
+ Header{"Cookie": {"Cookie-1=v$1; c2=v2"}},
+ "",
+ []*Cookie{
+ {Name: "Cookie-1", Value: "v$1"},
+ {Name: "c2", Value: "v2"},
+ },
+ },
+ {
+ Header{"Cookie": {"Cookie-1=v$1; c2=v2"}},
+ "c2",
+ []*Cookie{
+ {Name: "c2", Value: "v2"},
+ },
+ },
+ {
+ Header{"Cookie": {`Cookie-1="v$1"; c2="v2"`}},
+ "",
+ []*Cookie{
+ {Name: "Cookie-1", Value: "v$1"},
+ {Name: "c2", Value: "v2"},
+ },
+ },
+}
+
+func TestReadCookies(t *testing.T) {
+ for i, tt := range readCookiesTests {
+ for n := 0; n < 2; n++ { // to verify readCookies doesn't mutate its input
+ c := readCookies(tt.Header, tt.Filter)
+ if !reflect.DeepEqual(c, tt.Cookies) {
+ t.Errorf("#%d readCookies:\nhave: %s\nwant: %s\n", i, toJSON(c), toJSON(tt.Cookies))
+ continue
+ }
+ }
+ }
+}
+
+func TestSetCookieDoubleQuotes(t *testing.T) {
+ res := &Response{Header: Header{}}
+ res.Header.Add("Set-Cookie", `quoted0=none; max-age=30`)
+ res.Header.Add("Set-Cookie", `quoted1="cookieValue"; max-age=31`)
+ res.Header.Add("Set-Cookie", `quoted2=cookieAV; max-age="32"`)
+ res.Header.Add("Set-Cookie", `quoted3="both"; max-age="33"`)
+ got := res.Cookies()
+ want := []*Cookie{
+ {Name: "quoted0", Value: "none", MaxAge: 30},
+ {Name: "quoted1", Value: "cookieValue", MaxAge: 31},
+ {Name: "quoted2", Value: "cookieAV"},
+ {Name: "quoted3", Value: "both"},
+ }
+ if len(got) != len(want) {
+ t.Fatal("got %d cookies, want %d", len(got), len(want))
+ }
+ for i, w := range want {
+ g := got[i]
+ if g.Name != w.Name || g.Value != w.Value || g.MaxAge != w.MaxAge {
+ t.Errorf("cookie #%d:\ngot %v\nwant %v", i, g, w)
+ }
+ }
+}
+
+func TestCookieSanitizeValue(t *testing.T) {
+ defer log.SetOutput(os.Stderr)
+ var logbuf bytes.Buffer
+ log.SetOutput(&logbuf)
+
+ tests := []struct {
+ in, want string
+ }{
+ {"foo", "foo"},
+ {"foo;bar", "foobar"},
+ {"foo\\bar", "foobar"},
+ {"foo\"bar", "foobar"},
+ {"\x00\x7e\x7f\x80", "\x7e"},
+ {`"withquotes"`, "withquotes"},
+ {"a z", "a z"},
+ {" z", `" z"`},
+ {"a ", `"a "`},
+ }
+ for _, tt := range tests {
+ if got := sanitizeCookieValue(tt.in); got != tt.want {
+ t.Errorf("sanitizeCookieValue(%q) = %q; want %q", tt.in, got, tt.want)
+ }
+ }
+
+ if got, sub := logbuf.String(), "dropping invalid bytes"; !strings.Contains(got, sub) {
+ t.Errorf("Expected substring %q in log output. Got:\n%s", sub, got)
+ }
+}
+
+func TestCookieSanitizePath(t *testing.T) {
+ defer log.SetOutput(os.Stderr)
+ var logbuf bytes.Buffer
+ log.SetOutput(&logbuf)
+
+ tests := []struct {
+ in, want string
+ }{
+ {"/path", "/path"},
+ {"/path with space/", "/path with space/"},
+ {"/just;no;semicolon\x00orstuff/", "/justnosemicolonorstuff/"},
+ }
+ for _, tt := range tests {
+ if got := sanitizeCookiePath(tt.in); got != tt.want {
+ t.Errorf("sanitizeCookiePath(%q) = %q; want %q", tt.in, got, tt.want)
+ }
+ }
+
+ if got, sub := logbuf.String(), "dropping invalid bytes"; !strings.Contains(got, sub) {
+ t.Errorf("Expected substring %q in log output. Got:\n%s", sub, got)
+ }
+}
diff --git a/src/net/http/cookiejar/jar.go b/src/net/http/cookiejar/jar.go
new file mode 100644
index 000000000..0e0fac928
--- /dev/null
+++ b/src/net/http/cookiejar/jar.go
@@ -0,0 +1,497 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package cookiejar implements an in-memory RFC 6265-compliant http.CookieJar.
+package cookiejar
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "net/http"
+ "net/url"
+ "sort"
+ "strings"
+ "sync"
+ "time"
+)
+
+// PublicSuffixList provides the public suffix of a domain. For example:
+// - the public suffix of "example.com" is "com",
+// - the public suffix of "foo1.foo2.foo3.co.uk" is "co.uk", and
+// - the public suffix of "bar.pvt.k12.ma.us" is "pvt.k12.ma.us".
+//
+// Implementations of PublicSuffixList must be safe for concurrent use by
+// multiple goroutines.
+//
+// An implementation that always returns "" is valid and may be useful for
+// testing but it is not secure: it means that the HTTP server for foo.com can
+// set a cookie for bar.com.
+//
+// A public suffix list implementation is in the package
+// golang.org/x/net/publicsuffix.
+type PublicSuffixList interface {
+ // PublicSuffix returns the public suffix of domain.
+ //
+ // TODO: specify which of the caller and callee is responsible for IP
+ // addresses, for leading and trailing dots, for case sensitivity, and
+ // for IDN/Punycode.
+ PublicSuffix(domain string) string
+
+ // String returns a description of the source of this public suffix
+ // list. The description will typically contain something like a time
+ // stamp or version number.
+ String() string
+}
+
+// Options are the options for creating a new Jar.
+type Options struct {
+ // PublicSuffixList is the public suffix list that determines whether
+ // an HTTP server can set a cookie for a domain.
+ //
+ // A nil value is valid and may be useful for testing but it is not
+ // secure: it means that the HTTP server for foo.co.uk can set a cookie
+ // for bar.co.uk.
+ PublicSuffixList PublicSuffixList
+}
+
+// Jar implements the http.CookieJar interface from the net/http package.
+type Jar struct {
+ psList PublicSuffixList
+
+ // mu locks the remaining fields.
+ mu sync.Mutex
+
+ // entries is a set of entries, keyed by their eTLD+1 and subkeyed by
+ // their name/domain/path.
+ entries map[string]map[string]entry
+
+ // nextSeqNum is the next sequence number assigned to a new cookie
+ // created SetCookies.
+ nextSeqNum uint64
+}
+
+// New returns a new cookie jar. A nil *Options is equivalent to a zero
+// Options.
+func New(o *Options) (*Jar, error) {
+ jar := &Jar{
+ entries: make(map[string]map[string]entry),
+ }
+ if o != nil {
+ jar.psList = o.PublicSuffixList
+ }
+ return jar, nil
+}
+
+// entry is the internal representation of a cookie.
+//
+// This struct type is not used outside of this package per se, but the exported
+// fields are those of RFC 6265.
+type entry struct {
+ Name string
+ Value string
+ Domain string
+ Path string
+ Secure bool
+ HttpOnly bool
+ Persistent bool
+ HostOnly bool
+ Expires time.Time
+ Creation time.Time
+ LastAccess time.Time
+
+ // seqNum is a sequence number so that Cookies returns cookies in a
+ // deterministic order, even for cookies that have equal Path length and
+ // equal Creation time. This simplifies testing.
+ seqNum uint64
+}
+
+// Id returns the domain;path;name triple of e as an id.
+func (e *entry) id() string {
+ return fmt.Sprintf("%s;%s;%s", e.Domain, e.Path, e.Name)
+}
+
+// shouldSend determines whether e's cookie qualifies to be included in a
+// request to host/path. It is the caller's responsibility to check if the
+// cookie is expired.
+func (e *entry) shouldSend(https bool, host, path string) bool {
+ return e.domainMatch(host) && e.pathMatch(path) && (https || !e.Secure)
+}
+
+// domainMatch implements "domain-match" of RFC 6265 section 5.1.3.
+func (e *entry) domainMatch(host string) bool {
+ if e.Domain == host {
+ return true
+ }
+ return !e.HostOnly && hasDotSuffix(host, e.Domain)
+}
+
+// pathMatch implements "path-match" according to RFC 6265 section 5.1.4.
+func (e *entry) pathMatch(requestPath string) bool {
+ if requestPath == e.Path {
+ return true
+ }
+ if strings.HasPrefix(requestPath, e.Path) {
+ if e.Path[len(e.Path)-1] == '/' {
+ return true // The "/any/" matches "/any/path" case.
+ } else if requestPath[len(e.Path)] == '/' {
+ return true // The "/any" matches "/any/path" case.
+ }
+ }
+ return false
+}
+
+// hasDotSuffix reports whether s ends in "."+suffix.
+func hasDotSuffix(s, suffix string) bool {
+ return len(s) > len(suffix) && s[len(s)-len(suffix)-1] == '.' && s[len(s)-len(suffix):] == suffix
+}
+
+// byPathLength is a []entry sort.Interface that sorts according to RFC 6265
+// section 5.4 point 2: by longest path and then by earliest creation time.
+type byPathLength []entry
+
+func (s byPathLength) Len() int { return len(s) }
+
+func (s byPathLength) Less(i, j int) bool {
+ if len(s[i].Path) != len(s[j].Path) {
+ return len(s[i].Path) > len(s[j].Path)
+ }
+ if !s[i].Creation.Equal(s[j].Creation) {
+ return s[i].Creation.Before(s[j].Creation)
+ }
+ return s[i].seqNum < s[j].seqNum
+}
+
+func (s byPathLength) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+
+// Cookies implements the Cookies method of the http.CookieJar interface.
+//
+// It returns an empty slice if the URL's scheme is not HTTP or HTTPS.
+func (j *Jar) Cookies(u *url.URL) (cookies []*http.Cookie) {
+ return j.cookies(u, time.Now())
+}
+
+// cookies is like Cookies but takes the current time as a parameter.
+func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) {
+ if u.Scheme != "http" && u.Scheme != "https" {
+ return cookies
+ }
+ host, err := canonicalHost(u.Host)
+ if err != nil {
+ return cookies
+ }
+ key := jarKey(host, j.psList)
+
+ j.mu.Lock()
+ defer j.mu.Unlock()
+
+ submap := j.entries[key]
+ if submap == nil {
+ return cookies
+ }
+
+ https := u.Scheme == "https"
+ path := u.Path
+ if path == "" {
+ path = "/"
+ }
+
+ modified := false
+ var selected []entry
+ for id, e := range submap {
+ if e.Persistent && !e.Expires.After(now) {
+ delete(submap, id)
+ modified = true
+ continue
+ }
+ if !e.shouldSend(https, host, path) {
+ continue
+ }
+ e.LastAccess = now
+ submap[id] = e
+ selected = append(selected, e)
+ modified = true
+ }
+ if modified {
+ if len(submap) == 0 {
+ delete(j.entries, key)
+ } else {
+ j.entries[key] = submap
+ }
+ }
+
+ sort.Sort(byPathLength(selected))
+ for _, e := range selected {
+ cookies = append(cookies, &http.Cookie{Name: e.Name, Value: e.Value})
+ }
+
+ return cookies
+}
+
+// SetCookies implements the SetCookies method of the http.CookieJar interface.
+//
+// It does nothing if the URL's scheme is not HTTP or HTTPS.
+func (j *Jar) SetCookies(u *url.URL, cookies []*http.Cookie) {
+ j.setCookies(u, cookies, time.Now())
+}
+
+// setCookies is like SetCookies but takes the current time as parameter.
+func (j *Jar) setCookies(u *url.URL, cookies []*http.Cookie, now time.Time) {
+ if len(cookies) == 0 {
+ return
+ }
+ if u.Scheme != "http" && u.Scheme != "https" {
+ return
+ }
+ host, err := canonicalHost(u.Host)
+ if err != nil {
+ return
+ }
+ key := jarKey(host, j.psList)
+ defPath := defaultPath(u.Path)
+
+ j.mu.Lock()
+ defer j.mu.Unlock()
+
+ submap := j.entries[key]
+
+ modified := false
+ for _, cookie := range cookies {
+ e, remove, err := j.newEntry(cookie, now, defPath, host)
+ if err != nil {
+ continue
+ }
+ id := e.id()
+ if remove {
+ if submap != nil {
+ if _, ok := submap[id]; ok {
+ delete(submap, id)
+ modified = true
+ }
+ }
+ continue
+ }
+ if submap == nil {
+ submap = make(map[string]entry)
+ }
+
+ if old, ok := submap[id]; ok {
+ e.Creation = old.Creation
+ e.seqNum = old.seqNum
+ } else {
+ e.Creation = now
+ e.seqNum = j.nextSeqNum
+ j.nextSeqNum++
+ }
+ e.LastAccess = now
+ submap[id] = e
+ modified = true
+ }
+
+ if modified {
+ if len(submap) == 0 {
+ delete(j.entries, key)
+ } else {
+ j.entries[key] = submap
+ }
+ }
+}
+
+// canonicalHost strips port from host if present and returns the canonicalized
+// host name.
+func canonicalHost(host string) (string, error) {
+ var err error
+ host = strings.ToLower(host)
+ if hasPort(host) {
+ host, _, err = net.SplitHostPort(host)
+ if err != nil {
+ return "", err
+ }
+ }
+ if strings.HasSuffix(host, ".") {
+ // Strip trailing dot from fully qualified domain names.
+ host = host[:len(host)-1]
+ }
+ return toASCII(host)
+}
+
+// hasPort reports whether host contains a port number. host may be a host
+// name, an IPv4 or an IPv6 address.
+func hasPort(host string) bool {
+ colons := strings.Count(host, ":")
+ if colons == 0 {
+ return false
+ }
+ if colons == 1 {
+ return true
+ }
+ return host[0] == '[' && strings.Contains(host, "]:")
+}
+
+// jarKey returns the key to use for a jar.
+func jarKey(host string, psl PublicSuffixList) string {
+ if isIP(host) {
+ return host
+ }
+
+ var i int
+ if psl == nil {
+ i = strings.LastIndex(host, ".")
+ if i == -1 {
+ return host
+ }
+ } else {
+ suffix := psl.PublicSuffix(host)
+ if suffix == host {
+ return host
+ }
+ i = len(host) - len(suffix)
+ if i <= 0 || host[i-1] != '.' {
+ // The provided public suffix list psl is broken.
+ // Storing cookies under host is a safe stopgap.
+ return host
+ }
+ }
+ prevDot := strings.LastIndex(host[:i-1], ".")
+ return host[prevDot+1:]
+}
+
+// isIP reports whether host is an IP address.
+func isIP(host string) bool {
+ return net.ParseIP(host) != nil
+}
+
+// defaultPath returns the directory part of an URL's path according to
+// RFC 6265 section 5.1.4.
+func defaultPath(path string) string {
+ if len(path) == 0 || path[0] != '/' {
+ return "/" // Path is empty or malformed.
+ }
+
+ i := strings.LastIndex(path, "/") // Path starts with "/", so i != -1.
+ if i == 0 {
+ return "/" // Path has the form "/abc".
+ }
+ return path[:i] // Path is either of form "/abc/xyz" or "/abc/xyz/".
+}
+
+// newEntry creates an entry from a http.Cookie c. now is the current time and
+// is compared to c.Expires to determine deletion of c. defPath and host are the
+// default-path and the canonical host name of the URL c was received from.
+//
+// remove records whether the jar should delete this cookie, as it has already
+// expired with respect to now. In this case, e may be incomplete, but it will
+// be valid to call e.id (which depends on e's Name, Domain and Path).
+//
+// A malformed c.Domain will result in an error.
+func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e entry, remove bool, err error) {
+ e.Name = c.Name
+
+ if c.Path == "" || c.Path[0] != '/' {
+ e.Path = defPath
+ } else {
+ e.Path = c.Path
+ }
+
+ e.Domain, e.HostOnly, err = j.domainAndType(host, c.Domain)
+ if err != nil {
+ return e, false, err
+ }
+
+ // MaxAge takes precedence over Expires.
+ if c.MaxAge < 0 {
+ return e, true, nil
+ } else if c.MaxAge > 0 {
+ e.Expires = now.Add(time.Duration(c.MaxAge) * time.Second)
+ e.Persistent = true
+ } else {
+ if c.Expires.IsZero() {
+ e.Expires = endOfTime
+ e.Persistent = false
+ } else {
+ if !c.Expires.After(now) {
+ return e, true, nil
+ }
+ e.Expires = c.Expires
+ e.Persistent = true
+ }
+ }
+
+ e.Value = c.Value
+ e.Secure = c.Secure
+ e.HttpOnly = c.HttpOnly
+
+ return e, false, nil
+}
+
+var (
+ errIllegalDomain = errors.New("cookiejar: illegal cookie domain attribute")
+ errMalformedDomain = errors.New("cookiejar: malformed cookie domain attribute")
+ errNoHostname = errors.New("cookiejar: no host name available (IP only)")
+)
+
+// endOfTime is the time when session (non-persistent) cookies expire.
+// This instant is representable in most date/time formats (not just
+// Go's time.Time) and should be far enough in the future.
+var endOfTime = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC)
+
+// domainAndType determines the cookie's domain and hostOnly attribute.
+func (j *Jar) domainAndType(host, domain string) (string, bool, error) {
+ if domain == "" {
+ // No domain attribute in the SetCookie header indicates a
+ // host cookie.
+ return host, true, nil
+ }
+
+ if isIP(host) {
+ // According to RFC 6265 domain-matching includes not being
+ // an IP address.
+ // TODO: This might be relaxed as in common browsers.
+ return "", false, errNoHostname
+ }
+
+ // From here on: If the cookie is valid, it is a domain cookie (with
+ // the one exception of a public suffix below).
+ // See RFC 6265 section 5.2.3.
+ if domain[0] == '.' {
+ domain = domain[1:]
+ }
+
+ if len(domain) == 0 || domain[0] == '.' {
+ // Received either "Domain=." or "Domain=..some.thing",
+ // both are illegal.
+ return "", false, errMalformedDomain
+ }
+ domain = strings.ToLower(domain)
+
+ if domain[len(domain)-1] == '.' {
+ // We received stuff like "Domain=www.example.com.".
+ // Browsers do handle such stuff (actually differently) but
+ // RFC 6265 seems to be clear here (e.g. section 4.1.2.3) in
+ // requiring a reject. 4.1.2.3 is not normative, but
+ // "Domain Matching" (5.1.3) and "Canonicalized Host Names"
+ // (5.1.2) are.
+ return "", false, errMalformedDomain
+ }
+
+ // See RFC 6265 section 5.3 #5.
+ if j.psList != nil {
+ if ps := j.psList.PublicSuffix(domain); ps != "" && !hasDotSuffix(domain, ps) {
+ if host == domain {
+ // This is the one exception in which a cookie
+ // with a domain attribute is a host cookie.
+ return host, true, nil
+ }
+ return "", false, errIllegalDomain
+ }
+ }
+
+ // The domain must domain-match host: www.mycompany.com cannot
+ // set cookies for .ourcompetitors.com.
+ if host != domain && !hasDotSuffix(host, domain) {
+ return "", false, errIllegalDomain
+ }
+
+ return domain, false, nil
+}
diff --git a/src/net/http/cookiejar/jar_test.go b/src/net/http/cookiejar/jar_test.go
new file mode 100644
index 000000000..3aa601586
--- /dev/null
+++ b/src/net/http/cookiejar/jar_test.go
@@ -0,0 +1,1267 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cookiejar
+
+import (
+ "fmt"
+ "net/http"
+ "net/url"
+ "sort"
+ "strings"
+ "testing"
+ "time"
+)
+
+// tNow is the synthetic current time used as now during testing.
+var tNow = time.Date(2013, 1, 1, 12, 0, 0, 0, time.UTC)
+
+// testPSL implements PublicSuffixList with just two rules: "co.uk"
+// and the default rule "*".
+type testPSL struct{}
+
+func (testPSL) String() string {
+ return "testPSL"
+}
+func (testPSL) PublicSuffix(d string) string {
+ if d == "co.uk" || strings.HasSuffix(d, ".co.uk") {
+ return "co.uk"
+ }
+ return d[strings.LastIndex(d, ".")+1:]
+}
+
+// newTestJar creates an empty Jar with testPSL as the public suffix list.
+func newTestJar() *Jar {
+ jar, err := New(&Options{PublicSuffixList: testPSL{}})
+ if err != nil {
+ panic(err)
+ }
+ return jar
+}
+
+var hasDotSuffixTests = [...]struct {
+ s, suffix string
+}{
+ {"", ""},
+ {"", "."},
+ {"", "x"},
+ {".", ""},
+ {".", "."},
+ {".", ".."},
+ {".", "x"},
+ {".", "x."},
+ {".", ".x"},
+ {".", ".x."},
+ {"x", ""},
+ {"x", "."},
+ {"x", ".."},
+ {"x", "x"},
+ {"x", "x."},
+ {"x", ".x"},
+ {"x", ".x."},
+ {".x", ""},
+ {".x", "."},
+ {".x", ".."},
+ {".x", "x"},
+ {".x", "x."},
+ {".x", ".x"},
+ {".x", ".x."},
+ {"x.", ""},
+ {"x.", "."},
+ {"x.", ".."},
+ {"x.", "x"},
+ {"x.", "x."},
+ {"x.", ".x"},
+ {"x.", ".x."},
+ {"com", ""},
+ {"com", "m"},
+ {"com", "om"},
+ {"com", "com"},
+ {"com", ".com"},
+ {"com", "x.com"},
+ {"com", "xcom"},
+ {"com", "xorg"},
+ {"com", "org"},
+ {"com", "rg"},
+ {"foo.com", ""},
+ {"foo.com", "m"},
+ {"foo.com", "om"},
+ {"foo.com", "com"},
+ {"foo.com", ".com"},
+ {"foo.com", "o.com"},
+ {"foo.com", "oo.com"},
+ {"foo.com", "foo.com"},
+ {"foo.com", ".foo.com"},
+ {"foo.com", "x.foo.com"},
+ {"foo.com", "xfoo.com"},
+ {"foo.com", "xfoo.org"},
+ {"foo.com", "foo.org"},
+ {"foo.com", "oo.org"},
+ {"foo.com", "o.org"},
+ {"foo.com", ".org"},
+ {"foo.com", "org"},
+ {"foo.com", "rg"},
+}
+
+func TestHasDotSuffix(t *testing.T) {
+ for _, tc := range hasDotSuffixTests {
+ got := hasDotSuffix(tc.s, tc.suffix)
+ want := strings.HasSuffix(tc.s, "."+tc.suffix)
+ if got != want {
+ t.Errorf("s=%q, suffix=%q: got %v, want %v", tc.s, tc.suffix, got, want)
+ }
+ }
+}
+
+var canonicalHostTests = map[string]string{
+ "www.example.com": "www.example.com",
+ "WWW.EXAMPLE.COM": "www.example.com",
+ "wWw.eXAmple.CoM": "www.example.com",
+ "www.example.com:80": "www.example.com",
+ "192.168.0.10": "192.168.0.10",
+ "192.168.0.5:8080": "192.168.0.5",
+ "2001:4860:0:2001::68": "2001:4860:0:2001::68",
+ "[2001:4860:0:::68]:8080": "2001:4860:0:::68",
+ "www.bรผcher.de": "www.xn--bcher-kva.de",
+ "www.example.com.": "www.example.com",
+ "[bad.unmatched.bracket:": "error",
+}
+
+func TestCanonicalHost(t *testing.T) {
+ for h, want := range canonicalHostTests {
+ got, err := canonicalHost(h)
+ if want == "error" {
+ if err == nil {
+ t.Errorf("%q: got nil error, want non-nil", h)
+ }
+ continue
+ }
+ if err != nil {
+ t.Errorf("%q: %v", h, err)
+ continue
+ }
+ if got != want {
+ t.Errorf("%q: got %q, want %q", h, got, want)
+ continue
+ }
+ }
+}
+
+var hasPortTests = map[string]bool{
+ "www.example.com": false,
+ "www.example.com:80": true,
+ "127.0.0.1": false,
+ "127.0.0.1:8080": true,
+ "2001:4860:0:2001::68": false,
+ "[2001::0:::68]:80": true,
+}
+
+func TestHasPort(t *testing.T) {
+ for host, want := range hasPortTests {
+ if got := hasPort(host); got != want {
+ t.Errorf("%q: got %t, want %t", host, got, want)
+ }
+ }
+}
+
+var jarKeyTests = map[string]string{
+ "foo.www.example.com": "example.com",
+ "www.example.com": "example.com",
+ "example.com": "example.com",
+ "com": "com",
+ "foo.www.bbc.co.uk": "bbc.co.uk",
+ "www.bbc.co.uk": "bbc.co.uk",
+ "bbc.co.uk": "bbc.co.uk",
+ "co.uk": "co.uk",
+ "uk": "uk",
+ "192.168.0.5": "192.168.0.5",
+}
+
+func TestJarKey(t *testing.T) {
+ for host, want := range jarKeyTests {
+ if got := jarKey(host, testPSL{}); got != want {
+ t.Errorf("%q: got %q, want %q", host, got, want)
+ }
+ }
+}
+
+var jarKeyNilPSLTests = map[string]string{
+ "foo.www.example.com": "example.com",
+ "www.example.com": "example.com",
+ "example.com": "example.com",
+ "com": "com",
+ "foo.www.bbc.co.uk": "co.uk",
+ "www.bbc.co.uk": "co.uk",
+ "bbc.co.uk": "co.uk",
+ "co.uk": "co.uk",
+ "uk": "uk",
+ "192.168.0.5": "192.168.0.5",
+}
+
+func TestJarKeyNilPSL(t *testing.T) {
+ for host, want := range jarKeyNilPSLTests {
+ if got := jarKey(host, nil); got != want {
+ t.Errorf("%q: got %q, want %q", host, got, want)
+ }
+ }
+}
+
+var isIPTests = map[string]bool{
+ "127.0.0.1": true,
+ "1.2.3.4": true,
+ "2001:4860:0:2001::68": true,
+ "example.com": false,
+ "1.1.1.300": false,
+ "www.foo.bar.net": false,
+ "123.foo.bar.net": false,
+}
+
+func TestIsIP(t *testing.T) {
+ for host, want := range isIPTests {
+ if got := isIP(host); got != want {
+ t.Errorf("%q: got %t, want %t", host, got, want)
+ }
+ }
+}
+
+var defaultPathTests = map[string]string{
+ "/": "/",
+ "/abc": "/",
+ "/abc/": "/abc",
+ "/abc/xyz": "/abc",
+ "/abc/xyz/": "/abc/xyz",
+ "/a/b/c.html": "/a/b",
+ "": "/",
+ "strange": "/",
+ "//": "/",
+ "/a//b": "/a/",
+ "/a/./b": "/a/.",
+ "/a/../b": "/a/..",
+}
+
+func TestDefaultPath(t *testing.T) {
+ for path, want := range defaultPathTests {
+ if got := defaultPath(path); got != want {
+ t.Errorf("%q: got %q, want %q", path, got, want)
+ }
+ }
+}
+
+var domainAndTypeTests = [...]struct {
+ host string // host Set-Cookie header was received from
+ domain string // domain attribute in Set-Cookie header
+ wantDomain string // expected domain of cookie
+ wantHostOnly bool // expected host-cookie flag
+ wantErr error // expected error
+}{
+ {"www.example.com", "", "www.example.com", true, nil},
+ {"127.0.0.1", "", "127.0.0.1", true, nil},
+ {"2001:4860:0:2001::68", "", "2001:4860:0:2001::68", true, nil},
+ {"www.example.com", "example.com", "example.com", false, nil},
+ {"www.example.com", ".example.com", "example.com", false, nil},
+ {"www.example.com", "www.example.com", "www.example.com", false, nil},
+ {"www.example.com", ".www.example.com", "www.example.com", false, nil},
+ {"foo.sso.example.com", "sso.example.com", "sso.example.com", false, nil},
+ {"bar.co.uk", "bar.co.uk", "bar.co.uk", false, nil},
+ {"foo.bar.co.uk", ".bar.co.uk", "bar.co.uk", false, nil},
+ {"127.0.0.1", "127.0.0.1", "", false, errNoHostname},
+ {"2001:4860:0:2001::68", "2001:4860:0:2001::68", "2001:4860:0:2001::68", false, errNoHostname},
+ {"www.example.com", ".", "", false, errMalformedDomain},
+ {"www.example.com", "..", "", false, errMalformedDomain},
+ {"www.example.com", "other.com", "", false, errIllegalDomain},
+ {"www.example.com", "com", "", false, errIllegalDomain},
+ {"www.example.com", ".com", "", false, errIllegalDomain},
+ {"foo.bar.co.uk", ".co.uk", "", false, errIllegalDomain},
+ {"127.www.0.0.1", "127.0.0.1", "", false, errIllegalDomain},
+ {"com", "", "com", true, nil},
+ {"com", "com", "com", true, nil},
+ {"com", ".com", "com", true, nil},
+ {"co.uk", "", "co.uk", true, nil},
+ {"co.uk", "co.uk", "co.uk", true, nil},
+ {"co.uk", ".co.uk", "co.uk", true, nil},
+}
+
+func TestDomainAndType(t *testing.T) {
+ jar := newTestJar()
+ for _, tc := range domainAndTypeTests {
+ domain, hostOnly, err := jar.domainAndType(tc.host, tc.domain)
+ if err != tc.wantErr {
+ t.Errorf("%q/%q: got %q error, want %q",
+ tc.host, tc.domain, err, tc.wantErr)
+ continue
+ }
+ if err != nil {
+ continue
+ }
+ if domain != tc.wantDomain || hostOnly != tc.wantHostOnly {
+ t.Errorf("%q/%q: got %q/%t want %q/%t",
+ tc.host, tc.domain, domain, hostOnly,
+ tc.wantDomain, tc.wantHostOnly)
+ }
+ }
+}
+
+// expiresIn creates an expires attribute delta seconds from tNow.
+func expiresIn(delta int) string {
+ t := tNow.Add(time.Duration(delta) * time.Second)
+ return "expires=" + t.Format(time.RFC1123)
+}
+
+// mustParseURL parses s to an URL and panics on error.
+func mustParseURL(s string) *url.URL {
+ u, err := url.Parse(s)
+ if err != nil || u.Scheme == "" || u.Host == "" {
+ panic(fmt.Sprintf("Unable to parse URL %s.", s))
+ }
+ return u
+}
+
+// jarTest encapsulates the following actions on a jar:
+// 1. Perform SetCookies with fromURL and the cookies from setCookies.
+// (Done at time tNow + 0 ms.)
+// 2. Check that the entries in the jar matches content.
+// (Done at time tNow + 1001 ms.)
+// 3. For each query in tests: Check that Cookies with toURL yields the
+// cookies in want.
+// (Query n done at tNow + (n+2)*1001 ms.)
+type jarTest struct {
+ description string // The description of what this test is supposed to test
+ fromURL string // The full URL of the request from which Set-Cookie headers where received
+ setCookies []string // All the cookies received from fromURL
+ content string // The whole (non-expired) content of the jar
+ queries []query // Queries to test the Jar.Cookies method
+}
+
+// query contains one test of the cookies returned from Jar.Cookies.
+type query struct {
+ toURL string // the URL in the Cookies call
+ want string // the expected list of cookies (order matters)
+}
+
+// run runs the jarTest.
+func (test jarTest) run(t *testing.T, jar *Jar) {
+ now := tNow
+
+ // Populate jar with cookies.
+ setCookies := make([]*http.Cookie, len(test.setCookies))
+ for i, cs := range test.setCookies {
+ cookies := (&http.Response{Header: http.Header{"Set-Cookie": {cs}}}).Cookies()
+ if len(cookies) != 1 {
+ panic(fmt.Sprintf("Wrong cookie line %q: %#v", cs, cookies))
+ }
+ setCookies[i] = cookies[0]
+ }
+ jar.setCookies(mustParseURL(test.fromURL), setCookies, now)
+ now = now.Add(1001 * time.Millisecond)
+
+ // Serialize non-expired entries in the form "name1=val1 name2=val2".
+ var cs []string
+ for _, submap := range jar.entries {
+ for _, cookie := range submap {
+ if !cookie.Expires.After(now) {
+ continue
+ }
+ cs = append(cs, cookie.Name+"="+cookie.Value)
+ }
+ }
+ sort.Strings(cs)
+ got := strings.Join(cs, " ")
+
+ // Make sure jar content matches our expectations.
+ if got != test.content {
+ t.Errorf("Test %q Content\ngot %q\nwant %q",
+ test.description, got, test.content)
+ }
+
+ // Test different calls to Cookies.
+ for i, query := range test.queries {
+ now = now.Add(1001 * time.Millisecond)
+ var s []string
+ for _, c := range jar.cookies(mustParseURL(query.toURL), now) {
+ s = append(s, c.Name+"="+c.Value)
+ }
+ if got := strings.Join(s, " "); got != query.want {
+ t.Errorf("Test %q #%d\ngot %q\nwant %q", test.description, i, got, query.want)
+ }
+ }
+}
+
+// basicsTests contains fundamental tests. Each jarTest has to be performed on
+// a fresh, empty Jar.
+var basicsTests = [...]jarTest{
+ {
+ "Retrieval of a plain host cookie.",
+ "http://www.host.test/",
+ []string{"A=a"},
+ "A=a",
+ []query{
+ {"http://www.host.test", "A=a"},
+ {"http://www.host.test/", "A=a"},
+ {"http://www.host.test/some/path", "A=a"},
+ {"https://www.host.test", "A=a"},
+ {"https://www.host.test/", "A=a"},
+ {"https://www.host.test/some/path", "A=a"},
+ {"ftp://www.host.test", ""},
+ {"ftp://www.host.test/", ""},
+ {"ftp://www.host.test/some/path", ""},
+ {"http://www.other.org", ""},
+ {"http://sibling.host.test", ""},
+ {"http://deep.www.host.test", ""},
+ },
+ },
+ {
+ "Secure cookies are not returned to http.",
+ "http://www.host.test/",
+ []string{"A=a; secure"},
+ "A=a",
+ []query{
+ {"http://www.host.test", ""},
+ {"http://www.host.test/", ""},
+ {"http://www.host.test/some/path", ""},
+ {"https://www.host.test", "A=a"},
+ {"https://www.host.test/", "A=a"},
+ {"https://www.host.test/some/path", "A=a"},
+ },
+ },
+ {
+ "Explicit path.",
+ "http://www.host.test/",
+ []string{"A=a; path=/some/path"},
+ "A=a",
+ []query{
+ {"http://www.host.test", ""},
+ {"http://www.host.test/", ""},
+ {"http://www.host.test/some", ""},
+ {"http://www.host.test/some/", ""},
+ {"http://www.host.test/some/path", "A=a"},
+ {"http://www.host.test/some/paths", ""},
+ {"http://www.host.test/some/path/foo", "A=a"},
+ {"http://www.host.test/some/path/foo/", "A=a"},
+ },
+ },
+ {
+ "Implicit path #1: path is a directory.",
+ "http://www.host.test/some/path/",
+ []string{"A=a"},
+ "A=a",
+ []query{
+ {"http://www.host.test", ""},
+ {"http://www.host.test/", ""},
+ {"http://www.host.test/some", ""},
+ {"http://www.host.test/some/", ""},
+ {"http://www.host.test/some/path", "A=a"},
+ {"http://www.host.test/some/paths", ""},
+ {"http://www.host.test/some/path/foo", "A=a"},
+ {"http://www.host.test/some/path/foo/", "A=a"},
+ },
+ },
+ {
+ "Implicit path #2: path is not a directory.",
+ "http://www.host.test/some/path/index.html",
+ []string{"A=a"},
+ "A=a",
+ []query{
+ {"http://www.host.test", ""},
+ {"http://www.host.test/", ""},
+ {"http://www.host.test/some", ""},
+ {"http://www.host.test/some/", ""},
+ {"http://www.host.test/some/path", "A=a"},
+ {"http://www.host.test/some/paths", ""},
+ {"http://www.host.test/some/path/foo", "A=a"},
+ {"http://www.host.test/some/path/foo/", "A=a"},
+ },
+ },
+ {
+ "Implicit path #3: no path in URL at all.",
+ "http://www.host.test",
+ []string{"A=a"},
+ "A=a",
+ []query{
+ {"http://www.host.test", "A=a"},
+ {"http://www.host.test/", "A=a"},
+ {"http://www.host.test/some/path", "A=a"},
+ },
+ },
+ {
+ "Cookies are sorted by path length.",
+ "http://www.host.test/",
+ []string{
+ "A=a; path=/foo/bar",
+ "B=b; path=/foo/bar/baz/qux",
+ "C=c; path=/foo/bar/baz",
+ "D=d; path=/foo"},
+ "A=a B=b C=c D=d",
+ []query{
+ {"http://www.host.test/foo/bar/baz/qux", "B=b C=c A=a D=d"},
+ {"http://www.host.test/foo/bar/baz/", "C=c A=a D=d"},
+ {"http://www.host.test/foo/bar", "A=a D=d"},
+ },
+ },
+ {
+ "Creation time determines sorting on same length paths.",
+ "http://www.host.test/",
+ []string{
+ "A=a; path=/foo/bar",
+ "X=x; path=/foo/bar",
+ "Y=y; path=/foo/bar/baz/qux",
+ "B=b; path=/foo/bar/baz/qux",
+ "C=c; path=/foo/bar/baz",
+ "W=w; path=/foo/bar/baz",
+ "Z=z; path=/foo",
+ "D=d; path=/foo"},
+ "A=a B=b C=c D=d W=w X=x Y=y Z=z",
+ []query{
+ {"http://www.host.test/foo/bar/baz/qux", "Y=y B=b C=c W=w A=a X=x Z=z D=d"},
+ {"http://www.host.test/foo/bar/baz/", "C=c W=w A=a X=x Z=z D=d"},
+ {"http://www.host.test/foo/bar", "A=a X=x Z=z D=d"},
+ },
+ },
+ {
+ "Sorting of same-name cookies.",
+ "http://www.host.test/",
+ []string{
+ "A=1; path=/",
+ "A=2; path=/path",
+ "A=3; path=/quux",
+ "A=4; path=/path/foo",
+ "A=5; domain=.host.test; path=/path",
+ "A=6; domain=.host.test; path=/quux",
+ "A=7; domain=.host.test; path=/path/foo",
+ },
+ "A=1 A=2 A=3 A=4 A=5 A=6 A=7",
+ []query{
+ {"http://www.host.test/path", "A=2 A=5 A=1"},
+ {"http://www.host.test/path/foo", "A=4 A=7 A=2 A=5 A=1"},
+ },
+ },
+ {
+ "Disallow domain cookie on public suffix.",
+ "http://www.bbc.co.uk",
+ []string{
+ "a=1",
+ "b=2; domain=co.uk",
+ },
+ "a=1",
+ []query{{"http://www.bbc.co.uk", "a=1"}},
+ },
+ {
+ "Host cookie on IP.",
+ "http://192.168.0.10",
+ []string{"a=1"},
+ "a=1",
+ []query{{"http://192.168.0.10", "a=1"}},
+ },
+ {
+ "Port is ignored #1.",
+ "http://www.host.test/",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://www.host.test", "a=1"},
+ {"http://www.host.test:8080/", "a=1"},
+ },
+ },
+ {
+ "Port is ignored #2.",
+ "http://www.host.test:8080/",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://www.host.test", "a=1"},
+ {"http://www.host.test:8080/", "a=1"},
+ {"http://www.host.test:1234/", "a=1"},
+ },
+ },
+}
+
+func TestBasics(t *testing.T) {
+ for _, test := range basicsTests {
+ jar := newTestJar()
+ test.run(t, jar)
+ }
+}
+
+// updateAndDeleteTests contains jarTests which must be performed on the same
+// Jar.
+var updateAndDeleteTests = [...]jarTest{
+ {
+ "Set initial cookies.",
+ "http://www.host.test",
+ []string{
+ "a=1",
+ "b=2; secure",
+ "c=3; httponly",
+ "d=4; secure; httponly"},
+ "a=1 b=2 c=3 d=4",
+ []query{
+ {"http://www.host.test", "a=1 c=3"},
+ {"https://www.host.test", "a=1 b=2 c=3 d=4"},
+ },
+ },
+ {
+ "Update value via http.",
+ "http://www.host.test",
+ []string{
+ "a=w",
+ "b=x; secure",
+ "c=y; httponly",
+ "d=z; secure; httponly"},
+ "a=w b=x c=y d=z",
+ []query{
+ {"http://www.host.test", "a=w c=y"},
+ {"https://www.host.test", "a=w b=x c=y d=z"},
+ },
+ },
+ {
+ "Clear Secure flag from a http.",
+ "http://www.host.test/",
+ []string{
+ "b=xx",
+ "d=zz; httponly"},
+ "a=w b=xx c=y d=zz",
+ []query{{"http://www.host.test", "a=w b=xx c=y d=zz"}},
+ },
+ {
+ "Delete all.",
+ "http://www.host.test/",
+ []string{
+ "a=1; max-Age=-1", // delete via MaxAge
+ "b=2; " + expiresIn(-10), // delete via Expires
+ "c=2; max-age=-1; " + expiresIn(-10), // delete via both
+ "d=4; max-age=-1; " + expiresIn(10)}, // MaxAge takes precedence
+ "",
+ []query{{"http://www.host.test", ""}},
+ },
+ {
+ "Refill #1.",
+ "http://www.host.test",
+ []string{
+ "A=1",
+ "A=2; path=/foo",
+ "A=3; domain=.host.test",
+ "A=4; path=/foo; domain=.host.test"},
+ "A=1 A=2 A=3 A=4",
+ []query{{"http://www.host.test/foo", "A=2 A=4 A=1 A=3"}},
+ },
+ {
+ "Refill #2.",
+ "http://www.google.com",
+ []string{
+ "A=6",
+ "A=7; path=/foo",
+ "A=8; domain=.google.com",
+ "A=9; path=/foo; domain=.google.com"},
+ "A=1 A=2 A=3 A=4 A=6 A=7 A=8 A=9",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=4 A=1 A=3"},
+ {"http://www.google.com/foo", "A=7 A=9 A=6 A=8"},
+ },
+ },
+ {
+ "Delete A7.",
+ "http://www.google.com",
+ []string{"A=; path=/foo; max-age=-1"},
+ "A=1 A=2 A=3 A=4 A=6 A=8 A=9",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=4 A=1 A=3"},
+ {"http://www.google.com/foo", "A=9 A=6 A=8"},
+ },
+ },
+ {
+ "Delete A4.",
+ "http://www.host.test",
+ []string{"A=; path=/foo; domain=host.test; max-age=-1"},
+ "A=1 A=2 A=3 A=6 A=8 A=9",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=1 A=3"},
+ {"http://www.google.com/foo", "A=9 A=6 A=8"},
+ },
+ },
+ {
+ "Delete A6.",
+ "http://www.google.com",
+ []string{"A=; max-age=-1"},
+ "A=1 A=2 A=3 A=8 A=9",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=1 A=3"},
+ {"http://www.google.com/foo", "A=9 A=8"},
+ },
+ },
+ {
+ "Delete A3.",
+ "http://www.host.test",
+ []string{"A=; domain=host.test; max-age=-1"},
+ "A=1 A=2 A=8 A=9",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=1"},
+ {"http://www.google.com/foo", "A=9 A=8"},
+ },
+ },
+ {
+ "No cross-domain delete.",
+ "http://www.host.test",
+ []string{
+ "A=; domain=google.com; max-age=-1",
+ "A=; path=/foo; domain=google.com; max-age=-1"},
+ "A=1 A=2 A=8 A=9",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=1"},
+ {"http://www.google.com/foo", "A=9 A=8"},
+ },
+ },
+ {
+ "Delete A8 and A9.",
+ "http://www.google.com",
+ []string{
+ "A=; domain=google.com; max-age=-1",
+ "A=; path=/foo; domain=google.com; max-age=-1"},
+ "A=1 A=2",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=1"},
+ {"http://www.google.com/foo", ""},
+ },
+ },
+}
+
+func TestUpdateAndDelete(t *testing.T) {
+ jar := newTestJar()
+ for _, test := range updateAndDeleteTests {
+ test.run(t, jar)
+ }
+}
+
+func TestExpiration(t *testing.T) {
+ jar := newTestJar()
+ jarTest{
+ "Expiration.",
+ "http://www.host.test",
+ []string{
+ "a=1",
+ "b=2; max-age=3",
+ "c=3; " + expiresIn(3),
+ "d=4; max-age=5",
+ "e=5; " + expiresIn(5),
+ "f=6; max-age=100",
+ },
+ "a=1 b=2 c=3 d=4 e=5 f=6", // executed at t0 + 1001 ms
+ []query{
+ {"http://www.host.test", "a=1 b=2 c=3 d=4 e=5 f=6"}, // t0 + 2002 ms
+ {"http://www.host.test", "a=1 d=4 e=5 f=6"}, // t0 + 3003 ms
+ {"http://www.host.test", "a=1 d=4 e=5 f=6"}, // t0 + 4004 ms
+ {"http://www.host.test", "a=1 f=6"}, // t0 + 5005 ms
+ {"http://www.host.test", "a=1 f=6"}, // t0 + 6006 ms
+ },
+ }.run(t, jar)
+}
+
+//
+// Tests derived from Chromium's cookie_store_unittest.h.
+//
+
+// See http://src.chromium.org/viewvc/chrome/trunk/src/net/cookies/cookie_store_unittest.h?revision=159685&content-type=text/plain
+// Some of the original tests are in a bad condition (e.g.
+// DomainWithTrailingDotTest) or are not RFC 6265 conforming (e.g.
+// TestNonDottedAndTLD #1 and #6) and have not been ported.
+
+// chromiumBasicsTests contains fundamental tests. Each jarTest has to be
+// performed on a fresh, empty Jar.
+var chromiumBasicsTests = [...]jarTest{
+ {
+ "DomainWithTrailingDotTest.",
+ "http://www.google.com/",
+ []string{
+ "a=1; domain=.www.google.com.",
+ "b=2; domain=.www.google.com.."},
+ "",
+ []query{
+ {"http://www.google.com", ""},
+ },
+ },
+ {
+ "ValidSubdomainTest #1.",
+ "http://a.b.c.d.com",
+ []string{
+ "a=1; domain=.a.b.c.d.com",
+ "b=2; domain=.b.c.d.com",
+ "c=3; domain=.c.d.com",
+ "d=4; domain=.d.com"},
+ "a=1 b=2 c=3 d=4",
+ []query{
+ {"http://a.b.c.d.com", "a=1 b=2 c=3 d=4"},
+ {"http://b.c.d.com", "b=2 c=3 d=4"},
+ {"http://c.d.com", "c=3 d=4"},
+ {"http://d.com", "d=4"},
+ },
+ },
+ {
+ "ValidSubdomainTest #2.",
+ "http://a.b.c.d.com",
+ []string{
+ "a=1; domain=.a.b.c.d.com",
+ "b=2; domain=.b.c.d.com",
+ "c=3; domain=.c.d.com",
+ "d=4; domain=.d.com",
+ "X=bcd; domain=.b.c.d.com",
+ "X=cd; domain=.c.d.com"},
+ "X=bcd X=cd a=1 b=2 c=3 d=4",
+ []query{
+ {"http://b.c.d.com", "b=2 c=3 d=4 X=bcd X=cd"},
+ {"http://c.d.com", "c=3 d=4 X=cd"},
+ },
+ },
+ {
+ "InvalidDomainTest #1.",
+ "http://foo.bar.com",
+ []string{
+ "a=1; domain=.yo.foo.bar.com",
+ "b=2; domain=.foo.com",
+ "c=3; domain=.bar.foo.com",
+ "d=4; domain=.foo.bar.com.net",
+ "e=5; domain=ar.com",
+ "f=6; domain=.",
+ "g=7; domain=/",
+ "h=8; domain=http://foo.bar.com",
+ "i=9; domain=..foo.bar.com",
+ "j=10; domain=..bar.com",
+ "k=11; domain=.foo.bar.com?blah",
+ "l=12; domain=.foo.bar.com/blah",
+ "m=12; domain=.foo.bar.com:80",
+ "n=14; domain=.foo.bar.com:",
+ "o=15; domain=.foo.bar.com#sup",
+ },
+ "", // Jar is empty.
+ []query{{"http://foo.bar.com", ""}},
+ },
+ {
+ "InvalidDomainTest #2.",
+ "http://foo.com.com",
+ []string{"a=1; domain=.foo.com.com.com"},
+ "",
+ []query{{"http://foo.bar.com", ""}},
+ },
+ {
+ "DomainWithoutLeadingDotTest #1.",
+ "http://manage.hosted.filefront.com",
+ []string{"a=1; domain=filefront.com"},
+ "a=1",
+ []query{{"http://www.filefront.com", "a=1"}},
+ },
+ {
+ "DomainWithoutLeadingDotTest #2.",
+ "http://www.google.com",
+ []string{"a=1; domain=www.google.com"},
+ "a=1",
+ []query{
+ {"http://www.google.com", "a=1"},
+ {"http://sub.www.google.com", "a=1"},
+ {"http://something-else.com", ""},
+ },
+ },
+ {
+ "CaseInsensitiveDomainTest.",
+ "http://www.google.com",
+ []string{
+ "a=1; domain=.GOOGLE.COM",
+ "b=2; domain=.www.gOOgLE.coM"},
+ "a=1 b=2",
+ []query{{"http://www.google.com", "a=1 b=2"}},
+ },
+ {
+ "TestIpAddress #1.",
+ "http://1.2.3.4/foo",
+ []string{"a=1; path=/"},
+ "a=1",
+ []query{{"http://1.2.3.4/foo", "a=1"}},
+ },
+ {
+ "TestIpAddress #2.",
+ "http://1.2.3.4/foo",
+ []string{
+ "a=1; domain=.1.2.3.4",
+ "b=2; domain=.3.4"},
+ "",
+ []query{{"http://1.2.3.4/foo", ""}},
+ },
+ {
+ "TestIpAddress #3.",
+ "http://1.2.3.4/foo",
+ []string{"a=1; domain=1.2.3.4"},
+ "",
+ []query{{"http://1.2.3.4/foo", ""}},
+ },
+ {
+ "TestNonDottedAndTLD #2.",
+ "http://com./index.html",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://com./index.html", "a=1"},
+ {"http://no-cookies.com./index.html", ""},
+ },
+ },
+ {
+ "TestNonDottedAndTLD #3.",
+ "http://a.b",
+ []string{
+ "a=1; domain=.b",
+ "b=2; domain=b"},
+ "",
+ []query{{"http://bar.foo", ""}},
+ },
+ {
+ "TestNonDottedAndTLD #4.",
+ "http://google.com",
+ []string{
+ "a=1; domain=.com",
+ "b=2; domain=com"},
+ "",
+ []query{{"http://google.com", ""}},
+ },
+ {
+ "TestNonDottedAndTLD #5.",
+ "http://google.co.uk",
+ []string{
+ "a=1; domain=.co.uk",
+ "b=2; domain=.uk"},
+ "",
+ []query{
+ {"http://google.co.uk", ""},
+ {"http://else.co.com", ""},
+ {"http://else.uk", ""},
+ },
+ },
+ {
+ "TestHostEndsWithDot.",
+ "http://www.google.com",
+ []string{
+ "a=1",
+ "b=2; domain=.www.google.com."},
+ "a=1",
+ []query{{"http://www.google.com", "a=1"}},
+ },
+ {
+ "PathTest",
+ "http://www.google.izzle",
+ []string{"a=1; path=/wee"},
+ "a=1",
+ []query{
+ {"http://www.google.izzle/wee", "a=1"},
+ {"http://www.google.izzle/wee/", "a=1"},
+ {"http://www.google.izzle/wee/war", "a=1"},
+ {"http://www.google.izzle/wee/war/more/more", "a=1"},
+ {"http://www.google.izzle/weehee", ""},
+ {"http://www.google.izzle/", ""},
+ },
+ },
+}
+
+func TestChromiumBasics(t *testing.T) {
+ for _, test := range chromiumBasicsTests {
+ jar := newTestJar()
+ test.run(t, jar)
+ }
+}
+
+// chromiumDomainTests contains jarTests which must be executed all on the
+// same Jar.
+var chromiumDomainTests = [...]jarTest{
+ {
+ "Fill #1.",
+ "http://www.google.izzle",
+ []string{"A=B"},
+ "A=B",
+ []query{{"http://www.google.izzle", "A=B"}},
+ },
+ {
+ "Fill #2.",
+ "http://www.google.izzle",
+ []string{"C=D; domain=.google.izzle"},
+ "A=B C=D",
+ []query{{"http://www.google.izzle", "A=B C=D"}},
+ },
+ {
+ "Verify A is a host cookie and not accessible from subdomain.",
+ "http://unused.nil",
+ []string{},
+ "A=B C=D",
+ []query{{"http://foo.www.google.izzle", "C=D"}},
+ },
+ {
+ "Verify domain cookies are found on proper domain.",
+ "http://www.google.izzle",
+ []string{"E=F; domain=.www.google.izzle"},
+ "A=B C=D E=F",
+ []query{{"http://www.google.izzle", "A=B C=D E=F"}},
+ },
+ {
+ "Leading dots in domain attributes are optional.",
+ "http://www.google.izzle",
+ []string{"G=H; domain=www.google.izzle"},
+ "A=B C=D E=F G=H",
+ []query{{"http://www.google.izzle", "A=B C=D E=F G=H"}},
+ },
+ {
+ "Verify domain enforcement works #1.",
+ "http://www.google.izzle",
+ []string{"K=L; domain=.bar.www.google.izzle"},
+ "A=B C=D E=F G=H",
+ []query{{"http://bar.www.google.izzle", "C=D E=F G=H"}},
+ },
+ {
+ "Verify domain enforcement works #2.",
+ "http://unused.nil",
+ []string{},
+ "A=B C=D E=F G=H",
+ []query{{"http://www.google.izzle", "A=B C=D E=F G=H"}},
+ },
+}
+
+func TestChromiumDomain(t *testing.T) {
+ jar := newTestJar()
+ for _, test := range chromiumDomainTests {
+ test.run(t, jar)
+ }
+
+}
+
+// chromiumDeletionTests must be performed all on the same Jar.
+var chromiumDeletionTests = [...]jarTest{
+ {
+ "Create session cookie a1.",
+ "http://www.google.com",
+ []string{"a=1"},
+ "a=1",
+ []query{{"http://www.google.com", "a=1"}},
+ },
+ {
+ "Delete sc a1 via MaxAge.",
+ "http://www.google.com",
+ []string{"a=1; max-age=-1"},
+ "",
+ []query{{"http://www.google.com", ""}},
+ },
+ {
+ "Create session cookie b2.",
+ "http://www.google.com",
+ []string{"b=2"},
+ "b=2",
+ []query{{"http://www.google.com", "b=2"}},
+ },
+ {
+ "Delete sc b2 via Expires.",
+ "http://www.google.com",
+ []string{"b=2; " + expiresIn(-10)},
+ "",
+ []query{{"http://www.google.com", ""}},
+ },
+ {
+ "Create persistent cookie c3.",
+ "http://www.google.com",
+ []string{"c=3; max-age=3600"},
+ "c=3",
+ []query{{"http://www.google.com", "c=3"}},
+ },
+ {
+ "Delete pc c3 via MaxAge.",
+ "http://www.google.com",
+ []string{"c=3; max-age=-1"},
+ "",
+ []query{{"http://www.google.com", ""}},
+ },
+ {
+ "Create persistent cookie d4.",
+ "http://www.google.com",
+ []string{"d=4; max-age=3600"},
+ "d=4",
+ []query{{"http://www.google.com", "d=4"}},
+ },
+ {
+ "Delete pc d4 via Expires.",
+ "http://www.google.com",
+ []string{"d=4; " + expiresIn(-10)},
+ "",
+ []query{{"http://www.google.com", ""}},
+ },
+}
+
+func TestChromiumDeletion(t *testing.T) {
+ jar := newTestJar()
+ for _, test := range chromiumDeletionTests {
+ test.run(t, jar)
+ }
+}
+
+// domainHandlingTests tests and documents the rules for domain handling.
+// Each test must be performed on an empty new Jar.
+var domainHandlingTests = [...]jarTest{
+ {
+ "Host cookie",
+ "http://www.host.test",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://www.host.test", "a=1"},
+ {"http://host.test", ""},
+ {"http://bar.host.test", ""},
+ {"http://foo.www.host.test", ""},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Domain cookie #1",
+ "http://www.host.test",
+ []string{"a=1; domain=host.test"},
+ "a=1",
+ []query{
+ {"http://www.host.test", "a=1"},
+ {"http://host.test", "a=1"},
+ {"http://bar.host.test", "a=1"},
+ {"http://foo.www.host.test", "a=1"},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Domain cookie #2",
+ "http://www.host.test",
+ []string{"a=1; domain=.host.test"},
+ "a=1",
+ []query{
+ {"http://www.host.test", "a=1"},
+ {"http://host.test", "a=1"},
+ {"http://bar.host.test", "a=1"},
+ {"http://foo.www.host.test", "a=1"},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Host cookie on IDNA domain #1",
+ "http://www.bรผcher.test",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://www.bรผcher.test", "a=1"},
+ {"http://www.xn--bcher-kva.test", "a=1"},
+ {"http://bรผcher.test", ""},
+ {"http://xn--bcher-kva.test", ""},
+ {"http://bar.bรผcher.test", ""},
+ {"http://bar.xn--bcher-kva.test", ""},
+ {"http://foo.www.bรผcher.test", ""},
+ {"http://foo.www.xn--bcher-kva.test", ""},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Host cookie on IDNA domain #2",
+ "http://www.xn--bcher-kva.test",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://www.bรผcher.test", "a=1"},
+ {"http://www.xn--bcher-kva.test", "a=1"},
+ {"http://bรผcher.test", ""},
+ {"http://xn--bcher-kva.test", ""},
+ {"http://bar.bรผcher.test", ""},
+ {"http://bar.xn--bcher-kva.test", ""},
+ {"http://foo.www.bรผcher.test", ""},
+ {"http://foo.www.xn--bcher-kva.test", ""},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Domain cookie on IDNA domain #1",
+ "http://www.bรผcher.test",
+ []string{"a=1; domain=xn--bcher-kva.test"},
+ "a=1",
+ []query{
+ {"http://www.bรผcher.test", "a=1"},
+ {"http://www.xn--bcher-kva.test", "a=1"},
+ {"http://bรผcher.test", "a=1"},
+ {"http://xn--bcher-kva.test", "a=1"},
+ {"http://bar.bรผcher.test", "a=1"},
+ {"http://bar.xn--bcher-kva.test", "a=1"},
+ {"http://foo.www.bรผcher.test", "a=1"},
+ {"http://foo.www.xn--bcher-kva.test", "a=1"},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Domain cookie on IDNA domain #2",
+ "http://www.xn--bcher-kva.test",
+ []string{"a=1; domain=xn--bcher-kva.test"},
+ "a=1",
+ []query{
+ {"http://www.bรผcher.test", "a=1"},
+ {"http://www.xn--bcher-kva.test", "a=1"},
+ {"http://bรผcher.test", "a=1"},
+ {"http://xn--bcher-kva.test", "a=1"},
+ {"http://bar.bรผcher.test", "a=1"},
+ {"http://bar.xn--bcher-kva.test", "a=1"},
+ {"http://foo.www.bรผcher.test", "a=1"},
+ {"http://foo.www.xn--bcher-kva.test", "a=1"},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Host cookie on TLD.",
+ "http://com",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://com", "a=1"},
+ {"http://any.com", ""},
+ {"http://any.test", ""},
+ },
+ },
+ {
+ "Domain cookie on TLD becomes a host cookie.",
+ "http://com",
+ []string{"a=1; domain=com"},
+ "a=1",
+ []query{
+ {"http://com", "a=1"},
+ {"http://any.com", ""},
+ {"http://any.test", ""},
+ },
+ },
+ {
+ "Host cookie on public suffix.",
+ "http://co.uk",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://co.uk", "a=1"},
+ {"http://uk", ""},
+ {"http://some.co.uk", ""},
+ {"http://foo.some.co.uk", ""},
+ {"http://any.uk", ""},
+ },
+ },
+ {
+ "Domain cookie on public suffix is ignored.",
+ "http://some.co.uk",
+ []string{"a=1; domain=co.uk"},
+ "",
+ []query{
+ {"http://co.uk", ""},
+ {"http://uk", ""},
+ {"http://some.co.uk", ""},
+ {"http://foo.some.co.uk", ""},
+ {"http://any.uk", ""},
+ },
+ },
+}
+
+func TestDomainHandling(t *testing.T) {
+ for _, test := range domainHandlingTests {
+ jar := newTestJar()
+ test.run(t, jar)
+ }
+}
diff --git a/src/net/http/cookiejar/punycode.go b/src/net/http/cookiejar/punycode.go
new file mode 100644
index 000000000..ea7ceb5ef
--- /dev/null
+++ b/src/net/http/cookiejar/punycode.go
@@ -0,0 +1,159 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cookiejar
+
+// This file implements the Punycode algorithm from RFC 3492.
+
+import (
+ "fmt"
+ "strings"
+ "unicode/utf8"
+)
+
+// These parameter values are specified in section 5.
+//
+// All computation is done with int32s, so that overflow behavior is identical
+// regardless of whether int is 32-bit or 64-bit.
+const (
+ base int32 = 36
+ damp int32 = 700
+ initialBias int32 = 72
+ initialN int32 = 128
+ skew int32 = 38
+ tmax int32 = 26
+ tmin int32 = 1
+)
+
+// encode encodes a string as specified in section 6.3 and prepends prefix to
+// the result.
+//
+// The "while h < length(input)" line in the specification becomes "for
+// remaining != 0" in the Go code, because len(s) in Go is in bytes, not runes.
+func encode(prefix, s string) (string, error) {
+ output := make([]byte, len(prefix), len(prefix)+1+2*len(s))
+ copy(output, prefix)
+ delta, n, bias := int32(0), initialN, initialBias
+ b, remaining := int32(0), int32(0)
+ for _, r := range s {
+ if r < 0x80 {
+ b++
+ output = append(output, byte(r))
+ } else {
+ remaining++
+ }
+ }
+ h := b
+ if b > 0 {
+ output = append(output, '-')
+ }
+ for remaining != 0 {
+ m := int32(0x7fffffff)
+ for _, r := range s {
+ if m > r && r >= n {
+ m = r
+ }
+ }
+ delta += (m - n) * (h + 1)
+ if delta < 0 {
+ return "", fmt.Errorf("cookiejar: invalid label %q", s)
+ }
+ n = m
+ for _, r := range s {
+ if r < n {
+ delta++
+ if delta < 0 {
+ return "", fmt.Errorf("cookiejar: invalid label %q", s)
+ }
+ continue
+ }
+ if r > n {
+ continue
+ }
+ q := delta
+ for k := base; ; k += base {
+ t := k - bias
+ if t < tmin {
+ t = tmin
+ } else if t > tmax {
+ t = tmax
+ }
+ if q < t {
+ break
+ }
+ output = append(output, encodeDigit(t+(q-t)%(base-t)))
+ q = (q - t) / (base - t)
+ }
+ output = append(output, encodeDigit(q))
+ bias = adapt(delta, h+1, h == b)
+ delta = 0
+ h++
+ remaining--
+ }
+ delta++
+ n++
+ }
+ return string(output), nil
+}
+
+func encodeDigit(digit int32) byte {
+ switch {
+ case 0 <= digit && digit < 26:
+ return byte(digit + 'a')
+ case 26 <= digit && digit < 36:
+ return byte(digit + ('0' - 26))
+ }
+ panic("cookiejar: internal error in punycode encoding")
+}
+
+// adapt is the bias adaptation function specified in section 6.1.
+func adapt(delta, numPoints int32, firstTime bool) int32 {
+ if firstTime {
+ delta /= damp
+ } else {
+ delta /= 2
+ }
+ delta += delta / numPoints
+ k := int32(0)
+ for delta > ((base-tmin)*tmax)/2 {
+ delta /= base - tmin
+ k += base
+ }
+ return k + (base-tmin+1)*delta/(delta+skew)
+}
+
+// Strictly speaking, the remaining code below deals with IDNA (RFC 5890 and
+// friends) and not Punycode (RFC 3492) per se.
+
+// acePrefix is the ASCII Compatible Encoding prefix.
+const acePrefix = "xn--"
+
+// toASCII converts a domain or domain label to its ASCII form. For example,
+// toASCII("bรผcher.example.com") is "xn--bcher-kva.example.com", and
+// toASCII("golang") is "golang".
+func toASCII(s string) (string, error) {
+ if ascii(s) {
+ return s, nil
+ }
+ labels := strings.Split(s, ".")
+ for i, label := range labels {
+ if !ascii(label) {
+ a, err := encode(acePrefix, label)
+ if err != nil {
+ return "", err
+ }
+ labels[i] = a
+ }
+ }
+ return strings.Join(labels, "."), nil
+}
+
+func ascii(s string) bool {
+ for i := 0; i < len(s); i++ {
+ if s[i] >= utf8.RuneSelf {
+ return false
+ }
+ }
+ return true
+}
diff --git a/src/net/http/cookiejar/punycode_test.go b/src/net/http/cookiejar/punycode_test.go
new file mode 100644
index 000000000..0301de14e
--- /dev/null
+++ b/src/net/http/cookiejar/punycode_test.go
@@ -0,0 +1,161 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cookiejar
+
+import (
+ "testing"
+)
+
+var punycodeTestCases = [...]struct {
+ s, encoded string
+}{
+ {"", ""},
+ {"-", "--"},
+ {"-a", "-a-"},
+ {"-a-", "-a--"},
+ {"a", "a-"},
+ {"a-", "a--"},
+ {"a-b", "a-b-"},
+ {"books", "books-"},
+ {"bรผcher", "bcher-kva"},
+ {"Helloไธ–็•Œ", "Hello-ck1hg65u"},
+ {"รผ", "tda"},
+ {"รผรฝ", "tdac"},
+
+ // The test cases below come from RFC 3492 section 7.1 with Errata 3026.
+ {
+ // (A) Arabic (Egyptian).
+ "\u0644\u064A\u0647\u0645\u0627\u0628\u062A\u0643\u0644" +
+ "\u0645\u0648\u0634\u0639\u0631\u0628\u064A\u061F",
+ "egbpdaj6bu4bxfgehfvwxn",
+ },
+ {
+ // (B) Chinese (simplified).
+ "\u4ED6\u4EEC\u4E3A\u4EC0\u4E48\u4E0D\u8BF4\u4E2D\u6587",
+ "ihqwcrb4cv8a8dqg056pqjye",
+ },
+ {
+ // (C) Chinese (traditional).
+ "\u4ED6\u5011\u7232\u4EC0\u9EBD\u4E0D\u8AAA\u4E2D\u6587",
+ "ihqwctvzc91f659drss3x8bo0yb",
+ },
+ {
+ // (D) Czech.
+ "\u0050\u0072\u006F\u010D\u0070\u0072\u006F\u0073\u0074" +
+ "\u011B\u006E\u0065\u006D\u006C\u0075\u0076\u00ED\u010D" +
+ "\u0065\u0073\u006B\u0079",
+ "Proprostnemluvesky-uyb24dma41a",
+ },
+ {
+ // (E) Hebrew.
+ "\u05DC\u05DE\u05D4\u05D4\u05DD\u05E4\u05E9\u05D5\u05D8" +
+ "\u05DC\u05D0\u05DE\u05D3\u05D1\u05E8\u05D9\u05DD\u05E2" +
+ "\u05D1\u05E8\u05D9\u05EA",
+ "4dbcagdahymbxekheh6e0a7fei0b",
+ },
+ {
+ // (F) Hindi (Devanagari).
+ "\u092F\u0939\u0932\u094B\u0917\u0939\u093F\u0928\u094D" +
+ "\u0926\u0940\u0915\u094D\u092F\u094B\u0902\u0928\u0939" +
+ "\u0940\u0902\u092C\u094B\u0932\u0938\u0915\u0924\u0947" +
+ "\u0939\u0948\u0902",
+ "i1baa7eci9glrd9b2ae1bj0hfcgg6iyaf8o0a1dig0cd",
+ },
+ {
+ // (G) Japanese (kanji and hiragana).
+ "\u306A\u305C\u307F\u3093\u306A\u65E5\u672C\u8A9E\u3092" +
+ "\u8A71\u3057\u3066\u304F\u308C\u306A\u3044\u306E\u304B",
+ "n8jok5ay5dzabd5bym9f0cm5685rrjetr6pdxa",
+ },
+ {
+ // (H) Korean (Hangul syllables).
+ "\uC138\uACC4\uC758\uBAA8\uB4E0\uC0AC\uB78C\uB4E4\uC774" +
+ "\uD55C\uAD6D\uC5B4\uB97C\uC774\uD574\uD55C\uB2E4\uBA74" +
+ "\uC5BC\uB9C8\uB098\uC88B\uC744\uAE4C",
+ "989aomsvi5e83db1d2a355cv1e0vak1dwrv93d5xbh15a0dt30a5j" +
+ "psd879ccm6fea98c",
+ },
+ {
+ // (I) Russian (Cyrillic).
+ "\u043F\u043E\u0447\u0435\u043C\u0443\u0436\u0435\u043E" +
+ "\u043D\u0438\u043D\u0435\u0433\u043E\u0432\u043E\u0440" +
+ "\u044F\u0442\u043F\u043E\u0440\u0443\u0441\u0441\u043A" +
+ "\u0438",
+ "b1abfaaepdrnnbgefbadotcwatmq2g4l",
+ },
+ {
+ // (J) Spanish.
+ "\u0050\u006F\u0072\u0071\u0075\u00E9\u006E\u006F\u0070" +
+ "\u0075\u0065\u0064\u0065\u006E\u0073\u0069\u006D\u0070" +
+ "\u006C\u0065\u006D\u0065\u006E\u0074\u0065\u0068\u0061" +
+ "\u0062\u006C\u0061\u0072\u0065\u006E\u0045\u0073\u0070" +
+ "\u0061\u00F1\u006F\u006C",
+ "PorqunopuedensimplementehablarenEspaol-fmd56a",
+ },
+ {
+ // (K) Vietnamese.
+ "\u0054\u1EA1\u0069\u0073\u0061\u006F\u0068\u1ECD\u006B" +
+ "\u0068\u00F4\u006E\u0067\u0074\u0068\u1EC3\u0063\u0068" +
+ "\u1EC9\u006E\u00F3\u0069\u0074\u0069\u1EBF\u006E\u0067" +
+ "\u0056\u0069\u1EC7\u0074",
+ "TisaohkhngthchnitingVit-kjcr8268qyxafd2f1b9g",
+ },
+ {
+ // (L) 3<nen>B<gumi><kinpachi><sensei>.
+ "\u0033\u5E74\u0042\u7D44\u91D1\u516B\u5148\u751F",
+ "3B-ww4c5e180e575a65lsy2b",
+ },
+ {
+ // (M) <amuro><namie>-with-SUPER-MONKEYS.
+ "\u5B89\u5BA4\u5948\u7F8E\u6075\u002D\u0077\u0069\u0074" +
+ "\u0068\u002D\u0053\u0055\u0050\u0045\u0052\u002D\u004D" +
+ "\u004F\u004E\u004B\u0045\u0059\u0053",
+ "-with-SUPER-MONKEYS-pc58ag80a8qai00g7n9n",
+ },
+ {
+ // (N) Hello-Another-Way-<sorezore><no><basho>.
+ "\u0048\u0065\u006C\u006C\u006F\u002D\u0041\u006E\u006F" +
+ "\u0074\u0068\u0065\u0072\u002D\u0057\u0061\u0079\u002D" +
+ "\u305D\u308C\u305E\u308C\u306E\u5834\u6240",
+ "Hello-Another-Way--fc4qua05auwb3674vfr0b",
+ },
+ {
+ // (O) <hitotsu><yane><no><shita>2.
+ "\u3072\u3068\u3064\u5C4B\u6839\u306E\u4E0B\u0032",
+ "2-u9tlzr9756bt3uc0v",
+ },
+ {
+ // (P) Maji<de>Koi<suru>5<byou><mae>
+ "\u004D\u0061\u006A\u0069\u3067\u004B\u006F\u0069\u3059" +
+ "\u308B\u0035\u79D2\u524D",
+ "MajiKoi5-783gue6qz075azm5e",
+ },
+ {
+ // (Q) <pafii>de<runba>
+ "\u30D1\u30D5\u30A3\u30FC\u0064\u0065\u30EB\u30F3\u30D0",
+ "de-jg4avhby1noc0d",
+ },
+ {
+ // (R) <sono><supiido><de>
+ "\u305D\u306E\u30B9\u30D4\u30FC\u30C9\u3067",
+ "d9juau41awczczp",
+ },
+ {
+ // (S) -> $1.00 <-
+ "\u002D\u003E\u0020\u0024\u0031\u002E\u0030\u0030\u0020" +
+ "\u003C\u002D",
+ "-> $1.00 <--",
+ },
+}
+
+func TestPunycode(t *testing.T) {
+ for _, tc := range punycodeTestCases {
+ if got, err := encode("", tc.s); err != nil {
+ t.Errorf(`encode("", %q): %v`, tc.s, err)
+ } else if got != tc.encoded {
+ t.Errorf(`encode("", %q): got %q, want %q`, tc.s, got, tc.encoded)
+ }
+ }
+}
diff --git a/src/net/http/doc.go b/src/net/http/doc.go
new file mode 100644
index 000000000..b1216e8da
--- /dev/null
+++ b/src/net/http/doc.go
@@ -0,0 +1,80 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+Package http provides HTTP client and server implementations.
+
+Get, Head, Post, and PostForm make HTTP (or HTTPS) requests:
+
+ resp, err := http.Get("http://example.com/")
+ ...
+ resp, err := http.Post("http://example.com/upload", "image/jpeg", &buf)
+ ...
+ resp, err := http.PostForm("http://example.com/form",
+ url.Values{"key": {"Value"}, "id": {"123"}})
+
+The client must close the response body when finished with it:
+
+ resp, err := http.Get("http://example.com/")
+ if err != nil {
+ // handle error
+ }
+ defer resp.Body.Close()
+ body, err := ioutil.ReadAll(resp.Body)
+ // ...
+
+For control over HTTP client headers, redirect policy, and other
+settings, create a Client:
+
+ client := &http.Client{
+ CheckRedirect: redirectPolicyFunc,
+ }
+
+ resp, err := client.Get("http://example.com")
+ // ...
+
+ req, err := http.NewRequest("GET", "http://example.com", nil)
+ // ...
+ req.Header.Add("If-None-Match", `W/"wyzzy"`)
+ resp, err := client.Do(req)
+ // ...
+
+For control over proxies, TLS configuration, keep-alives,
+compression, and other settings, create a Transport:
+
+ tr := &http.Transport{
+ TLSClientConfig: &tls.Config{RootCAs: pool},
+ DisableCompression: true,
+ }
+ client := &http.Client{Transport: tr}
+ resp, err := client.Get("https://example.com")
+
+Clients and Transports are safe for concurrent use by multiple
+goroutines and for efficiency should only be created once and re-used.
+
+ListenAndServe starts an HTTP server with a given address and handler.
+The handler is usually nil, which means to use DefaultServeMux.
+Handle and HandleFunc add handlers to DefaultServeMux:
+
+ http.Handle("/foo", fooHandler)
+
+ http.HandleFunc("/bar", func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path))
+ })
+
+ log.Fatal(http.ListenAndServe(":8080", nil))
+
+More control over the server's behavior is available by creating a
+custom Server:
+
+ s := &http.Server{
+ Addr: ":8080",
+ Handler: myHandler,
+ ReadTimeout: 10 * time.Second,
+ WriteTimeout: 10 * time.Second,
+ MaxHeaderBytes: 1 << 20,
+ }
+ log.Fatal(s.ListenAndServe())
+*/
+package http
diff --git a/src/net/http/example_test.go b/src/net/http/example_test.go
new file mode 100644
index 000000000..88b97d9e3
--- /dev/null
+++ b/src/net/http/example_test.go
@@ -0,0 +1,88 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+ "fmt"
+ "io/ioutil"
+ "log"
+ "net/http"
+)
+
+func ExampleHijacker() {
+ http.HandleFunc("/hijack", func(w http.ResponseWriter, r *http.Request) {
+ hj, ok := w.(http.Hijacker)
+ if !ok {
+ http.Error(w, "webserver doesn't support hijacking", http.StatusInternalServerError)
+ return
+ }
+ conn, bufrw, err := hj.Hijack()
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ // Don't forget to close the connection:
+ defer conn.Close()
+ bufrw.WriteString("Now we're speaking raw TCP. Say hi: ")
+ bufrw.Flush()
+ s, err := bufrw.ReadString('\n')
+ if err != nil {
+ log.Printf("error reading string: %v", err)
+ return
+ }
+ fmt.Fprintf(bufrw, "You said: %q\nBye.\n", s)
+ bufrw.Flush()
+ })
+}
+
+func ExampleGet() {
+ res, err := http.Get("http://www.google.com/robots.txt")
+ if err != nil {
+ log.Fatal(err)
+ }
+ robots, err := ioutil.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ log.Fatal(err)
+ }
+ fmt.Printf("%s", robots)
+}
+
+func ExampleFileServer() {
+ // Simple static webserver:
+ log.Fatal(http.ListenAndServe(":8080", http.FileServer(http.Dir("/usr/share/doc"))))
+}
+
+func ExampleFileServer_stripPrefix() {
+ // To serve a directory on disk (/tmp) under an alternate URL
+ // path (/tmpfiles/), use StripPrefix to modify the request
+ // URL's path before the FileServer sees it:
+ http.Handle("/tmpfiles/", http.StripPrefix("/tmpfiles/", http.FileServer(http.Dir("/tmp"))))
+}
+
+func ExampleStripPrefix() {
+ // To serve a directory on disk (/tmp) under an alternate URL
+ // path (/tmpfiles/), use StripPrefix to modify the request
+ // URL's path before the FileServer sees it:
+ http.Handle("/tmpfiles/", http.StripPrefix("/tmpfiles/", http.FileServer(http.Dir("/tmp"))))
+}
+
+type apiHandler struct{}
+
+func (apiHandler) ServeHTTP(http.ResponseWriter, *http.Request) {}
+
+func ExampleServeMux_Handle() {
+ mux := http.NewServeMux()
+ mux.Handle("/api/", apiHandler{})
+ mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
+ // The "/" pattern matches everything, so we need to check
+ // that we're at the root here.
+ if req.URL.Path != "/" {
+ http.NotFound(w, req)
+ return
+ }
+ fmt.Fprintf(w, "Welcome to the home page!")
+ })
+}
diff --git a/src/net/http/export_test.go b/src/net/http/export_test.go
new file mode 100644
index 000000000..87b6c0773
--- /dev/null
+++ b/src/net/http/export_test.go
@@ -0,0 +1,108 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Bridge package to expose http internals to tests in the http_test
+// package.
+
+package http
+
+import (
+ "net"
+ "net/url"
+ "time"
+)
+
+func NewLoggingConn(baseName string, c net.Conn) net.Conn {
+ return newLoggingConn(baseName, c)
+}
+
+var ExportAppendTime = appendTime
+
+func (t *Transport) NumPendingRequestsForTesting() int {
+ t.reqMu.Lock()
+ defer t.reqMu.Unlock()
+ return len(t.reqCanceler)
+}
+
+func (t *Transport) IdleConnKeysForTesting() (keys []string) {
+ keys = make([]string, 0)
+ t.idleMu.Lock()
+ defer t.idleMu.Unlock()
+ if t.idleConn == nil {
+ return
+ }
+ for key := range t.idleConn {
+ keys = append(keys, key.String())
+ }
+ return
+}
+
+func (t *Transport) IdleConnCountForTesting(cacheKey string) int {
+ t.idleMu.Lock()
+ defer t.idleMu.Unlock()
+ if t.idleConn == nil {
+ return 0
+ }
+ for k, conns := range t.idleConn {
+ if k.String() == cacheKey {
+ return len(conns)
+ }
+ }
+ return 0
+}
+
+func (t *Transport) IdleConnChMapSizeForTesting() int {
+ t.idleMu.Lock()
+ defer t.idleMu.Unlock()
+ return len(t.idleConnCh)
+}
+
+func (t *Transport) IsIdleForTesting() bool {
+ t.idleMu.Lock()
+ defer t.idleMu.Unlock()
+ return t.wantIdle
+}
+
+func (t *Transport) RequestIdleConnChForTesting() {
+ t.getIdleConnCh(connectMethod{nil, "http", "example.com"})
+}
+
+func (t *Transport) PutIdleTestConn() bool {
+ c, _ := net.Pipe()
+ return t.putIdleConn(&persistConn{
+ t: t,
+ conn: c, // dummy
+ closech: make(chan struct{}), // so it can be closed
+ cacheKey: connectMethodKey{"", "http", "example.com"},
+ })
+}
+
+func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler {
+ f := func() <-chan time.Time {
+ return ch
+ }
+ return &timeoutHandler{handler, f, ""}
+}
+
+func ResetCachedEnvironment() {
+ httpProxyEnv.reset()
+ httpsProxyEnv.reset()
+ noProxyEnv.reset()
+}
+
+var DefaultUserAgent = defaultUserAgent
+
+func ExportRefererForURL(lastReq, newReq *url.URL) string {
+ return refererForURL(lastReq, newReq)
+}
+
+// SetPendingDialHooks sets the hooks that run before and after handling
+// pending dials.
+func SetPendingDialHooks(before, after func()) {
+ prePendingDial, postPendingDial = before, after
+}
+
+var ExportServerNewConn = (*Server).newConn
+
+var ExportCloseWriteAndWait = (*conn).closeWriteAndWait
diff --git a/src/net/http/fcgi/child.go b/src/net/http/fcgi/child.go
new file mode 100644
index 000000000..a3beaa33a
--- /dev/null
+++ b/src/net/http/fcgi/child.go
@@ -0,0 +1,305 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fcgi
+
+// This file implements FastCGI from the perspective of a child process.
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net"
+ "net/http"
+ "net/http/cgi"
+ "os"
+ "strings"
+ "sync"
+ "time"
+)
+
+// request holds the state for an in-progress request. As soon as it's complete,
+// it's converted to an http.Request.
+type request struct {
+ pw *io.PipeWriter
+ reqId uint16
+ params map[string]string
+ buf [1024]byte
+ rawParams []byte
+ keepConn bool
+}
+
+func newRequest(reqId uint16, flags uint8) *request {
+ r := &request{
+ reqId: reqId,
+ params: map[string]string{},
+ keepConn: flags&flagKeepConn != 0,
+ }
+ r.rawParams = r.buf[:0]
+ return r
+}
+
+// parseParams reads an encoded []byte into Params.
+func (r *request) parseParams() {
+ text := r.rawParams
+ r.rawParams = nil
+ for len(text) > 0 {
+ keyLen, n := readSize(text)
+ if n == 0 {
+ return
+ }
+ text = text[n:]
+ valLen, n := readSize(text)
+ if n == 0 {
+ return
+ }
+ text = text[n:]
+ key := readString(text, keyLen)
+ text = text[keyLen:]
+ val := readString(text, valLen)
+ text = text[valLen:]
+ r.params[key] = val
+ }
+}
+
+// response implements http.ResponseWriter.
+type response struct {
+ req *request
+ header http.Header
+ w *bufWriter
+ wroteHeader bool
+}
+
+func newResponse(c *child, req *request) *response {
+ return &response{
+ req: req,
+ header: http.Header{},
+ w: newWriter(c.conn, typeStdout, req.reqId),
+ }
+}
+
+func (r *response) Header() http.Header {
+ return r.header
+}
+
+func (r *response) Write(data []byte) (int, error) {
+ if !r.wroteHeader {
+ r.WriteHeader(http.StatusOK)
+ }
+ return r.w.Write(data)
+}
+
+func (r *response) WriteHeader(code int) {
+ if r.wroteHeader {
+ return
+ }
+ r.wroteHeader = true
+ if code == http.StatusNotModified {
+ // Must not have body.
+ r.header.Del("Content-Type")
+ r.header.Del("Content-Length")
+ r.header.Del("Transfer-Encoding")
+ } else if r.header.Get("Content-Type") == "" {
+ r.header.Set("Content-Type", "text/html; charset=utf-8")
+ }
+
+ if r.header.Get("Date") == "" {
+ r.header.Set("Date", time.Now().UTC().Format(http.TimeFormat))
+ }
+
+ fmt.Fprintf(r.w, "Status: %d %s\r\n", code, http.StatusText(code))
+ r.header.Write(r.w)
+ r.w.WriteString("\r\n")
+}
+
+func (r *response) Flush() {
+ if !r.wroteHeader {
+ r.WriteHeader(http.StatusOK)
+ }
+ r.w.Flush()
+}
+
+func (r *response) Close() error {
+ r.Flush()
+ return r.w.Close()
+}
+
+type child struct {
+ conn *conn
+ handler http.Handler
+
+ mu sync.Mutex // protects requests:
+ requests map[uint16]*request // keyed by request ID
+}
+
+func newChild(rwc io.ReadWriteCloser, handler http.Handler) *child {
+ return &child{
+ conn: newConn(rwc),
+ handler: handler,
+ requests: make(map[uint16]*request),
+ }
+}
+
+func (c *child) serve() {
+ defer c.conn.Close()
+ var rec record
+ for {
+ if err := rec.read(c.conn.rwc); err != nil {
+ return
+ }
+ if err := c.handleRecord(&rec); err != nil {
+ return
+ }
+ }
+}
+
+var errCloseConn = errors.New("fcgi: connection should be closed")
+
+var emptyBody = ioutil.NopCloser(strings.NewReader(""))
+
+func (c *child) handleRecord(rec *record) error {
+ c.mu.Lock()
+ req, ok := c.requests[rec.h.Id]
+ c.mu.Unlock()
+ if !ok && rec.h.Type != typeBeginRequest && rec.h.Type != typeGetValues {
+ // The spec says to ignore unknown request IDs.
+ return nil
+ }
+
+ switch rec.h.Type {
+ case typeBeginRequest:
+ if req != nil {
+ // The server is trying to begin a request with the same ID
+ // as an in-progress request. This is an error.
+ return errors.New("fcgi: received ID that is already in-flight")
+ }
+
+ var br beginRequest
+ if err := br.read(rec.content()); err != nil {
+ return err
+ }
+ if br.role != roleResponder {
+ c.conn.writeEndRequest(rec.h.Id, 0, statusUnknownRole)
+ return nil
+ }
+ req = newRequest(rec.h.Id, br.flags)
+ c.mu.Lock()
+ c.requests[rec.h.Id] = req
+ c.mu.Unlock()
+ return nil
+ case typeParams:
+ // NOTE(eds): Technically a key-value pair can straddle the boundary
+ // between two packets. We buffer until we've received all parameters.
+ if len(rec.content()) > 0 {
+ req.rawParams = append(req.rawParams, rec.content()...)
+ return nil
+ }
+ req.parseParams()
+ return nil
+ case typeStdin:
+ content := rec.content()
+ if req.pw == nil {
+ var body io.ReadCloser
+ if len(content) > 0 {
+ // body could be an io.LimitReader, but it shouldn't matter
+ // as long as both sides are behaving.
+ body, req.pw = io.Pipe()
+ } else {
+ body = emptyBody
+ }
+ go c.serveRequest(req, body)
+ }
+ if len(content) > 0 {
+ // TODO(eds): This blocks until the handler reads from the pipe.
+ // If the handler takes a long time, it might be a problem.
+ req.pw.Write(content)
+ } else if req.pw != nil {
+ req.pw.Close()
+ }
+ return nil
+ case typeGetValues:
+ values := map[string]string{"FCGI_MPXS_CONNS": "1"}
+ c.conn.writePairs(typeGetValuesResult, 0, values)
+ return nil
+ case typeData:
+ // If the filter role is implemented, read the data stream here.
+ return nil
+ case typeAbortRequest:
+ println("abort")
+ c.mu.Lock()
+ delete(c.requests, rec.h.Id)
+ c.mu.Unlock()
+ c.conn.writeEndRequest(rec.h.Id, 0, statusRequestComplete)
+ if !req.keepConn {
+ // connection will close upon return
+ return errCloseConn
+ }
+ return nil
+ default:
+ b := make([]byte, 8)
+ b[0] = byte(rec.h.Type)
+ c.conn.writeRecord(typeUnknownType, 0, b)
+ return nil
+ }
+}
+
+func (c *child) serveRequest(req *request, body io.ReadCloser) {
+ r := newResponse(c, req)
+ httpReq, err := cgi.RequestFromMap(req.params)
+ if err != nil {
+ // there was an error reading the request
+ r.WriteHeader(http.StatusInternalServerError)
+ c.conn.writeRecord(typeStderr, req.reqId, []byte(err.Error()))
+ } else {
+ httpReq.Body = body
+ c.handler.ServeHTTP(r, httpReq)
+ }
+ r.Close()
+ c.mu.Lock()
+ delete(c.requests, req.reqId)
+ c.mu.Unlock()
+ c.conn.writeEndRequest(req.reqId, 0, statusRequestComplete)
+
+ // Consume the entire body, so the host isn't still writing to
+ // us when we close the socket below in the !keepConn case,
+ // otherwise we'd send a RST. (golang.org/issue/4183)
+ // TODO(bradfitz): also bound this copy in time. Or send
+ // some sort of abort request to the host, so the host
+ // can properly cut off the client sending all the data.
+ // For now just bound it a little and
+ io.CopyN(ioutil.Discard, body, 100<<20)
+ body.Close()
+
+ if !req.keepConn {
+ c.conn.Close()
+ }
+}
+
+// Serve accepts incoming FastCGI connections on the listener l, creating a new
+// goroutine for each. The goroutine reads requests and then calls handler
+// to reply to them.
+// If l is nil, Serve accepts connections from os.Stdin.
+// If handler is nil, http.DefaultServeMux is used.
+func Serve(l net.Listener, handler http.Handler) error {
+ if l == nil {
+ var err error
+ l, err = net.FileListener(os.Stdin)
+ if err != nil {
+ return err
+ }
+ defer l.Close()
+ }
+ if handler == nil {
+ handler = http.DefaultServeMux
+ }
+ for {
+ rw, err := l.Accept()
+ if err != nil {
+ return err
+ }
+ c := newChild(rw, handler)
+ go c.serve()
+ }
+}
diff --git a/src/net/http/fcgi/fcgi.go b/src/net/http/fcgi/fcgi.go
new file mode 100644
index 000000000..06bba0488
--- /dev/null
+++ b/src/net/http/fcgi/fcgi.go
@@ -0,0 +1,274 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package fcgi implements the FastCGI protocol.
+// Currently only the responder role is supported.
+// The protocol is defined at http://www.fastcgi.com/drupal/node/6?q=node/22
+package fcgi
+
+// This file defines the raw protocol and some utilities used by the child and
+// the host.
+
+import (
+ "bufio"
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "io"
+ "sync"
+)
+
+// recType is a record type, as defined by
+// http://www.fastcgi.com/devkit/doc/fcgi-spec.html#S8
+type recType uint8
+
+const (
+ typeBeginRequest recType = 1
+ typeAbortRequest recType = 2
+ typeEndRequest recType = 3
+ typeParams recType = 4
+ typeStdin recType = 5
+ typeStdout recType = 6
+ typeStderr recType = 7
+ typeData recType = 8
+ typeGetValues recType = 9
+ typeGetValuesResult recType = 10
+ typeUnknownType recType = 11
+)
+
+// keep the connection between web-server and responder open after request
+const flagKeepConn = 1
+
+const (
+ maxWrite = 65535 // maximum record body
+ maxPad = 255
+)
+
+const (
+ roleResponder = iota + 1 // only Responders are implemented.
+ roleAuthorizer
+ roleFilter
+)
+
+const (
+ statusRequestComplete = iota
+ statusCantMultiplex
+ statusOverloaded
+ statusUnknownRole
+)
+
+const headerLen = 8
+
+type header struct {
+ Version uint8
+ Type recType
+ Id uint16
+ ContentLength uint16
+ PaddingLength uint8
+ Reserved uint8
+}
+
+type beginRequest struct {
+ role uint16
+ flags uint8
+ reserved [5]uint8
+}
+
+func (br *beginRequest) read(content []byte) error {
+ if len(content) != 8 {
+ return errors.New("fcgi: invalid begin request record")
+ }
+ br.role = binary.BigEndian.Uint16(content)
+ br.flags = content[2]
+ return nil
+}
+
+// for padding so we don't have to allocate all the time
+// not synchronized because we don't care what the contents are
+var pad [maxPad]byte
+
+func (h *header) init(recType recType, reqId uint16, contentLength int) {
+ h.Version = 1
+ h.Type = recType
+ h.Id = reqId
+ h.ContentLength = uint16(contentLength)
+ h.PaddingLength = uint8(-contentLength & 7)
+}
+
+// conn sends records over rwc
+type conn struct {
+ mutex sync.Mutex
+ rwc io.ReadWriteCloser
+
+ // to avoid allocations
+ buf bytes.Buffer
+ h header
+}
+
+func newConn(rwc io.ReadWriteCloser) *conn {
+ return &conn{rwc: rwc}
+}
+
+func (c *conn) Close() error {
+ c.mutex.Lock()
+ defer c.mutex.Unlock()
+ return c.rwc.Close()
+}
+
+type record struct {
+ h header
+ buf [maxWrite + maxPad]byte
+}
+
+func (rec *record) read(r io.Reader) (err error) {
+ if err = binary.Read(r, binary.BigEndian, &rec.h); err != nil {
+ return err
+ }
+ if rec.h.Version != 1 {
+ return errors.New("fcgi: invalid header version")
+ }
+ n := int(rec.h.ContentLength) + int(rec.h.PaddingLength)
+ if _, err = io.ReadFull(r, rec.buf[:n]); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (r *record) content() []byte {
+ return r.buf[:r.h.ContentLength]
+}
+
+// writeRecord writes and sends a single record.
+func (c *conn) writeRecord(recType recType, reqId uint16, b []byte) error {
+ c.mutex.Lock()
+ defer c.mutex.Unlock()
+ c.buf.Reset()
+ c.h.init(recType, reqId, len(b))
+ if err := binary.Write(&c.buf, binary.BigEndian, c.h); err != nil {
+ return err
+ }
+ if _, err := c.buf.Write(b); err != nil {
+ return err
+ }
+ if _, err := c.buf.Write(pad[:c.h.PaddingLength]); err != nil {
+ return err
+ }
+ _, err := c.rwc.Write(c.buf.Bytes())
+ return err
+}
+
+func (c *conn) writeBeginRequest(reqId uint16, role uint16, flags uint8) error {
+ b := [8]byte{byte(role >> 8), byte(role), flags}
+ return c.writeRecord(typeBeginRequest, reqId, b[:])
+}
+
+func (c *conn) writeEndRequest(reqId uint16, appStatus int, protocolStatus uint8) error {
+ b := make([]byte, 8)
+ binary.BigEndian.PutUint32(b, uint32(appStatus))
+ b[4] = protocolStatus
+ return c.writeRecord(typeEndRequest, reqId, b)
+}
+
+func (c *conn) writePairs(recType recType, reqId uint16, pairs map[string]string) error {
+ w := newWriter(c, recType, reqId)
+ b := make([]byte, 8)
+ for k, v := range pairs {
+ n := encodeSize(b, uint32(len(k)))
+ n += encodeSize(b[n:], uint32(len(v)))
+ if _, err := w.Write(b[:n]); err != nil {
+ return err
+ }
+ if _, err := w.WriteString(k); err != nil {
+ return err
+ }
+ if _, err := w.WriteString(v); err != nil {
+ return err
+ }
+ }
+ w.Close()
+ return nil
+}
+
+func readSize(s []byte) (uint32, int) {
+ if len(s) == 0 {
+ return 0, 0
+ }
+ size, n := uint32(s[0]), 1
+ if size&(1<<7) != 0 {
+ if len(s) < 4 {
+ return 0, 0
+ }
+ n = 4
+ size = binary.BigEndian.Uint32(s)
+ size &^= 1 << 31
+ }
+ return size, n
+}
+
+func readString(s []byte, size uint32) string {
+ if size > uint32(len(s)) {
+ return ""
+ }
+ return string(s[:size])
+}
+
+func encodeSize(b []byte, size uint32) int {
+ if size > 127 {
+ size |= 1 << 31
+ binary.BigEndian.PutUint32(b, size)
+ return 4
+ }
+ b[0] = byte(size)
+ return 1
+}
+
+// bufWriter encapsulates bufio.Writer but also closes the underlying stream when
+// Closed.
+type bufWriter struct {
+ closer io.Closer
+ *bufio.Writer
+}
+
+func (w *bufWriter) Close() error {
+ if err := w.Writer.Flush(); err != nil {
+ w.closer.Close()
+ return err
+ }
+ return w.closer.Close()
+}
+
+func newWriter(c *conn, recType recType, reqId uint16) *bufWriter {
+ s := &streamWriter{c: c, recType: recType, reqId: reqId}
+ w := bufio.NewWriterSize(s, maxWrite)
+ return &bufWriter{s, w}
+}
+
+// streamWriter abstracts out the separation of a stream into discrete records.
+// It only writes maxWrite bytes at a time.
+type streamWriter struct {
+ c *conn
+ recType recType
+ reqId uint16
+}
+
+func (w *streamWriter) Write(p []byte) (int, error) {
+ nn := 0
+ for len(p) > 0 {
+ n := len(p)
+ if n > maxWrite {
+ n = maxWrite
+ }
+ if err := w.c.writeRecord(w.recType, w.reqId, p[:n]); err != nil {
+ return nn, err
+ }
+ nn += n
+ p = p[n:]
+ }
+ return nn, nil
+}
+
+func (w *streamWriter) Close() error {
+ // send empty record to close the stream
+ return w.c.writeRecord(w.recType, w.reqId, nil)
+}
diff --git a/src/net/http/fcgi/fcgi_test.go b/src/net/http/fcgi/fcgi_test.go
new file mode 100644
index 000000000..6c7e1a9ce
--- /dev/null
+++ b/src/net/http/fcgi/fcgi_test.go
@@ -0,0 +1,150 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fcgi
+
+import (
+ "bytes"
+ "errors"
+ "io"
+ "testing"
+)
+
+var sizeTests = []struct {
+ size uint32
+ bytes []byte
+}{
+ {0, []byte{0x00}},
+ {127, []byte{0x7F}},
+ {128, []byte{0x80, 0x00, 0x00, 0x80}},
+ {1000, []byte{0x80, 0x00, 0x03, 0xE8}},
+ {33554431, []byte{0x81, 0xFF, 0xFF, 0xFF}},
+}
+
+func TestSize(t *testing.T) {
+ b := make([]byte, 4)
+ for i, test := range sizeTests {
+ n := encodeSize(b, test.size)
+ if !bytes.Equal(b[:n], test.bytes) {
+ t.Errorf("%d expected %x, encoded %x", i, test.bytes, b)
+ }
+ size, n := readSize(test.bytes)
+ if size != test.size {
+ t.Errorf("%d expected %d, read %d", i, test.size, size)
+ }
+ if len(test.bytes) != n {
+ t.Errorf("%d did not consume all the bytes", i)
+ }
+ }
+}
+
+var streamTests = []struct {
+ desc string
+ recType recType
+ reqId uint16
+ content []byte
+ raw []byte
+}{
+ {"single record", typeStdout, 1, nil,
+ []byte{1, byte(typeStdout), 0, 1, 0, 0, 0, 0},
+ },
+ // this data will have to be split into two records
+ {"two records", typeStdin, 300, make([]byte, 66000),
+ bytes.Join([][]byte{
+ // header for the first record
+ {1, byte(typeStdin), 0x01, 0x2C, 0xFF, 0xFF, 1, 0},
+ make([]byte, 65536),
+ // header for the second
+ {1, byte(typeStdin), 0x01, 0x2C, 0x01, 0xD1, 7, 0},
+ make([]byte, 472),
+ // header for the empty record
+ {1, byte(typeStdin), 0x01, 0x2C, 0, 0, 0, 0},
+ },
+ nil),
+ },
+}
+
+type nilCloser struct {
+ io.ReadWriter
+}
+
+func (c *nilCloser) Close() error { return nil }
+
+func TestStreams(t *testing.T) {
+ var rec record
+outer:
+ for _, test := range streamTests {
+ buf := bytes.NewBuffer(test.raw)
+ var content []byte
+ for buf.Len() > 0 {
+ if err := rec.read(buf); err != nil {
+ t.Errorf("%s: error reading record: %v", test.desc, err)
+ continue outer
+ }
+ content = append(content, rec.content()...)
+ }
+ if rec.h.Type != test.recType {
+ t.Errorf("%s: got type %d expected %d", test.desc, rec.h.Type, test.recType)
+ continue
+ }
+ if rec.h.Id != test.reqId {
+ t.Errorf("%s: got request ID %d expected %d", test.desc, rec.h.Id, test.reqId)
+ continue
+ }
+ if !bytes.Equal(content, test.content) {
+ t.Errorf("%s: read wrong content", test.desc)
+ continue
+ }
+ buf.Reset()
+ c := newConn(&nilCloser{buf})
+ w := newWriter(c, test.recType, test.reqId)
+ if _, err := w.Write(test.content); err != nil {
+ t.Errorf("%s: error writing record: %v", test.desc, err)
+ continue
+ }
+ if err := w.Close(); err != nil {
+ t.Errorf("%s: error closing stream: %v", test.desc, err)
+ continue
+ }
+ if !bytes.Equal(buf.Bytes(), test.raw) {
+ t.Errorf("%s: wrote wrong content", test.desc)
+ }
+ }
+}
+
+type writeOnlyConn struct {
+ buf []byte
+}
+
+func (c *writeOnlyConn) Write(p []byte) (int, error) {
+ c.buf = append(c.buf, p...)
+ return len(p), nil
+}
+
+func (c *writeOnlyConn) Read(p []byte) (int, error) {
+ return 0, errors.New("conn is write-only")
+}
+
+func (c *writeOnlyConn) Close() error {
+ return nil
+}
+
+func TestGetValues(t *testing.T) {
+ var rec record
+ rec.h.Type = typeGetValues
+
+ wc := new(writeOnlyConn)
+ c := newChild(wc, nil)
+ err := c.handleRecord(&rec)
+ if err != nil {
+ t.Fatalf("handleRecord: %v", err)
+ }
+
+ const want = "\x01\n\x00\x00\x00\x12\x06\x00" +
+ "\x0f\x01FCGI_MPXS_CONNS1" +
+ "\x00\x00\x00\x00\x00\x00\x01\n\x00\x00\x00\x00\x00\x00"
+ if got := string(wc.buf); got != want {
+ t.Errorf(" got: %q\nwant: %q\n", got, want)
+ }
+}
diff --git a/src/net/http/filetransport.go b/src/net/http/filetransport.go
new file mode 100644
index 000000000..821787e0c
--- /dev/null
+++ b/src/net/http/filetransport.go
@@ -0,0 +1,123 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "fmt"
+ "io"
+)
+
+// fileTransport implements RoundTripper for the 'file' protocol.
+type fileTransport struct {
+ fh fileHandler
+}
+
+// NewFileTransport returns a new RoundTripper, serving the provided
+// FileSystem. The returned RoundTripper ignores the URL host in its
+// incoming requests, as well as most other properties of the
+// request.
+//
+// The typical use case for NewFileTransport is to register the "file"
+// protocol with a Transport, as in:
+//
+// t := &http.Transport{}
+// t.RegisterProtocol("file", http.NewFileTransport(http.Dir("/")))
+// c := &http.Client{Transport: t}
+// res, err := c.Get("file:///etc/passwd")
+// ...
+func NewFileTransport(fs FileSystem) RoundTripper {
+ return fileTransport{fileHandler{fs}}
+}
+
+func (t fileTransport) RoundTrip(req *Request) (resp *Response, err error) {
+ // We start ServeHTTP in a goroutine, which may take a long
+ // time if the file is large. The newPopulateResponseWriter
+ // call returns a channel which either ServeHTTP or finish()
+ // sends our *Response on, once the *Response itself has been
+ // populated (even if the body itself is still being
+ // written to the res.Body, a pipe)
+ rw, resc := newPopulateResponseWriter()
+ go func() {
+ t.fh.ServeHTTP(rw, req)
+ rw.finish()
+ }()
+ return <-resc, nil
+}
+
+func newPopulateResponseWriter() (*populateResponse, <-chan *Response) {
+ pr, pw := io.Pipe()
+ rw := &populateResponse{
+ ch: make(chan *Response),
+ pw: pw,
+ res: &Response{
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ Header: make(Header),
+ Close: true,
+ Body: pr,
+ },
+ }
+ return rw, rw.ch
+}
+
+// populateResponse is a ResponseWriter that populates the *Response
+// in res, and writes its body to a pipe connected to the response
+// body. Once writes begin or finish() is called, the response is sent
+// on ch.
+type populateResponse struct {
+ res *Response
+ ch chan *Response
+ wroteHeader bool
+ hasContent bool
+ sentResponse bool
+ pw *io.PipeWriter
+}
+
+func (pr *populateResponse) finish() {
+ if !pr.wroteHeader {
+ pr.WriteHeader(500)
+ }
+ if !pr.sentResponse {
+ pr.sendResponse()
+ }
+ pr.pw.Close()
+}
+
+func (pr *populateResponse) sendResponse() {
+ if pr.sentResponse {
+ return
+ }
+ pr.sentResponse = true
+
+ if pr.hasContent {
+ pr.res.ContentLength = -1
+ }
+ pr.ch <- pr.res
+}
+
+func (pr *populateResponse) Header() Header {
+ return pr.res.Header
+}
+
+func (pr *populateResponse) WriteHeader(code int) {
+ if pr.wroteHeader {
+ return
+ }
+ pr.wroteHeader = true
+
+ pr.res.StatusCode = code
+ pr.res.Status = fmt.Sprintf("%d %s", code, StatusText(code))
+}
+
+func (pr *populateResponse) Write(p []byte) (n int, err error) {
+ if !pr.wroteHeader {
+ pr.WriteHeader(StatusOK)
+ }
+ pr.hasContent = true
+ if !pr.sentResponse {
+ pr.sendResponse()
+ }
+ return pr.pw.Write(p)
+}
diff --git a/src/net/http/filetransport_test.go b/src/net/http/filetransport_test.go
new file mode 100644
index 000000000..6f1a537e2
--- /dev/null
+++ b/src/net/http/filetransport_test.go
@@ -0,0 +1,65 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func checker(t *testing.T) func(string, error) {
+ return func(call string, err error) {
+ if err == nil {
+ return
+ }
+ t.Fatalf("%s: %v", call, err)
+ }
+}
+
+func TestFileTransport(t *testing.T) {
+ check := checker(t)
+
+ dname, err := ioutil.TempDir("", "")
+ check("TempDir", err)
+ fname := filepath.Join(dname, "foo.txt")
+ err = ioutil.WriteFile(fname, []byte("Bar"), 0644)
+ check("WriteFile", err)
+ defer os.Remove(dname)
+ defer os.Remove(fname)
+
+ tr := &Transport{}
+ tr.RegisterProtocol("file", NewFileTransport(Dir(dname)))
+ c := &Client{Transport: tr}
+
+ fooURLs := []string{"file:///foo.txt", "file://../foo.txt"}
+ for _, urlstr := range fooURLs {
+ res, err := c.Get(urlstr)
+ check("Get "+urlstr, err)
+ if res.StatusCode != 200 {
+ t.Errorf("for %s, StatusCode = %d, want 200", urlstr, res.StatusCode)
+ }
+ if res.ContentLength != -1 {
+ t.Errorf("for %s, ContentLength = %d, want -1", urlstr, res.ContentLength)
+ }
+ if res.Body == nil {
+ t.Fatalf("for %s, nil Body", urlstr)
+ }
+ slurp, err := ioutil.ReadAll(res.Body)
+ check("ReadAll "+urlstr, err)
+ if string(slurp) != "Bar" {
+ t.Errorf("for %s, got content %q, want %q", urlstr, string(slurp), "Bar")
+ }
+ }
+
+ const badURL = "file://../no-exist.txt"
+ res, err := c.Get(badURL)
+ check("Get "+badURL, err)
+ if res.StatusCode != 404 {
+ t.Errorf("for %s, StatusCode = %d, want 404", badURL, res.StatusCode)
+ }
+ res.Body.Close()
+}
diff --git a/src/net/http/fs.go b/src/net/http/fs.go
new file mode 100644
index 000000000..e322f710a
--- /dev/null
+++ b/src/net/http/fs.go
@@ -0,0 +1,556 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// HTTP file system request handler
+
+package http
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "mime"
+ "mime/multipart"
+ "net/textproto"
+ "net/url"
+ "os"
+ "path"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "time"
+)
+
+// A Dir implements FileSystem using the native file system restricted to a
+// specific directory tree.
+//
+// While the FileSystem.Open method takes '/'-separated paths, a Dir's string
+// value is a filename on the native file system, not a URL, so it is separated
+// by filepath.Separator, which isn't necessarily '/'.
+//
+// An empty Dir is treated as ".".
+type Dir string
+
+func (d Dir) Open(name string) (File, error) {
+ if filepath.Separator != '/' && strings.IndexRune(name, filepath.Separator) >= 0 ||
+ strings.Contains(name, "\x00") {
+ return nil, errors.New("http: invalid character in file path")
+ }
+ dir := string(d)
+ if dir == "" {
+ dir = "."
+ }
+ f, err := os.Open(filepath.Join(dir, filepath.FromSlash(path.Clean("/"+name))))
+ if err != nil {
+ return nil, err
+ }
+ return f, nil
+}
+
+// A FileSystem implements access to a collection of named files.
+// The elements in a file path are separated by slash ('/', U+002F)
+// characters, regardless of host operating system convention.
+type FileSystem interface {
+ Open(name string) (File, error)
+}
+
+// A File is returned by a FileSystem's Open method and can be
+// served by the FileServer implementation.
+//
+// The methods should behave the same as those on an *os.File.
+type File interface {
+ io.Closer
+ io.Reader
+ Readdir(count int) ([]os.FileInfo, error)
+ Seek(offset int64, whence int) (int64, error)
+ Stat() (os.FileInfo, error)
+}
+
+func dirList(w ResponseWriter, f File) {
+ w.Header().Set("Content-Type", "text/html; charset=utf-8")
+ fmt.Fprintf(w, "<pre>\n")
+ for {
+ dirs, err := f.Readdir(100)
+ if err != nil || len(dirs) == 0 {
+ break
+ }
+ for _, d := range dirs {
+ name := d.Name()
+ if d.IsDir() {
+ name += "/"
+ }
+ // name may contain '?' or '#', which must be escaped to remain
+ // part of the URL path, and not indicate the start of a query
+ // string or fragment.
+ url := url.URL{Path: name}
+ fmt.Fprintf(w, "<a href=\"%s\">%s</a>\n", url.String(), htmlReplacer.Replace(name))
+ }
+ }
+ fmt.Fprintf(w, "</pre>\n")
+}
+
+// ServeContent replies to the request using the content in the
+// provided ReadSeeker. The main benefit of ServeContent over io.Copy
+// is that it handles Range requests properly, sets the MIME type, and
+// handles If-Modified-Since requests.
+//
+// If the response's Content-Type header is not set, ServeContent
+// first tries to deduce the type from name's file extension and,
+// if that fails, falls back to reading the first block of the content
+// and passing it to DetectContentType.
+// The name is otherwise unused; in particular it can be empty and is
+// never sent in the response.
+//
+// If modtime is not the zero time, ServeContent includes it in a
+// Last-Modified header in the response. If the request includes an
+// If-Modified-Since header, ServeContent uses modtime to decide
+// whether the content needs to be sent at all.
+//
+// The content's Seek method must work: ServeContent uses
+// a seek to the end of the content to determine its size.
+//
+// If the caller has set w's ETag header, ServeContent uses it to
+// handle requests using If-Range and If-None-Match.
+//
+// Note that *os.File implements the io.ReadSeeker interface.
+func ServeContent(w ResponseWriter, req *Request, name string, modtime time.Time, content io.ReadSeeker) {
+ sizeFunc := func() (int64, error) {
+ size, err := content.Seek(0, os.SEEK_END)
+ if err != nil {
+ return 0, errSeeker
+ }
+ _, err = content.Seek(0, os.SEEK_SET)
+ if err != nil {
+ return 0, errSeeker
+ }
+ return size, nil
+ }
+ serveContent(w, req, name, modtime, sizeFunc, content)
+}
+
+// errSeeker is returned by ServeContent's sizeFunc when the content
+// doesn't seek properly. The underlying Seeker's error text isn't
+// included in the sizeFunc reply so it's not sent over HTTP to end
+// users.
+var errSeeker = errors.New("seeker can't seek")
+
+// if name is empty, filename is unknown. (used for mime type, before sniffing)
+// if modtime.IsZero(), modtime is unknown.
+// content must be seeked to the beginning of the file.
+// The sizeFunc is called at most once. Its error, if any, is sent in the HTTP response.
+func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, sizeFunc func() (int64, error), content io.ReadSeeker) {
+ if checkLastModified(w, r, modtime) {
+ return
+ }
+ rangeReq, done := checkETag(w, r, modtime)
+ if done {
+ return
+ }
+
+ code := StatusOK
+
+ // If Content-Type isn't set, use the file's extension to find it, but
+ // if the Content-Type is unset explicitly, do not sniff the type.
+ ctypes, haveType := w.Header()["Content-Type"]
+ var ctype string
+ if !haveType {
+ ctype = mime.TypeByExtension(filepath.Ext(name))
+ if ctype == "" {
+ // read a chunk to decide between utf-8 text and binary
+ var buf [sniffLen]byte
+ n, _ := io.ReadFull(content, buf[:])
+ ctype = DetectContentType(buf[:n])
+ _, err := content.Seek(0, os.SEEK_SET) // rewind to output whole file
+ if err != nil {
+ Error(w, "seeker can't seek", StatusInternalServerError)
+ return
+ }
+ }
+ w.Header().Set("Content-Type", ctype)
+ } else if len(ctypes) > 0 {
+ ctype = ctypes[0]
+ }
+
+ size, err := sizeFunc()
+ if err != nil {
+ Error(w, err.Error(), StatusInternalServerError)
+ return
+ }
+
+ // handle Content-Range header.
+ sendSize := size
+ var sendContent io.Reader = content
+ if size >= 0 {
+ ranges, err := parseRange(rangeReq, size)
+ if err != nil {
+ Error(w, err.Error(), StatusRequestedRangeNotSatisfiable)
+ return
+ }
+ if sumRangesSize(ranges) > size {
+ // The total number of bytes in all the ranges
+ // is larger than the size of the file by
+ // itself, so this is probably an attack, or a
+ // dumb client. Ignore the range request.
+ ranges = nil
+ }
+ switch {
+ case len(ranges) == 1:
+ // RFC 2616, Section 14.16:
+ // "When an HTTP message includes the content of a single
+ // range (for example, a response to a request for a
+ // single range, or to a request for a set of ranges
+ // that overlap without any holes), this content is
+ // transmitted with a Content-Range header, and a
+ // Content-Length header showing the number of bytes
+ // actually transferred.
+ // ...
+ // A response to a request for a single range MUST NOT
+ // be sent using the multipart/byteranges media type."
+ ra := ranges[0]
+ if _, err := content.Seek(ra.start, os.SEEK_SET); err != nil {
+ Error(w, err.Error(), StatusRequestedRangeNotSatisfiable)
+ return
+ }
+ sendSize = ra.length
+ code = StatusPartialContent
+ w.Header().Set("Content-Range", ra.contentRange(size))
+ case len(ranges) > 1:
+ sendSize = rangesMIMESize(ranges, ctype, size)
+ code = StatusPartialContent
+
+ pr, pw := io.Pipe()
+ mw := multipart.NewWriter(pw)
+ w.Header().Set("Content-Type", "multipart/byteranges; boundary="+mw.Boundary())
+ sendContent = pr
+ defer pr.Close() // cause writing goroutine to fail and exit if CopyN doesn't finish.
+ go func() {
+ for _, ra := range ranges {
+ part, err := mw.CreatePart(ra.mimeHeader(ctype, size))
+ if err != nil {
+ pw.CloseWithError(err)
+ return
+ }
+ if _, err := content.Seek(ra.start, os.SEEK_SET); err != nil {
+ pw.CloseWithError(err)
+ return
+ }
+ if _, err := io.CopyN(part, content, ra.length); err != nil {
+ pw.CloseWithError(err)
+ return
+ }
+ }
+ mw.Close()
+ pw.Close()
+ }()
+ }
+
+ w.Header().Set("Accept-Ranges", "bytes")
+ if w.Header().Get("Content-Encoding") == "" {
+ w.Header().Set("Content-Length", strconv.FormatInt(sendSize, 10))
+ }
+ }
+
+ w.WriteHeader(code)
+
+ if r.Method != "HEAD" {
+ io.CopyN(w, sendContent, sendSize)
+ }
+}
+
+// modtime is the modification time of the resource to be served, or IsZero().
+// return value is whether this request is now complete.
+func checkLastModified(w ResponseWriter, r *Request, modtime time.Time) bool {
+ if modtime.IsZero() {
+ return false
+ }
+
+ // The Date-Modified header truncates sub-second precision, so
+ // use mtime < t+1s instead of mtime <= t to check for unmodified.
+ if t, err := time.Parse(TimeFormat, r.Header.Get("If-Modified-Since")); err == nil && modtime.Before(t.Add(1*time.Second)) {
+ h := w.Header()
+ delete(h, "Content-Type")
+ delete(h, "Content-Length")
+ w.WriteHeader(StatusNotModified)
+ return true
+ }
+ w.Header().Set("Last-Modified", modtime.UTC().Format(TimeFormat))
+ return false
+}
+
+// checkETag implements If-None-Match and If-Range checks.
+//
+// The ETag or modtime must have been previously set in the
+// ResponseWriter's headers. The modtime is only compared at second
+// granularity and may be the zero value to mean unknown.
+//
+// The return value is the effective request "Range" header to use and
+// whether this request is now considered done.
+func checkETag(w ResponseWriter, r *Request, modtime time.Time) (rangeReq string, done bool) {
+ etag := w.Header().get("Etag")
+ rangeReq = r.Header.get("Range")
+
+ // Invalidate the range request if the entity doesn't match the one
+ // the client was expecting.
+ // "If-Range: version" means "ignore the Range: header unless version matches the
+ // current file."
+ // We only support ETag versions.
+ // The caller must have set the ETag on the response already.
+ if ir := r.Header.get("If-Range"); ir != "" && ir != etag {
+ // The If-Range value is typically the ETag value, but it may also be
+ // the modtime date. See golang.org/issue/8367.
+ timeMatches := false
+ if !modtime.IsZero() {
+ if t, err := ParseTime(ir); err == nil && t.Unix() == modtime.Unix() {
+ timeMatches = true
+ }
+ }
+ if !timeMatches {
+ rangeReq = ""
+ }
+ }
+
+ if inm := r.Header.get("If-None-Match"); inm != "" {
+ // Must know ETag.
+ if etag == "" {
+ return rangeReq, false
+ }
+
+ // TODO(bradfitz): non-GET/HEAD requests require more work:
+ // sending a different status code on matches, and
+ // also can't use weak cache validators (those with a "W/
+ // prefix). But most users of ServeContent will be using
+ // it on GET or HEAD, so only support those for now.
+ if r.Method != "GET" && r.Method != "HEAD" {
+ return rangeReq, false
+ }
+
+ // TODO(bradfitz): deal with comma-separated or multiple-valued
+ // list of If-None-match values. For now just handle the common
+ // case of a single item.
+ if inm == etag || inm == "*" {
+ h := w.Header()
+ delete(h, "Content-Type")
+ delete(h, "Content-Length")
+ w.WriteHeader(StatusNotModified)
+ return "", true
+ }
+ }
+ return rangeReq, false
+}
+
+// name is '/'-separated, not filepath.Separator.
+func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirect bool) {
+ const indexPage = "/index.html"
+
+ // redirect .../index.html to .../
+ // can't use Redirect() because that would make the path absolute,
+ // which would be a problem running under StripPrefix
+ if strings.HasSuffix(r.URL.Path, indexPage) {
+ localRedirect(w, r, "./")
+ return
+ }
+
+ f, err := fs.Open(name)
+ if err != nil {
+ // TODO expose actual error?
+ NotFound(w, r)
+ return
+ }
+ defer f.Close()
+
+ d, err1 := f.Stat()
+ if err1 != nil {
+ // TODO expose actual error?
+ NotFound(w, r)
+ return
+ }
+
+ if redirect {
+ // redirect to canonical path: / at end of directory url
+ // r.URL.Path always begins with /
+ url := r.URL.Path
+ if d.IsDir() {
+ if url[len(url)-1] != '/' {
+ localRedirect(w, r, path.Base(url)+"/")
+ return
+ }
+ } else {
+ if url[len(url)-1] == '/' {
+ localRedirect(w, r, "../"+path.Base(url))
+ return
+ }
+ }
+ }
+
+ // use contents of index.html for directory, if present
+ if d.IsDir() {
+ index := strings.TrimSuffix(name, "/") + indexPage
+ ff, err := fs.Open(index)
+ if err == nil {
+ defer ff.Close()
+ dd, err := ff.Stat()
+ if err == nil {
+ name = index
+ d = dd
+ f = ff
+ }
+ }
+ }
+
+ // Still a directory? (we didn't find an index.html file)
+ if d.IsDir() {
+ if checkLastModified(w, r, d.ModTime()) {
+ return
+ }
+ dirList(w, f)
+ return
+ }
+
+ // serveContent will check modification time
+ sizeFunc := func() (int64, error) { return d.Size(), nil }
+ serveContent(w, r, d.Name(), d.ModTime(), sizeFunc, f)
+}
+
+// localRedirect gives a Moved Permanently response.
+// It does not convert relative paths to absolute paths like Redirect does.
+func localRedirect(w ResponseWriter, r *Request, newPath string) {
+ if q := r.URL.RawQuery; q != "" {
+ newPath += "?" + q
+ }
+ w.Header().Set("Location", newPath)
+ w.WriteHeader(StatusMovedPermanently)
+}
+
+// ServeFile replies to the request with the contents of the named file or directory.
+func ServeFile(w ResponseWriter, r *Request, name string) {
+ dir, file := filepath.Split(name)
+ serveFile(w, r, Dir(dir), file, false)
+}
+
+type fileHandler struct {
+ root FileSystem
+}
+
+// FileServer returns a handler that serves HTTP requests
+// with the contents of the file system rooted at root.
+//
+// To use the operating system's file system implementation,
+// use http.Dir:
+//
+// http.Handle("/", http.FileServer(http.Dir("/tmp")))
+func FileServer(root FileSystem) Handler {
+ return &fileHandler{root}
+}
+
+func (f *fileHandler) ServeHTTP(w ResponseWriter, r *Request) {
+ upath := r.URL.Path
+ if !strings.HasPrefix(upath, "/") {
+ upath = "/" + upath
+ r.URL.Path = upath
+ }
+ serveFile(w, r, f.root, path.Clean(upath), true)
+}
+
+// httpRange specifies the byte range to be sent to the client.
+type httpRange struct {
+ start, length int64
+}
+
+func (r httpRange) contentRange(size int64) string {
+ return fmt.Sprintf("bytes %d-%d/%d", r.start, r.start+r.length-1, size)
+}
+
+func (r httpRange) mimeHeader(contentType string, size int64) textproto.MIMEHeader {
+ return textproto.MIMEHeader{
+ "Content-Range": {r.contentRange(size)},
+ "Content-Type": {contentType},
+ }
+}
+
+// parseRange parses a Range header string as per RFC 2616.
+func parseRange(s string, size int64) ([]httpRange, error) {
+ if s == "" {
+ return nil, nil // header not present
+ }
+ const b = "bytes="
+ if !strings.HasPrefix(s, b) {
+ return nil, errors.New("invalid range")
+ }
+ var ranges []httpRange
+ for _, ra := range strings.Split(s[len(b):], ",") {
+ ra = strings.TrimSpace(ra)
+ if ra == "" {
+ continue
+ }
+ i := strings.Index(ra, "-")
+ if i < 0 {
+ return nil, errors.New("invalid range")
+ }
+ start, end := strings.TrimSpace(ra[:i]), strings.TrimSpace(ra[i+1:])
+ var r httpRange
+ if start == "" {
+ // If no start is specified, end specifies the
+ // range start relative to the end of the file.
+ i, err := strconv.ParseInt(end, 10, 64)
+ if err != nil {
+ return nil, errors.New("invalid range")
+ }
+ if i > size {
+ i = size
+ }
+ r.start = size - i
+ r.length = size - r.start
+ } else {
+ i, err := strconv.ParseInt(start, 10, 64)
+ if err != nil || i > size || i < 0 {
+ return nil, errors.New("invalid range")
+ }
+ r.start = i
+ if end == "" {
+ // If no end is specified, range extends to end of the file.
+ r.length = size - r.start
+ } else {
+ i, err := strconv.ParseInt(end, 10, 64)
+ if err != nil || r.start > i {
+ return nil, errors.New("invalid range")
+ }
+ if i >= size {
+ i = size - 1
+ }
+ r.length = i - r.start + 1
+ }
+ }
+ ranges = append(ranges, r)
+ }
+ return ranges, nil
+}
+
+// countingWriter counts how many bytes have been written to it.
+type countingWriter int64
+
+func (w *countingWriter) Write(p []byte) (n int, err error) {
+ *w += countingWriter(len(p))
+ return len(p), nil
+}
+
+// rangesMIMESize returns the number of bytes it takes to encode the
+// provided ranges as a multipart response.
+func rangesMIMESize(ranges []httpRange, contentType string, contentSize int64) (encSize int64) {
+ var w countingWriter
+ mw := multipart.NewWriter(&w)
+ for _, ra := range ranges {
+ mw.CreatePart(ra.mimeHeader(contentType, contentSize))
+ encSize += ra.length
+ }
+ mw.Close()
+ encSize += int64(w)
+ return
+}
+
+func sumRangesSize(ranges []httpRange) (size int64) {
+ for _, ra := range ranges {
+ size += ra.length
+ }
+ return
+}
diff --git a/src/net/http/fs_test.go b/src/net/http/fs_test.go
new file mode 100644
index 000000000..8770d9b41
--- /dev/null
+++ b/src/net/http/fs_test.go
@@ -0,0 +1,917 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "mime"
+ "mime/multipart"
+ "net"
+ . "net/http"
+ "net/http/httptest"
+ "net/url"
+ "os"
+ "os/exec"
+ "path"
+ "path/filepath"
+ "reflect"
+ "regexp"
+ "runtime"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+)
+
+const (
+ testFile = "testdata/file"
+ testFileLen = 11
+)
+
+type wantRange struct {
+ start, end int64 // range [start,end)
+}
+
+var itoa = strconv.Itoa
+
+var ServeFileRangeTests = []struct {
+ r string
+ code int
+ ranges []wantRange
+}{
+ {r: "", code: StatusOK},
+ {r: "bytes=0-4", code: StatusPartialContent, ranges: []wantRange{{0, 5}}},
+ {r: "bytes=2-", code: StatusPartialContent, ranges: []wantRange{{2, testFileLen}}},
+ {r: "bytes=-5", code: StatusPartialContent, ranges: []wantRange{{testFileLen - 5, testFileLen}}},
+ {r: "bytes=3-7", code: StatusPartialContent, ranges: []wantRange{{3, 8}}},
+ {r: "bytes=20-", code: StatusRequestedRangeNotSatisfiable},
+ {r: "bytes=0-0,-2", code: StatusPartialContent, ranges: []wantRange{{0, 1}, {testFileLen - 2, testFileLen}}},
+ {r: "bytes=0-1,5-8", code: StatusPartialContent, ranges: []wantRange{{0, 2}, {5, 9}}},
+ {r: "bytes=0-1,5-", code: StatusPartialContent, ranges: []wantRange{{0, 2}, {5, testFileLen}}},
+ {r: "bytes=5-1000", code: StatusPartialContent, ranges: []wantRange{{5, testFileLen}}},
+ {r: "bytes=0-,1-,2-,3-,4-", code: StatusOK}, // ignore wasteful range request
+ {r: "bytes=0-" + itoa(testFileLen-2), code: StatusPartialContent, ranges: []wantRange{{0, testFileLen - 1}}},
+ {r: "bytes=0-" + itoa(testFileLen-1), code: StatusPartialContent, ranges: []wantRange{{0, testFileLen}}},
+ {r: "bytes=0-" + itoa(testFileLen), code: StatusPartialContent, ranges: []wantRange{{0, testFileLen}}},
+}
+
+func TestServeFile(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ ServeFile(w, r, "testdata/file")
+ }))
+ defer ts.Close()
+
+ var err error
+
+ file, err := ioutil.ReadFile(testFile)
+ if err != nil {
+ t.Fatal("reading file:", err)
+ }
+
+ // set up the Request (re-used for all tests)
+ var req Request
+ req.Header = make(Header)
+ if req.URL, err = url.Parse(ts.URL); err != nil {
+ t.Fatal("ParseURL:", err)
+ }
+ req.Method = "GET"
+
+ // straight GET
+ _, body := getBody(t, "straight get", req)
+ if !bytes.Equal(body, file) {
+ t.Fatalf("body mismatch: got %q, want %q", body, file)
+ }
+
+ // Range tests
+Cases:
+ for _, rt := range ServeFileRangeTests {
+ if rt.r != "" {
+ req.Header.Set("Range", rt.r)
+ }
+ resp, body := getBody(t, fmt.Sprintf("range test %q", rt.r), req)
+ if resp.StatusCode != rt.code {
+ t.Errorf("range=%q: StatusCode=%d, want %d", rt.r, resp.StatusCode, rt.code)
+ }
+ if rt.code == StatusRequestedRangeNotSatisfiable {
+ continue
+ }
+ wantContentRange := ""
+ if len(rt.ranges) == 1 {
+ rng := rt.ranges[0]
+ wantContentRange = fmt.Sprintf("bytes %d-%d/%d", rng.start, rng.end-1, testFileLen)
+ }
+ cr := resp.Header.Get("Content-Range")
+ if cr != wantContentRange {
+ t.Errorf("range=%q: Content-Range = %q, want %q", rt.r, cr, wantContentRange)
+ }
+ ct := resp.Header.Get("Content-Type")
+ if len(rt.ranges) == 1 {
+ rng := rt.ranges[0]
+ wantBody := file[rng.start:rng.end]
+ if !bytes.Equal(body, wantBody) {
+ t.Errorf("range=%q: body = %q, want %q", rt.r, body, wantBody)
+ }
+ if strings.HasPrefix(ct, "multipart/byteranges") {
+ t.Errorf("range=%q content-type = %q; unexpected multipart/byteranges", rt.r, ct)
+ }
+ }
+ if len(rt.ranges) > 1 {
+ typ, params, err := mime.ParseMediaType(ct)
+ if err != nil {
+ t.Errorf("range=%q content-type = %q; %v", rt.r, ct, err)
+ continue
+ }
+ if typ != "multipart/byteranges" {
+ t.Errorf("range=%q content-type = %q; want multipart/byteranges", rt.r, typ)
+ continue
+ }
+ if params["boundary"] == "" {
+ t.Errorf("range=%q content-type = %q; lacks boundary", rt.r, ct)
+ continue
+ }
+ if g, w := resp.ContentLength, int64(len(body)); g != w {
+ t.Errorf("range=%q Content-Length = %d; want %d", rt.r, g, w)
+ continue
+ }
+ mr := multipart.NewReader(bytes.NewReader(body), params["boundary"])
+ for ri, rng := range rt.ranges {
+ part, err := mr.NextPart()
+ if err != nil {
+ t.Errorf("range=%q, reading part index %d: %v", rt.r, ri, err)
+ continue Cases
+ }
+ wantContentRange = fmt.Sprintf("bytes %d-%d/%d", rng.start, rng.end-1, testFileLen)
+ if g, w := part.Header.Get("Content-Range"), wantContentRange; g != w {
+ t.Errorf("range=%q: part Content-Range = %q; want %q", rt.r, g, w)
+ }
+ body, err := ioutil.ReadAll(part)
+ if err != nil {
+ t.Errorf("range=%q, reading part index %d body: %v", rt.r, ri, err)
+ continue Cases
+ }
+ wantBody := file[rng.start:rng.end]
+ if !bytes.Equal(body, wantBody) {
+ t.Errorf("range=%q: body = %q, want %q", rt.r, body, wantBody)
+ }
+ }
+ _, err = mr.NextPart()
+ if err != io.EOF {
+ t.Errorf("range=%q; expected final error io.EOF; got %v", rt.r, err)
+ }
+ }
+ }
+}
+
+var fsRedirectTestData = []struct {
+ original, redirect string
+}{
+ {"/test/index.html", "/test/"},
+ {"/test/testdata", "/test/testdata/"},
+ {"/test/testdata/file/", "/test/testdata/file"},
+}
+
+func TestFSRedirect(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(StripPrefix("/test", FileServer(Dir("."))))
+ defer ts.Close()
+
+ for _, data := range fsRedirectTestData {
+ res, err := Get(ts.URL + data.original)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if g, e := res.Request.URL.Path, data.redirect; g != e {
+ t.Errorf("redirect from %s: got %s, want %s", data.original, g, e)
+ }
+ }
+}
+
+type testFileSystem struct {
+ open func(name string) (File, error)
+}
+
+func (fs *testFileSystem) Open(name string) (File, error) {
+ return fs.open(name)
+}
+
+func TestFileServerCleans(t *testing.T) {
+ defer afterTest(t)
+ ch := make(chan string, 1)
+ fs := FileServer(&testFileSystem{func(name string) (File, error) {
+ ch <- name
+ return nil, errors.New("file does not exist")
+ }})
+ tests := []struct {
+ reqPath, openArg string
+ }{
+ {"/foo.txt", "/foo.txt"},
+ {"//foo.txt", "/foo.txt"},
+ {"/../foo.txt", "/foo.txt"},
+ }
+ req, _ := NewRequest("GET", "http://example.com", nil)
+ for n, test := range tests {
+ rec := httptest.NewRecorder()
+ req.URL.Path = test.reqPath
+ fs.ServeHTTP(rec, req)
+ if got := <-ch; got != test.openArg {
+ t.Errorf("test %d: got %q, want %q", n, got, test.openArg)
+ }
+ }
+}
+
+func TestFileServerEscapesNames(t *testing.T) {
+ defer afterTest(t)
+ const dirListPrefix = "<pre>\n"
+ const dirListSuffix = "\n</pre>\n"
+ tests := []struct {
+ name, escaped string
+ }{
+ {`simple_name`, `<a href="simple_name">simple_name</a>`},
+ {`"'<>&`, `<a href="%22%27%3C%3E&">&#34;&#39;&lt;&gt;&amp;</a>`},
+ {`?foo=bar#baz`, `<a href="%3Ffoo=bar%23baz">?foo=bar#baz</a>`},
+ {`<combo>?foo`, `<a href="%3Ccombo%3E%3Ffoo">&lt;combo&gt;?foo</a>`},
+ }
+
+ // We put each test file in its own directory in the fakeFS so we can look at it in isolation.
+ fs := make(fakeFS)
+ for i, test := range tests {
+ testFile := &fakeFileInfo{basename: test.name}
+ fs[fmt.Sprintf("/%d", i)] = &fakeFileInfo{
+ dir: true,
+ modtime: time.Unix(1000000000, 0).UTC(),
+ ents: []*fakeFileInfo{testFile},
+ }
+ fs[fmt.Sprintf("/%d/%s", i, test.name)] = testFile
+ }
+
+ ts := httptest.NewServer(FileServer(&fs))
+ defer ts.Close()
+ for i, test := range tests {
+ url := fmt.Sprintf("%s/%d", ts.URL, i)
+ res, err := Get(url)
+ if err != nil {
+ t.Fatalf("test %q: Get: %v", test.name, err)
+ }
+ b, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("test %q: read Body: %v", test.name, err)
+ }
+ s := string(b)
+ if !strings.HasPrefix(s, dirListPrefix) || !strings.HasSuffix(s, dirListSuffix) {
+ t.Errorf("test %q: listing dir, full output is %q, want prefix %q and suffix %q", test.name, s, dirListPrefix, dirListSuffix)
+ }
+ if trimmed := strings.TrimSuffix(strings.TrimPrefix(s, dirListPrefix), dirListSuffix); trimmed != test.escaped {
+ t.Errorf("test %q: listing dir, filename escaped to %q, want %q", test.name, trimmed, test.escaped)
+ }
+ res.Body.Close()
+ }
+}
+
+func mustRemoveAll(dir string) {
+ err := os.RemoveAll(dir)
+ if err != nil {
+ panic(err)
+ }
+}
+
+func TestFileServerImplicitLeadingSlash(t *testing.T) {
+ defer afterTest(t)
+ tempDir, err := ioutil.TempDir("", "")
+ if err != nil {
+ t.Fatalf("TempDir: %v", err)
+ }
+ defer mustRemoveAll(tempDir)
+ if err := ioutil.WriteFile(filepath.Join(tempDir, "foo.txt"), []byte("Hello world"), 0644); err != nil {
+ t.Fatalf("WriteFile: %v", err)
+ }
+ ts := httptest.NewServer(StripPrefix("/bar/", FileServer(Dir(tempDir))))
+ defer ts.Close()
+ get := func(suffix string) string {
+ res, err := Get(ts.URL + suffix)
+ if err != nil {
+ t.Fatalf("Get %s: %v", suffix, err)
+ }
+ b, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("ReadAll %s: %v", suffix, err)
+ }
+ res.Body.Close()
+ return string(b)
+ }
+ if s := get("/bar/"); !strings.Contains(s, ">foo.txt<") {
+ t.Logf("expected a directory listing with foo.txt, got %q", s)
+ }
+ if s := get("/bar/foo.txt"); s != "Hello world" {
+ t.Logf("expected %q, got %q", "Hello world", s)
+ }
+}
+
+func TestDirJoin(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("skipping test on windows")
+ }
+ wfi, err := os.Stat("/etc/hosts")
+ if err != nil {
+ t.Skip("skipping test; no /etc/hosts file")
+ }
+ test := func(d Dir, name string) {
+ f, err := d.Open(name)
+ if err != nil {
+ t.Fatalf("open of %s: %v", name, err)
+ }
+ defer f.Close()
+ gfi, err := f.Stat()
+ if err != nil {
+ t.Fatalf("stat of %s: %v", name, err)
+ }
+ if !os.SameFile(gfi, wfi) {
+ t.Errorf("%s got different file", name)
+ }
+ }
+ test(Dir("/etc/"), "/hosts")
+ test(Dir("/etc/"), "hosts")
+ test(Dir("/etc/"), "../../../../hosts")
+ test(Dir("/etc"), "/hosts")
+ test(Dir("/etc"), "hosts")
+ test(Dir("/etc"), "../../../../hosts")
+
+ // Not really directories, but since we use this trick in
+ // ServeFile, test it:
+ test(Dir("/etc/hosts"), "")
+ test(Dir("/etc/hosts"), "/")
+ test(Dir("/etc/hosts"), "../")
+}
+
+func TestEmptyDirOpenCWD(t *testing.T) {
+ test := func(d Dir) {
+ name := "fs_test.go"
+ f, err := d.Open(name)
+ if err != nil {
+ t.Fatalf("open of %s: %v", name, err)
+ }
+ defer f.Close()
+ }
+ test(Dir(""))
+ test(Dir("."))
+ test(Dir("./"))
+}
+
+func TestServeFileContentType(t *testing.T) {
+ defer afterTest(t)
+ const ctype = "icecream/chocolate"
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ switch r.FormValue("override") {
+ case "1":
+ w.Header().Set("Content-Type", ctype)
+ case "2":
+ // Explicitly inhibit sniffing.
+ w.Header()["Content-Type"] = []string{}
+ }
+ ServeFile(w, r, "testdata/file")
+ }))
+ defer ts.Close()
+ get := func(override string, want []string) {
+ resp, err := Get(ts.URL + "?override=" + override)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if h := resp.Header["Content-Type"]; !reflect.DeepEqual(h, want) {
+ t.Errorf("Content-Type mismatch: got %v, want %v", h, want)
+ }
+ resp.Body.Close()
+ }
+ get("0", []string{"text/plain; charset=utf-8"})
+ get("1", []string{ctype})
+ get("2", nil)
+}
+
+func TestServeFileMimeType(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ ServeFile(w, r, "testdata/style.css")
+ }))
+ defer ts.Close()
+ resp, err := Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ resp.Body.Close()
+ want := "text/css; charset=utf-8"
+ if h := resp.Header.Get("Content-Type"); h != want {
+ t.Errorf("Content-Type mismatch: got %q, want %q", h, want)
+ }
+}
+
+func TestServeFileFromCWD(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ ServeFile(w, r, "fs_test.go")
+ }))
+ defer ts.Close()
+ r, err := Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ r.Body.Close()
+ if r.StatusCode != 200 {
+ t.Fatalf("expected 200 OK, got %s", r.Status)
+ }
+}
+
+func TestServeFileWithContentEncoding(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Encoding", "foo")
+ ServeFile(w, r, "testdata/file")
+ }))
+ defer ts.Close()
+ resp, err := Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ resp.Body.Close()
+ if g, e := resp.ContentLength, int64(-1); g != e {
+ t.Errorf("Content-Length mismatch: got %d, want %d", g, e)
+ }
+}
+
+func TestServeIndexHtml(t *testing.T) {
+ defer afterTest(t)
+ const want = "index.html says hello\n"
+ ts := httptest.NewServer(FileServer(Dir(".")))
+ defer ts.Close()
+
+ for _, path := range []string{"/testdata/", "/testdata/index.html"} {
+ res, err := Get(ts.URL + path)
+ if err != nil {
+ t.Fatal(err)
+ }
+ b, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal("reading Body:", err)
+ }
+ if s := string(b); s != want {
+ t.Errorf("for path %q got %q, want %q", path, s, want)
+ }
+ res.Body.Close()
+ }
+}
+
+func TestFileServerZeroByte(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(FileServer(Dir(".")))
+ defer ts.Close()
+
+ res, err := Get(ts.URL + "/..\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ b, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal("reading Body:", err)
+ }
+ if res.StatusCode == 200 {
+ t.Errorf("got status 200; want an error. Body is:\n%s", string(b))
+ }
+}
+
+type fakeFileInfo struct {
+ dir bool
+ basename string
+ modtime time.Time
+ ents []*fakeFileInfo
+ contents string
+}
+
+func (f *fakeFileInfo) Name() string { return f.basename }
+func (f *fakeFileInfo) Sys() interface{} { return nil }
+func (f *fakeFileInfo) ModTime() time.Time { return f.modtime }
+func (f *fakeFileInfo) IsDir() bool { return f.dir }
+func (f *fakeFileInfo) Size() int64 { return int64(len(f.contents)) }
+func (f *fakeFileInfo) Mode() os.FileMode {
+ if f.dir {
+ return 0755 | os.ModeDir
+ }
+ return 0644
+}
+
+type fakeFile struct {
+ io.ReadSeeker
+ fi *fakeFileInfo
+ path string // as opened
+ entpos int
+}
+
+func (f *fakeFile) Close() error { return nil }
+func (f *fakeFile) Stat() (os.FileInfo, error) { return f.fi, nil }
+func (f *fakeFile) Readdir(count int) ([]os.FileInfo, error) {
+ if !f.fi.dir {
+ return nil, os.ErrInvalid
+ }
+ var fis []os.FileInfo
+
+ limit := f.entpos + count
+ if count <= 0 || limit > len(f.fi.ents) {
+ limit = len(f.fi.ents)
+ }
+ for ; f.entpos < limit; f.entpos++ {
+ fis = append(fis, f.fi.ents[f.entpos])
+ }
+
+ if len(fis) == 0 && count > 0 {
+ return fis, io.EOF
+ } else {
+ return fis, nil
+ }
+}
+
+type fakeFS map[string]*fakeFileInfo
+
+func (fs fakeFS) Open(name string) (File, error) {
+ name = path.Clean(name)
+ f, ok := fs[name]
+ if !ok {
+ return nil, os.ErrNotExist
+ }
+ return &fakeFile{ReadSeeker: strings.NewReader(f.contents), fi: f, path: name}, nil
+}
+
+func TestDirectoryIfNotModified(t *testing.T) {
+ defer afterTest(t)
+ const indexContents = "I am a fake index.html file"
+ fileMod := time.Unix(1000000000, 0).UTC()
+ fileModStr := fileMod.Format(TimeFormat)
+ dirMod := time.Unix(123, 0).UTC()
+ indexFile := &fakeFileInfo{
+ basename: "index.html",
+ modtime: fileMod,
+ contents: indexContents,
+ }
+ fs := fakeFS{
+ "/": &fakeFileInfo{
+ dir: true,
+ modtime: dirMod,
+ ents: []*fakeFileInfo{indexFile},
+ },
+ "/index.html": indexFile,
+ }
+
+ ts := httptest.NewServer(FileServer(fs))
+ defer ts.Close()
+
+ res, err := Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ b, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(b) != indexContents {
+ t.Fatalf("Got body %q; want %q", b, indexContents)
+ }
+ res.Body.Close()
+
+ lastMod := res.Header.Get("Last-Modified")
+ if lastMod != fileModStr {
+ t.Fatalf("initial Last-Modified = %q; want %q", lastMod, fileModStr)
+ }
+
+ req, _ := NewRequest("GET", ts.URL, nil)
+ req.Header.Set("If-Modified-Since", lastMod)
+
+ res, err = DefaultClient.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != 304 {
+ t.Fatalf("Code after If-Modified-Since request = %v; want 304", res.StatusCode)
+ }
+ res.Body.Close()
+
+ // Advance the index.html file's modtime, but not the directory's.
+ indexFile.modtime = indexFile.modtime.Add(1 * time.Hour)
+
+ res, err = DefaultClient.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != 200 {
+ t.Fatalf("Code after second If-Modified-Since request = %v; want 200; res is %#v", res.StatusCode, res)
+ }
+ res.Body.Close()
+}
+
+func mustStat(t *testing.T, fileName string) os.FileInfo {
+ fi, err := os.Stat(fileName)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return fi
+}
+
+func TestServeContent(t *testing.T) {
+ defer afterTest(t)
+ type serveParam struct {
+ name string
+ modtime time.Time
+ content io.ReadSeeker
+ contentType string
+ etag string
+ }
+ servec := make(chan serveParam, 1)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ p := <-servec
+ if p.etag != "" {
+ w.Header().Set("ETag", p.etag)
+ }
+ if p.contentType != "" {
+ w.Header().Set("Content-Type", p.contentType)
+ }
+ ServeContent(w, r, p.name, p.modtime, p.content)
+ }))
+ defer ts.Close()
+
+ type testCase struct {
+ // One of file or content must be set:
+ file string
+ content io.ReadSeeker
+
+ modtime time.Time
+ serveETag string // optional
+ serveContentType string // optional
+ reqHeader map[string]string
+ wantLastMod string
+ wantContentType string
+ wantStatus int
+ }
+ htmlModTime := mustStat(t, "testdata/index.html").ModTime()
+ tests := map[string]testCase{
+ "no_last_modified": {
+ file: "testdata/style.css",
+ wantContentType: "text/css; charset=utf-8",
+ wantStatus: 200,
+ },
+ "with_last_modified": {
+ file: "testdata/index.html",
+ wantContentType: "text/html; charset=utf-8",
+ modtime: htmlModTime,
+ wantLastMod: htmlModTime.UTC().Format(TimeFormat),
+ wantStatus: 200,
+ },
+ "not_modified_modtime": {
+ file: "testdata/style.css",
+ modtime: htmlModTime,
+ reqHeader: map[string]string{
+ "If-Modified-Since": htmlModTime.UTC().Format(TimeFormat),
+ },
+ wantStatus: 304,
+ },
+ "not_modified_modtime_with_contenttype": {
+ file: "testdata/style.css",
+ serveContentType: "text/css", // explicit content type
+ modtime: htmlModTime,
+ reqHeader: map[string]string{
+ "If-Modified-Since": htmlModTime.UTC().Format(TimeFormat),
+ },
+ wantStatus: 304,
+ },
+ "not_modified_etag": {
+ file: "testdata/style.css",
+ serveETag: `"foo"`,
+ reqHeader: map[string]string{
+ "If-None-Match": `"foo"`,
+ },
+ wantStatus: 304,
+ },
+ "not_modified_etag_no_seek": {
+ content: panicOnSeek{nil}, // should never be called
+ serveETag: `"foo"`,
+ reqHeader: map[string]string{
+ "If-None-Match": `"foo"`,
+ },
+ wantStatus: 304,
+ },
+ "range_good": {
+ file: "testdata/style.css",
+ serveETag: `"A"`,
+ reqHeader: map[string]string{
+ "Range": "bytes=0-4",
+ },
+ wantStatus: StatusPartialContent,
+ wantContentType: "text/css; charset=utf-8",
+ },
+ // An If-Range resource for entity "A", but entity "B" is now current.
+ // The Range request should be ignored.
+ "range_no_match": {
+ file: "testdata/style.css",
+ serveETag: `"A"`,
+ reqHeader: map[string]string{
+ "Range": "bytes=0-4",
+ "If-Range": `"B"`,
+ },
+ wantStatus: 200,
+ wantContentType: "text/css; charset=utf-8",
+ },
+ "range_with_modtime": {
+ file: "testdata/style.css",
+ modtime: time.Date(2014, 6, 25, 17, 12, 18, 0 /* nanos */, time.UTC),
+ reqHeader: map[string]string{
+ "Range": "bytes=0-4",
+ "If-Range": "Wed, 25 Jun 2014 17:12:18 GMT",
+ },
+ wantStatus: StatusPartialContent,
+ wantContentType: "text/css; charset=utf-8",
+ wantLastMod: "Wed, 25 Jun 2014 17:12:18 GMT",
+ },
+ "range_with_modtime_nanos": {
+ file: "testdata/style.css",
+ modtime: time.Date(2014, 6, 25, 17, 12, 18, 123 /* nanos */, time.UTC),
+ reqHeader: map[string]string{
+ "Range": "bytes=0-4",
+ "If-Range": "Wed, 25 Jun 2014 17:12:18 GMT",
+ },
+ wantStatus: StatusPartialContent,
+ wantContentType: "text/css; charset=utf-8",
+ wantLastMod: "Wed, 25 Jun 2014 17:12:18 GMT",
+ },
+ }
+ for testName, tt := range tests {
+ var content io.ReadSeeker
+ if tt.file != "" {
+ f, err := os.Open(tt.file)
+ if err != nil {
+ t.Fatalf("test %q: %v", testName, err)
+ }
+ defer f.Close()
+ content = f
+ } else {
+ content = tt.content
+ }
+
+ servec <- serveParam{
+ name: filepath.Base(tt.file),
+ content: content,
+ modtime: tt.modtime,
+ etag: tt.serveETag,
+ contentType: tt.serveContentType,
+ }
+ req, err := NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ for k, v := range tt.reqHeader {
+ req.Header.Set(k, v)
+ }
+ res, err := DefaultClient.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ io.Copy(ioutil.Discard, res.Body)
+ res.Body.Close()
+ if res.StatusCode != tt.wantStatus {
+ t.Errorf("test %q: status = %d; want %d", testName, res.StatusCode, tt.wantStatus)
+ }
+ if g, e := res.Header.Get("Content-Type"), tt.wantContentType; g != e {
+ t.Errorf("test %q: content-type = %q, want %q", testName, g, e)
+ }
+ if g, e := res.Header.Get("Last-Modified"), tt.wantLastMod; g != e {
+ t.Errorf("test %q: last-modified = %q, want %q", testName, g, e)
+ }
+ }
+}
+
+// verifies that sendfile is being used on Linux
+func TestLinuxSendfile(t *testing.T) {
+ defer afterTest(t)
+ if runtime.GOOS != "linux" {
+ t.Skip("skipping; linux-only test")
+ }
+ if _, err := exec.LookPath("strace"); err != nil {
+ t.Skip("skipping; strace not found in path")
+ }
+
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ lnf, err := ln.(*net.TCPListener).File()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+
+ var buf bytes.Buffer
+ child := exec.Command("strace", "-f", "-q", "-e", "trace=sendfile,sendfile64", os.Args[0], "-test.run=TestLinuxSendfileChild")
+ child.ExtraFiles = append(child.ExtraFiles, lnf)
+ child.Env = append([]string{"GO_WANT_HELPER_PROCESS=1"}, os.Environ()...)
+ child.Stdout = &buf
+ child.Stderr = &buf
+ if err := child.Start(); err != nil {
+ t.Skipf("skipping; failed to start straced child: %v", err)
+ }
+
+ res, err := Get(fmt.Sprintf("http://%s/", ln.Addr()))
+ if err != nil {
+ t.Fatalf("http client error: %v", err)
+ }
+ _, err = io.Copy(ioutil.Discard, res.Body)
+ if err != nil {
+ t.Fatalf("client body read error: %v", err)
+ }
+ res.Body.Close()
+
+ // Force child to exit cleanly.
+ Get(fmt.Sprintf("http://%s/quit", ln.Addr()))
+ child.Wait()
+
+ rx := regexp.MustCompile(`sendfile(64)?\(\d+,\s*\d+,\s*NULL,\s*\d+\)\s*=\s*\d+\s*\n`)
+ rxResume := regexp.MustCompile(`<\.\.\. sendfile(64)? resumed> \)\s*=\s*\d+\s*\n`)
+ out := buf.String()
+ if !rx.MatchString(out) && !rxResume.MatchString(out) {
+ t.Errorf("no sendfile system call found in:\n%s", out)
+ }
+}
+
+func getBody(t *testing.T, testName string, req Request) (*Response, []byte) {
+ r, err := DefaultClient.Do(&req)
+ if err != nil {
+ t.Fatalf("%s: for URL %q, send error: %v", testName, req.URL.String(), err)
+ }
+ b, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ t.Fatalf("%s: for URL %q, reading body: %v", testName, req.URL.String(), err)
+ }
+ return r, b
+}
+
+// TestLinuxSendfileChild isn't a real test. It's used as a helper process
+// for TestLinuxSendfile.
+func TestLinuxSendfileChild(*testing.T) {
+ if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" {
+ return
+ }
+ defer os.Exit(0)
+ fd3 := os.NewFile(3, "ephemeral-port-listener")
+ ln, err := net.FileListener(fd3)
+ if err != nil {
+ panic(err)
+ }
+ mux := NewServeMux()
+ mux.Handle("/", FileServer(Dir("testdata")))
+ mux.HandleFunc("/quit", func(ResponseWriter, *Request) {
+ os.Exit(0)
+ })
+ s := &Server{Handler: mux}
+ err = s.Serve(ln)
+ if err != nil {
+ panic(err)
+ }
+}
+
+func TestFileServerCleanPath(t *testing.T) {
+ tests := []struct {
+ path string
+ wantCode int
+ wantOpen []string
+ }{
+ {"/", 200, []string{"/", "/index.html"}},
+ {"/dir", 301, []string{"/dir"}},
+ {"/dir/", 200, []string{"/dir", "/dir/index.html"}},
+ }
+ for _, tt := range tests {
+ var log []string
+ rr := httptest.NewRecorder()
+ req, _ := NewRequest("GET", "http://foo.localhost"+tt.path, nil)
+ FileServer(fileServerCleanPathDir{&log}).ServeHTTP(rr, req)
+ if !reflect.DeepEqual(log, tt.wantOpen) {
+ t.Logf("For %s: Opens = %q; want %q", tt.path, log, tt.wantOpen)
+ }
+ if rr.Code != tt.wantCode {
+ t.Logf("For %s: Response code = %d; want %d", tt.path, rr.Code, tt.wantCode)
+ }
+ }
+}
+
+type fileServerCleanPathDir struct {
+ log *[]string
+}
+
+func (d fileServerCleanPathDir) Open(path string) (File, error) {
+ *(d.log) = append(*(d.log), path)
+ if path == "/" || path == "/dir" || path == "/dir/" {
+ // Just return back something that's a directory.
+ return Dir(".").Open(".")
+ }
+ return nil, os.ErrNotExist
+}
+
+type panicOnSeek struct{ io.ReadSeeker }
diff --git a/src/net/http/header.go b/src/net/http/header.go
new file mode 100644
index 000000000..153b94370
--- /dev/null
+++ b/src/net/http/header.go
@@ -0,0 +1,211 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "io"
+ "net/textproto"
+ "sort"
+ "strings"
+ "sync"
+ "time"
+)
+
+var raceEnabled = false // set by race.go
+
+// A Header represents the key-value pairs in an HTTP header.
+type Header map[string][]string
+
+// Add adds the key, value pair to the header.
+// It appends to any existing values associated with key.
+func (h Header) Add(key, value string) {
+ textproto.MIMEHeader(h).Add(key, value)
+}
+
+// Set sets the header entries associated with key to
+// the single element value. It replaces any existing
+// values associated with key.
+func (h Header) Set(key, value string) {
+ textproto.MIMEHeader(h).Set(key, value)
+}
+
+// Get gets the first value associated with the given key.
+// If there are no values associated with the key, Get returns "".
+// To access multiple values of a key, access the map directly
+// with CanonicalHeaderKey.
+func (h Header) Get(key string) string {
+ return textproto.MIMEHeader(h).Get(key)
+}
+
+// get is like Get, but key must already be in CanonicalHeaderKey form.
+func (h Header) get(key string) string {
+ if v := h[key]; len(v) > 0 {
+ return v[0]
+ }
+ return ""
+}
+
+// Del deletes the values associated with key.
+func (h Header) Del(key string) {
+ textproto.MIMEHeader(h).Del(key)
+}
+
+// Write writes a header in wire format.
+func (h Header) Write(w io.Writer) error {
+ return h.WriteSubset(w, nil)
+}
+
+func (h Header) clone() Header {
+ h2 := make(Header, len(h))
+ for k, vv := range h {
+ vv2 := make([]string, len(vv))
+ copy(vv2, vv)
+ h2[k] = vv2
+ }
+ return h2
+}
+
+var timeFormats = []string{
+ TimeFormat,
+ time.RFC850,
+ time.ANSIC,
+}
+
+// ParseTime parses a time header (such as the Date: header),
+// trying each of the three formats allowed by HTTP/1.1:
+// TimeFormat, time.RFC850, and time.ANSIC.
+func ParseTime(text string) (t time.Time, err error) {
+ for _, layout := range timeFormats {
+ t, err = time.Parse(layout, text)
+ if err == nil {
+ return
+ }
+ }
+ return
+}
+
+var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ")
+
+type writeStringer interface {
+ WriteString(string) (int, error)
+}
+
+// stringWriter implements WriteString on a Writer.
+type stringWriter struct {
+ w io.Writer
+}
+
+func (w stringWriter) WriteString(s string) (n int, err error) {
+ return w.w.Write([]byte(s))
+}
+
+type keyValues struct {
+ key string
+ values []string
+}
+
+// A headerSorter implements sort.Interface by sorting a []keyValues
+// by key. It's used as a pointer, so it can fit in a sort.Interface
+// interface value without allocation.
+type headerSorter struct {
+ kvs []keyValues
+}
+
+func (s *headerSorter) Len() int { return len(s.kvs) }
+func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] }
+func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key }
+
+var headerSorterPool = sync.Pool{
+ New: func() interface{} { return new(headerSorter) },
+}
+
+// sortedKeyValues returns h's keys sorted in the returned kvs
+// slice. The headerSorter used to sort is also returned, for possible
+// return to headerSorterCache.
+func (h Header) sortedKeyValues(exclude map[string]bool) (kvs []keyValues, hs *headerSorter) {
+ hs = headerSorterPool.Get().(*headerSorter)
+ if cap(hs.kvs) < len(h) {
+ hs.kvs = make([]keyValues, 0, len(h))
+ }
+ kvs = hs.kvs[:0]
+ for k, vv := range h {
+ if !exclude[k] {
+ kvs = append(kvs, keyValues{k, vv})
+ }
+ }
+ hs.kvs = kvs
+ sort.Sort(hs)
+ return kvs, hs
+}
+
+// WriteSubset writes a header in wire format.
+// If exclude is not nil, keys where exclude[key] == true are not written.
+func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error {
+ ws, ok := w.(writeStringer)
+ if !ok {
+ ws = stringWriter{w}
+ }
+ kvs, sorter := h.sortedKeyValues(exclude)
+ for _, kv := range kvs {
+ for _, v := range kv.values {
+ v = headerNewlineToSpace.Replace(v)
+ v = textproto.TrimString(v)
+ for _, s := range []string{kv.key, ": ", v, "\r\n"} {
+ if _, err := ws.WriteString(s); err != nil {
+ return err
+ }
+ }
+ }
+ }
+ headerSorterPool.Put(sorter)
+ return nil
+}
+
+// CanonicalHeaderKey returns the canonical format of the
+// header key s. The canonicalization converts the first
+// letter and any letter following a hyphen to upper case;
+// the rest are converted to lowercase. For example, the
+// canonical key for "accept-encoding" is "Accept-Encoding".
+func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) }
+
+// hasToken reports whether token appears with v, ASCII
+// case-insensitive, with space or comma boundaries.
+// token must be all lowercase.
+// v may contain mixed cased.
+func hasToken(v, token string) bool {
+ if len(token) > len(v) || token == "" {
+ return false
+ }
+ if v == token {
+ return true
+ }
+ for sp := 0; sp <= len(v)-len(token); sp++ {
+ // Check that first character is good.
+ // The token is ASCII, so checking only a single byte
+ // is sufficient. We skip this potential starting
+ // position if both the first byte and its potential
+ // ASCII uppercase equivalent (b|0x20) don't match.
+ // False positives ('^' => '~') are caught by EqualFold.
+ if b := v[sp]; b != token[0] && b|0x20 != token[0] {
+ continue
+ }
+ // Check that start pos is on a valid token boundary.
+ if sp > 0 && !isTokenBoundary(v[sp-1]) {
+ continue
+ }
+ // Check that end pos is on a valid token boundary.
+ if endPos := sp + len(token); endPos != len(v) && !isTokenBoundary(v[endPos]) {
+ continue
+ }
+ if strings.EqualFold(v[sp:sp+len(token)], token) {
+ return true
+ }
+ }
+ return false
+}
+
+func isTokenBoundary(b byte) bool {
+ return b == ' ' || b == ',' || b == '\t'
+}
diff --git a/src/net/http/header_test.go b/src/net/http/header_test.go
new file mode 100644
index 000000000..9dcd591fa
--- /dev/null
+++ b/src/net/http/header_test.go
@@ -0,0 +1,212 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bytes"
+ "runtime"
+ "testing"
+ "time"
+)
+
+var headerWriteTests = []struct {
+ h Header
+ exclude map[string]bool
+ expected string
+}{
+ {Header{}, nil, ""},
+ {
+ Header{
+ "Content-Type": {"text/html; charset=UTF-8"},
+ "Content-Length": {"0"},
+ },
+ nil,
+ "Content-Length: 0\r\nContent-Type: text/html; charset=UTF-8\r\n",
+ },
+ {
+ Header{
+ "Content-Length": {"0", "1", "2"},
+ },
+ nil,
+ "Content-Length: 0\r\nContent-Length: 1\r\nContent-Length: 2\r\n",
+ },
+ {
+ Header{
+ "Expires": {"-1"},
+ "Content-Length": {"0"},
+ "Content-Encoding": {"gzip"},
+ },
+ map[string]bool{"Content-Length": true},
+ "Content-Encoding: gzip\r\nExpires: -1\r\n",
+ },
+ {
+ Header{
+ "Expires": {"-1"},
+ "Content-Length": {"0", "1", "2"},
+ "Content-Encoding": {"gzip"},
+ },
+ map[string]bool{"Content-Length": true},
+ "Content-Encoding: gzip\r\nExpires: -1\r\n",
+ },
+ {
+ Header{
+ "Expires": {"-1"},
+ "Content-Length": {"0"},
+ "Content-Encoding": {"gzip"},
+ },
+ map[string]bool{"Content-Length": true, "Expires": true, "Content-Encoding": true},
+ "",
+ },
+ {
+ Header{
+ "Nil": nil,
+ "Empty": {},
+ "Blank": {""},
+ "Double-Blank": {"", ""},
+ },
+ nil,
+ "Blank: \r\nDouble-Blank: \r\nDouble-Blank: \r\n",
+ },
+ // Tests header sorting when over the insertion sort threshold side:
+ {
+ Header{
+ "k1": {"1a", "1b"},
+ "k2": {"2a", "2b"},
+ "k3": {"3a", "3b"},
+ "k4": {"4a", "4b"},
+ "k5": {"5a", "5b"},
+ "k6": {"6a", "6b"},
+ "k7": {"7a", "7b"},
+ "k8": {"8a", "8b"},
+ "k9": {"9a", "9b"},
+ },
+ map[string]bool{"k5": true},
+ "k1: 1a\r\nk1: 1b\r\nk2: 2a\r\nk2: 2b\r\nk3: 3a\r\nk3: 3b\r\n" +
+ "k4: 4a\r\nk4: 4b\r\nk6: 6a\r\nk6: 6b\r\n" +
+ "k7: 7a\r\nk7: 7b\r\nk8: 8a\r\nk8: 8b\r\nk9: 9a\r\nk9: 9b\r\n",
+ },
+}
+
+func TestHeaderWrite(t *testing.T) {
+ var buf bytes.Buffer
+ for i, test := range headerWriteTests {
+ test.h.WriteSubset(&buf, test.exclude)
+ if buf.String() != test.expected {
+ t.Errorf("#%d:\n got: %q\nwant: %q", i, buf.String(), test.expected)
+ }
+ buf.Reset()
+ }
+}
+
+var parseTimeTests = []struct {
+ h Header
+ err bool
+}{
+ {Header{"Date": {""}}, true},
+ {Header{"Date": {"invalid"}}, true},
+ {Header{"Date": {"1994-11-06T08:49:37Z00:00"}}, true},
+ {Header{"Date": {"Sun, 06 Nov 1994 08:49:37 GMT"}}, false},
+ {Header{"Date": {"Sunday, 06-Nov-94 08:49:37 GMT"}}, false},
+ {Header{"Date": {"Sun Nov 6 08:49:37 1994"}}, false},
+}
+
+func TestParseTime(t *testing.T) {
+ expect := time.Date(1994, 11, 6, 8, 49, 37, 0, time.UTC)
+ for i, test := range parseTimeTests {
+ d, err := ParseTime(test.h.Get("Date"))
+ if err != nil {
+ if !test.err {
+ t.Errorf("#%d:\n got err: %v", i, err)
+ }
+ continue
+ }
+ if test.err {
+ t.Errorf("#%d:\n should err", i)
+ continue
+ }
+ if !expect.Equal(d) {
+ t.Errorf("#%d:\n got: %v\nwant: %v", i, d, expect)
+ }
+ }
+}
+
+type hasTokenTest struct {
+ header string
+ token string
+ want bool
+}
+
+var hasTokenTests = []hasTokenTest{
+ {"", "", false},
+ {"", "foo", false},
+ {"foo", "foo", true},
+ {"foo ", "foo", true},
+ {" foo", "foo", true},
+ {" foo ", "foo", true},
+ {"foo,bar", "foo", true},
+ {"bar,foo", "foo", true},
+ {"bar, foo", "foo", true},
+ {"bar,foo, baz", "foo", true},
+ {"bar, foo,baz", "foo", true},
+ {"bar,foo, baz", "foo", true},
+ {"bar, foo, baz", "foo", true},
+ {"FOO", "foo", true},
+ {"FOO ", "foo", true},
+ {" FOO", "foo", true},
+ {" FOO ", "foo", true},
+ {"FOO,BAR", "foo", true},
+ {"BAR,FOO", "foo", true},
+ {"BAR, FOO", "foo", true},
+ {"BAR,FOO, baz", "foo", true},
+ {"BAR, FOO,BAZ", "foo", true},
+ {"BAR,FOO, BAZ", "foo", true},
+ {"BAR, FOO, BAZ", "foo", true},
+ {"foobar", "foo", false},
+ {"barfoo ", "foo", false},
+}
+
+func TestHasToken(t *testing.T) {
+ for _, tt := range hasTokenTests {
+ if hasToken(tt.header, tt.token) != tt.want {
+ t.Errorf("hasToken(%q, %q) = %v; want %v", tt.header, tt.token, !tt.want, tt.want)
+ }
+ }
+}
+
+var testHeader = Header{
+ "Content-Length": {"123"},
+ "Content-Type": {"text/plain"},
+ "Date": {"some date at some time Z"},
+ "Server": {DefaultUserAgent},
+}
+
+var buf bytes.Buffer
+
+func BenchmarkHeaderWriteSubset(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ buf.Reset()
+ testHeader.WriteSubset(&buf, nil)
+ }
+}
+
+func TestHeaderWriteSubsetAllocs(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping alloc test in short mode")
+ }
+ if raceEnabled {
+ t.Skip("skipping test under race detector")
+ }
+ if runtime.GOMAXPROCS(0) > 1 {
+ t.Skip("skipping; GOMAXPROCS>1")
+ }
+ n := testing.AllocsPerRun(100, func() {
+ buf.Reset()
+ testHeader.WriteSubset(&buf, nil)
+ })
+ if n > 0 {
+ t.Errorf("allocs = %g; want 0", n)
+ }
+}
diff --git a/src/net/http/httptest/example_test.go b/src/net/http/httptest/example_test.go
new file mode 100644
index 000000000..42a0ec953
--- /dev/null
+++ b/src/net/http/httptest/example_test.go
@@ -0,0 +1,50 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httptest_test
+
+import (
+ "fmt"
+ "io/ioutil"
+ "log"
+ "net/http"
+ "net/http/httptest"
+)
+
+func ExampleResponseRecorder() {
+ handler := func(w http.ResponseWriter, r *http.Request) {
+ http.Error(w, "something failed", http.StatusInternalServerError)
+ }
+
+ req, err := http.NewRequest("GET", "http://example.com/foo", nil)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ w := httptest.NewRecorder()
+ handler(w, req)
+
+ fmt.Printf("%d - %s", w.Code, w.Body.String())
+ // Output: 500 - something failed
+}
+
+func ExampleServer() {
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintln(w, "Hello, client")
+ }))
+ defer ts.Close()
+
+ res, err := http.Get(ts.URL)
+ if err != nil {
+ log.Fatal(err)
+ }
+ greeting, err := ioutil.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ fmt.Printf("%s", greeting)
+ // Output: Hello, client
+}
diff --git a/src/net/http/httptest/recorder.go b/src/net/http/httptest/recorder.go
new file mode 100644
index 000000000..5451f5423
--- /dev/null
+++ b/src/net/http/httptest/recorder.go
@@ -0,0 +1,72 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package httptest provides utilities for HTTP testing.
+package httptest
+
+import (
+ "bytes"
+ "net/http"
+)
+
+// ResponseRecorder is an implementation of http.ResponseWriter that
+// records its mutations for later inspection in tests.
+type ResponseRecorder struct {
+ Code int // the HTTP response code from WriteHeader
+ HeaderMap http.Header // the HTTP response headers
+ Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to
+ Flushed bool
+
+ wroteHeader bool
+}
+
+// NewRecorder returns an initialized ResponseRecorder.
+func NewRecorder() *ResponseRecorder {
+ return &ResponseRecorder{
+ HeaderMap: make(http.Header),
+ Body: new(bytes.Buffer),
+ Code: 200,
+ }
+}
+
+// DefaultRemoteAddr is the default remote address to return in RemoteAddr if
+// an explicit DefaultRemoteAddr isn't set on ResponseRecorder.
+const DefaultRemoteAddr = "1.2.3.4"
+
+// Header returns the response headers.
+func (rw *ResponseRecorder) Header() http.Header {
+ m := rw.HeaderMap
+ if m == nil {
+ m = make(http.Header)
+ rw.HeaderMap = m
+ }
+ return m
+}
+
+// Write always succeeds and writes to rw.Body, if not nil.
+func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
+ if !rw.wroteHeader {
+ rw.WriteHeader(200)
+ }
+ if rw.Body != nil {
+ rw.Body.Write(buf)
+ }
+ return len(buf), nil
+}
+
+// WriteHeader sets rw.Code.
+func (rw *ResponseRecorder) WriteHeader(code int) {
+ if !rw.wroteHeader {
+ rw.Code = code
+ }
+ rw.wroteHeader = true
+}
+
+// Flush sets rw.Flushed to true.
+func (rw *ResponseRecorder) Flush() {
+ if !rw.wroteHeader {
+ rw.WriteHeader(200)
+ }
+ rw.Flushed = true
+}
diff --git a/src/net/http/httptest/recorder_test.go b/src/net/http/httptest/recorder_test.go
new file mode 100644
index 000000000..2b563260c
--- /dev/null
+++ b/src/net/http/httptest/recorder_test.go
@@ -0,0 +1,90 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httptest
+
+import (
+ "fmt"
+ "net/http"
+ "testing"
+)
+
+func TestRecorder(t *testing.T) {
+ type checkFunc func(*ResponseRecorder) error
+ check := func(fns ...checkFunc) []checkFunc { return fns }
+
+ hasStatus := func(wantCode int) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if rec.Code != wantCode {
+ return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode)
+ }
+ return nil
+ }
+ }
+ hasContents := func(want string) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if rec.Body.String() != want {
+ return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want)
+ }
+ return nil
+ }
+ }
+ hasFlush := func(want bool) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if rec.Flushed != want {
+ return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want)
+ }
+ return nil
+ }
+ }
+
+ tests := []struct {
+ name string
+ h func(w http.ResponseWriter, r *http.Request)
+ checks []checkFunc
+ }{
+ {
+ "200 default",
+ func(w http.ResponseWriter, r *http.Request) {},
+ check(hasStatus(200), hasContents("")),
+ },
+ {
+ "first code only",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(201)
+ w.WriteHeader(202)
+ w.Write([]byte("hi"))
+ },
+ check(hasStatus(201), hasContents("hi")),
+ },
+ {
+ "write sends 200",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("hi first"))
+ w.WriteHeader(201)
+ w.WriteHeader(202)
+ },
+ check(hasStatus(200), hasContents("hi first"), hasFlush(false)),
+ },
+ {
+ "flush",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.(http.Flusher).Flush() // also sends a 200
+ w.WriteHeader(201)
+ },
+ check(hasStatus(200), hasFlush(true)),
+ },
+ }
+ r, _ := http.NewRequest("GET", "http://foo.com/", nil)
+ for _, tt := range tests {
+ h := http.HandlerFunc(tt.h)
+ rec := NewRecorder()
+ h.ServeHTTP(rec, r)
+ for _, check := range tt.checks {
+ if err := check(rec); err != nil {
+ t.Errorf("%s: %v", tt.name, err)
+ }
+ }
+ }
+}
diff --git a/src/net/http/httptest/server.go b/src/net/http/httptest/server.go
new file mode 100644
index 000000000..789e7bf41
--- /dev/null
+++ b/src/net/http/httptest/server.go
@@ -0,0 +1,228 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Implementation of Server
+
+package httptest
+
+import (
+ "crypto/tls"
+ "flag"
+ "fmt"
+ "net"
+ "net/http"
+ "os"
+ "sync"
+)
+
+// A Server is an HTTP server listening on a system-chosen port on the
+// local loopback interface, for use in end-to-end HTTP tests.
+type Server struct {
+ URL string // base URL of form http://ipaddr:port with no trailing slash
+ Listener net.Listener
+
+ // TLS is the optional TLS configuration, populated with a new config
+ // after TLS is started. If set on an unstarted server before StartTLS
+ // is called, existing fields are copied into the new config.
+ TLS *tls.Config
+
+ // Config may be changed after calling NewUnstartedServer and
+ // before Start or StartTLS.
+ Config *http.Server
+
+ // wg counts the number of outstanding HTTP requests on this server.
+ // Close blocks until all requests are finished.
+ wg sync.WaitGroup
+}
+
+// historyListener keeps track of all connections that it's ever
+// accepted.
+type historyListener struct {
+ net.Listener
+ sync.Mutex // protects history
+ history []net.Conn
+}
+
+func (hs *historyListener) Accept() (c net.Conn, err error) {
+ c, err = hs.Listener.Accept()
+ if err == nil {
+ hs.Lock()
+ hs.history = append(hs.history, c)
+ hs.Unlock()
+ }
+ return
+}
+
+func newLocalListener() net.Listener {
+ if *serve != "" {
+ l, err := net.Listen("tcp", *serve)
+ if err != nil {
+ panic(fmt.Sprintf("httptest: failed to listen on %v: %v", *serve, err))
+ }
+ return l
+ }
+ l, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
+ panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
+ }
+ }
+ return l
+}
+
+// When debugging a particular http server-based test,
+// this flag lets you run
+// go test -run=BrokenTest -httptest.serve=127.0.0.1:8000
+// to start the broken server so you can interact with it manually.
+var serve = flag.String("httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks")
+
+// NewServer starts and returns a new Server.
+// The caller should call Close when finished, to shut it down.
+func NewServer(handler http.Handler) *Server {
+ ts := NewUnstartedServer(handler)
+ ts.Start()
+ return ts
+}
+
+// NewUnstartedServer returns a new Server but doesn't start it.
+//
+// After changing its configuration, the caller should call Start or
+// StartTLS.
+//
+// The caller should call Close when finished, to shut it down.
+func NewUnstartedServer(handler http.Handler) *Server {
+ return &Server{
+ Listener: newLocalListener(),
+ Config: &http.Server{Handler: handler},
+ }
+}
+
+// Start starts a server from NewUnstartedServer.
+func (s *Server) Start() {
+ if s.URL != "" {
+ panic("Server already started")
+ }
+ s.Listener = &historyListener{Listener: s.Listener}
+ s.URL = "http://" + s.Listener.Addr().String()
+ s.wrapHandler()
+ go s.Config.Serve(s.Listener)
+ if *serve != "" {
+ fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
+ select {}
+ }
+}
+
+// StartTLS starts TLS on a server from NewUnstartedServer.
+func (s *Server) StartTLS() {
+ if s.URL != "" {
+ panic("Server already started")
+ }
+ cert, err := tls.X509KeyPair(localhostCert, localhostKey)
+ if err != nil {
+ panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
+ }
+
+ existingConfig := s.TLS
+ s.TLS = new(tls.Config)
+ if existingConfig != nil {
+ *s.TLS = *existingConfig
+ }
+ if s.TLS.NextProtos == nil {
+ s.TLS.NextProtos = []string{"http/1.1"}
+ }
+ if len(s.TLS.Certificates) == 0 {
+ s.TLS.Certificates = []tls.Certificate{cert}
+ }
+ tlsListener := tls.NewListener(s.Listener, s.TLS)
+
+ s.Listener = &historyListener{Listener: tlsListener}
+ s.URL = "https://" + s.Listener.Addr().String()
+ s.wrapHandler()
+ go s.Config.Serve(s.Listener)
+}
+
+func (s *Server) wrapHandler() {
+ h := s.Config.Handler
+ if h == nil {
+ h = http.DefaultServeMux
+ }
+ s.Config.Handler = &waitGroupHandler{
+ s: s,
+ h: h,
+ }
+}
+
+// NewTLSServer starts and returns a new Server using TLS.
+// The caller should call Close when finished, to shut it down.
+func NewTLSServer(handler http.Handler) *Server {
+ ts := NewUnstartedServer(handler)
+ ts.StartTLS()
+ return ts
+}
+
+// Close shuts down the server and blocks until all outstanding
+// requests on this server have completed.
+func (s *Server) Close() {
+ s.Listener.Close()
+ s.wg.Wait()
+ s.CloseClientConnections()
+ if t, ok := http.DefaultTransport.(*http.Transport); ok {
+ t.CloseIdleConnections()
+ }
+}
+
+// CloseClientConnections closes any currently open HTTP connections
+// to the test Server.
+func (s *Server) CloseClientConnections() {
+ hl, ok := s.Listener.(*historyListener)
+ if !ok {
+ return
+ }
+ hl.Lock()
+ for _, conn := range hl.history {
+ conn.Close()
+ }
+ hl.Unlock()
+}
+
+// waitGroupHandler wraps a handler, incrementing and decrementing a
+// sync.WaitGroup on each request, to enable Server.Close to block
+// until outstanding requests are finished.
+type waitGroupHandler struct {
+ s *Server
+ h http.Handler // non-nil
+}
+
+func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ h.s.wg.Add(1)
+ defer h.s.wg.Done() // a defer, in case ServeHTTP below panics
+ h.h.ServeHTTP(w, r)
+}
+
+// localhostCert is a PEM-encoded TLS cert with SAN IPs
+// "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end
+// of ASN.1 time).
+// generated from src/crypto/tls:
+// go run generate_cert.go --rsa-bits 512 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
+var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
+MIIBdzCCASOgAwIBAgIBADALBgkqhkiG9w0BAQUwEjEQMA4GA1UEChMHQWNtZSBD
+bzAeFw03MDAxMDEwMDAwMDBaFw00OTEyMzEyMzU5NTlaMBIxEDAOBgNVBAoTB0Fj
+bWUgQ28wWjALBgkqhkiG9w0BAQEDSwAwSAJBAN55NcYKZeInyTuhcCwFMhDHCmwa
+IUSdtXdcbItRB/yfXGBhiex00IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEA
+AaNoMGYwDgYDVR0PAQH/BAQDAgCkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1Ud
+EwEB/wQFMAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAA
+AAAAAAAAAAAAAAEwCwYJKoZIhvcNAQEFA0EAAoQn/ytgqpiLcZu9XKbCJsJcvkgk
+Se6AbGXgSlq+ZCEVo0qIwSgeBqmsJxUu7NCSOwVJLYNEBO2DtIxoYVk+MA==
+-----END CERTIFICATE-----`)
+
+// localhostKey is the private key for localhostCert.
+var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
+MIIBPAIBAAJBAN55NcYKZeInyTuhcCwFMhDHCmwaIUSdtXdcbItRB/yfXGBhiex0
+0IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEAAQJBAQdUx66rfh8sYsgfdcvV
+NoafYpnEcB5s4m/vSVe6SU7dCK6eYec9f9wpT353ljhDUHq3EbmE4foNzJngh35d
+AekCIQDhRQG5Li0Wj8TM4obOnnXUXf1jRv0UkzE9AHWLG5q3AwIhAPzSjpYUDjVW
+MCUXgckTpKCuGwbJk7424Nb8bLzf3kllAiA5mUBgjfr/WtFSJdWcPQ4Zt9KTMNKD
+EUO0ukpTwEIl6wIhAMbGqZK3zAAFdq8DD2jPx+UJXnh0rnOkZBzDtJ6/iN69AiEA
+1Aq8MJgTaYsDQWyU/hDq5YkDJc9e9DSCvUIzqxQWMQE=
+-----END RSA PRIVATE KEY-----`)
diff --git a/src/net/http/httptest/server_test.go b/src/net/http/httptest/server_test.go
new file mode 100644
index 000000000..500a9f0b8
--- /dev/null
+++ b/src/net/http/httptest/server_test.go
@@ -0,0 +1,29 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httptest
+
+import (
+ "io/ioutil"
+ "net/http"
+ "testing"
+)
+
+func TestServer(t *testing.T) {
+ ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("hello"))
+ }))
+ defer ts.Close()
+ res, err := http.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ got, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(got) != "hello" {
+ t.Errorf("got %q, want hello", string(got))
+ }
+}
diff --git a/src/net/http/httputil/dump.go b/src/net/http/httputil/dump.go
new file mode 100644
index 000000000..ac8f103f9
--- /dev/null
+++ b/src/net/http/httputil/dump.go
@@ -0,0 +1,284 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httputil
+
+import (
+ "bufio"
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+)
+
+// One of the copies, say from b to r2, could be avoided by using a more
+// elaborate trick where the other copy is made during Request/Response.Write.
+// This would complicate things too much, given that these functions are for
+// debugging only.
+func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) {
+ var buf bytes.Buffer
+ if _, err = buf.ReadFrom(b); err != nil {
+ return nil, nil, err
+ }
+ if err = b.Close(); err != nil {
+ return nil, nil, err
+ }
+ return ioutil.NopCloser(&buf), ioutil.NopCloser(bytes.NewReader(buf.Bytes())), nil
+}
+
+// dumpConn is a net.Conn which writes to Writer and reads from Reader
+type dumpConn struct {
+ io.Writer
+ io.Reader
+}
+
+func (c *dumpConn) Close() error { return nil }
+func (c *dumpConn) LocalAddr() net.Addr { return nil }
+func (c *dumpConn) RemoteAddr() net.Addr { return nil }
+func (c *dumpConn) SetDeadline(t time.Time) error { return nil }
+func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil }
+func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil }
+
+type neverEnding byte
+
+func (b neverEnding) Read(p []byte) (n int, err error) {
+ for i := range p {
+ p[i] = byte(b)
+ }
+ return len(p), nil
+}
+
+// DumpRequestOut is like DumpRequest but includes
+// headers that the standard http.Transport adds,
+// such as User-Agent.
+func DumpRequestOut(req *http.Request, body bool) ([]byte, error) {
+ save := req.Body
+ dummyBody := false
+ if !body || req.Body == nil {
+ req.Body = nil
+ if req.ContentLength != 0 {
+ req.Body = ioutil.NopCloser(io.LimitReader(neverEnding('x'), req.ContentLength))
+ dummyBody = true
+ }
+ } else {
+ var err error
+ save, req.Body, err = drainBody(req.Body)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ // Since we're using the actual Transport code to write the request,
+ // switch to http so the Transport doesn't try to do an SSL
+ // negotiation with our dumpConn and its bytes.Buffer & pipe.
+ // The wire format for https and http are the same, anyway.
+ reqSend := req
+ if req.URL.Scheme == "https" {
+ reqSend = new(http.Request)
+ *reqSend = *req
+ reqSend.URL = new(url.URL)
+ *reqSend.URL = *req.URL
+ reqSend.URL.Scheme = "http"
+ }
+
+ // Use the actual Transport code to record what we would send
+ // on the wire, but not using TCP. Use a Transport with a
+ // custom dialer that returns a fake net.Conn that waits
+ // for the full input (and recording it), and then responds
+ // with a dummy response.
+ var buf bytes.Buffer // records the output
+ pr, pw := io.Pipe()
+ defer pr.Close()
+ defer pw.Close()
+ dr := &delegateReader{c: make(chan io.Reader)}
+ // Wait for the request before replying with a dummy response:
+ go func() {
+ req, err := http.ReadRequest(bufio.NewReader(pr))
+ if err == nil {
+ // Ensure all the body is read; otherwise
+ // we'll get a partial dump.
+ io.Copy(ioutil.Discard, req.Body)
+ req.Body.Close()
+ }
+ dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\n\r\n")
+ }()
+
+ t := &http.Transport{
+ DisableKeepAlives: true,
+ Dial: func(net, addr string) (net.Conn, error) {
+ return &dumpConn{io.MultiWriter(&buf, pw), dr}, nil
+ },
+ }
+
+ _, err := t.RoundTrip(reqSend)
+
+ req.Body = save
+ if err != nil {
+ return nil, err
+ }
+ dump := buf.Bytes()
+
+ // If we used a dummy body above, remove it now.
+ // TODO: if the req.ContentLength is large, we allocate memory
+ // unnecessarily just to slice it off here. But this is just
+ // a debug function, so this is acceptable for now. We could
+ // discard the body earlier if this matters.
+ if dummyBody {
+ if i := bytes.Index(dump, []byte("\r\n\r\n")); i >= 0 {
+ dump = dump[:i+4]
+ }
+ }
+ return dump, nil
+}
+
+// delegateReader is a reader that delegates to another reader,
+// once it arrives on a channel.
+type delegateReader struct {
+ c chan io.Reader
+ r io.Reader // nil until received from c
+}
+
+func (r *delegateReader) Read(p []byte) (int, error) {
+ if r.r == nil {
+ r.r = <-r.c
+ }
+ return r.r.Read(p)
+}
+
+// Return value if nonempty, def otherwise.
+func valueOrDefault(value, def string) string {
+ if value != "" {
+ return value
+ }
+ return def
+}
+
+var reqWriteExcludeHeaderDump = map[string]bool{
+ "Host": true, // not in Header map anyway
+ "Content-Length": true,
+ "Transfer-Encoding": true,
+ "Trailer": true,
+}
+
+// dumpAsReceived writes req to w in the form as it was received, or
+// at least as accurately as possible from the information retained in
+// the request.
+func dumpAsReceived(req *http.Request, w io.Writer) error {
+ return nil
+}
+
+// DumpRequest returns the as-received wire representation of req,
+// optionally including the request body, for debugging.
+// DumpRequest is semantically a no-op, but in order to
+// dump the body, it reads the body data into memory and
+// changes req.Body to refer to the in-memory copy.
+// The documentation for http.Request.Write details which fields
+// of req are used.
+func DumpRequest(req *http.Request, body bool) (dump []byte, err error) {
+ save := req.Body
+ if !body || req.Body == nil {
+ req.Body = nil
+ } else {
+ save, req.Body, err = drainBody(req.Body)
+ if err != nil {
+ return
+ }
+ }
+
+ var b bytes.Buffer
+
+ fmt.Fprintf(&b, "%s %s HTTP/%d.%d\r\n", valueOrDefault(req.Method, "GET"),
+ req.URL.RequestURI(), req.ProtoMajor, req.ProtoMinor)
+
+ host := req.Host
+ if host == "" && req.URL != nil {
+ host = req.URL.Host
+ }
+ if host != "" {
+ fmt.Fprintf(&b, "Host: %s\r\n", host)
+ }
+
+ chunked := len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked"
+ if len(req.TransferEncoding) > 0 {
+ fmt.Fprintf(&b, "Transfer-Encoding: %s\r\n", strings.Join(req.TransferEncoding, ","))
+ }
+ if req.Close {
+ fmt.Fprintf(&b, "Connection: close\r\n")
+ }
+
+ err = req.Header.WriteSubset(&b, reqWriteExcludeHeaderDump)
+ if err != nil {
+ return
+ }
+
+ io.WriteString(&b, "\r\n")
+
+ if req.Body != nil {
+ var dest io.Writer = &b
+ if chunked {
+ dest = NewChunkedWriter(dest)
+ }
+ _, err = io.Copy(dest, req.Body)
+ if chunked {
+ dest.(io.Closer).Close()
+ io.WriteString(&b, "\r\n")
+ }
+ }
+
+ req.Body = save
+ if err != nil {
+ return
+ }
+ dump = b.Bytes()
+ return
+}
+
+// errNoBody is a sentinel error value used by failureToReadBody so we can detect
+// that the lack of body was intentional.
+var errNoBody = errors.New("sentinel error value")
+
+// failureToReadBody is a io.ReadCloser that just returns errNoBody on
+// Read. It's swapped in when we don't actually want to consume the
+// body, but need a non-nil one, and want to distinguish the error
+// from reading the dummy body.
+type failureToReadBody struct{}
+
+func (failureToReadBody) Read([]byte) (int, error) { return 0, errNoBody }
+func (failureToReadBody) Close() error { return nil }
+
+var emptyBody = ioutil.NopCloser(strings.NewReader(""))
+
+// DumpResponse is like DumpRequest but dumps a response.
+func DumpResponse(resp *http.Response, body bool) (dump []byte, err error) {
+ var b bytes.Buffer
+ save := resp.Body
+ savecl := resp.ContentLength
+
+ if !body {
+ resp.Body = failureToReadBody{}
+ } else if resp.Body == nil {
+ resp.Body = emptyBody
+ } else {
+ save, resp.Body, err = drainBody(resp.Body)
+ if err != nil {
+ return
+ }
+ }
+ err = resp.Write(&b)
+ if err == errNoBody {
+ err = nil
+ }
+ resp.Body = save
+ resp.ContentLength = savecl
+ if err != nil {
+ return nil, err
+ }
+ return b.Bytes(), nil
+}
diff --git a/src/net/http/httputil/dump_test.go b/src/net/http/httputil/dump_test.go
new file mode 100644
index 000000000..024ee5a86
--- /dev/null
+++ b/src/net/http/httputil/dump_test.go
@@ -0,0 +1,291 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httputil
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "net/url"
+ "runtime"
+ "strings"
+ "testing"
+)
+
+type dumpTest struct {
+ Req http.Request
+ Body interface{} // optional []byte or func() io.ReadCloser to populate Req.Body
+
+ WantDump string
+ WantDumpOut string
+ NoBody bool // if true, set DumpRequest{,Out} body to false
+}
+
+var dumpTests = []dumpTest{
+
+ // HTTP/1.1 => chunked coding; body; empty trailer
+ {
+ Req: http.Request{
+ Method: "GET",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/search",
+ },
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ TransferEncoding: []string{"chunked"},
+ },
+
+ Body: []byte("abcdef"),
+
+ WantDump: "GET /search HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ chunk("abcdef") + chunk(""),
+ },
+
+ // Verify that DumpRequest preserves the HTTP version number, doesn't add a Host,
+ // and doesn't add a User-Agent.
+ {
+ Req: http.Request{
+ Method: "GET",
+ URL: mustParseURL("/foo"),
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Header: http.Header{
+ "X-Foo": []string{"X-Bar"},
+ },
+ },
+
+ WantDump: "GET /foo HTTP/1.0\r\n" +
+ "X-Foo: X-Bar\r\n\r\n",
+ },
+
+ {
+ Req: *mustNewRequest("GET", "http://example.com/foo", nil),
+
+ WantDumpOut: "GET /foo HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go 1.1 package http\r\n" +
+ "Accept-Encoding: gzip\r\n\r\n",
+ },
+
+ // Test that an https URL doesn't try to do an SSL negotiation
+ // with a bytes.Buffer and hang with all goroutines not
+ // runnable.
+ {
+ Req: *mustNewRequest("GET", "https://example.com/foo", nil),
+
+ WantDumpOut: "GET /foo HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go 1.1 package http\r\n" +
+ "Accept-Encoding: gzip\r\n\r\n",
+ },
+
+ // Request with Body, but Dump requested without it.
+ {
+ Req: http.Request{
+ Method: "POST",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "post.tld",
+ Path: "/",
+ },
+ ContentLength: 6,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ },
+
+ Body: []byte("abcdef"),
+
+ WantDumpOut: "POST / HTTP/1.1\r\n" +
+ "Host: post.tld\r\n" +
+ "User-Agent: Go 1.1 package http\r\n" +
+ "Content-Length: 6\r\n" +
+ "Accept-Encoding: gzip\r\n\r\n",
+
+ NoBody: true,
+ },
+
+ // Request with Body > 8196 (default buffer size)
+ {
+ Req: http.Request{
+ Method: "POST",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "post.tld",
+ Path: "/",
+ },
+ ContentLength: 8193,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ },
+
+ Body: bytes.Repeat([]byte("a"), 8193),
+
+ WantDumpOut: "POST / HTTP/1.1\r\n" +
+ "Host: post.tld\r\n" +
+ "User-Agent: Go 1.1 package http\r\n" +
+ "Content-Length: 8193\r\n" +
+ "Accept-Encoding: gzip\r\n\r\n" +
+ strings.Repeat("a", 8193),
+ },
+}
+
+func TestDumpRequest(t *testing.T) {
+ numg0 := runtime.NumGoroutine()
+ for i, tt := range dumpTests {
+ setBody := func() {
+ if tt.Body == nil {
+ return
+ }
+ switch b := tt.Body.(type) {
+ case []byte:
+ tt.Req.Body = ioutil.NopCloser(bytes.NewReader(b))
+ case func() io.ReadCloser:
+ tt.Req.Body = b()
+ default:
+ t.Fatalf("Test %d: unsupported Body of %T", i, tt.Body)
+ }
+ }
+ setBody()
+ if tt.Req.Header == nil {
+ tt.Req.Header = make(http.Header)
+ }
+
+ if tt.WantDump != "" {
+ setBody()
+ dump, err := DumpRequest(&tt.Req, !tt.NoBody)
+ if err != nil {
+ t.Errorf("DumpRequest #%d: %s", i, err)
+ continue
+ }
+ if string(dump) != tt.WantDump {
+ t.Errorf("DumpRequest %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantDump, string(dump))
+ continue
+ }
+ }
+
+ if tt.WantDumpOut != "" {
+ setBody()
+ dump, err := DumpRequestOut(&tt.Req, !tt.NoBody)
+ if err != nil {
+ t.Errorf("DumpRequestOut #%d: %s", i, err)
+ continue
+ }
+ if string(dump) != tt.WantDumpOut {
+ t.Errorf("DumpRequestOut %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantDumpOut, string(dump))
+ continue
+ }
+ }
+ }
+ if dg := runtime.NumGoroutine() - numg0; dg > 4 {
+ buf := make([]byte, 4096)
+ buf = buf[:runtime.Stack(buf, true)]
+ t.Errorf("Unexpectedly large number of new goroutines: %d new: %s", dg, buf)
+ }
+}
+
+func chunk(s string) string {
+ return fmt.Sprintf("%x\r\n%s\r\n", len(s), s)
+}
+
+func mustParseURL(s string) *url.URL {
+ u, err := url.Parse(s)
+ if err != nil {
+ panic(fmt.Sprintf("Error parsing URL %q: %v", s, err))
+ }
+ return u
+}
+
+func mustNewRequest(method, url string, body io.Reader) *http.Request {
+ req, err := http.NewRequest(method, url, body)
+ if err != nil {
+ panic(fmt.Sprintf("NewRequest(%q, %q, %p) err = %v", method, url, body, err))
+ }
+ return req
+}
+
+var dumpResTests = []struct {
+ res *http.Response
+ body bool
+ want string
+}{
+ {
+ res: &http.Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 50,
+ Header: http.Header{
+ "Foo": []string{"Bar"},
+ },
+ Body: ioutil.NopCloser(strings.NewReader("foo")), // shouldn't be used
+ },
+ body: false, // to verify we see 50, not empty or 3.
+ want: `HTTP/1.1 200 OK
+Content-Length: 50
+Foo: Bar`,
+ },
+
+ {
+ res: &http.Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 3,
+ Body: ioutil.NopCloser(strings.NewReader("foo")),
+ },
+ body: true,
+ want: `HTTP/1.1 200 OK
+Content-Length: 3
+
+foo`,
+ },
+
+ {
+ res: &http.Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: -1,
+ Body: ioutil.NopCloser(strings.NewReader("foo")),
+ TransferEncoding: []string{"chunked"},
+ },
+ body: true,
+ want: `HTTP/1.1 200 OK
+Transfer-Encoding: chunked
+
+3
+foo
+0`,
+ },
+}
+
+func TestDumpResponse(t *testing.T) {
+ for i, tt := range dumpResTests {
+ gotb, err := DumpResponse(tt.res, tt.body)
+ if err != nil {
+ t.Errorf("%d. DumpResponse = %v", i, err)
+ continue
+ }
+ got := string(gotb)
+ got = strings.TrimSpace(got)
+ got = strings.Replace(got, "\r", "", -1)
+
+ if got != tt.want {
+ t.Errorf("%d.\nDumpResponse got:\n%s\n\nWant:\n%s\n", i, got, tt.want)
+ }
+ }
+}
diff --git a/src/net/http/httputil/httputil.go b/src/net/http/httputil/httputil.go
new file mode 100644
index 000000000..2e523e9e2
--- /dev/null
+++ b/src/net/http/httputil/httputil.go
@@ -0,0 +1,39 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package httputil provides HTTP utility functions, complementing the
+// more common ones in the net/http package.
+package httputil
+
+import (
+ "io"
+ "net/http/internal"
+)
+
+// NewChunkedReader returns a new chunkedReader that translates the data read from r
+// out of HTTP "chunked" format before returning it.
+// The chunkedReader returns io.EOF when the final 0-length chunk is read.
+//
+// NewChunkedReader is not needed by normal applications. The http package
+// automatically decodes chunking when reading response bodies.
+func NewChunkedReader(r io.Reader) io.Reader {
+ return internal.NewChunkedReader(r)
+}
+
+// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP
+// "chunked" format before writing them to w. Closing the returned chunkedWriter
+// sends the final 0-length chunk that marks the end of the stream.
+//
+// NewChunkedWriter is not needed by normal applications. The http
+// package adds chunking automatically if handlers don't set a
+// Content-Length header. Using NewChunkedWriter inside a handler
+// would result in double chunking or chunking with a Content-Length
+// length, both of which are wrong.
+func NewChunkedWriter(w io.Writer) io.WriteCloser {
+ return internal.NewChunkedWriter(w)
+}
+
+// ErrLineTooLong is returned when reading malformed chunked data
+// with lines that are too long.
+var ErrLineTooLong = internal.ErrLineTooLong
diff --git a/src/net/http/httputil/persist.go b/src/net/http/httputil/persist.go
new file mode 100644
index 000000000..987bcc96b
--- /dev/null
+++ b/src/net/http/httputil/persist.go
@@ -0,0 +1,429 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httputil
+
+import (
+ "bufio"
+ "errors"
+ "io"
+ "net"
+ "net/http"
+ "net/textproto"
+ "sync"
+)
+
+var (
+ ErrPersistEOF = &http.ProtocolError{ErrorString: "persistent connection closed"}
+ ErrClosed = &http.ProtocolError{ErrorString: "connection closed by user"}
+ ErrPipeline = &http.ProtocolError{ErrorString: "pipeline error"}
+)
+
+// This is an API usage error - the local side is closed.
+// ErrPersistEOF (above) reports that the remote side is closed.
+var errClosed = errors.New("i/o operation on closed connection")
+
+// A ServerConn reads requests and sends responses over an underlying
+// connection, until the HTTP keepalive logic commands an end. ServerConn
+// also allows hijacking the underlying connection by calling Hijack
+// to regain control over the connection. ServerConn supports pipe-lining,
+// i.e. requests can be read out of sync (but in the same order) while the
+// respective responses are sent.
+//
+// ServerConn is low-level and old. Applications should instead use Server
+// in the net/http package.
+type ServerConn struct {
+ lk sync.Mutex // read-write protects the following fields
+ c net.Conn
+ r *bufio.Reader
+ re, we error // read/write errors
+ lastbody io.ReadCloser
+ nread, nwritten int
+ pipereq map[*http.Request]uint
+
+ pipe textproto.Pipeline
+}
+
+// NewServerConn returns a new ServerConn reading and writing c. If r is not
+// nil, it is the buffer to use when reading c.
+//
+// ServerConn is low-level and old. Applications should instead use Server
+// in the net/http package.
+func NewServerConn(c net.Conn, r *bufio.Reader) *ServerConn {
+ if r == nil {
+ r = bufio.NewReader(c)
+ }
+ return &ServerConn{c: c, r: r, pipereq: make(map[*http.Request]uint)}
+}
+
+// Hijack detaches the ServerConn and returns the underlying connection as well
+// as the read-side bufio which may have some left over data. Hijack may be
+// called before Read has signaled the end of the keep-alive logic. The user
+// should not call Hijack while Read or Write is in progress.
+func (sc *ServerConn) Hijack() (c net.Conn, r *bufio.Reader) {
+ sc.lk.Lock()
+ defer sc.lk.Unlock()
+ c = sc.c
+ r = sc.r
+ sc.c = nil
+ sc.r = nil
+ return
+}
+
+// Close calls Hijack and then also closes the underlying connection
+func (sc *ServerConn) Close() error {
+ c, _ := sc.Hijack()
+ if c != nil {
+ return c.Close()
+ }
+ return nil
+}
+
+// Read returns the next request on the wire. An ErrPersistEOF is returned if
+// it is gracefully determined that there are no more requests (e.g. after the
+// first request on an HTTP/1.0 connection, or after a Connection:close on a
+// HTTP/1.1 connection).
+func (sc *ServerConn) Read() (req *http.Request, err error) {
+
+ // Ensure ordered execution of Reads and Writes
+ id := sc.pipe.Next()
+ sc.pipe.StartRequest(id)
+ defer func() {
+ sc.pipe.EndRequest(id)
+ if req == nil {
+ sc.pipe.StartResponse(id)
+ sc.pipe.EndResponse(id)
+ } else {
+ // Remember the pipeline id of this request
+ sc.lk.Lock()
+ sc.pipereq[req] = id
+ sc.lk.Unlock()
+ }
+ }()
+
+ sc.lk.Lock()
+ if sc.we != nil { // no point receiving if write-side broken or closed
+ defer sc.lk.Unlock()
+ return nil, sc.we
+ }
+ if sc.re != nil {
+ defer sc.lk.Unlock()
+ return nil, sc.re
+ }
+ if sc.r == nil { // connection closed by user in the meantime
+ defer sc.lk.Unlock()
+ return nil, errClosed
+ }
+ r := sc.r
+ lastbody := sc.lastbody
+ sc.lastbody = nil
+ sc.lk.Unlock()
+
+ // Make sure body is fully consumed, even if user does not call body.Close
+ if lastbody != nil {
+ // body.Close is assumed to be idempotent and multiple calls to
+ // it should return the error that its first invocation
+ // returned.
+ err = lastbody.Close()
+ if err != nil {
+ sc.lk.Lock()
+ defer sc.lk.Unlock()
+ sc.re = err
+ return nil, err
+ }
+ }
+
+ req, err = http.ReadRequest(r)
+ sc.lk.Lock()
+ defer sc.lk.Unlock()
+ if err != nil {
+ if err == io.ErrUnexpectedEOF {
+ // A close from the opposing client is treated as a
+ // graceful close, even if there was some unparse-able
+ // data before the close.
+ sc.re = ErrPersistEOF
+ return nil, sc.re
+ } else {
+ sc.re = err
+ return req, err
+ }
+ }
+ sc.lastbody = req.Body
+ sc.nread++
+ if req.Close {
+ sc.re = ErrPersistEOF
+ return req, sc.re
+ }
+ return req, err
+}
+
+// Pending returns the number of unanswered requests
+// that have been received on the connection.
+func (sc *ServerConn) Pending() int {
+ sc.lk.Lock()
+ defer sc.lk.Unlock()
+ return sc.nread - sc.nwritten
+}
+
+// Write writes resp in response to req. To close the connection gracefully, set the
+// Response.Close field to true. Write should be considered operational until
+// it returns an error, regardless of any errors returned on the Read side.
+func (sc *ServerConn) Write(req *http.Request, resp *http.Response) error {
+
+ // Retrieve the pipeline ID of this request/response pair
+ sc.lk.Lock()
+ id, ok := sc.pipereq[req]
+ delete(sc.pipereq, req)
+ if !ok {
+ sc.lk.Unlock()
+ return ErrPipeline
+ }
+ sc.lk.Unlock()
+
+ // Ensure pipeline order
+ sc.pipe.StartResponse(id)
+ defer sc.pipe.EndResponse(id)
+
+ sc.lk.Lock()
+ if sc.we != nil {
+ defer sc.lk.Unlock()
+ return sc.we
+ }
+ if sc.c == nil { // connection closed by user in the meantime
+ defer sc.lk.Unlock()
+ return ErrClosed
+ }
+ c := sc.c
+ if sc.nread <= sc.nwritten {
+ defer sc.lk.Unlock()
+ return errors.New("persist server pipe count")
+ }
+ if resp.Close {
+ // After signaling a keep-alive close, any pipelined unread
+ // requests will be lost. It is up to the user to drain them
+ // before signaling.
+ sc.re = ErrPersistEOF
+ }
+ sc.lk.Unlock()
+
+ err := resp.Write(c)
+ sc.lk.Lock()
+ defer sc.lk.Unlock()
+ if err != nil {
+ sc.we = err
+ return err
+ }
+ sc.nwritten++
+
+ return nil
+}
+
+// A ClientConn sends request and receives headers over an underlying
+// connection, while respecting the HTTP keepalive logic. ClientConn
+// supports hijacking the connection calling Hijack to
+// regain control of the underlying net.Conn and deal with it as desired.
+//
+// ClientConn is low-level and old. Applications should instead use
+// Client or Transport in the net/http package.
+type ClientConn struct {
+ lk sync.Mutex // read-write protects the following fields
+ c net.Conn
+ r *bufio.Reader
+ re, we error // read/write errors
+ lastbody io.ReadCloser
+ nread, nwritten int
+ pipereq map[*http.Request]uint
+
+ pipe textproto.Pipeline
+ writeReq func(*http.Request, io.Writer) error
+}
+
+// NewClientConn returns a new ClientConn reading and writing c. If r is not
+// nil, it is the buffer to use when reading c.
+//
+// ClientConn is low-level and old. Applications should use Client or
+// Transport in the net/http package.
+func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn {
+ if r == nil {
+ r = bufio.NewReader(c)
+ }
+ return &ClientConn{
+ c: c,
+ r: r,
+ pipereq: make(map[*http.Request]uint),
+ writeReq: (*http.Request).Write,
+ }
+}
+
+// NewProxyClientConn works like NewClientConn but writes Requests
+// using Request's WriteProxy method.
+//
+// New code should not use NewProxyClientConn. See Client or
+// Transport in the net/http package instead.
+func NewProxyClientConn(c net.Conn, r *bufio.Reader) *ClientConn {
+ cc := NewClientConn(c, r)
+ cc.writeReq = (*http.Request).WriteProxy
+ return cc
+}
+
+// Hijack detaches the ClientConn and returns the underlying connection as well
+// as the read-side bufio which may have some left over data. Hijack may be
+// called before the user or Read have signaled the end of the keep-alive
+// logic. The user should not call Hijack while Read or Write is in progress.
+func (cc *ClientConn) Hijack() (c net.Conn, r *bufio.Reader) {
+ cc.lk.Lock()
+ defer cc.lk.Unlock()
+ c = cc.c
+ r = cc.r
+ cc.c = nil
+ cc.r = nil
+ return
+}
+
+// Close calls Hijack and then also closes the underlying connection
+func (cc *ClientConn) Close() error {
+ c, _ := cc.Hijack()
+ if c != nil {
+ return c.Close()
+ }
+ return nil
+}
+
+// Write writes a request. An ErrPersistEOF error is returned if the connection
+// has been closed in an HTTP keepalive sense. If req.Close equals true, the
+// keepalive connection is logically closed after this request and the opposing
+// server is informed. An ErrUnexpectedEOF indicates the remote closed the
+// underlying TCP connection, which is usually considered as graceful close.
+func (cc *ClientConn) Write(req *http.Request) (err error) {
+
+ // Ensure ordered execution of Writes
+ id := cc.pipe.Next()
+ cc.pipe.StartRequest(id)
+ defer func() {
+ cc.pipe.EndRequest(id)
+ if err != nil {
+ cc.pipe.StartResponse(id)
+ cc.pipe.EndResponse(id)
+ } else {
+ // Remember the pipeline id of this request
+ cc.lk.Lock()
+ cc.pipereq[req] = id
+ cc.lk.Unlock()
+ }
+ }()
+
+ cc.lk.Lock()
+ if cc.re != nil { // no point sending if read-side closed or broken
+ defer cc.lk.Unlock()
+ return cc.re
+ }
+ if cc.we != nil {
+ defer cc.lk.Unlock()
+ return cc.we
+ }
+ if cc.c == nil { // connection closed by user in the meantime
+ defer cc.lk.Unlock()
+ return errClosed
+ }
+ c := cc.c
+ if req.Close {
+ // We write the EOF to the write-side error, because there
+ // still might be some pipelined reads
+ cc.we = ErrPersistEOF
+ }
+ cc.lk.Unlock()
+
+ err = cc.writeReq(req, c)
+ cc.lk.Lock()
+ defer cc.lk.Unlock()
+ if err != nil {
+ cc.we = err
+ return err
+ }
+ cc.nwritten++
+
+ return nil
+}
+
+// Pending returns the number of unanswered requests
+// that have been sent on the connection.
+func (cc *ClientConn) Pending() int {
+ cc.lk.Lock()
+ defer cc.lk.Unlock()
+ return cc.nwritten - cc.nread
+}
+
+// Read reads the next response from the wire. A valid response might be
+// returned together with an ErrPersistEOF, which means that the remote
+// requested that this be the last request serviced. Read can be called
+// concurrently with Write, but not with another Read.
+func (cc *ClientConn) Read(req *http.Request) (resp *http.Response, err error) {
+ // Retrieve the pipeline ID of this request/response pair
+ cc.lk.Lock()
+ id, ok := cc.pipereq[req]
+ delete(cc.pipereq, req)
+ if !ok {
+ cc.lk.Unlock()
+ return nil, ErrPipeline
+ }
+ cc.lk.Unlock()
+
+ // Ensure pipeline order
+ cc.pipe.StartResponse(id)
+ defer cc.pipe.EndResponse(id)
+
+ cc.lk.Lock()
+ if cc.re != nil {
+ defer cc.lk.Unlock()
+ return nil, cc.re
+ }
+ if cc.r == nil { // connection closed by user in the meantime
+ defer cc.lk.Unlock()
+ return nil, errClosed
+ }
+ r := cc.r
+ lastbody := cc.lastbody
+ cc.lastbody = nil
+ cc.lk.Unlock()
+
+ // Make sure body is fully consumed, even if user does not call body.Close
+ if lastbody != nil {
+ // body.Close is assumed to be idempotent and multiple calls to
+ // it should return the error that its first invocation
+ // returned.
+ err = lastbody.Close()
+ if err != nil {
+ cc.lk.Lock()
+ defer cc.lk.Unlock()
+ cc.re = err
+ return nil, err
+ }
+ }
+
+ resp, err = http.ReadResponse(r, req)
+ cc.lk.Lock()
+ defer cc.lk.Unlock()
+ if err != nil {
+ cc.re = err
+ return resp, err
+ }
+ cc.lastbody = resp.Body
+
+ cc.nread++
+
+ if resp.Close {
+ cc.re = ErrPersistEOF // don't send any more requests
+ return resp, cc.re
+ }
+ return resp, err
+}
+
+// Do is convenience method that writes a request and reads a response.
+func (cc *ClientConn) Do(req *http.Request) (resp *http.Response, err error) {
+ err = cc.Write(req)
+ if err != nil {
+ return
+ }
+ return cc.Read(req)
+}
diff --git a/src/net/http/httputil/reverseproxy.go b/src/net/http/httputil/reverseproxy.go
new file mode 100644
index 000000000..ab4637018
--- /dev/null
+++ b/src/net/http/httputil/reverseproxy.go
@@ -0,0 +1,225 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// HTTP reverse proxy handler
+
+package httputil
+
+import (
+ "io"
+ "log"
+ "net"
+ "net/http"
+ "net/url"
+ "strings"
+ "sync"
+ "time"
+)
+
+// onExitFlushLoop is a callback set by tests to detect the state of the
+// flushLoop() goroutine.
+var onExitFlushLoop func()
+
+// ReverseProxy is an HTTP Handler that takes an incoming request and
+// sends it to another server, proxying the response back to the
+// client.
+type ReverseProxy struct {
+ // Director must be a function which modifies
+ // the request into a new request to be sent
+ // using Transport. Its response is then copied
+ // back to the original client unmodified.
+ Director func(*http.Request)
+
+ // The transport used to perform proxy requests.
+ // If nil, http.DefaultTransport is used.
+ Transport http.RoundTripper
+
+ // FlushInterval specifies the flush interval
+ // to flush to the client while copying the
+ // response body.
+ // If zero, no periodic flushing is done.
+ FlushInterval time.Duration
+
+ // ErrorLog specifies an optional logger for errors
+ // that occur when attempting to proxy the request.
+ // If nil, logging goes to os.Stderr via the log package's
+ // standard logger.
+ ErrorLog *log.Logger
+}
+
+func singleJoiningSlash(a, b string) string {
+ aslash := strings.HasSuffix(a, "/")
+ bslash := strings.HasPrefix(b, "/")
+ switch {
+ case aslash && bslash:
+ return a + b[1:]
+ case !aslash && !bslash:
+ return a + "/" + b
+ }
+ return a + b
+}
+
+// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites
+// URLs to the scheme, host, and base path provided in target. If the
+// target's path is "/base" and the incoming request was for "/dir",
+// the target request will be for /base/dir.
+func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
+ targetQuery := target.RawQuery
+ director := func(req *http.Request) {
+ req.URL.Scheme = target.Scheme
+ req.URL.Host = target.Host
+ req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
+ if targetQuery == "" || req.URL.RawQuery == "" {
+ req.URL.RawQuery = targetQuery + req.URL.RawQuery
+ } else {
+ req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
+ }
+ }
+ return &ReverseProxy{Director: director}
+}
+
+func copyHeader(dst, src http.Header) {
+ for k, vv := range src {
+ for _, v := range vv {
+ dst.Add(k, v)
+ }
+ }
+}
+
+// Hop-by-hop headers. These are removed when sent to the backend.
+// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
+var hopHeaders = []string{
+ "Connection",
+ "Keep-Alive",
+ "Proxy-Authenticate",
+ "Proxy-Authorization",
+ "Te", // canonicalized version of "TE"
+ "Trailers",
+ "Transfer-Encoding",
+ "Upgrade",
+}
+
+func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
+ transport := p.Transport
+ if transport == nil {
+ transport = http.DefaultTransport
+ }
+
+ outreq := new(http.Request)
+ *outreq = *req // includes shallow copies of maps, but okay
+
+ p.Director(outreq)
+ outreq.Proto = "HTTP/1.1"
+ outreq.ProtoMajor = 1
+ outreq.ProtoMinor = 1
+ outreq.Close = false
+
+ // Remove hop-by-hop headers to the backend. Especially
+ // important is "Connection" because we want a persistent
+ // connection, regardless of what the client sent to us. This
+ // is modifying the same underlying map from req (shallow
+ // copied above) so we only copy it if necessary.
+ copiedHeaders := false
+ for _, h := range hopHeaders {
+ if outreq.Header.Get(h) != "" {
+ if !copiedHeaders {
+ outreq.Header = make(http.Header)
+ copyHeader(outreq.Header, req.Header)
+ copiedHeaders = true
+ }
+ outreq.Header.Del(h)
+ }
+ }
+
+ if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
+ // If we aren't the first proxy retain prior
+ // X-Forwarded-For information as a comma+space
+ // separated list and fold multiple headers into one.
+ if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
+ clientIP = strings.Join(prior, ", ") + ", " + clientIP
+ }
+ outreq.Header.Set("X-Forwarded-For", clientIP)
+ }
+
+ res, err := transport.RoundTrip(outreq)
+ if err != nil {
+ p.logf("http: proxy error: %v", err)
+ rw.WriteHeader(http.StatusInternalServerError)
+ return
+ }
+ defer res.Body.Close()
+
+ for _, h := range hopHeaders {
+ res.Header.Del(h)
+ }
+
+ copyHeader(rw.Header(), res.Header)
+
+ rw.WriteHeader(res.StatusCode)
+ p.copyResponse(rw, res.Body)
+}
+
+func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
+ if p.FlushInterval != 0 {
+ if wf, ok := dst.(writeFlusher); ok {
+ mlw := &maxLatencyWriter{
+ dst: wf,
+ latency: p.FlushInterval,
+ done: make(chan bool),
+ }
+ go mlw.flushLoop()
+ defer mlw.stop()
+ dst = mlw
+ }
+ }
+
+ io.Copy(dst, src)
+}
+
+func (p *ReverseProxy) logf(format string, args ...interface{}) {
+ if p.ErrorLog != nil {
+ p.ErrorLog.Printf(format, args...)
+ } else {
+ log.Printf(format, args...)
+ }
+}
+
+type writeFlusher interface {
+ io.Writer
+ http.Flusher
+}
+
+type maxLatencyWriter struct {
+ dst writeFlusher
+ latency time.Duration
+
+ lk sync.Mutex // protects Write + Flush
+ done chan bool
+}
+
+func (m *maxLatencyWriter) Write(p []byte) (int, error) {
+ m.lk.Lock()
+ defer m.lk.Unlock()
+ return m.dst.Write(p)
+}
+
+func (m *maxLatencyWriter) flushLoop() {
+ t := time.NewTicker(m.latency)
+ defer t.Stop()
+ for {
+ select {
+ case <-m.done:
+ if onExitFlushLoop != nil {
+ onExitFlushLoop()
+ }
+ return
+ case <-t.C:
+ m.lk.Lock()
+ m.dst.Flush()
+ m.lk.Unlock()
+ }
+ }
+}
+
+func (m *maxLatencyWriter) stop() { m.done <- true }
diff --git a/src/net/http/httputil/reverseproxy_test.go b/src/net/http/httputil/reverseproxy_test.go
new file mode 100644
index 000000000..e9539b44b
--- /dev/null
+++ b/src/net/http/httputil/reverseproxy_test.go
@@ -0,0 +1,213 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Reverse proxy tests.
+
+package httputil
+
+import (
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+ "time"
+)
+
+const fakeHopHeader = "X-Fake-Hop-Header-For-Test"
+
+func init() {
+ hopHeaders = append(hopHeaders, fakeHopHeader)
+}
+
+func TestReverseProxy(t *testing.T) {
+ const backendResponse = "I am the backend"
+ const backendStatus = 404
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if len(r.TransferEncoding) > 0 {
+ t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding)
+ }
+ if r.Header.Get("X-Forwarded-For") == "" {
+ t.Errorf("didn't get X-Forwarded-For header")
+ }
+ if c := r.Header.Get("Connection"); c != "" {
+ t.Errorf("handler got Connection header value %q", c)
+ }
+ if c := r.Header.Get("Upgrade"); c != "" {
+ t.Errorf("handler got Upgrade header value %q", c)
+ }
+ if g, e := r.Host, "some-name"; g != e {
+ t.Errorf("backend got Host header %q, want %q", g, e)
+ }
+ w.Header().Set("X-Foo", "bar")
+ w.Header().Set("Upgrade", "foo")
+ w.Header().Set(fakeHopHeader, "foo")
+ w.Header().Add("X-Multi-Value", "foo")
+ w.Header().Add("X-Multi-Value", "bar")
+ http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"})
+ w.WriteHeader(backendStatus)
+ w.Write([]byte(backendResponse))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ getReq.Host = "some-name"
+ getReq.Header.Set("Connection", "close")
+ getReq.Header.Set("Upgrade", "foo")
+ getReq.Close = true
+ res, err := http.DefaultClient.Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ if g, e := res.StatusCode, backendStatus; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
+ t.Errorf("got X-Foo %q; expected %q", g, e)
+ }
+ if c := res.Header.Get(fakeHopHeader); c != "" {
+ t.Errorf("got %s header value %q", fakeHopHeader, c)
+ }
+ if g, e := len(res.Header["X-Multi-Value"]), 2; g != e {
+ t.Errorf("got %d X-Multi-Value header values; expected %d", g, e)
+ }
+ if g, e := len(res.Header["Set-Cookie"]), 1; g != e {
+ t.Fatalf("got %d SetCookies, want %d", g, e)
+ }
+ if cookie := res.Cookies()[0]; cookie.Name != "flavor" {
+ t.Errorf("unexpected cookie %q", cookie.Name)
+ }
+ bodyBytes, _ := ioutil.ReadAll(res.Body)
+ if g, e := string(bodyBytes), backendResponse; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+}
+
+func TestXForwardedFor(t *testing.T) {
+ const prevForwardedFor = "client ip"
+ const backendResponse = "I am the backend"
+ const backendStatus = 404
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Header.Get("X-Forwarded-For") == "" {
+ t.Errorf("didn't get X-Forwarded-For header")
+ }
+ if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) {
+ t.Errorf("X-Forwarded-For didn't contain prior data")
+ }
+ w.WriteHeader(backendStatus)
+ w.Write([]byte(backendResponse))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ getReq.Host = "some-name"
+ getReq.Header.Set("Connection", "close")
+ getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
+ getReq.Close = true
+ res, err := http.DefaultClient.Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ if g, e := res.StatusCode, backendStatus; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ bodyBytes, _ := ioutil.ReadAll(res.Body)
+ if g, e := string(bodyBytes), backendResponse; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+}
+
+var proxyQueryTests = []struct {
+ baseSuffix string // suffix to add to backend URL
+ reqSuffix string // suffix to add to frontend's request URL
+ want string // what backend should see for final request URL (without ?)
+}{
+ {"", "", ""},
+ {"?sta=tic", "?us=er", "sta=tic&us=er"},
+ {"", "?us=er", "us=er"},
+ {"?sta=tic", "", "sta=tic"},
+}
+
+func TestReverseProxyQuery(t *testing.T) {
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("X-Got-Query", r.URL.RawQuery)
+ w.Write([]byte("hi"))
+ }))
+ defer backend.Close()
+
+ for i, tt := range proxyQueryTests {
+ backendURL, err := url.Parse(backend.URL + tt.baseSuffix)
+ if err != nil {
+ t.Fatal(err)
+ }
+ frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
+ req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
+ req.Close = true
+ res, err := http.DefaultClient.Do(req)
+ if err != nil {
+ t.Fatalf("%d. Get: %v", i, err)
+ }
+ if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e {
+ t.Errorf("%d. got query %q; expected %q", i, g, e)
+ }
+ res.Body.Close()
+ frontend.Close()
+ }
+}
+
+func TestReverseProxyFlushInterval(t *testing.T) {
+ const expected = "hi"
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(expected))
+ }))
+ defer backend.Close()
+
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.FlushInterval = time.Microsecond
+
+ done := make(chan bool)
+ onExitFlushLoop = func() { done <- true }
+ defer func() { onExitFlushLoop = nil }()
+
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ req, _ := http.NewRequest("GET", frontend.URL, nil)
+ req.Close = true
+ res, err := http.DefaultClient.Do(req)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+ if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected {
+ t.Errorf("got body %q; expected %q", bodyBytes, expected)
+ }
+
+ select {
+ case <-done:
+ // OK
+ case <-time.After(5 * time.Second):
+ t.Error("maxLatencyWriter flushLoop() never exited")
+ }
+}
diff --git a/src/net/http/internal/chunked.go b/src/net/http/internal/chunked.go
new file mode 100644
index 000000000..9294deb3e
--- /dev/null
+++ b/src/net/http/internal/chunked.go
@@ -0,0 +1,202 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// The wire protocol for HTTP's "chunked" Transfer-Encoding.
+
+// Package internal contains HTTP internals shared by net/http and
+// net/http/httputil.
+package internal
+
+import (
+ "bufio"
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+)
+
+const maxLineLength = 4096 // assumed <= bufio.defaultBufSize
+
+var ErrLineTooLong = errors.New("header line too long")
+
+// NewChunkedReader returns a new chunkedReader that translates the data read from r
+// out of HTTP "chunked" format before returning it.
+// The chunkedReader returns io.EOF when the final 0-length chunk is read.
+//
+// NewChunkedReader is not needed by normal applications. The http package
+// automatically decodes chunking when reading response bodies.
+func NewChunkedReader(r io.Reader) io.Reader {
+ br, ok := r.(*bufio.Reader)
+ if !ok {
+ br = bufio.NewReader(r)
+ }
+ return &chunkedReader{r: br}
+}
+
+type chunkedReader struct {
+ r *bufio.Reader
+ n uint64 // unread bytes in chunk
+ err error
+ buf [2]byte
+}
+
+func (cr *chunkedReader) beginChunk() {
+ // chunk-size CRLF
+ var line []byte
+ line, cr.err = readLine(cr.r)
+ if cr.err != nil {
+ return
+ }
+ cr.n, cr.err = parseHexUint(line)
+ if cr.err != nil {
+ return
+ }
+ if cr.n == 0 {
+ cr.err = io.EOF
+ }
+}
+
+func (cr *chunkedReader) chunkHeaderAvailable() bool {
+ n := cr.r.Buffered()
+ if n > 0 {
+ peek, _ := cr.r.Peek(n)
+ return bytes.IndexByte(peek, '\n') >= 0
+ }
+ return false
+}
+
+func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
+ for cr.err == nil {
+ if cr.n == 0 {
+ if n > 0 && !cr.chunkHeaderAvailable() {
+ // We've read enough. Don't potentially block
+ // reading a new chunk header.
+ break
+ }
+ cr.beginChunk()
+ continue
+ }
+ if len(b) == 0 {
+ break
+ }
+ rbuf := b
+ if uint64(len(rbuf)) > cr.n {
+ rbuf = rbuf[:cr.n]
+ }
+ var n0 int
+ n0, cr.err = cr.r.Read(rbuf)
+ n += n0
+ b = b[n0:]
+ cr.n -= uint64(n0)
+ // If we're at the end of a chunk, read the next two
+ // bytes to verify they are "\r\n".
+ if cr.n == 0 && cr.err == nil {
+ if _, cr.err = io.ReadFull(cr.r, cr.buf[:2]); cr.err == nil {
+ if cr.buf[0] != '\r' || cr.buf[1] != '\n' {
+ cr.err = errors.New("malformed chunked encoding")
+ }
+ }
+ }
+ }
+ return n, cr.err
+}
+
+// Read a line of bytes (up to \n) from b.
+// Give up if the line exceeds maxLineLength.
+// The returned bytes are a pointer into storage in
+// the bufio, so they are only valid until the next bufio read.
+func readLine(b *bufio.Reader) (p []byte, err error) {
+ if p, err = b.ReadSlice('\n'); err != nil {
+ // We always know when EOF is coming.
+ // If the caller asked for a line, there should be a line.
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ } else if err == bufio.ErrBufferFull {
+ err = ErrLineTooLong
+ }
+ return nil, err
+ }
+ if len(p) >= maxLineLength {
+ return nil, ErrLineTooLong
+ }
+ return trimTrailingWhitespace(p), nil
+}
+
+func trimTrailingWhitespace(b []byte) []byte {
+ for len(b) > 0 && isASCIISpace(b[len(b)-1]) {
+ b = b[:len(b)-1]
+ }
+ return b
+}
+
+func isASCIISpace(b byte) bool {
+ return b == ' ' || b == '\t' || b == '\n' || b == '\r'
+}
+
+// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP
+// "chunked" format before writing them to w. Closing the returned chunkedWriter
+// sends the final 0-length chunk that marks the end of the stream.
+//
+// NewChunkedWriter is not needed by normal applications. The http
+// package adds chunking automatically if handlers don't set a
+// Content-Length header. Using newChunkedWriter inside a handler
+// would result in double chunking or chunking with a Content-Length
+// length, both of which are wrong.
+func NewChunkedWriter(w io.Writer) io.WriteCloser {
+ return &chunkedWriter{w}
+}
+
+// Writing to chunkedWriter translates to writing in HTTP chunked Transfer
+// Encoding wire format to the underlying Wire chunkedWriter.
+type chunkedWriter struct {
+ Wire io.Writer
+}
+
+// Write the contents of data as one chunk to Wire.
+// NOTE: Note that the corresponding chunk-writing procedure in Conn.Write has
+// a bug since it does not check for success of io.WriteString
+func (cw *chunkedWriter) Write(data []byte) (n int, err error) {
+
+ // Don't send 0-length data. It looks like EOF for chunked encoding.
+ if len(data) == 0 {
+ return 0, nil
+ }
+
+ if _, err = fmt.Fprintf(cw.Wire, "%x\r\n", len(data)); err != nil {
+ return 0, err
+ }
+ if n, err = cw.Wire.Write(data); err != nil {
+ return
+ }
+ if n != len(data) {
+ err = io.ErrShortWrite
+ return
+ }
+ _, err = io.WriteString(cw.Wire, "\r\n")
+
+ return
+}
+
+func (cw *chunkedWriter) Close() error {
+ _, err := io.WriteString(cw.Wire, "0\r\n")
+ return err
+}
+
+func parseHexUint(v []byte) (n uint64, err error) {
+ for _, b := range v {
+ n <<= 4
+ switch {
+ case '0' <= b && b <= '9':
+ b = b - '0'
+ case 'a' <= b && b <= 'f':
+ b = b - 'a' + 10
+ case 'A' <= b && b <= 'F':
+ b = b - 'A' + 10
+ default:
+ return 0, errors.New("invalid byte in chunk length")
+ }
+ n |= uint64(b)
+ }
+ return
+}
diff --git a/src/net/http/internal/chunked_test.go b/src/net/http/internal/chunked_test.go
new file mode 100644
index 000000000..ebc626ea9
--- /dev/null
+++ b/src/net/http/internal/chunked_test.go
@@ -0,0 +1,156 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package internal
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "strings"
+ "testing"
+)
+
+func TestChunk(t *testing.T) {
+ var b bytes.Buffer
+
+ w := NewChunkedWriter(&b)
+ const chunk1 = "hello, "
+ const chunk2 = "world! 0123456789abcdef"
+ w.Write([]byte(chunk1))
+ w.Write([]byte(chunk2))
+ w.Close()
+
+ if g, e := b.String(), "7\r\nhello, \r\n17\r\nworld! 0123456789abcdef\r\n0\r\n"; g != e {
+ t.Fatalf("chunk writer wrote %q; want %q", g, e)
+ }
+
+ r := NewChunkedReader(&b)
+ data, err := ioutil.ReadAll(r)
+ if err != nil {
+ t.Logf(`data: "%s"`, data)
+ t.Fatalf("ReadAll from reader: %v", err)
+ }
+ if g, e := string(data), chunk1+chunk2; g != e {
+ t.Errorf("chunk reader read %q; want %q", g, e)
+ }
+}
+
+func TestChunkReadMultiple(t *testing.T) {
+ // Bunch of small chunks, all read together.
+ {
+ var b bytes.Buffer
+ w := NewChunkedWriter(&b)
+ w.Write([]byte("foo"))
+ w.Write([]byte("bar"))
+ w.Close()
+
+ r := NewChunkedReader(&b)
+ buf := make([]byte, 10)
+ n, err := r.Read(buf)
+ if n != 6 || err != io.EOF {
+ t.Errorf("Read = %d, %v; want 6, EOF", n, err)
+ }
+ buf = buf[:n]
+ if string(buf) != "foobar" {
+ t.Errorf("Read = %q; want %q", buf, "foobar")
+ }
+ }
+
+ // One big chunk followed by a little chunk, but the small bufio.Reader size
+ // should prevent the second chunk header from being read.
+ {
+ var b bytes.Buffer
+ w := NewChunkedWriter(&b)
+ // fillBufChunk is 11 bytes + 3 bytes header + 2 bytes footer = 16 bytes,
+ // the same as the bufio ReaderSize below (the minimum), so even
+ // though we're going to try to Read with a buffer larger enough to also
+ // receive "foo", the second chunk header won't be read yet.
+ const fillBufChunk = "0123456789a"
+ const shortChunk = "foo"
+ w.Write([]byte(fillBufChunk))
+ w.Write([]byte(shortChunk))
+ w.Close()
+
+ r := NewChunkedReader(bufio.NewReaderSize(&b, 16))
+ buf := make([]byte, len(fillBufChunk)+len(shortChunk))
+ n, err := r.Read(buf)
+ if n != len(fillBufChunk) || err != nil {
+ t.Errorf("Read = %d, %v; want %d, nil", n, err, len(fillBufChunk))
+ }
+ buf = buf[:n]
+ if string(buf) != fillBufChunk {
+ t.Errorf("Read = %q; want %q", buf, fillBufChunk)
+ }
+
+ n, err = r.Read(buf)
+ if n != len(shortChunk) || err != io.EOF {
+ t.Errorf("Read = %d, %v; want %d, EOF", n, err, len(shortChunk))
+ }
+ }
+
+ // And test that we see an EOF chunk, even though our buffer is already full:
+ {
+ r := NewChunkedReader(bufio.NewReader(strings.NewReader("3\r\nfoo\r\n0\r\n")))
+ buf := make([]byte, 3)
+ n, err := r.Read(buf)
+ if n != 3 || err != io.EOF {
+ t.Errorf("Read = %d, %v; want 3, EOF", n, err)
+ }
+ if string(buf) != "foo" {
+ t.Errorf("buf = %q; want foo", buf)
+ }
+ }
+}
+
+func TestChunkReaderAllocs(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ var buf bytes.Buffer
+ w := NewChunkedWriter(&buf)
+ a, b, c := []byte("aaaaaa"), []byte("bbbbbbbbbbbb"), []byte("cccccccccccccccccccccccc")
+ w.Write(a)
+ w.Write(b)
+ w.Write(c)
+ w.Close()
+
+ readBuf := make([]byte, len(a)+len(b)+len(c)+1)
+ byter := bytes.NewReader(buf.Bytes())
+ bufr := bufio.NewReader(byter)
+ mallocs := testing.AllocsPerRun(100, func() {
+ byter.Seek(0, 0)
+ bufr.Reset(byter)
+ r := NewChunkedReader(bufr)
+ n, err := io.ReadFull(r, readBuf)
+ if n != len(readBuf)-1 {
+ t.Fatalf("read %d bytes; want %d", n, len(readBuf)-1)
+ }
+ if err != io.ErrUnexpectedEOF {
+ t.Fatalf("read error = %v; want ErrUnexpectedEOF", err)
+ }
+ })
+ if mallocs > 1.5 {
+ t.Errorf("mallocs = %v; want 1", mallocs)
+ }
+}
+
+func TestParseHexUint(t *testing.T) {
+ for i := uint64(0); i <= 1234; i++ {
+ line := []byte(fmt.Sprintf("%x", i))
+ got, err := parseHexUint(line)
+ if err != nil {
+ t.Fatalf("on %d: %v", i, err)
+ }
+ if got != i {
+ t.Errorf("for input %q = %d; want %d", line, got, i)
+ }
+ }
+ _, err := parseHexUint([]byte("bogus"))
+ if err == nil {
+ t.Error("expected error on bogus input")
+ }
+}
diff --git a/src/net/http/jar.go b/src/net/http/jar.go
new file mode 100644
index 000000000..5c3de0dad
--- /dev/null
+++ b/src/net/http/jar.go
@@ -0,0 +1,27 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "net/url"
+)
+
+// A CookieJar manages storage and use of cookies in HTTP requests.
+//
+// Implementations of CookieJar must be safe for concurrent use by multiple
+// goroutines.
+//
+// The net/http/cookiejar package provides a CookieJar implementation.
+type CookieJar interface {
+ // SetCookies handles the receipt of the cookies in a reply for the
+ // given URL. It may or may not choose to save the cookies, depending
+ // on the jar's policy and implementation.
+ SetCookies(u *url.URL, cookies []*Cookie)
+
+ // Cookies returns the cookies to send in a request for the given URL.
+ // It is up to the implementation to honor the standard cookie use
+ // restrictions such as in RFC 6265.
+ Cookies(u *url.URL) []*Cookie
+}
diff --git a/src/net/http/lex.go b/src/net/http/lex.go
new file mode 100644
index 000000000..cb33318f4
--- /dev/null
+++ b/src/net/http/lex.go
@@ -0,0 +1,96 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+// This file deals with lexical matters of HTTP
+
+var isTokenTable = [127]bool{
+ '!': true,
+ '#': true,
+ '$': true,
+ '%': true,
+ '&': true,
+ '\'': true,
+ '*': true,
+ '+': true,
+ '-': true,
+ '.': true,
+ '0': true,
+ '1': true,
+ '2': true,
+ '3': true,
+ '4': true,
+ '5': true,
+ '6': true,
+ '7': true,
+ '8': true,
+ '9': true,
+ 'A': true,
+ 'B': true,
+ 'C': true,
+ 'D': true,
+ 'E': true,
+ 'F': true,
+ 'G': true,
+ 'H': true,
+ 'I': true,
+ 'J': true,
+ 'K': true,
+ 'L': true,
+ 'M': true,
+ 'N': true,
+ 'O': true,
+ 'P': true,
+ 'Q': true,
+ 'R': true,
+ 'S': true,
+ 'T': true,
+ 'U': true,
+ 'W': true,
+ 'V': true,
+ 'X': true,
+ 'Y': true,
+ 'Z': true,
+ '^': true,
+ '_': true,
+ '`': true,
+ 'a': true,
+ 'b': true,
+ 'c': true,
+ 'd': true,
+ 'e': true,
+ 'f': true,
+ 'g': true,
+ 'h': true,
+ 'i': true,
+ 'j': true,
+ 'k': true,
+ 'l': true,
+ 'm': true,
+ 'n': true,
+ 'o': true,
+ 'p': true,
+ 'q': true,
+ 'r': true,
+ 's': true,
+ 't': true,
+ 'u': true,
+ 'v': true,
+ 'w': true,
+ 'x': true,
+ 'y': true,
+ 'z': true,
+ '|': true,
+ '~': true,
+}
+
+func isToken(r rune) bool {
+ i := int(r)
+ return i < len(isTokenTable) && isTokenTable[i]
+}
+
+func isNotToken(r rune) bool {
+ return !isToken(r)
+}
diff --git a/src/net/http/lex_test.go b/src/net/http/lex_test.go
new file mode 100644
index 000000000..6d9d294f7
--- /dev/null
+++ b/src/net/http/lex_test.go
@@ -0,0 +1,31 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "testing"
+)
+
+func isChar(c rune) bool { return c <= 127 }
+
+func isCtl(c rune) bool { return c <= 31 || c == 127 }
+
+func isSeparator(c rune) bool {
+ switch c {
+ case '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', '?', '=', '{', '}', ' ', '\t':
+ return true
+ }
+ return false
+}
+
+func TestIsToken(t *testing.T) {
+ for i := 0; i <= 130; i++ {
+ r := rune(i)
+ expected := isChar(r) && !isCtl(r) && !isSeparator(r)
+ if isToken(r) != expected {
+ t.Errorf("isToken(0x%x) = %v", r, !expected)
+ }
+ }
+}
diff --git a/src/net/http/main_test.go b/src/net/http/main_test.go
new file mode 100644
index 000000000..b8c71fd19
--- /dev/null
+++ b/src/net/http/main_test.go
@@ -0,0 +1,109 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+ "fmt"
+ "net/http"
+ "os"
+ "runtime"
+ "sort"
+ "strings"
+ "testing"
+ "time"
+)
+
+func TestMain(m *testing.M) {
+ v := m.Run()
+ if v == 0 && goroutineLeaked() {
+ os.Exit(1)
+ }
+ os.Exit(v)
+}
+
+func interestingGoroutines() (gs []string) {
+ buf := make([]byte, 2<<20)
+ buf = buf[:runtime.Stack(buf, true)]
+ for _, g := range strings.Split(string(buf), "\n\n") {
+ sl := strings.SplitN(g, "\n", 2)
+ if len(sl) != 2 {
+ continue
+ }
+ stack := strings.TrimSpace(sl[1])
+ if stack == "" ||
+ strings.Contains(stack, "created by net.startServer") ||
+ strings.Contains(stack, "created by testing.RunTests") ||
+ strings.Contains(stack, "closeWriteAndWait") ||
+ strings.Contains(stack, "testing.Main(") ||
+ // These only show up with GOTRACEBACK=2; Issue 5005 (comment 28)
+ strings.Contains(stack, "runtime.goexit") ||
+ strings.Contains(stack, "created by runtime.gc") ||
+ strings.Contains(stack, "net/http_test.interestingGoroutines") ||
+ strings.Contains(stack, "runtime.MHeap_Scavenger") {
+ continue
+ }
+ gs = append(gs, stack)
+ }
+ sort.Strings(gs)
+ return
+}
+
+// Verify the other tests didn't leave any goroutines running.
+func goroutineLeaked() bool {
+ if testing.Short() {
+ // not counting goroutines for leakage in -short mode
+ return false
+ }
+ gs := interestingGoroutines()
+
+ n := 0
+ stackCount := make(map[string]int)
+ for _, g := range gs {
+ stackCount[g]++
+ n++
+ }
+
+ if n == 0 {
+ return false
+ }
+ fmt.Fprintf(os.Stderr, "Too many goroutines running after net/http test(s).\n")
+ for stack, count := range stackCount {
+ fmt.Fprintf(os.Stderr, "%d instances of:\n%s\n", count, stack)
+ }
+ return true
+}
+
+func afterTest(t *testing.T) {
+ http.DefaultTransport.(*http.Transport).CloseIdleConnections()
+ if testing.Short() {
+ return
+ }
+ var bad string
+ badSubstring := map[string]string{
+ ").readLoop(": "a Transport",
+ ").writeLoop(": "a Transport",
+ "created by net/http/httptest.(*Server).Start": "an httptest.Server",
+ "timeoutHandler": "a TimeoutHandler",
+ "net.(*netFD).connect(": "a timing out dial",
+ ").noteClientGone(": "a closenotifier sender",
+ }
+ var stacks string
+ for i := 0; i < 4; i++ {
+ bad = ""
+ stacks = strings.Join(interestingGoroutines(), "\n\n")
+ for substr, what := range badSubstring {
+ if strings.Contains(stacks, substr) {
+ bad = what
+ }
+ }
+ if bad == "" {
+ return
+ }
+ // Bad stuff found, but goroutines might just still be
+ // shutting down, so give it some time.
+ time.Sleep(250 * time.Millisecond)
+ }
+ t.Errorf("Test appears to have leaked %s:\n%s", bad, stacks)
+}
diff --git a/src/net/http/npn_test.go b/src/net/http/npn_test.go
new file mode 100644
index 000000000..98b8930d0
--- /dev/null
+++ b/src/net/http/npn_test.go
@@ -0,0 +1,118 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+ "bufio"
+ "crypto/tls"
+ "fmt"
+ "io"
+ "io/ioutil"
+ . "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+)
+
+func TestNextProtoUpgrade(t *testing.T) {
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "path=%s,proto=", r.URL.Path)
+ if r.TLS != nil {
+ w.Write([]byte(r.TLS.NegotiatedProtocol))
+ }
+ if r.RemoteAddr == "" {
+ t.Error("request with no RemoteAddr")
+ }
+ if r.Body == nil {
+ t.Errorf("request with nil Body")
+ }
+ }))
+ ts.TLS = &tls.Config{
+ NextProtos: []string{"unhandled-proto", "tls-0.9"},
+ }
+ ts.Config.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){
+ "tls-0.9": handleTLSProtocol09,
+ }
+ ts.StartTLS()
+ defer ts.Close()
+
+ tr := newTLSTransport(t, ts)
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ // Normal request, without NPN.
+ {
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ body, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if want := "path=/,proto="; string(body) != want {
+ t.Errorf("plain request = %q; want %q", body, want)
+ }
+ }
+
+ // Request to an advertised but unhandled NPN protocol.
+ // Server will hang up.
+ {
+ tr.CloseIdleConnections()
+ tr.TLSClientConfig.NextProtos = []string{"unhandled-proto"}
+ _, err := c.Get(ts.URL)
+ if err == nil {
+ t.Errorf("expected error on unhandled-proto request")
+ }
+ }
+
+ // Request using the "tls-0.9" protocol, which we register here.
+ // It is HTTP/0.9 over TLS.
+ {
+ tlsConfig := newTLSTransport(t, ts).TLSClientConfig
+ tlsConfig.NextProtos = []string{"tls-0.9"}
+ conn, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig)
+ if err != nil {
+ t.Fatal(err)
+ }
+ conn.Write([]byte("GET /foo\n"))
+ body, err := ioutil.ReadAll(conn)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if want := "path=/foo,proto=tls-0.9"; string(body) != want {
+ t.Errorf("plain request = %q; want %q", body, want)
+ }
+ }
+}
+
+// handleTLSProtocol09 implements the HTTP/0.9 protocol over TLS, for the
+// TestNextProtoUpgrade test.
+func handleTLSProtocol09(srv *Server, conn *tls.Conn, h Handler) {
+ br := bufio.NewReader(conn)
+ line, err := br.ReadString('\n')
+ if err != nil {
+ return
+ }
+ line = strings.TrimSpace(line)
+ path := strings.TrimPrefix(line, "GET ")
+ if path == line {
+ return
+ }
+ req, _ := NewRequest("GET", path, nil)
+ req.Proto = "HTTP/0.9"
+ req.ProtoMajor = 0
+ req.ProtoMinor = 9
+ rw := &http09Writer{conn, make(Header)}
+ h.ServeHTTP(rw, req)
+}
+
+type http09Writer struct {
+ io.Writer
+ h Header
+}
+
+func (w http09Writer) Header() Header { return w.h }
+func (w http09Writer) WriteHeader(int) {} // no headers
diff --git a/src/net/http/pprof/pprof.go b/src/net/http/pprof/pprof.go
new file mode 100644
index 000000000..a23f1bc4b
--- /dev/null
+++ b/src/net/http/pprof/pprof.go
@@ -0,0 +1,209 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package pprof serves via its HTTP server runtime profiling data
+// in the format expected by the pprof visualization tool.
+// For more information about pprof, see
+// http://code.google.com/p/google-perftools/.
+//
+// The package is typically only imported for the side effect of
+// registering its HTTP handlers.
+// The handled paths all begin with /debug/pprof/.
+//
+// To use pprof, link this package into your program:
+// import _ "net/http/pprof"
+//
+// If your application is not already running an http server, you
+// need to start one. Add "net/http" and "log" to your imports and
+// the following code to your main function:
+//
+// go func() {
+// log.Println(http.ListenAndServe("localhost:6060", nil))
+// }()
+//
+// Then use the pprof tool to look at the heap profile:
+//
+// go tool pprof http://localhost:6060/debug/pprof/heap
+//
+// Or to look at a 30-second CPU profile:
+//
+// go tool pprof http://localhost:6060/debug/pprof/profile
+//
+// Or to look at the goroutine blocking profile:
+//
+// go tool pprof http://localhost:6060/debug/pprof/block
+//
+// To view all available profiles, open http://localhost:6060/debug/pprof/
+// in your browser.
+//
+// For a study of the facility in action, visit
+//
+// http://blog.golang.org/2011/06/profiling-go-programs.html
+//
+package pprof
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "html/template"
+ "io"
+ "log"
+ "net/http"
+ "os"
+ "runtime"
+ "runtime/pprof"
+ "strconv"
+ "strings"
+ "time"
+)
+
+func init() {
+ http.Handle("/debug/pprof/", http.HandlerFunc(Index))
+ http.Handle("/debug/pprof/cmdline", http.HandlerFunc(Cmdline))
+ http.Handle("/debug/pprof/profile", http.HandlerFunc(Profile))
+ http.Handle("/debug/pprof/symbol", http.HandlerFunc(Symbol))
+}
+
+// Cmdline responds with the running program's
+// command line, with arguments separated by NUL bytes.
+// The package initialization registers it as /debug/pprof/cmdline.
+func Cmdline(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+ fmt.Fprintf(w, strings.Join(os.Args, "\x00"))
+}
+
+// Profile responds with the pprof-formatted cpu profile.
+// The package initialization registers it as /debug/pprof/profile.
+func Profile(w http.ResponseWriter, r *http.Request) {
+ sec, _ := strconv.ParseInt(r.FormValue("seconds"), 10, 64)
+ if sec == 0 {
+ sec = 30
+ }
+
+ // Set Content Type assuming StartCPUProfile will work,
+ // because if it does it starts writing.
+ w.Header().Set("Content-Type", "application/octet-stream")
+ if err := pprof.StartCPUProfile(w); err != nil {
+ // StartCPUProfile failed, so no writes yet.
+ // Can change header back to text content
+ // and send error code.
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+ w.WriteHeader(http.StatusInternalServerError)
+ fmt.Fprintf(w, "Could not enable CPU profiling: %s\n", err)
+ return
+ }
+ time.Sleep(time.Duration(sec) * time.Second)
+ pprof.StopCPUProfile()
+}
+
+// Symbol looks up the program counters listed in the request,
+// responding with a table mapping program counters to function names.
+// The package initialization registers it as /debug/pprof/symbol.
+func Symbol(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+
+ // We have to read the whole POST body before
+ // writing any output. Buffer the output here.
+ var buf bytes.Buffer
+
+ // We don't know how many symbols we have, but we
+ // do have symbol information. Pprof only cares whether
+ // this number is 0 (no symbols available) or > 0.
+ fmt.Fprintf(&buf, "num_symbols: 1\n")
+
+ var b *bufio.Reader
+ if r.Method == "POST" {
+ b = bufio.NewReader(r.Body)
+ } else {
+ b = bufio.NewReader(strings.NewReader(r.URL.RawQuery))
+ }
+
+ for {
+ word, err := b.ReadSlice('+')
+ if err == nil {
+ word = word[0 : len(word)-1] // trim +
+ }
+ pc, _ := strconv.ParseUint(string(word), 0, 64)
+ if pc != 0 {
+ f := runtime.FuncForPC(uintptr(pc))
+ if f != nil {
+ fmt.Fprintf(&buf, "%#x %s\n", pc, f.Name())
+ }
+ }
+
+ // Wait until here to check for err; the last
+ // symbol will have an err because it doesn't end in +.
+ if err != nil {
+ if err != io.EOF {
+ fmt.Fprintf(&buf, "reading request: %v\n", err)
+ }
+ break
+ }
+ }
+
+ w.Write(buf.Bytes())
+}
+
+// Handler returns an HTTP handler that serves the named profile.
+func Handler(name string) http.Handler {
+ return handler(name)
+}
+
+type handler string
+
+func (name handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+ debug, _ := strconv.Atoi(r.FormValue("debug"))
+ p := pprof.Lookup(string(name))
+ if p == nil {
+ w.WriteHeader(404)
+ fmt.Fprintf(w, "Unknown profile: %s\n", name)
+ return
+ }
+ gc, _ := strconv.Atoi(r.FormValue("gc"))
+ if name == "heap" && gc > 0 {
+ runtime.GC()
+ }
+ p.WriteTo(w, debug)
+ return
+}
+
+// Index responds with the pprof-formatted profile named by the request.
+// For example, "/debug/pprof/heap" serves the "heap" profile.
+// Index responds to a request for "/debug/pprof/" with an HTML page
+// listing the available profiles.
+func Index(w http.ResponseWriter, r *http.Request) {
+ if strings.HasPrefix(r.URL.Path, "/debug/pprof/") {
+ name := strings.TrimPrefix(r.URL.Path, "/debug/pprof/")
+ if name != "" {
+ handler(name).ServeHTTP(w, r)
+ return
+ }
+ }
+
+ profiles := pprof.Profiles()
+ if err := indexTmpl.Execute(w, profiles); err != nil {
+ log.Print(err)
+ }
+}
+
+var indexTmpl = template.Must(template.New("index").Parse(`<html>
+<head>
+<title>/debug/pprof/</title>
+</head>
+/debug/pprof/<br>
+<br>
+<body>
+profiles:<br>
+<table>
+{{range .}}
+<tr><td align=right>{{.Count}}<td><a href="/debug/pprof/{{.Name}}?debug=1">{{.Name}}</a>
+{{end}}
+</table>
+<br>
+<a href="/debug/pprof/goroutine?debug=2">full goroutine stack dump</a><br>
+</body>
+</html>
+`))
diff --git a/src/net/http/proxy_test.go b/src/net/http/proxy_test.go
new file mode 100644
index 000000000..b6aed3792
--- /dev/null
+++ b/src/net/http/proxy_test.go
@@ -0,0 +1,81 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "net/url"
+ "os"
+ "testing"
+)
+
+// TODO(mattn):
+// test ProxyAuth
+
+var UseProxyTests = []struct {
+ host string
+ match bool
+}{
+ // Never proxy localhost:
+ {"localhost:80", false},
+ {"127.0.0.1", false},
+ {"127.0.0.2", false},
+ {"[::1]", false},
+ {"[::2]", true}, // not a loopback address
+
+ {"barbaz.net", false}, // match as .barbaz.net
+ {"foobar.com", false}, // have a port but match
+ {"foofoobar.com", true}, // not match as a part of foobar.com
+ {"baz.com", true}, // not match as a part of barbaz.com
+ {"localhost.net", true}, // not match as suffix of address
+ {"local.localhost", true}, // not match as prefix as address
+ {"barbarbaz.net", true}, // not match because NO_PROXY have a '.'
+ {"www.foobar.com", false}, // match because NO_PROXY includes "foobar.com"
+}
+
+func TestUseProxy(t *testing.T) {
+ ResetProxyEnv()
+ os.Setenv("NO_PROXY", "foobar.com, .barbaz.net")
+ for _, test := range UseProxyTests {
+ if useProxy(test.host+":80") != test.match {
+ t.Errorf("useProxy(%v) = %v, want %v", test.host, !test.match, test.match)
+ }
+ }
+}
+
+var cacheKeysTests = []struct {
+ proxy string
+ scheme string
+ addr string
+ key string
+}{
+ {"", "http", "foo.com", "|http|foo.com"},
+ {"", "https", "foo.com", "|https|foo.com"},
+ {"http://foo.com", "http", "foo.com", "http://foo.com|http|"},
+ {"http://foo.com", "https", "foo.com", "http://foo.com|https|foo.com"},
+}
+
+func TestCacheKeys(t *testing.T) {
+ for _, tt := range cacheKeysTests {
+ var proxy *url.URL
+ if tt.proxy != "" {
+ u, err := url.Parse(tt.proxy)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxy = u
+ }
+ cm := connectMethod{proxy, tt.scheme, tt.addr}
+ if got := cm.key().String(); got != tt.key {
+ t.Fatalf("{%q, %q, %q} cache key = %q; want %q", tt.proxy, tt.scheme, tt.addr, got, tt.key)
+ }
+ }
+}
+
+func ResetProxyEnv() {
+ for _, v := range []string{"HTTP_PROXY", "http_proxy", "NO_PROXY", "no_proxy"} {
+ os.Setenv(v, "")
+ }
+ ResetCachedEnvironment()
+}
diff --git a/src/net/http/race.go b/src/net/http/race.go
new file mode 100644
index 000000000..766503967
--- /dev/null
+++ b/src/net/http/race.go
@@ -0,0 +1,11 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build race
+
+package http
+
+func init() {
+ raceEnabled = true
+}
diff --git a/src/net/http/range_test.go b/src/net/http/range_test.go
new file mode 100644
index 000000000..ef911af7b
--- /dev/null
+++ b/src/net/http/range_test.go
@@ -0,0 +1,79 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "testing"
+)
+
+var ParseRangeTests = []struct {
+ s string
+ length int64
+ r []httpRange
+}{
+ {"", 0, nil},
+ {"", 1000, nil},
+ {"foo", 0, nil},
+ {"bytes=", 0, nil},
+ {"bytes=7", 10, nil},
+ {"bytes= 7 ", 10, nil},
+ {"bytes=1-", 0, nil},
+ {"bytes=5-4", 10, nil},
+ {"bytes=0-2,5-4", 10, nil},
+ {"bytes=2-5,4-3", 10, nil},
+ {"bytes=--5,4--3", 10, nil},
+ {"bytes=A-", 10, nil},
+ {"bytes=A- ", 10, nil},
+ {"bytes=A-Z", 10, nil},
+ {"bytes= -Z", 10, nil},
+ {"bytes=5-Z", 10, nil},
+ {"bytes=Ran-dom, garbage", 10, nil},
+ {"bytes=0x01-0x02", 10, nil},
+ {"bytes= ", 10, nil},
+ {"bytes= , , , ", 10, nil},
+
+ {"bytes=0-9", 10, []httpRange{{0, 10}}},
+ {"bytes=0-", 10, []httpRange{{0, 10}}},
+ {"bytes=5-", 10, []httpRange{{5, 5}}},
+ {"bytes=0-20", 10, []httpRange{{0, 10}}},
+ {"bytes=15-,0-5", 10, nil},
+ {"bytes=1-2,5-", 10, []httpRange{{1, 2}, {5, 5}}},
+ {"bytes=-2 , 7-", 11, []httpRange{{9, 2}, {7, 4}}},
+ {"bytes=0-0 ,2-2, 7-", 11, []httpRange{{0, 1}, {2, 1}, {7, 4}}},
+ {"bytes=-5", 10, []httpRange{{5, 5}}},
+ {"bytes=-15", 10, []httpRange{{0, 10}}},
+ {"bytes=0-499", 10000, []httpRange{{0, 500}}},
+ {"bytes=500-999", 10000, []httpRange{{500, 500}}},
+ {"bytes=-500", 10000, []httpRange{{9500, 500}}},
+ {"bytes=9500-", 10000, []httpRange{{9500, 500}}},
+ {"bytes=0-0,-1", 10000, []httpRange{{0, 1}, {9999, 1}}},
+ {"bytes=500-600,601-999", 10000, []httpRange{{500, 101}, {601, 399}}},
+ {"bytes=500-700,601-999", 10000, []httpRange{{500, 201}, {601, 399}}},
+
+ // Match Apache laxity:
+ {"bytes= 1 -2 , 4- 5, 7 - 8 , ,,", 11, []httpRange{{1, 2}, {4, 2}, {7, 2}}},
+}
+
+func TestParseRange(t *testing.T) {
+ for _, test := range ParseRangeTests {
+ r := test.r
+ ranges, err := parseRange(test.s, test.length)
+ if err != nil && r != nil {
+ t.Errorf("parseRange(%q) returned error %q", test.s, err)
+ }
+ if len(ranges) != len(r) {
+ t.Errorf("len(parseRange(%q)) = %d, want %d", test.s, len(ranges), len(r))
+ continue
+ }
+ for i := range r {
+ if ranges[i].start != r[i].start {
+ t.Errorf("parseRange(%q)[%d].start = %d, want %d", test.s, i, ranges[i].start, r[i].start)
+ }
+ if ranges[i].length != r[i].length {
+ t.Errorf("parseRange(%q)[%d].length = %d, want %d", test.s, i, ranges[i].length, r[i].length)
+ }
+ }
+ }
+}
diff --git a/src/net/http/readrequest_test.go b/src/net/http/readrequest_test.go
new file mode 100644
index 000000000..e930d99af
--- /dev/null
+++ b/src/net/http/readrequest_test.go
@@ -0,0 +1,358 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "io"
+ "net/url"
+ "reflect"
+ "strings"
+ "testing"
+)
+
+type reqTest struct {
+ Raw string
+ Req *Request
+ Body string
+ Trailer Header
+ Error string
+}
+
+var noError = ""
+var noBody = ""
+var noTrailer Header = nil
+
+var reqTests = []reqTest{
+ // Baseline test; All Request fields included for template use
+ {
+ "GET http://www.techcrunch.com/ HTTP/1.1\r\n" +
+ "Host: www.techcrunch.com\r\n" +
+ "User-Agent: Fake\r\n" +
+ "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n" +
+ "Accept-Language: en-us,en;q=0.5\r\n" +
+ "Accept-Encoding: gzip,deflate\r\n" +
+ "Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.7\r\n" +
+ "Keep-Alive: 300\r\n" +
+ "Content-Length: 7\r\n" +
+ "Proxy-Connection: keep-alive\r\n\r\n" +
+ "abcdef\n???",
+
+ &Request{
+ Method: "GET",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "www.techcrunch.com",
+ Path: "/",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{
+ "Accept": {"text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"},
+ "Accept-Language": {"en-us,en;q=0.5"},
+ "Accept-Encoding": {"gzip,deflate"},
+ "Accept-Charset": {"ISO-8859-1,utf-8;q=0.7,*;q=0.7"},
+ "Keep-Alive": {"300"},
+ "Proxy-Connection": {"keep-alive"},
+ "Content-Length": {"7"},
+ "User-Agent": {"Fake"},
+ },
+ Close: false,
+ ContentLength: 7,
+ Host: "www.techcrunch.com",
+ RequestURI: "http://www.techcrunch.com/",
+ },
+
+ "abcdef\n",
+
+ noTrailer,
+ noError,
+ },
+
+ // GET request with no body (the normal case)
+ {
+ "GET / HTTP/1.1\r\n" +
+ "Host: foo.com\r\n\r\n",
+
+ &Request{
+ Method: "GET",
+ URL: &url.URL{
+ Path: "/",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Close: false,
+ ContentLength: 0,
+ Host: "foo.com",
+ RequestURI: "/",
+ },
+
+ noBody,
+ noTrailer,
+ noError,
+ },
+
+ // Tests that we don't parse a path that looks like a
+ // scheme-relative URI as a scheme-relative URI.
+ {
+ "GET //user@host/is/actually/a/path/ HTTP/1.1\r\n" +
+ "Host: test\r\n\r\n",
+
+ &Request{
+ Method: "GET",
+ URL: &url.URL{
+ Path: "//user@host/is/actually/a/path/",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Close: false,
+ ContentLength: 0,
+ Host: "test",
+ RequestURI: "//user@host/is/actually/a/path/",
+ },
+
+ noBody,
+ noTrailer,
+ noError,
+ },
+
+ // Tests a bogus abs_path on the Request-Line (RFC 2616 section 5.1.2)
+ {
+ "GET ../../../../etc/passwd HTTP/1.1\r\n" +
+ "Host: test\r\n\r\n",
+ nil,
+ noBody,
+ noTrailer,
+ "parse ../../../../etc/passwd: invalid URI for request",
+ },
+
+ // Tests missing URL:
+ {
+ "GET HTTP/1.1\r\n" +
+ "Host: test\r\n\r\n",
+ nil,
+ noBody,
+ noTrailer,
+ "parse : empty url",
+ },
+
+ // Tests chunked body with trailer:
+ {
+ "POST / HTTP/1.1\r\n" +
+ "Host: foo.com\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ "3\r\nfoo\r\n" +
+ "3\r\nbar\r\n" +
+ "0\r\n" +
+ "Trailer-Key: Trailer-Value\r\n" +
+ "\r\n",
+ &Request{
+ Method: "POST",
+ URL: &url.URL{
+ Path: "/",
+ },
+ TransferEncoding: []string{"chunked"},
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ ContentLength: -1,
+ Host: "foo.com",
+ RequestURI: "/",
+ },
+
+ "foobar",
+ Header{
+ "Trailer-Key": {"Trailer-Value"},
+ },
+ noError,
+ },
+
+ // CONNECT request with domain name:
+ {
+ "CONNECT www.google.com:443 HTTP/1.1\r\n\r\n",
+
+ &Request{
+ Method: "CONNECT",
+ URL: &url.URL{
+ Host: "www.google.com:443",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Close: false,
+ ContentLength: 0,
+ Host: "www.google.com:443",
+ RequestURI: "www.google.com:443",
+ },
+
+ noBody,
+ noTrailer,
+ noError,
+ },
+
+ // CONNECT request with IP address:
+ {
+ "CONNECT 127.0.0.1:6060 HTTP/1.1\r\n\r\n",
+
+ &Request{
+ Method: "CONNECT",
+ URL: &url.URL{
+ Host: "127.0.0.1:6060",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Close: false,
+ ContentLength: 0,
+ Host: "127.0.0.1:6060",
+ RequestURI: "127.0.0.1:6060",
+ },
+
+ noBody,
+ noTrailer,
+ noError,
+ },
+
+ // CONNECT request for RPC:
+ {
+ "CONNECT /_goRPC_ HTTP/1.1\r\n\r\n",
+
+ &Request{
+ Method: "CONNECT",
+ URL: &url.URL{
+ Path: "/_goRPC_",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Close: false,
+ ContentLength: 0,
+ Host: "",
+ RequestURI: "/_goRPC_",
+ },
+
+ noBody,
+ noTrailer,
+ noError,
+ },
+
+ // SSDP Notify request. golang.org/issue/3692
+ {
+ "NOTIFY * HTTP/1.1\r\nServer: foo\r\n\r\n",
+ &Request{
+ Method: "NOTIFY",
+ URL: &url.URL{
+ Path: "*",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{
+ "Server": []string{"foo"},
+ },
+ Close: false,
+ ContentLength: 0,
+ RequestURI: "*",
+ },
+
+ noBody,
+ noTrailer,
+ noError,
+ },
+
+ // OPTIONS request. Similar to golang.org/issue/3692
+ {
+ "OPTIONS * HTTP/1.1\r\nServer: foo\r\n\r\n",
+ &Request{
+ Method: "OPTIONS",
+ URL: &url.URL{
+ Path: "*",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{
+ "Server": []string{"foo"},
+ },
+ Close: false,
+ ContentLength: 0,
+ RequestURI: "*",
+ },
+
+ noBody,
+ noTrailer,
+ noError,
+ },
+
+ // Connection: close. golang.org/issue/8261
+ {
+ "GET / HTTP/1.1\r\nHost: issue8261.com\r\nConnection: close\r\n\r\n",
+ &Request{
+ Method: "GET",
+ URL: &url.URL{
+ Path: "/",
+ },
+ Header: Header{
+ // This wasn't removed from Go 1.0 to
+ // Go 1.3, so locking it in that we
+ // keep this:
+ "Connection": []string{"close"},
+ },
+ Host: "issue8261.com",
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Close: true,
+ RequestURI: "/",
+ },
+
+ noBody,
+ noTrailer,
+ noError,
+ },
+}
+
+func TestReadRequest(t *testing.T) {
+ for i := range reqTests {
+ tt := &reqTests[i]
+ req, err := ReadRequest(bufio.NewReader(strings.NewReader(tt.Raw)))
+ if err != nil {
+ if err.Error() != tt.Error {
+ t.Errorf("#%d: error %q, want error %q", i, err.Error(), tt.Error)
+ }
+ continue
+ }
+ rbody := req.Body
+ req.Body = nil
+ testName := fmt.Sprintf("Test %d (%q)", i, tt.Raw)
+ diff(t, testName, req, tt.Req)
+ var bout bytes.Buffer
+ if rbody != nil {
+ _, err := io.Copy(&bout, rbody)
+ if err != nil {
+ t.Fatalf("%s: copying body: %v", testName, err)
+ }
+ rbody.Close()
+ }
+ body := bout.String()
+ if body != tt.Body {
+ t.Errorf("%s: Body = %q want %q", testName, body, tt.Body)
+ }
+ if !reflect.DeepEqual(tt.Trailer, req.Trailer) {
+ t.Errorf("%s: Trailers differ.\n got: %v\nwant: %v", testName, req.Trailer, tt.Trailer)
+ }
+ }
+}
diff --git a/src/net/http/request.go b/src/net/http/request.go
new file mode 100644
index 000000000..487eebcb8
--- /dev/null
+++ b/src/net/http/request.go
@@ -0,0 +1,921 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// HTTP Request reading and parsing.
+
+package http
+
+import (
+ "bufio"
+ "bytes"
+ "crypto/tls"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "mime"
+ "mime/multipart"
+ "net/textproto"
+ "net/url"
+ "strconv"
+ "strings"
+ "sync"
+)
+
+const (
+ maxValueLength = 4096
+ maxHeaderLines = 1024
+ chunkSize = 4 << 10 // 4 KB chunks
+ defaultMaxMemory = 32 << 20 // 32 MB
+)
+
+// ErrMissingFile is returned by FormFile when the provided file field name
+// is either not present in the request or not a file field.
+var ErrMissingFile = errors.New("http: no such file")
+
+// HTTP request parsing errors.
+type ProtocolError struct {
+ ErrorString string
+}
+
+func (err *ProtocolError) Error() string { return err.ErrorString }
+
+var (
+ ErrHeaderTooLong = &ProtocolError{"header too long"}
+ ErrShortBody = &ProtocolError{"entity body too short"}
+ ErrNotSupported = &ProtocolError{"feature not supported"}
+ ErrUnexpectedTrailer = &ProtocolError{"trailer header without chunked transfer encoding"}
+ ErrMissingContentLength = &ProtocolError{"missing ContentLength in HEAD response"}
+ ErrNotMultipart = &ProtocolError{"request Content-Type isn't multipart/form-data"}
+ ErrMissingBoundary = &ProtocolError{"no multipart boundary param in Content-Type"}
+)
+
+type badStringError struct {
+ what string
+ str string
+}
+
+func (e *badStringError) Error() string { return fmt.Sprintf("%s %q", e.what, e.str) }
+
+// Headers that Request.Write handles itself and should be skipped.
+var reqWriteExcludeHeader = map[string]bool{
+ "Host": true, // not in Header map anyway
+ "User-Agent": true,
+ "Content-Length": true,
+ "Transfer-Encoding": true,
+ "Trailer": true,
+}
+
+// A Request represents an HTTP request received by a server
+// or to be sent by a client.
+//
+// The field semantics differ slightly between client and server
+// usage. In addition to the notes on the fields below, see the
+// documentation for Request.Write and RoundTripper.
+type Request struct {
+ // Method specifies the HTTP method (GET, POST, PUT, etc.).
+ // For client requests an empty string means GET.
+ Method string
+
+ // URL specifies either the URI being requested (for server
+ // requests) or the URL to access (for client requests).
+ //
+ // For server requests the URL is parsed from the URI
+ // supplied on the Request-Line as stored in RequestURI. For
+ // most requests, fields other than Path and RawQuery will be
+ // empty. (See RFC 2616, Section 5.1.2)
+ //
+ // For client requests, the URL's Host specifies the server to
+ // connect to, while the Request's Host field optionally
+ // specifies the Host header value to send in the HTTP
+ // request.
+ URL *url.URL
+
+ // The protocol version for incoming requests.
+ // Client requests always use HTTP/1.1.
+ Proto string // "HTTP/1.0"
+ ProtoMajor int // 1
+ ProtoMinor int // 0
+
+ // A header maps request lines to their values.
+ // If the header says
+ //
+ // accept-encoding: gzip, deflate
+ // Accept-Language: en-us
+ // Connection: keep-alive
+ //
+ // then
+ //
+ // Header = map[string][]string{
+ // "Accept-Encoding": {"gzip, deflate"},
+ // "Accept-Language": {"en-us"},
+ // "Connection": {"keep-alive"},
+ // }
+ //
+ // HTTP defines that header names are case-insensitive.
+ // The request parser implements this by canonicalizing the
+ // name, making the first character and any characters
+ // following a hyphen uppercase and the rest lowercase.
+ //
+ // For client requests certain headers are automatically
+ // added and may override values in Header.
+ //
+ // See the documentation for the Request.Write method.
+ Header Header
+
+ // Body is the request's body.
+ //
+ // For client requests a nil body means the request has no
+ // body, such as a GET request. The HTTP Client's Transport
+ // is responsible for calling the Close method.
+ //
+ // For server requests the Request Body is always non-nil
+ // but will return EOF immediately when no body is present.
+ // The Server will close the request body. The ServeHTTP
+ // Handler does not need to.
+ Body io.ReadCloser
+
+ // ContentLength records the length of the associated content.
+ // The value -1 indicates that the length is unknown.
+ // Values >= 0 indicate that the given number of bytes may
+ // be read from Body.
+ // For client requests, a value of 0 means unknown if Body is not nil.
+ ContentLength int64
+
+ // TransferEncoding lists the transfer encodings from outermost to
+ // innermost. An empty list denotes the "identity" encoding.
+ // TransferEncoding can usually be ignored; chunked encoding is
+ // automatically added and removed as necessary when sending and
+ // receiving requests.
+ TransferEncoding []string
+
+ // Close indicates whether to close the connection after
+ // replying to this request (for servers) or after sending
+ // the request (for clients).
+ Close bool
+
+ // For server requests Host specifies the host on which the
+ // URL is sought. Per RFC 2616, this is either the value of
+ // the "Host" header or the host name given in the URL itself.
+ // It may be of the form "host:port".
+ //
+ // For client requests Host optionally overrides the Host
+ // header to send. If empty, the Request.Write method uses
+ // the value of URL.Host.
+ Host string
+
+ // Form contains the parsed form data, including both the URL
+ // field's query parameters and the POST or PUT form data.
+ // This field is only available after ParseForm is called.
+ // The HTTP client ignores Form and uses Body instead.
+ Form url.Values
+
+ // PostForm contains the parsed form data from POST or PUT
+ // body parameters.
+ // This field is only available after ParseForm is called.
+ // The HTTP client ignores PostForm and uses Body instead.
+ PostForm url.Values
+
+ // MultipartForm is the parsed multipart form, including file uploads.
+ // This field is only available after ParseMultipartForm is called.
+ // The HTTP client ignores MultipartForm and uses Body instead.
+ MultipartForm *multipart.Form
+
+ // Trailer specifies additional headers that are sent after the request
+ // body.
+ //
+ // For server requests the Trailer map initially contains only the
+ // trailer keys, with nil values. (The client declares which trailers it
+ // will later send.) While the handler is reading from Body, it must
+ // not reference Trailer. After reading from Body returns EOF, Trailer
+ // can be read again and will contain non-nil values, if they were sent
+ // by the client.
+ //
+ // For client requests Trailer must be initialized to a map containing
+ // the trailer keys to later send. The values may be nil or their final
+ // values. The ContentLength must be 0 or -1, to send a chunked request.
+ // After the HTTP request is sent the map values can be updated while
+ // the request body is read. Once the body returns EOF, the caller must
+ // not mutate Trailer.
+ //
+ // Few HTTP clients, servers, or proxies support HTTP trailers.
+ Trailer Header
+
+ // RemoteAddr allows HTTP servers and other software to record
+ // the network address that sent the request, usually for
+ // logging. This field is not filled in by ReadRequest and
+ // has no defined format. The HTTP server in this package
+ // sets RemoteAddr to an "IP:port" address before invoking a
+ // handler.
+ // This field is ignored by the HTTP client.
+ RemoteAddr string
+
+ // RequestURI is the unmodified Request-URI of the
+ // Request-Line (RFC 2616, Section 5.1) as sent by the client
+ // to a server. Usually the URL field should be used instead.
+ // It is an error to set this field in an HTTP client request.
+ RequestURI string
+
+ // TLS allows HTTP servers and other software to record
+ // information about the TLS connection on which the request
+ // was received. This field is not filled in by ReadRequest.
+ // The HTTP server in this package sets the field for
+ // TLS-enabled connections before invoking a handler;
+ // otherwise it leaves the field nil.
+ // This field is ignored by the HTTP client.
+ TLS *tls.ConnectionState
+}
+
+// ProtoAtLeast reports whether the HTTP protocol used
+// in the request is at least major.minor.
+func (r *Request) ProtoAtLeast(major, minor int) bool {
+ return r.ProtoMajor > major ||
+ r.ProtoMajor == major && r.ProtoMinor >= minor
+}
+
+// UserAgent returns the client's User-Agent, if sent in the request.
+func (r *Request) UserAgent() string {
+ return r.Header.Get("User-Agent")
+}
+
+// Cookies parses and returns the HTTP cookies sent with the request.
+func (r *Request) Cookies() []*Cookie {
+ return readCookies(r.Header, "")
+}
+
+var ErrNoCookie = errors.New("http: named cookie not present")
+
+// Cookie returns the named cookie provided in the request or
+// ErrNoCookie if not found.
+func (r *Request) Cookie(name string) (*Cookie, error) {
+ for _, c := range readCookies(r.Header, name) {
+ return c, nil
+ }
+ return nil, ErrNoCookie
+}
+
+// AddCookie adds a cookie to the request. Per RFC 6265 section 5.4,
+// AddCookie does not attach more than one Cookie header field. That
+// means all cookies, if any, are written into the same line,
+// separated by semicolon.
+func (r *Request) AddCookie(c *Cookie) {
+ s := fmt.Sprintf("%s=%s", sanitizeCookieName(c.Name), sanitizeCookieValue(c.Value))
+ if c := r.Header.Get("Cookie"); c != "" {
+ r.Header.Set("Cookie", c+"; "+s)
+ } else {
+ r.Header.Set("Cookie", s)
+ }
+}
+
+// Referer returns the referring URL, if sent in the request.
+//
+// Referer is misspelled as in the request itself, a mistake from the
+// earliest days of HTTP. This value can also be fetched from the
+// Header map as Header["Referer"]; the benefit of making it available
+// as a method is that the compiler can diagnose programs that use the
+// alternate (correct English) spelling req.Referrer() but cannot
+// diagnose programs that use Header["Referrer"].
+func (r *Request) Referer() string {
+ return r.Header.Get("Referer")
+}
+
+// multipartByReader is a sentinel value.
+// Its presence in Request.MultipartForm indicates that parsing of the request
+// body has been handed off to a MultipartReader instead of ParseMultipartFrom.
+var multipartByReader = &multipart.Form{
+ Value: make(map[string][]string),
+ File: make(map[string][]*multipart.FileHeader),
+}
+
+// MultipartReader returns a MIME multipart reader if this is a
+// multipart/form-data POST request, else returns nil and an error.
+// Use this function instead of ParseMultipartForm to
+// process the request body as a stream.
+func (r *Request) MultipartReader() (*multipart.Reader, error) {
+ if r.MultipartForm == multipartByReader {
+ return nil, errors.New("http: MultipartReader called twice")
+ }
+ if r.MultipartForm != nil {
+ return nil, errors.New("http: multipart handled by ParseMultipartForm")
+ }
+ r.MultipartForm = multipartByReader
+ return r.multipartReader()
+}
+
+func (r *Request) multipartReader() (*multipart.Reader, error) {
+ v := r.Header.Get("Content-Type")
+ if v == "" {
+ return nil, ErrNotMultipart
+ }
+ d, params, err := mime.ParseMediaType(v)
+ if err != nil || d != "multipart/form-data" {
+ return nil, ErrNotMultipart
+ }
+ boundary, ok := params["boundary"]
+ if !ok {
+ return nil, ErrMissingBoundary
+ }
+ return multipart.NewReader(r.Body, boundary), nil
+}
+
+// Return value if nonempty, def otherwise.
+func valueOrDefault(value, def string) string {
+ if value != "" {
+ return value
+ }
+ return def
+}
+
+// NOTE: This is not intended to reflect the actual Go version being used.
+// It was changed from "Go http package" to "Go 1.1 package http" at the
+// time of the Go 1.1 release because the former User-Agent had ended up
+// on a blacklist for some intrusion detection systems.
+// See https://codereview.appspot.com/7532043.
+const defaultUserAgent = "Go 1.1 package http"
+
+// Write writes an HTTP/1.1 request -- header and body -- in wire format.
+// This method consults the following fields of the request:
+// Host
+// URL
+// Method (defaults to "GET")
+// Header
+// ContentLength
+// TransferEncoding
+// Body
+//
+// If Body is present, Content-Length is <= 0 and TransferEncoding
+// hasn't been set to "identity", Write adds "Transfer-Encoding:
+// chunked" to the header. Body is closed after it is sent.
+func (r *Request) Write(w io.Writer) error {
+ return r.write(w, false, nil)
+}
+
+// WriteProxy is like Write but writes the request in the form
+// expected by an HTTP proxy. In particular, WriteProxy writes the
+// initial Request-URI line of the request with an absolute URI, per
+// section 5.1.2 of RFC 2616, including the scheme and host.
+// In either case, WriteProxy also writes a Host header, using
+// either r.Host or r.URL.Host.
+func (r *Request) WriteProxy(w io.Writer) error {
+ return r.write(w, true, nil)
+}
+
+// extraHeaders may be nil
+func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) error {
+ host := req.Host
+ if host == "" {
+ if req.URL == nil {
+ return errors.New("http: Request.Write on Request with no Host or URL set")
+ }
+ host = req.URL.Host
+ }
+
+ ruri := req.URL.RequestURI()
+ if usingProxy && req.URL.Scheme != "" && req.URL.Opaque == "" {
+ ruri = req.URL.Scheme + "://" + host + ruri
+ } else if req.Method == "CONNECT" && req.URL.Path == "" {
+ // CONNECT requests normally give just the host and port, not a full URL.
+ ruri = host
+ }
+ // TODO(bradfitz): escape at least newlines in ruri?
+
+ // Wrap the writer in a bufio Writer if it's not already buffered.
+ // Don't always call NewWriter, as that forces a bytes.Buffer
+ // and other small bufio Writers to have a minimum 4k buffer
+ // size.
+ var bw *bufio.Writer
+ if _, ok := w.(io.ByteWriter); !ok {
+ bw = bufio.NewWriter(w)
+ w = bw
+ }
+
+ _, err := fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri)
+ if err != nil {
+ return err
+ }
+
+ // Header lines
+ _, err = fmt.Fprintf(w, "Host: %s\r\n", host)
+ if err != nil {
+ return err
+ }
+
+ // Use the defaultUserAgent unless the Header contains one, which
+ // may be blank to not send the header.
+ userAgent := defaultUserAgent
+ if req.Header != nil {
+ if ua := req.Header["User-Agent"]; len(ua) > 0 {
+ userAgent = ua[0]
+ }
+ }
+ if userAgent != "" {
+ _, err = fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent)
+ if err != nil {
+ return err
+ }
+ }
+
+ // Process Body,ContentLength,Close,Trailer
+ tw, err := newTransferWriter(req)
+ if err != nil {
+ return err
+ }
+ err = tw.WriteHeader(w)
+ if err != nil {
+ return err
+ }
+
+ err = req.Header.WriteSubset(w, reqWriteExcludeHeader)
+ if err != nil {
+ return err
+ }
+
+ if extraHeaders != nil {
+ err = extraHeaders.Write(w)
+ if err != nil {
+ return err
+ }
+ }
+
+ _, err = io.WriteString(w, "\r\n")
+ if err != nil {
+ return err
+ }
+
+ // Write body and trailer
+ err = tw.WriteBody(w)
+ if err != nil {
+ return err
+ }
+
+ if bw != nil {
+ return bw.Flush()
+ }
+ return nil
+}
+
+// ParseHTTPVersion parses a HTTP version string.
+// "HTTP/1.0" returns (1, 0, true).
+func ParseHTTPVersion(vers string) (major, minor int, ok bool) {
+ const Big = 1000000 // arbitrary upper bound
+ switch vers {
+ case "HTTP/1.1":
+ return 1, 1, true
+ case "HTTP/1.0":
+ return 1, 0, true
+ }
+ if !strings.HasPrefix(vers, "HTTP/") {
+ return 0, 0, false
+ }
+ dot := strings.Index(vers, ".")
+ if dot < 0 {
+ return 0, 0, false
+ }
+ major, err := strconv.Atoi(vers[5:dot])
+ if err != nil || major < 0 || major > Big {
+ return 0, 0, false
+ }
+ minor, err = strconv.Atoi(vers[dot+1:])
+ if err != nil || minor < 0 || minor > Big {
+ return 0, 0, false
+ }
+ return major, minor, true
+}
+
+// NewRequest returns a new Request given a method, URL, and optional body.
+//
+// If the provided body is also an io.Closer, the returned
+// Request.Body is set to body and will be closed by the Client
+// methods Do, Post, and PostForm, and Transport.RoundTrip.
+func NewRequest(method, urlStr string, body io.Reader) (*Request, error) {
+ u, err := url.Parse(urlStr)
+ if err != nil {
+ return nil, err
+ }
+ rc, ok := body.(io.ReadCloser)
+ if !ok && body != nil {
+ rc = ioutil.NopCloser(body)
+ }
+ req := &Request{
+ Method: method,
+ URL: u,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: make(Header),
+ Body: rc,
+ Host: u.Host,
+ }
+ if body != nil {
+ switch v := body.(type) {
+ case *bytes.Buffer:
+ req.ContentLength = int64(v.Len())
+ case *bytes.Reader:
+ req.ContentLength = int64(v.Len())
+ case *strings.Reader:
+ req.ContentLength = int64(v.Len())
+ }
+ }
+
+ return req, nil
+}
+
+// BasicAuth returns the username and password provided in the request's
+// Authorization header, if the request uses HTTP Basic Authentication.
+// See RFC 2617, Section 2.
+func (r *Request) BasicAuth() (username, password string, ok bool) {
+ auth := r.Header.Get("Authorization")
+ if auth == "" {
+ return
+ }
+ return parseBasicAuth(auth)
+}
+
+// parseBasicAuth parses an HTTP Basic Authentication string.
+// "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" returns ("Aladdin", "open sesame", true).
+func parseBasicAuth(auth string) (username, password string, ok bool) {
+ if !strings.HasPrefix(auth, "Basic ") {
+ return
+ }
+ c, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(auth, "Basic "))
+ if err != nil {
+ return
+ }
+ cs := string(c)
+ s := strings.IndexByte(cs, ':')
+ if s < 0 {
+ return
+ }
+ return cs[:s], cs[s+1:], true
+}
+
+// SetBasicAuth sets the request's Authorization header to use HTTP
+// Basic Authentication with the provided username and password.
+//
+// With HTTP Basic Authentication the provided username and password
+// are not encrypted.
+func (r *Request) SetBasicAuth(username, password string) {
+ r.Header.Set("Authorization", "Basic "+basicAuth(username, password))
+}
+
+// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts.
+func parseRequestLine(line string) (method, requestURI, proto string, ok bool) {
+ s1 := strings.Index(line, " ")
+ s2 := strings.Index(line[s1+1:], " ")
+ if s1 < 0 || s2 < 0 {
+ return
+ }
+ s2 += s1 + 1
+ return line[:s1], line[s1+1 : s2], line[s2+1:], true
+}
+
+var textprotoReaderPool sync.Pool
+
+func newTextprotoReader(br *bufio.Reader) *textproto.Reader {
+ if v := textprotoReaderPool.Get(); v != nil {
+ tr := v.(*textproto.Reader)
+ tr.R = br
+ return tr
+ }
+ return textproto.NewReader(br)
+}
+
+func putTextprotoReader(r *textproto.Reader) {
+ r.R = nil
+ textprotoReaderPool.Put(r)
+}
+
+// ReadRequest reads and parses a request from b.
+func ReadRequest(b *bufio.Reader) (req *Request, err error) {
+
+ tp := newTextprotoReader(b)
+ req = new(Request)
+
+ // First line: GET /index.html HTTP/1.0
+ var s string
+ if s, err = tp.ReadLine(); err != nil {
+ return nil, err
+ }
+ defer func() {
+ putTextprotoReader(tp)
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ }()
+
+ var ok bool
+ req.Method, req.RequestURI, req.Proto, ok = parseRequestLine(s)
+ if !ok {
+ return nil, &badStringError{"malformed HTTP request", s}
+ }
+ rawurl := req.RequestURI
+ if req.ProtoMajor, req.ProtoMinor, ok = ParseHTTPVersion(req.Proto); !ok {
+ return nil, &badStringError{"malformed HTTP version", req.Proto}
+ }
+
+ // CONNECT requests are used two different ways, and neither uses a full URL:
+ // The standard use is to tunnel HTTPS through an HTTP proxy.
+ // It looks like "CONNECT www.google.com:443 HTTP/1.1", and the parameter is
+ // just the authority section of a URL. This information should go in req.URL.Host.
+ //
+ // The net/rpc package also uses CONNECT, but there the parameter is a path
+ // that starts with a slash. It can be parsed with the regular URL parser,
+ // and the path will end up in req.URL.Path, where it needs to be in order for
+ // RPC to work.
+ justAuthority := req.Method == "CONNECT" && !strings.HasPrefix(rawurl, "/")
+ if justAuthority {
+ rawurl = "http://" + rawurl
+ }
+
+ if req.URL, err = url.ParseRequestURI(rawurl); err != nil {
+ return nil, err
+ }
+
+ if justAuthority {
+ // Strip the bogus "http://" back off.
+ req.URL.Scheme = ""
+ }
+
+ // Subsequent lines: Key: value.
+ mimeHeader, err := tp.ReadMIMEHeader()
+ if err != nil {
+ return nil, err
+ }
+ req.Header = Header(mimeHeader)
+
+ // RFC2616: Must treat
+ // GET /index.html HTTP/1.1
+ // Host: www.google.com
+ // and
+ // GET http://www.google.com/index.html HTTP/1.1
+ // Host: doesntmatter
+ // the same. In the second case, any Host line is ignored.
+ req.Host = req.URL.Host
+ if req.Host == "" {
+ req.Host = req.Header.get("Host")
+ }
+ delete(req.Header, "Host")
+
+ fixPragmaCacheControl(req.Header)
+
+ err = readTransfer(req, b)
+ if err != nil {
+ return nil, err
+ }
+
+ req.Close = shouldClose(req.ProtoMajor, req.ProtoMinor, req.Header, false)
+ return req, nil
+}
+
+// MaxBytesReader is similar to io.LimitReader but is intended for
+// limiting the size of incoming request bodies. In contrast to
+// io.LimitReader, MaxBytesReader's result is a ReadCloser, returns a
+// non-EOF error for a Read beyond the limit, and Closes the
+// underlying reader when its Close method is called.
+//
+// MaxBytesReader prevents clients from accidentally or maliciously
+// sending a large request and wasting server resources.
+func MaxBytesReader(w ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser {
+ return &maxBytesReader{w: w, r: r, n: n}
+}
+
+type maxBytesReader struct {
+ w ResponseWriter
+ r io.ReadCloser // underlying reader
+ n int64 // max bytes remaining
+ stopped bool
+}
+
+func (l *maxBytesReader) Read(p []byte) (n int, err error) {
+ if l.n <= 0 {
+ if !l.stopped {
+ l.stopped = true
+ if res, ok := l.w.(*response); ok {
+ res.requestTooLarge()
+ }
+ }
+ return 0, errors.New("http: request body too large")
+ }
+ if int64(len(p)) > l.n {
+ p = p[:l.n]
+ }
+ n, err = l.r.Read(p)
+ l.n -= int64(n)
+ return
+}
+
+func (l *maxBytesReader) Close() error {
+ return l.r.Close()
+}
+
+func copyValues(dst, src url.Values) {
+ for k, vs := range src {
+ for _, value := range vs {
+ dst.Add(k, value)
+ }
+ }
+}
+
+func parsePostForm(r *Request) (vs url.Values, err error) {
+ if r.Body == nil {
+ err = errors.New("missing form body")
+ return
+ }
+ ct := r.Header.Get("Content-Type")
+ // RFC 2616, section 7.2.1 - empty type
+ // SHOULD be treated as application/octet-stream
+ if ct == "" {
+ ct = "application/octet-stream"
+ }
+ ct, _, err = mime.ParseMediaType(ct)
+ switch {
+ case ct == "application/x-www-form-urlencoded":
+ var reader io.Reader = r.Body
+ maxFormSize := int64(1<<63 - 1)
+ if _, ok := r.Body.(*maxBytesReader); !ok {
+ maxFormSize = int64(10 << 20) // 10 MB is a lot of text.
+ reader = io.LimitReader(r.Body, maxFormSize+1)
+ }
+ b, e := ioutil.ReadAll(reader)
+ if e != nil {
+ if err == nil {
+ err = e
+ }
+ break
+ }
+ if int64(len(b)) > maxFormSize {
+ err = errors.New("http: POST too large")
+ return
+ }
+ vs, e = url.ParseQuery(string(b))
+ if err == nil {
+ err = e
+ }
+ case ct == "multipart/form-data":
+ // handled by ParseMultipartForm (which is calling us, or should be)
+ // TODO(bradfitz): there are too many possible
+ // orders to call too many functions here.
+ // Clean this up and write more tests.
+ // request_test.go contains the start of this,
+ // in TestParseMultipartFormOrder and others.
+ }
+ return
+}
+
+// ParseForm parses the raw query from the URL and updates r.Form.
+//
+// For POST or PUT requests, it also parses the request body as a form and
+// put the results into both r.PostForm and r.Form.
+// POST and PUT body parameters take precedence over URL query string values
+// in r.Form.
+//
+// If the request Body's size has not already been limited by MaxBytesReader,
+// the size is capped at 10MB.
+//
+// ParseMultipartForm calls ParseForm automatically.
+// It is idempotent.
+func (r *Request) ParseForm() error {
+ var err error
+ if r.PostForm == nil {
+ if r.Method == "POST" || r.Method == "PUT" || r.Method == "PATCH" {
+ r.PostForm, err = parsePostForm(r)
+ }
+ if r.PostForm == nil {
+ r.PostForm = make(url.Values)
+ }
+ }
+ if r.Form == nil {
+ if len(r.PostForm) > 0 {
+ r.Form = make(url.Values)
+ copyValues(r.Form, r.PostForm)
+ }
+ var newValues url.Values
+ if r.URL != nil {
+ var e error
+ newValues, e = url.ParseQuery(r.URL.RawQuery)
+ if err == nil {
+ err = e
+ }
+ }
+ if newValues == nil {
+ newValues = make(url.Values)
+ }
+ if r.Form == nil {
+ r.Form = newValues
+ } else {
+ copyValues(r.Form, newValues)
+ }
+ }
+ return err
+}
+
+// ParseMultipartForm parses a request body as multipart/form-data.
+// The whole request body is parsed and up to a total of maxMemory bytes of
+// its file parts are stored in memory, with the remainder stored on
+// disk in temporary files.
+// ParseMultipartForm calls ParseForm if necessary.
+// After one call to ParseMultipartForm, subsequent calls have no effect.
+func (r *Request) ParseMultipartForm(maxMemory int64) error {
+ if r.MultipartForm == multipartByReader {
+ return errors.New("http: multipart handled by MultipartReader")
+ }
+ if r.Form == nil {
+ err := r.ParseForm()
+ if err != nil {
+ return err
+ }
+ }
+ if r.MultipartForm != nil {
+ return nil
+ }
+
+ mr, err := r.multipartReader()
+ if err != nil {
+ return err
+ }
+
+ f, err := mr.ReadForm(maxMemory)
+ if err != nil {
+ return err
+ }
+ for k, v := range f.Value {
+ r.Form[k] = append(r.Form[k], v...)
+ }
+ r.MultipartForm = f
+
+ return nil
+}
+
+// FormValue returns the first value for the named component of the query.
+// POST and PUT body parameters take precedence over URL query string values.
+// FormValue calls ParseMultipartForm and ParseForm if necessary and ignores
+// any errors returned by these functions.
+// To access multiple values of the same key, call ParseForm and
+// then inspect Request.Form directly.
+func (r *Request) FormValue(key string) string {
+ if r.Form == nil {
+ r.ParseMultipartForm(defaultMaxMemory)
+ }
+ if vs := r.Form[key]; len(vs) > 0 {
+ return vs[0]
+ }
+ return ""
+}
+
+// PostFormValue returns the first value for the named component of the POST
+// or PUT request body. URL query parameters are ignored.
+// PostFormValue calls ParseMultipartForm and ParseForm if necessary and ignores
+// any errors returned by these functions.
+func (r *Request) PostFormValue(key string) string {
+ if r.PostForm == nil {
+ r.ParseMultipartForm(defaultMaxMemory)
+ }
+ if vs := r.PostForm[key]; len(vs) > 0 {
+ return vs[0]
+ }
+ return ""
+}
+
+// FormFile returns the first file for the provided form key.
+// FormFile calls ParseMultipartForm and ParseForm if necessary.
+func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, error) {
+ if r.MultipartForm == multipartByReader {
+ return nil, nil, errors.New("http: multipart handled by MultipartReader")
+ }
+ if r.MultipartForm == nil {
+ err := r.ParseMultipartForm(defaultMaxMemory)
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+ if r.MultipartForm != nil && r.MultipartForm.File != nil {
+ if fhs := r.MultipartForm.File[key]; len(fhs) > 0 {
+ f, err := fhs[0].Open()
+ return f, fhs[0], err
+ }
+ }
+ return nil, nil, ErrMissingFile
+}
+
+func (r *Request) expectsContinue() bool {
+ return hasToken(r.Header.get("Expect"), "100-continue")
+}
+
+func (r *Request) wantsHttp10KeepAlive() bool {
+ if r.ProtoMajor != 1 || r.ProtoMinor != 0 {
+ return false
+ }
+ return hasToken(r.Header.get("Connection"), "keep-alive")
+}
+
+func (r *Request) wantsClose() bool {
+ return hasToken(r.Header.get("Connection"), "close")
+}
+
+func (r *Request) closeBody() {
+ if r.Body != nil {
+ r.Body.Close()
+ }
+}
diff --git a/src/net/http/request_test.go b/src/net/http/request_test.go
new file mode 100644
index 000000000..759ea4e8b
--- /dev/null
+++ b/src/net/http/request_test.go
@@ -0,0 +1,680 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+ "bufio"
+ "bytes"
+ "encoding/base64"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "mime/multipart"
+ . "net/http"
+ "net/http/httptest"
+ "net/url"
+ "os"
+ "reflect"
+ "regexp"
+ "strings"
+ "testing"
+)
+
+func TestQuery(t *testing.T) {
+ req := &Request{Method: "GET"}
+ req.URL, _ = url.Parse("http://www.google.com/search?q=foo&q=bar")
+ if q := req.FormValue("q"); q != "foo" {
+ t.Errorf(`req.FormValue("q") = %q, want "foo"`, q)
+ }
+}
+
+func TestPostQuery(t *testing.T) {
+ req, _ := NewRequest("POST", "http://www.google.com/search?q=foo&q=bar&both=x&prio=1&empty=not",
+ strings.NewReader("z=post&both=y&prio=2&empty="))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
+
+ if q := req.FormValue("q"); q != "foo" {
+ t.Errorf(`req.FormValue("q") = %q, want "foo"`, q)
+ }
+ if z := req.FormValue("z"); z != "post" {
+ t.Errorf(`req.FormValue("z") = %q, want "post"`, z)
+ }
+ if bq, found := req.PostForm["q"]; found {
+ t.Errorf(`req.PostForm["q"] = %q, want no entry in map`, bq)
+ }
+ if bz := req.PostFormValue("z"); bz != "post" {
+ t.Errorf(`req.PostFormValue("z") = %q, want "post"`, bz)
+ }
+ if qs := req.Form["q"]; !reflect.DeepEqual(qs, []string{"foo", "bar"}) {
+ t.Errorf(`req.Form["q"] = %q, want ["foo", "bar"]`, qs)
+ }
+ if both := req.Form["both"]; !reflect.DeepEqual(both, []string{"y", "x"}) {
+ t.Errorf(`req.Form["both"] = %q, want ["y", "x"]`, both)
+ }
+ if prio := req.FormValue("prio"); prio != "2" {
+ t.Errorf(`req.FormValue("prio") = %q, want "2" (from body)`, prio)
+ }
+ if empty := req.FormValue("empty"); empty != "" {
+ t.Errorf(`req.FormValue("empty") = %q, want "" (from body)`, empty)
+ }
+}
+
+func TestPatchQuery(t *testing.T) {
+ req, _ := NewRequest("PATCH", "http://www.google.com/search?q=foo&q=bar&both=x&prio=1&empty=not",
+ strings.NewReader("z=post&both=y&prio=2&empty="))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
+
+ if q := req.FormValue("q"); q != "foo" {
+ t.Errorf(`req.FormValue("q") = %q, want "foo"`, q)
+ }
+ if z := req.FormValue("z"); z != "post" {
+ t.Errorf(`req.FormValue("z") = %q, want "post"`, z)
+ }
+ if bq, found := req.PostForm["q"]; found {
+ t.Errorf(`req.PostForm["q"] = %q, want no entry in map`, bq)
+ }
+ if bz := req.PostFormValue("z"); bz != "post" {
+ t.Errorf(`req.PostFormValue("z") = %q, want "post"`, bz)
+ }
+ if qs := req.Form["q"]; !reflect.DeepEqual(qs, []string{"foo", "bar"}) {
+ t.Errorf(`req.Form["q"] = %q, want ["foo", "bar"]`, qs)
+ }
+ if both := req.Form["both"]; !reflect.DeepEqual(both, []string{"y", "x"}) {
+ t.Errorf(`req.Form["both"] = %q, want ["y", "x"]`, both)
+ }
+ if prio := req.FormValue("prio"); prio != "2" {
+ t.Errorf(`req.FormValue("prio") = %q, want "2" (from body)`, prio)
+ }
+ if empty := req.FormValue("empty"); empty != "" {
+ t.Errorf(`req.FormValue("empty") = %q, want "" (from body)`, empty)
+ }
+}
+
+type stringMap map[string][]string
+type parseContentTypeTest struct {
+ shouldError bool
+ contentType stringMap
+}
+
+var parseContentTypeTests = []parseContentTypeTest{
+ {false, stringMap{"Content-Type": {"text/plain"}}},
+ // Empty content type is legal - shoult be treated as
+ // application/octet-stream (RFC 2616, section 7.2.1)
+ {false, stringMap{}},
+ {true, stringMap{"Content-Type": {"text/plain; boundary="}}},
+ {false, stringMap{"Content-Type": {"application/unknown"}}},
+}
+
+func TestParseFormUnknownContentType(t *testing.T) {
+ for i, test := range parseContentTypeTests {
+ req := &Request{
+ Method: "POST",
+ Header: Header(test.contentType),
+ Body: ioutil.NopCloser(strings.NewReader("body")),
+ }
+ err := req.ParseForm()
+ switch {
+ case err == nil && test.shouldError:
+ t.Errorf("test %d should have returned error", i)
+ case err != nil && !test.shouldError:
+ t.Errorf("test %d should not have returned error, got %v", i, err)
+ }
+ }
+}
+
+func TestParseFormInitializeOnError(t *testing.T) {
+ nilBody, _ := NewRequest("POST", "http://www.google.com/search?q=foo", nil)
+ tests := []*Request{
+ nilBody,
+ {Method: "GET", URL: nil},
+ }
+ for i, req := range tests {
+ err := req.ParseForm()
+ if req.Form == nil {
+ t.Errorf("%d. Form not initialized, error %v", i, err)
+ }
+ if req.PostForm == nil {
+ t.Errorf("%d. PostForm not initialized, error %v", i, err)
+ }
+ }
+}
+
+func TestMultipartReader(t *testing.T) {
+ req := &Request{
+ Method: "POST",
+ Header: Header{"Content-Type": {`multipart/form-data; boundary="foo123"`}},
+ Body: ioutil.NopCloser(new(bytes.Buffer)),
+ }
+ multipart, err := req.MultipartReader()
+ if multipart == nil {
+ t.Errorf("expected multipart; error: %v", err)
+ }
+
+ req.Header = Header{"Content-Type": {"text/plain"}}
+ multipart, err = req.MultipartReader()
+ if multipart != nil {
+ t.Error("unexpected multipart for text/plain")
+ }
+}
+
+func TestParseMultipartForm(t *testing.T) {
+ req := &Request{
+ Method: "POST",
+ Header: Header{"Content-Type": {`multipart/form-data; boundary="foo123"`}},
+ Body: ioutil.NopCloser(new(bytes.Buffer)),
+ }
+ err := req.ParseMultipartForm(25)
+ if err == nil {
+ t.Error("expected multipart EOF, got nil")
+ }
+
+ req.Header = Header{"Content-Type": {"text/plain"}}
+ err = req.ParseMultipartForm(25)
+ if err != ErrNotMultipart {
+ t.Error("expected ErrNotMultipart for text/plain")
+ }
+}
+
+func TestRedirect(t *testing.T) {
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ switch r.URL.Path {
+ case "/":
+ w.Header().Set("Location", "/foo/")
+ w.WriteHeader(StatusSeeOther)
+ case "/foo/":
+ fmt.Fprintf(w, "foo")
+ default:
+ w.WriteHeader(StatusBadRequest)
+ }
+ }))
+ defer ts.Close()
+
+ var end = regexp.MustCompile("/foo/$")
+ r, err := Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ r.Body.Close()
+ url := r.Request.URL.String()
+ if r.StatusCode != 200 || !end.MatchString(url) {
+ t.Fatalf("Get got status %d at %q, want 200 matching /foo/$", r.StatusCode, url)
+ }
+}
+
+func TestSetBasicAuth(t *testing.T) {
+ r, _ := NewRequest("GET", "http://example.com/", nil)
+ r.SetBasicAuth("Aladdin", "open sesame")
+ if g, e := r.Header.Get("Authorization"), "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="; g != e {
+ t.Errorf("got header %q, want %q", g, e)
+ }
+}
+
+func TestMultipartRequest(t *testing.T) {
+ // Test that we can read the values and files of a
+ // multipart request with FormValue and FormFile,
+ // and that ParseMultipartForm can be called multiple times.
+ req := newTestMultipartRequest(t)
+ if err := req.ParseMultipartForm(25); err != nil {
+ t.Fatal("ParseMultipartForm first call:", err)
+ }
+ defer req.MultipartForm.RemoveAll()
+ validateTestMultipartContents(t, req, false)
+ if err := req.ParseMultipartForm(25); err != nil {
+ t.Fatal("ParseMultipartForm second call:", err)
+ }
+ validateTestMultipartContents(t, req, false)
+}
+
+func TestMultipartRequestAuto(t *testing.T) {
+ // Test that FormValue and FormFile automatically invoke
+ // ParseMultipartForm and return the right values.
+ req := newTestMultipartRequest(t)
+ defer func() {
+ if req.MultipartForm != nil {
+ req.MultipartForm.RemoveAll()
+ }
+ }()
+ validateTestMultipartContents(t, req, true)
+}
+
+func TestMissingFileMultipartRequest(t *testing.T) {
+ // Test that FormFile returns an error if
+ // the named file is missing.
+ req := newTestMultipartRequest(t)
+ testMissingFile(t, req)
+}
+
+// Test that FormValue invokes ParseMultipartForm.
+func TestFormValueCallsParseMultipartForm(t *testing.T) {
+ req, _ := NewRequest("POST", "http://www.google.com/", strings.NewReader("z=post"))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
+ if req.Form != nil {
+ t.Fatal("Unexpected request Form, want nil")
+ }
+ req.FormValue("z")
+ if req.Form == nil {
+ t.Fatal("ParseMultipartForm not called by FormValue")
+ }
+}
+
+// Test that FormFile invokes ParseMultipartForm.
+func TestFormFileCallsParseMultipartForm(t *testing.T) {
+ req := newTestMultipartRequest(t)
+ if req.Form != nil {
+ t.Fatal("Unexpected request Form, want nil")
+ }
+ req.FormFile("")
+ if req.Form == nil {
+ t.Fatal("ParseMultipartForm not called by FormFile")
+ }
+}
+
+// Test that ParseMultipartForm errors if called
+// after MultipartReader on the same request.
+func TestParseMultipartFormOrder(t *testing.T) {
+ req := newTestMultipartRequest(t)
+ if _, err := req.MultipartReader(); err != nil {
+ t.Fatalf("MultipartReader: %v", err)
+ }
+ if err := req.ParseMultipartForm(1024); err == nil {
+ t.Fatal("expected an error from ParseMultipartForm after call to MultipartReader")
+ }
+}
+
+// Test that MultipartReader errors if called
+// after ParseMultipartForm on the same request.
+func TestMultipartReaderOrder(t *testing.T) {
+ req := newTestMultipartRequest(t)
+ if err := req.ParseMultipartForm(25); err != nil {
+ t.Fatalf("ParseMultipartForm: %v", err)
+ }
+ defer req.MultipartForm.RemoveAll()
+ if _, err := req.MultipartReader(); err == nil {
+ t.Fatal("expected an error from MultipartReader after call to ParseMultipartForm")
+ }
+}
+
+// Test that FormFile errors if called after
+// MultipartReader on the same request.
+func TestFormFileOrder(t *testing.T) {
+ req := newTestMultipartRequest(t)
+ if _, err := req.MultipartReader(); err != nil {
+ t.Fatalf("MultipartReader: %v", err)
+ }
+ if _, _, err := req.FormFile(""); err == nil {
+ t.Fatal("expected an error from FormFile after call to MultipartReader")
+ }
+}
+
+var readRequestErrorTests = []struct {
+ in string
+ err error
+}{
+ {"GET / HTTP/1.1\r\nheader:foo\r\n\r\n", nil},
+ {"GET / HTTP/1.1\r\nheader:foo\r\n", io.ErrUnexpectedEOF},
+ {"", io.EOF},
+}
+
+func TestReadRequestErrors(t *testing.T) {
+ for i, tt := range readRequestErrorTests {
+ _, err := ReadRequest(bufio.NewReader(strings.NewReader(tt.in)))
+ if err != tt.err {
+ t.Errorf("%d. got error = %v; want %v", i, err, tt.err)
+ }
+ }
+}
+
+func TestNewRequestHost(t *testing.T) {
+ req, err := NewRequest("GET", "http://localhost:1234/", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if req.Host != "localhost:1234" {
+ t.Errorf("Host = %q; want localhost:1234", req.Host)
+ }
+}
+
+func TestNewRequestContentLength(t *testing.T) {
+ readByte := func(r io.Reader) io.Reader {
+ var b [1]byte
+ r.Read(b[:])
+ return r
+ }
+ tests := []struct {
+ r io.Reader
+ want int64
+ }{
+ {bytes.NewReader([]byte("123")), 3},
+ {bytes.NewBuffer([]byte("1234")), 4},
+ {strings.NewReader("12345"), 5},
+ // Not detected:
+ {struct{ io.Reader }{strings.NewReader("xyz")}, 0},
+ {io.NewSectionReader(strings.NewReader("x"), 0, 6), 0},
+ {readByte(io.NewSectionReader(strings.NewReader("xy"), 0, 6)), 0},
+ }
+ for _, tt := range tests {
+ req, err := NewRequest("POST", "http://localhost/", tt.r)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if req.ContentLength != tt.want {
+ t.Errorf("ContentLength(%T) = %d; want %d", tt.r, req.ContentLength, tt.want)
+ }
+ }
+}
+
+var parseHTTPVersionTests = []struct {
+ vers string
+ major, minor int
+ ok bool
+}{
+ {"HTTP/0.9", 0, 9, true},
+ {"HTTP/1.0", 1, 0, true},
+ {"HTTP/1.1", 1, 1, true},
+ {"HTTP/3.14", 3, 14, true},
+
+ {"HTTP", 0, 0, false},
+ {"HTTP/one.one", 0, 0, false},
+ {"HTTP/1.1/", 0, 0, false},
+ {"HTTP/-1,0", 0, 0, false},
+ {"HTTP/0,-1", 0, 0, false},
+ {"HTTP/", 0, 0, false},
+ {"HTTP/1,1", 0, 0, false},
+}
+
+func TestParseHTTPVersion(t *testing.T) {
+ for _, tt := range parseHTTPVersionTests {
+ major, minor, ok := ParseHTTPVersion(tt.vers)
+ if ok != tt.ok || major != tt.major || minor != tt.minor {
+ type version struct {
+ major, minor int
+ ok bool
+ }
+ t.Errorf("failed to parse %q, expected: %#v, got %#v", tt.vers, version{tt.major, tt.minor, tt.ok}, version{major, minor, ok})
+ }
+ }
+}
+
+type getBasicAuthTest struct {
+ username, password string
+ ok bool
+}
+
+type parseBasicAuthTest getBasicAuthTest
+
+type basicAuthCredentialsTest struct {
+ username, password string
+}
+
+var getBasicAuthTests = []struct {
+ username, password string
+ ok bool
+}{
+ {"Aladdin", "open sesame", true},
+ {"Aladdin", "open:sesame", true},
+ {"", "", true},
+}
+
+func TestGetBasicAuth(t *testing.T) {
+ for _, tt := range getBasicAuthTests {
+ r, _ := NewRequest("GET", "http://example.com/", nil)
+ r.SetBasicAuth(tt.username, tt.password)
+ username, password, ok := r.BasicAuth()
+ if ok != tt.ok || username != tt.username || password != tt.password {
+ t.Errorf("BasicAuth() = %#v, want %#v", getBasicAuthTest{username, password, ok},
+ getBasicAuthTest{tt.username, tt.password, tt.ok})
+ }
+ }
+ // Unauthenticated request.
+ r, _ := NewRequest("GET", "http://example.com/", nil)
+ username, password, ok := r.BasicAuth()
+ if ok {
+ t.Errorf("expected false from BasicAuth when the request is unauthenticated")
+ }
+ want := basicAuthCredentialsTest{"", ""}
+ if username != want.username || password != want.password {
+ t.Errorf("expected credentials: %#v when the request is unauthenticated, got %#v",
+ want, basicAuthCredentialsTest{username, password})
+ }
+}
+
+var parseBasicAuthTests = []struct {
+ header, username, password string
+ ok bool
+}{
+ {"Basic " + base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "Aladdin", "open sesame", true},
+ {"Basic " + base64.StdEncoding.EncodeToString([]byte("Aladdin:open:sesame")), "Aladdin", "open:sesame", true},
+ {"Basic " + base64.StdEncoding.EncodeToString([]byte(":")), "", "", true},
+ {"Basic" + base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "", "", false},
+ {base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "", "", false},
+ {"Basic ", "", "", false},
+ {"Basic Aladdin:open sesame", "", "", false},
+ {`Digest username="Aladdin"`, "", "", false},
+}
+
+func TestParseBasicAuth(t *testing.T) {
+ for _, tt := range parseBasicAuthTests {
+ r, _ := NewRequest("GET", "http://example.com/", nil)
+ r.Header.Set("Authorization", tt.header)
+ username, password, ok := r.BasicAuth()
+ if ok != tt.ok || username != tt.username || password != tt.password {
+ t.Errorf("BasicAuth() = %#v, want %#v", getBasicAuthTest{username, password, ok},
+ getBasicAuthTest{tt.username, tt.password, tt.ok})
+ }
+ }
+}
+
+type logWrites struct {
+ t *testing.T
+ dst *[]string
+}
+
+func (l logWrites) WriteByte(c byte) error {
+ l.t.Fatalf("unexpected WriteByte call")
+ return nil
+}
+
+func (l logWrites) Write(p []byte) (n int, err error) {
+ *l.dst = append(*l.dst, string(p))
+ return len(p), nil
+}
+
+func TestRequestWriteBufferedWriter(t *testing.T) {
+ got := []string{}
+ req, _ := NewRequest("GET", "http://foo.com/", nil)
+ req.Write(logWrites{t, &got})
+ want := []string{
+ "GET / HTTP/1.1\r\n",
+ "Host: foo.com\r\n",
+ "User-Agent: " + DefaultUserAgent + "\r\n",
+ "\r\n",
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("Writes = %q\n Want = %q", got, want)
+ }
+}
+
+func testMissingFile(t *testing.T, req *Request) {
+ f, fh, err := req.FormFile("missing")
+ if f != nil {
+ t.Errorf("FormFile file = %v, want nil", f)
+ }
+ if fh != nil {
+ t.Errorf("FormFile file header = %q, want nil", fh)
+ }
+ if err != ErrMissingFile {
+ t.Errorf("FormFile err = %q, want ErrMissingFile", err)
+ }
+}
+
+func newTestMultipartRequest(t *testing.T) *Request {
+ b := strings.NewReader(strings.Replace(message, "\n", "\r\n", -1))
+ req, err := NewRequest("POST", "/", b)
+ if err != nil {
+ t.Fatal("NewRequest:", err)
+ }
+ ctype := fmt.Sprintf(`multipart/form-data; boundary="%s"`, boundary)
+ req.Header.Set("Content-type", ctype)
+ return req
+}
+
+func validateTestMultipartContents(t *testing.T, req *Request, allMem bool) {
+ if g, e := req.FormValue("texta"), textaValue; g != e {
+ t.Errorf("texta value = %q, want %q", g, e)
+ }
+ if g, e := req.FormValue("textb"), textbValue; g != e {
+ t.Errorf("textb value = %q, want %q", g, e)
+ }
+ if g := req.FormValue("missing"); g != "" {
+ t.Errorf("missing value = %q, want empty string", g)
+ }
+
+ assertMem := func(n string, fd multipart.File) {
+ if _, ok := fd.(*os.File); ok {
+ t.Error(n, " is *os.File, should not be")
+ }
+ }
+ fda := testMultipartFile(t, req, "filea", "filea.txt", fileaContents)
+ defer fda.Close()
+ assertMem("filea", fda)
+ fdb := testMultipartFile(t, req, "fileb", "fileb.txt", filebContents)
+ defer fdb.Close()
+ if allMem {
+ assertMem("fileb", fdb)
+ } else {
+ if _, ok := fdb.(*os.File); !ok {
+ t.Errorf("fileb has unexpected underlying type %T", fdb)
+ }
+ }
+
+ testMissingFile(t, req)
+}
+
+func testMultipartFile(t *testing.T, req *Request, key, expectFilename, expectContent string) multipart.File {
+ f, fh, err := req.FormFile(key)
+ if err != nil {
+ t.Fatalf("FormFile(%q): %q", key, err)
+ }
+ if fh.Filename != expectFilename {
+ t.Errorf("filename = %q, want %q", fh.Filename, expectFilename)
+ }
+ var b bytes.Buffer
+ _, err = io.Copy(&b, f)
+ if err != nil {
+ t.Fatal("copying contents:", err)
+ }
+ if g := b.String(); g != expectContent {
+ t.Errorf("contents = %q, want %q", g, expectContent)
+ }
+ return f
+}
+
+const (
+ fileaContents = "This is a test file."
+ filebContents = "Another test file."
+ textaValue = "foo"
+ textbValue = "bar"
+ boundary = `MyBoundary`
+)
+
+const message = `
+--MyBoundary
+Content-Disposition: form-data; name="filea"; filename="filea.txt"
+Content-Type: text/plain
+
+` + fileaContents + `
+--MyBoundary
+Content-Disposition: form-data; name="fileb"; filename="fileb.txt"
+Content-Type: text/plain
+
+` + filebContents + `
+--MyBoundary
+Content-Disposition: form-data; name="texta"
+
+` + textaValue + `
+--MyBoundary
+Content-Disposition: form-data; name="textb"
+
+` + textbValue + `
+--MyBoundary--
+`
+
+func benchmarkReadRequest(b *testing.B, request string) {
+ request = request + "\n" // final \n
+ request = strings.Replace(request, "\n", "\r\n", -1) // expand \n to \r\n
+ b.SetBytes(int64(len(request)))
+ r := bufio.NewReader(&infiniteReader{buf: []byte(request)})
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _, err := ReadRequest(r)
+ if err != nil {
+ b.Fatalf("failed to read request: %v", err)
+ }
+ }
+}
+
+// infiniteReader satisfies Read requests as if the contents of buf
+// loop indefinitely.
+type infiniteReader struct {
+ buf []byte
+ offset int
+}
+
+func (r *infiniteReader) Read(b []byte) (int, error) {
+ n := copy(b, r.buf[r.offset:])
+ r.offset = (r.offset + n) % len(r.buf)
+ return n, nil
+}
+
+func BenchmarkReadRequestChrome(b *testing.B) {
+ // https://github.com/felixge/node-http-perf/blob/master/fixtures/get.http
+ benchmarkReadRequest(b, `GET / HTTP/1.1
+Host: localhost:8080
+Connection: keep-alive
+Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
+User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
+Accept-Encoding: gzip,deflate,sdch
+Accept-Language: en-US,en;q=0.8
+Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
+Cookie: __utma=1.1978842379.1323102373.1323102373.1323102373.1; EPi:NumberOfVisits=1,2012-02-28T13:42:18; CrmSession=5b707226b9563e1bc69084d07a107c98; plushContainerWidth=100%25; plushNoTopMenu=0; hudson_auto_refresh=false
+`)
+}
+
+func BenchmarkReadRequestCurl(b *testing.B) {
+ // curl http://localhost:8080/
+ benchmarkReadRequest(b, `GET / HTTP/1.1
+User-Agent: curl/7.27.0
+Host: localhost:8080
+Accept: */*
+`)
+}
+
+func BenchmarkReadRequestApachebench(b *testing.B) {
+ // ab -n 1 -c 1 http://localhost:8080/
+ benchmarkReadRequest(b, `GET / HTTP/1.0
+Host: localhost:8080
+User-Agent: ApacheBench/2.3
+Accept: */*
+`)
+}
+
+func BenchmarkReadRequestSiege(b *testing.B) {
+ // siege -r 1 -c 1 http://localhost:8080/
+ benchmarkReadRequest(b, `GET / HTTP/1.1
+Host: localhost:8080
+Accept: */*
+Accept-Encoding: gzip
+User-Agent: JoeDog/1.00 [en] (X11; I; Siege 2.70)
+Connection: keep-alive
+`)
+}
+
+func BenchmarkReadRequestWrk(b *testing.B) {
+ // wrk -t 1 -r 1 -c 1 http://localhost:8080/
+ benchmarkReadRequest(b, `GET / HTTP/1.1
+Host: localhost:8080
+`)
+}
diff --git a/src/net/http/requestwrite_test.go b/src/net/http/requestwrite_test.go
new file mode 100644
index 000000000..7a6bd5878
--- /dev/null
+++ b/src/net/http/requestwrite_test.go
@@ -0,0 +1,623 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/url"
+ "strings"
+ "testing"
+)
+
+type reqWriteTest struct {
+ Req Request
+ Body interface{} // optional []byte or func() io.ReadCloser to populate Req.Body
+
+ // Any of these three may be empty to skip that test.
+ WantWrite string // Request.Write
+ WantProxy string // Request.WriteProxy
+
+ WantError error // wanted error from Request.Write
+}
+
+var reqWriteTests = []reqWriteTest{
+ // HTTP/1.1 => chunked coding; no body; no trailer
+ {
+ Req: Request{
+ Method: "GET",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "www.techcrunch.com",
+ Path: "/",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{
+ "Accept": {"text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"},
+ "Accept-Charset": {"ISO-8859-1,utf-8;q=0.7,*;q=0.7"},
+ "Accept-Encoding": {"gzip,deflate"},
+ "Accept-Language": {"en-us,en;q=0.5"},
+ "Keep-Alive": {"300"},
+ "Proxy-Connection": {"keep-alive"},
+ "User-Agent": {"Fake"},
+ },
+ Body: nil,
+ Close: false,
+ Host: "www.techcrunch.com",
+ Form: map[string][]string{},
+ },
+
+ WantWrite: "GET / HTTP/1.1\r\n" +
+ "Host: www.techcrunch.com\r\n" +
+ "User-Agent: Fake\r\n" +
+ "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n" +
+ "Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.7\r\n" +
+ "Accept-Encoding: gzip,deflate\r\n" +
+ "Accept-Language: en-us,en;q=0.5\r\n" +
+ "Keep-Alive: 300\r\n" +
+ "Proxy-Connection: keep-alive\r\n\r\n",
+
+ WantProxy: "GET http://www.techcrunch.com/ HTTP/1.1\r\n" +
+ "Host: www.techcrunch.com\r\n" +
+ "User-Agent: Fake\r\n" +
+ "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n" +
+ "Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.7\r\n" +
+ "Accept-Encoding: gzip,deflate\r\n" +
+ "Accept-Language: en-us,en;q=0.5\r\n" +
+ "Keep-Alive: 300\r\n" +
+ "Proxy-Connection: keep-alive\r\n\r\n",
+ },
+ // HTTP/1.1 => chunked coding; body; empty trailer
+ {
+ Req: Request{
+ Method: "GET",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/search",
+ },
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ TransferEncoding: []string{"chunked"},
+ },
+
+ Body: []byte("abcdef"),
+
+ WantWrite: "GET /search HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go 1.1 package http\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ chunk("abcdef") + chunk(""),
+
+ WantProxy: "GET http://www.google.com/search HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go 1.1 package http\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ chunk("abcdef") + chunk(""),
+ },
+ // HTTP/1.1 POST => chunked coding; body; empty trailer
+ {
+ Req: Request{
+ Method: "POST",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/search",
+ },
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Close: true,
+ TransferEncoding: []string{"chunked"},
+ },
+
+ Body: []byte("abcdef"),
+
+ WantWrite: "POST /search HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go 1.1 package http\r\n" +
+ "Connection: close\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ chunk("abcdef") + chunk(""),
+
+ WantProxy: "POST http://www.google.com/search HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go 1.1 package http\r\n" +
+ "Connection: close\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ chunk("abcdef") + chunk(""),
+ },
+
+ // HTTP/1.1 POST with Content-Length, no chunking
+ {
+ Req: Request{
+ Method: "POST",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/search",
+ },
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Close: true,
+ ContentLength: 6,
+ },
+
+ Body: []byte("abcdef"),
+
+ WantWrite: "POST /search HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go 1.1 package http\r\n" +
+ "Connection: close\r\n" +
+ "Content-Length: 6\r\n" +
+ "\r\n" +
+ "abcdef",
+
+ WantProxy: "POST http://www.google.com/search HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go 1.1 package http\r\n" +
+ "Connection: close\r\n" +
+ "Content-Length: 6\r\n" +
+ "\r\n" +
+ "abcdef",
+ },
+
+ // HTTP/1.1 POST with Content-Length in headers
+ {
+ Req: Request{
+ Method: "POST",
+ URL: mustParseURL("http://example.com/"),
+ Host: "example.com",
+ Header: Header{
+ "Content-Length": []string{"10"}, // ignored
+ },
+ ContentLength: 6,
+ },
+
+ Body: []byte("abcdef"),
+
+ WantWrite: "POST / HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go 1.1 package http\r\n" +
+ "Content-Length: 6\r\n" +
+ "\r\n" +
+ "abcdef",
+
+ WantProxy: "POST http://example.com/ HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go 1.1 package http\r\n" +
+ "Content-Length: 6\r\n" +
+ "\r\n" +
+ "abcdef",
+ },
+
+ // default to HTTP/1.1
+ {
+ Req: Request{
+ Method: "GET",
+ URL: mustParseURL("/search"),
+ Host: "www.google.com",
+ },
+
+ WantWrite: "GET /search HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go 1.1 package http\r\n" +
+ "\r\n",
+ },
+
+ // Request with a 0 ContentLength and a 0 byte body.
+ {
+ Req: Request{
+ Method: "POST",
+ URL: mustParseURL("/"),
+ Host: "example.com",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 0, // as if unset by user
+ },
+
+ Body: func() io.ReadCloser { return ioutil.NopCloser(io.LimitReader(strings.NewReader("xx"), 0)) },
+
+ // RFC 2616 Section 14.13 says Content-Length should be specified
+ // unless body is prohibited by the request method.
+ // Also, nginx expects it for POST and PUT.
+ WantWrite: "POST / HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go 1.1 package http\r\n" +
+ "Content-Length: 0\r\n" +
+ "\r\n",
+
+ WantProxy: "POST / HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go 1.1 package http\r\n" +
+ "Content-Length: 0\r\n" +
+ "\r\n",
+ },
+
+ // Request with a 0 ContentLength and a 1 byte body.
+ {
+ Req: Request{
+ Method: "POST",
+ URL: mustParseURL("/"),
+ Host: "example.com",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 0, // as if unset by user
+ },
+
+ Body: func() io.ReadCloser { return ioutil.NopCloser(io.LimitReader(strings.NewReader("xx"), 1)) },
+
+ WantWrite: "POST / HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go 1.1 package http\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ chunk("x") + chunk(""),
+
+ WantProxy: "POST / HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go 1.1 package http\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ chunk("x") + chunk(""),
+ },
+
+ // Request with a ContentLength of 10 but a 5 byte body.
+ {
+ Req: Request{
+ Method: "POST",
+ URL: mustParseURL("/"),
+ Host: "example.com",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 10, // but we're going to send only 5 bytes
+ },
+ Body: []byte("12345"),
+ WantError: errors.New("http: ContentLength=10 with Body length 5"),
+ },
+
+ // Request with a ContentLength of 4 but an 8 byte body.
+ {
+ Req: Request{
+ Method: "POST",
+ URL: mustParseURL("/"),
+ Host: "example.com",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 4, // but we're going to try to send 8 bytes
+ },
+ Body: []byte("12345678"),
+ WantError: errors.New("http: ContentLength=4 with Body length 8"),
+ },
+
+ // Request with a 5 ContentLength and nil body.
+ {
+ Req: Request{
+ Method: "POST",
+ URL: mustParseURL("/"),
+ Host: "example.com",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 5, // but we'll omit the body
+ },
+ WantError: errors.New("http: Request.ContentLength=5 with nil Body"),
+ },
+
+ // Request with a 0 ContentLength and a body with 1 byte content and an error.
+ {
+ Req: Request{
+ Method: "POST",
+ URL: mustParseURL("/"),
+ Host: "example.com",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 0, // as if unset by user
+ },
+
+ Body: func() io.ReadCloser {
+ err := errors.New("Custom reader error")
+ errReader := &errorReader{err}
+ return ioutil.NopCloser(io.MultiReader(strings.NewReader("x"), errReader))
+ },
+
+ WantError: errors.New("Custom reader error"),
+ },
+
+ // Request with a 0 ContentLength and a body without content and an error.
+ {
+ Req: Request{
+ Method: "POST",
+ URL: mustParseURL("/"),
+ Host: "example.com",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 0, // as if unset by user
+ },
+
+ Body: func() io.ReadCloser {
+ err := errors.New("Custom reader error")
+ errReader := &errorReader{err}
+ return ioutil.NopCloser(errReader)
+ },
+
+ WantError: errors.New("Custom reader error"),
+ },
+
+ // Verify that DumpRequest preserves the HTTP version number, doesn't add a Host,
+ // and doesn't add a User-Agent.
+ {
+ Req: Request{
+ Method: "GET",
+ URL: mustParseURL("/foo"),
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Header: Header{
+ "X-Foo": []string{"X-Bar"},
+ },
+ },
+
+ WantWrite: "GET /foo HTTP/1.1\r\n" +
+ "Host: \r\n" +
+ "User-Agent: Go 1.1 package http\r\n" +
+ "X-Foo: X-Bar\r\n\r\n",
+ },
+
+ // If no Request.Host and no Request.URL.Host, we send
+ // an empty Host header, and don't use
+ // Request.Header["Host"]. This is just testing that
+ // we don't change Go 1.0 behavior.
+ {
+ Req: Request{
+ Method: "GET",
+ Host: "",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "",
+ Path: "/search",
+ },
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{
+ "Host": []string{"bad.example.com"},
+ },
+ },
+
+ WantWrite: "GET /search HTTP/1.1\r\n" +
+ "Host: \r\n" +
+ "User-Agent: Go 1.1 package http\r\n\r\n",
+ },
+
+ // Opaque test #1 from golang.org/issue/4860
+ {
+ Req: Request{
+ Method: "GET",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Opaque: "/%2F/%2F/",
+ },
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ },
+
+ WantWrite: "GET /%2F/%2F/ HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go 1.1 package http\r\n\r\n",
+ },
+
+ // Opaque test #2 from golang.org/issue/4860
+ {
+ Req: Request{
+ Method: "GET",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "x.google.com",
+ Opaque: "//y.google.com/%2F/%2F/",
+ },
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ },
+
+ WantWrite: "GET http://y.google.com/%2F/%2F/ HTTP/1.1\r\n" +
+ "Host: x.google.com\r\n" +
+ "User-Agent: Go 1.1 package http\r\n\r\n",
+ },
+
+ // Testing custom case in header keys. Issue 5022.
+ {
+ Req: Request{
+ Method: "GET",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{
+ "ALL-CAPS": {"x"},
+ },
+ },
+
+ WantWrite: "GET / HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go 1.1 package http\r\n" +
+ "ALL-CAPS: x\r\n" +
+ "\r\n",
+ },
+}
+
+func TestRequestWrite(t *testing.T) {
+ for i := range reqWriteTests {
+ tt := &reqWriteTests[i]
+
+ setBody := func() {
+ if tt.Body == nil {
+ return
+ }
+ switch b := tt.Body.(type) {
+ case []byte:
+ tt.Req.Body = ioutil.NopCloser(bytes.NewReader(b))
+ case func() io.ReadCloser:
+ tt.Req.Body = b()
+ }
+ }
+ setBody()
+ if tt.Req.Header == nil {
+ tt.Req.Header = make(Header)
+ }
+
+ var braw bytes.Buffer
+ err := tt.Req.Write(&braw)
+ if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.WantError); g != e {
+ t.Errorf("writing #%d, err = %q, want %q", i, g, e)
+ continue
+ }
+ if err != nil {
+ continue
+ }
+
+ if tt.WantWrite != "" {
+ sraw := braw.String()
+ if sraw != tt.WantWrite {
+ t.Errorf("Test %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantWrite, sraw)
+ continue
+ }
+ }
+
+ if tt.WantProxy != "" {
+ setBody()
+ var praw bytes.Buffer
+ err = tt.Req.WriteProxy(&praw)
+ if err != nil {
+ t.Errorf("WriteProxy #%d: %s", i, err)
+ continue
+ }
+ sraw := praw.String()
+ if sraw != tt.WantProxy {
+ t.Errorf("Test Proxy %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantProxy, sraw)
+ continue
+ }
+ }
+ }
+}
+
+type closeChecker struct {
+ io.Reader
+ closed bool
+}
+
+func (rc *closeChecker) Close() error {
+ rc.closed = true
+ return nil
+}
+
+// TestRequestWriteClosesBody tests that Request.Write does close its request.Body.
+// It also indirectly tests NewRequest and that it doesn't wrap an existing Closer
+// inside a NopCloser, and that it serializes it correctly.
+func TestRequestWriteClosesBody(t *testing.T) {
+ rc := &closeChecker{Reader: strings.NewReader("my body")}
+ req, _ := NewRequest("POST", "http://foo.com/", rc)
+ if req.ContentLength != 0 {
+ t.Errorf("got req.ContentLength %d, want 0", req.ContentLength)
+ }
+ buf := new(bytes.Buffer)
+ req.Write(buf)
+ if !rc.closed {
+ t.Error("body not closed after write")
+ }
+ expected := "POST / HTTP/1.1\r\n" +
+ "Host: foo.com\r\n" +
+ "User-Agent: Go 1.1 package http\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ // TODO: currently we don't buffer before chunking, so we get a
+ // single "m" chunk before the other chunks, as this was the 1-byte
+ // read from our MultiReader where we stiched the Body back together
+ // after sniffing whether the Body was 0 bytes or not.
+ chunk("m") +
+ chunk("y body") +
+ chunk("")
+ if buf.String() != expected {
+ t.Errorf("write:\n got: %s\nwant: %s", buf.String(), expected)
+ }
+}
+
+func chunk(s string) string {
+ return fmt.Sprintf("%x\r\n%s\r\n", len(s), s)
+}
+
+func mustParseURL(s string) *url.URL {
+ u, err := url.Parse(s)
+ if err != nil {
+ panic(fmt.Sprintf("Error parsing URL %q: %v", s, err))
+ }
+ return u
+}
+
+type writerFunc func([]byte) (int, error)
+
+func (f writerFunc) Write(p []byte) (int, error) { return f(p) }
+
+// TestRequestWriteError tests the Write err != nil checks in (*Request).write.
+func TestRequestWriteError(t *testing.T) {
+ failAfter, writeCount := 0, 0
+ errFail := errors.New("fake write failure")
+
+ // w is the buffered io.Writer to write the request to. It
+ // fails exactly once on its Nth Write call, as controlled by
+ // failAfter. It also tracks the number of calls in
+ // writeCount.
+ w := struct {
+ io.ByteWriter // to avoid being wrapped by a bufio.Writer
+ io.Writer
+ }{
+ nil,
+ writerFunc(func(p []byte) (n int, err error) {
+ writeCount++
+ if failAfter == 0 {
+ err = errFail
+ }
+ failAfter--
+ return len(p), err
+ }),
+ }
+
+ req, _ := NewRequest("GET", "http://example.com/", nil)
+ const writeCalls = 4 // number of Write calls in current implementation
+ sawGood := false
+ for n := 0; n <= writeCalls+2; n++ {
+ failAfter = n
+ writeCount = 0
+ err := req.Write(w)
+ var wantErr error
+ if n < writeCalls {
+ wantErr = errFail
+ }
+ if err != wantErr {
+ t.Errorf("for fail-after %d Writes, err = %v; want %v", n, err, wantErr)
+ continue
+ }
+ if err == nil {
+ sawGood = true
+ if writeCount != writeCalls {
+ t.Fatalf("writeCalls constant is outdated in test")
+ }
+ }
+ if writeCount > writeCalls || writeCount > n+1 {
+ t.Errorf("for fail-after %d, saw unexpectedly high (%d) write calls", n, writeCount)
+ }
+ }
+ if !sawGood {
+ t.Fatalf("writeCalls constant is outdated in test")
+ }
+}
diff --git a/src/net/http/response.go b/src/net/http/response.go
new file mode 100644
index 000000000..5d2c39080
--- /dev/null
+++ b/src/net/http/response.go
@@ -0,0 +1,291 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// HTTP Response reading and parsing.
+
+package http
+
+import (
+ "bufio"
+ "bytes"
+ "crypto/tls"
+ "errors"
+ "io"
+ "net/textproto"
+ "net/url"
+ "strconv"
+ "strings"
+)
+
+var respExcludeHeader = map[string]bool{
+ "Content-Length": true,
+ "Transfer-Encoding": true,
+ "Trailer": true,
+}
+
+// Response represents the response from an HTTP request.
+//
+type Response struct {
+ Status string // e.g. "200 OK"
+ StatusCode int // e.g. 200
+ Proto string // e.g. "HTTP/1.0"
+ ProtoMajor int // e.g. 1
+ ProtoMinor int // e.g. 0
+
+ // Header maps header keys to values. If the response had multiple
+ // headers with the same key, they may be concatenated, with comma
+ // delimiters. (Section 4.2 of RFC 2616 requires that multiple headers
+ // be semantically equivalent to a comma-delimited sequence.) Values
+ // duplicated by other fields in this struct (e.g., ContentLength) are
+ // omitted from Header.
+ //
+ // Keys in the map are canonicalized (see CanonicalHeaderKey).
+ Header Header
+
+ // Body represents the response body.
+ //
+ // The http Client and Transport guarantee that Body is always
+ // non-nil, even on responses without a body or responses with
+ // a zero-length body. It is the caller's responsibility to
+ // close Body.
+ //
+ // The Body is automatically dechunked if the server replied
+ // with a "chunked" Transfer-Encoding.
+ Body io.ReadCloser
+
+ // ContentLength records the length of the associated content. The
+ // value -1 indicates that the length is unknown. Unless Request.Method
+ // is "HEAD", values >= 0 indicate that the given number of bytes may
+ // be read from Body.
+ ContentLength int64
+
+ // Contains transfer encodings from outer-most to inner-most. Value is
+ // nil, means that "identity" encoding is used.
+ TransferEncoding []string
+
+ // Close records whether the header directed that the connection be
+ // closed after reading Body. The value is advice for clients: neither
+ // ReadResponse nor Response.Write ever closes a connection.
+ Close bool
+
+ // Trailer maps trailer keys to values, in the same
+ // format as the header.
+ Trailer Header
+
+ // The Request that was sent to obtain this Response.
+ // Request's Body is nil (having already been consumed).
+ // This is only populated for Client requests.
+ Request *Request
+
+ // TLS contains information about the TLS connection on which the
+ // response was received. It is nil for unencrypted responses.
+ // The pointer is shared between responses and should not be
+ // modified.
+ TLS *tls.ConnectionState
+}
+
+// Cookies parses and returns the cookies set in the Set-Cookie headers.
+func (r *Response) Cookies() []*Cookie {
+ return readSetCookies(r.Header)
+}
+
+var ErrNoLocation = errors.New("http: no Location header in response")
+
+// Location returns the URL of the response's "Location" header,
+// if present. Relative redirects are resolved relative to
+// the Response's Request. ErrNoLocation is returned if no
+// Location header is present.
+func (r *Response) Location() (*url.URL, error) {
+ lv := r.Header.Get("Location")
+ if lv == "" {
+ return nil, ErrNoLocation
+ }
+ if r.Request != nil && r.Request.URL != nil {
+ return r.Request.URL.Parse(lv)
+ }
+ return url.Parse(lv)
+}
+
+// ReadResponse reads and returns an HTTP response from r.
+// The req parameter optionally specifies the Request that corresponds
+// to this Response. If nil, a GET request is assumed.
+// Clients must call resp.Body.Close when finished reading resp.Body.
+// After that call, clients can inspect resp.Trailer to find key/value
+// pairs included in the response trailer.
+func ReadResponse(r *bufio.Reader, req *Request) (*Response, error) {
+ tp := textproto.NewReader(r)
+ resp := &Response{
+ Request: req,
+ }
+
+ // Parse the first line of the response.
+ line, err := tp.ReadLine()
+ if err != nil {
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ return nil, err
+ }
+ f := strings.SplitN(line, " ", 3)
+ if len(f) < 2 {
+ return nil, &badStringError{"malformed HTTP response", line}
+ }
+ reasonPhrase := ""
+ if len(f) > 2 {
+ reasonPhrase = f[2]
+ }
+ resp.Status = f[1] + " " + reasonPhrase
+ resp.StatusCode, err = strconv.Atoi(f[1])
+ if err != nil {
+ return nil, &badStringError{"malformed HTTP status code", f[1]}
+ }
+
+ resp.Proto = f[0]
+ var ok bool
+ if resp.ProtoMajor, resp.ProtoMinor, ok = ParseHTTPVersion(resp.Proto); !ok {
+ return nil, &badStringError{"malformed HTTP version", resp.Proto}
+ }
+
+ // Parse the response headers.
+ mimeHeader, err := tp.ReadMIMEHeader()
+ if err != nil {
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ return nil, err
+ }
+ resp.Header = Header(mimeHeader)
+
+ fixPragmaCacheControl(resp.Header)
+
+ err = readTransfer(resp, r)
+ if err != nil {
+ return nil, err
+ }
+
+ return resp, nil
+}
+
+// RFC2616: Should treat
+// Pragma: no-cache
+// like
+// Cache-Control: no-cache
+func fixPragmaCacheControl(header Header) {
+ if hp, ok := header["Pragma"]; ok && len(hp) > 0 && hp[0] == "no-cache" {
+ if _, presentcc := header["Cache-Control"]; !presentcc {
+ header["Cache-Control"] = []string{"no-cache"}
+ }
+ }
+}
+
+// ProtoAtLeast reports whether the HTTP protocol used
+// in the response is at least major.minor.
+func (r *Response) ProtoAtLeast(major, minor int) bool {
+ return r.ProtoMajor > major ||
+ r.ProtoMajor == major && r.ProtoMinor >= minor
+}
+
+// Writes the response (header, body and trailer) in wire format. This method
+// consults the following fields of the response:
+//
+// StatusCode
+// ProtoMajor
+// ProtoMinor
+// Request.Method
+// TransferEncoding
+// Trailer
+// Body
+// ContentLength
+// Header, values for non-canonical keys will have unpredictable behavior
+//
+// Body is closed after it is sent.
+func (r *Response) Write(w io.Writer) error {
+ // Status line
+ text := r.Status
+ if text == "" {
+ var ok bool
+ text, ok = statusText[r.StatusCode]
+ if !ok {
+ text = "status code " + strconv.Itoa(r.StatusCode)
+ }
+ }
+ protoMajor, protoMinor := strconv.Itoa(r.ProtoMajor), strconv.Itoa(r.ProtoMinor)
+ statusCode := strconv.Itoa(r.StatusCode) + " "
+ text = strings.TrimPrefix(text, statusCode)
+ if _, err := io.WriteString(w, "HTTP/"+protoMajor+"."+protoMinor+" "+statusCode+text+"\r\n"); err != nil {
+ return err
+ }
+
+ // Clone it, so we can modify r1 as needed.
+ r1 := new(Response)
+ *r1 = *r
+ if r1.ContentLength == 0 && r1.Body != nil {
+ // Is it actually 0 length? Or just unknown?
+ var buf [1]byte
+ n, err := r1.Body.Read(buf[:])
+ if err != nil && err != io.EOF {
+ return err
+ }
+ if n == 0 {
+ // Reset it to a known zero reader, in case underlying one
+ // is unhappy being read repeatedly.
+ r1.Body = eofReader
+ } else {
+ r1.ContentLength = -1
+ r1.Body = struct {
+ io.Reader
+ io.Closer
+ }{
+ io.MultiReader(bytes.NewReader(buf[:1]), r.Body),
+ r.Body,
+ }
+ }
+ }
+ // If we're sending a non-chunked HTTP/1.1 response without a
+ // content-length, the only way to do that is the old HTTP/1.0
+ // way, by noting the EOF with a connection close, so we need
+ // to set Close.
+ if r1.ContentLength == -1 && !r1.Close && r1.ProtoAtLeast(1, 1) && !chunked(r1.TransferEncoding) {
+ r1.Close = true
+ }
+
+ // Process Body,ContentLength,Close,Trailer
+ tw, err := newTransferWriter(r1)
+ if err != nil {
+ return err
+ }
+ err = tw.WriteHeader(w)
+ if err != nil {
+ return err
+ }
+
+ // Rest of header
+ err = r.Header.WriteSubset(w, respExcludeHeader)
+ if err != nil {
+ return err
+ }
+
+ // contentLengthAlreadySent may have been already sent for
+ // POST/PUT requests, even if zero length. See Issue 8180.
+ contentLengthAlreadySent := tw.shouldSendContentLength()
+ if r1.ContentLength == 0 && !chunked(r1.TransferEncoding) && !contentLengthAlreadySent {
+ if _, err := io.WriteString(w, "Content-Length: 0\r\n"); err != nil {
+ return err
+ }
+ }
+
+ // End-of-header
+ if _, err := io.WriteString(w, "\r\n"); err != nil {
+ return err
+ }
+
+ // Write body and trailer
+ err = tw.WriteBody(w)
+ if err != nil {
+ return err
+ }
+
+ // Success
+ return nil
+}
diff --git a/src/net/http/response_test.go b/src/net/http/response_test.go
new file mode 100644
index 000000000..06e940d9a
--- /dev/null
+++ b/src/net/http/response_test.go
@@ -0,0 +1,674 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bufio"
+ "bytes"
+ "compress/gzip"
+ "crypto/rand"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/http/internal"
+ "net/url"
+ "reflect"
+ "regexp"
+ "strings"
+ "testing"
+)
+
+type respTest struct {
+ Raw string
+ Resp Response
+ Body string
+}
+
+func dummyReq(method string) *Request {
+ return &Request{Method: method}
+}
+
+func dummyReq11(method string) *Request {
+ return &Request{Method: method, Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1}
+}
+
+var respTests = []respTest{
+ // Unchunked response without Content-Length.
+ {
+ "HTTP/1.0 200 OK\r\n" +
+ "Connection: close\r\n" +
+ "\r\n" +
+ "Body here\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{
+ "Connection": {"close"}, // TODO(rsc): Delete?
+ },
+ Close: true,
+ ContentLength: -1,
+ },
+
+ "Body here\n",
+ },
+
+ // Unchunked HTTP/1.1 response without Content-Length or
+ // Connection headers.
+ {
+ "HTTP/1.1 200 OK\r\n" +
+ "\r\n" +
+ "Body here\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Request: dummyReq("GET"),
+ Close: true,
+ ContentLength: -1,
+ },
+
+ "Body here\n",
+ },
+
+ // Unchunked HTTP/1.1 204 response without Content-Length.
+ {
+ "HTTP/1.1 204 No Content\r\n" +
+ "\r\n" +
+ "Body should not be read!\n",
+
+ Response{
+ Status: "204 No Content",
+ StatusCode: 204,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Request: dummyReq("GET"),
+ Close: false,
+ ContentLength: 0,
+ },
+
+ "",
+ },
+
+ // Unchunked response with Content-Length.
+ {
+ "HTTP/1.0 200 OK\r\n" +
+ "Content-Length: 10\r\n" +
+ "Connection: close\r\n" +
+ "\r\n" +
+ "Body here\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{
+ "Connection": {"close"},
+ "Content-Length": {"10"},
+ },
+ Close: true,
+ ContentLength: 10,
+ },
+
+ "Body here\n",
+ },
+
+ // Chunked response without Content-Length.
+ {
+ "HTTP/1.1 200 OK\r\n" +
+ "Transfer-Encoding: chunked\r\n" +
+ "\r\n" +
+ "0a\r\n" +
+ "Body here\n\r\n" +
+ "09\r\n" +
+ "continued\r\n" +
+ "0\r\n" +
+ "\r\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Close: false,
+ ContentLength: -1,
+ TransferEncoding: []string{"chunked"},
+ },
+
+ "Body here\ncontinued",
+ },
+
+ // Chunked response with Content-Length.
+ {
+ "HTTP/1.1 200 OK\r\n" +
+ "Transfer-Encoding: chunked\r\n" +
+ "Content-Length: 10\r\n" +
+ "\r\n" +
+ "0a\r\n" +
+ "Body here\n\r\n" +
+ "0\r\n" +
+ "\r\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Close: false,
+ ContentLength: -1,
+ TransferEncoding: []string{"chunked"},
+ },
+
+ "Body here\n",
+ },
+
+ // Chunked response in response to a HEAD request
+ {
+ "HTTP/1.1 200 OK\r\n" +
+ "Transfer-Encoding: chunked\r\n" +
+ "\r\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("HEAD"),
+ Header: Header{},
+ TransferEncoding: []string{"chunked"},
+ Close: false,
+ ContentLength: -1,
+ },
+
+ "",
+ },
+
+ // Content-Length in response to a HEAD request
+ {
+ "HTTP/1.0 200 OK\r\n" +
+ "Content-Length: 256\r\n" +
+ "\r\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("HEAD"),
+ Header: Header{"Content-Length": {"256"}},
+ TransferEncoding: nil,
+ Close: true,
+ ContentLength: 256,
+ },
+
+ "",
+ },
+
+ // Content-Length in response to a HEAD request with HTTP/1.1
+ {
+ "HTTP/1.1 200 OK\r\n" +
+ "Content-Length: 256\r\n" +
+ "\r\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("HEAD"),
+ Header: Header{"Content-Length": {"256"}},
+ TransferEncoding: nil,
+ Close: false,
+ ContentLength: 256,
+ },
+
+ "",
+ },
+
+ // No Content-Length or Chunked in response to a HEAD request
+ {
+ "HTTP/1.0 200 OK\r\n" +
+ "\r\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("HEAD"),
+ Header: Header{},
+ TransferEncoding: nil,
+ Close: true,
+ ContentLength: -1,
+ },
+
+ "",
+ },
+
+ // explicit Content-Length of 0.
+ {
+ "HTTP/1.1 200 OK\r\n" +
+ "Content-Length: 0\r\n" +
+ "\r\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("GET"),
+ Header: Header{
+ "Content-Length": {"0"},
+ },
+ Close: false,
+ ContentLength: 0,
+ },
+
+ "",
+ },
+
+ // Status line without a Reason-Phrase, but trailing space.
+ // (permitted by RFC 2616)
+ {
+ "HTTP/1.0 303 \r\n\r\n",
+ Response{
+ Status: "303 ",
+ StatusCode: 303,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Close: true,
+ ContentLength: -1,
+ },
+
+ "",
+ },
+
+ // Status line without a Reason-Phrase, and no trailing space.
+ // (not permitted by RFC 2616, but we'll accept it anyway)
+ {
+ "HTTP/1.0 303\r\n\r\n",
+ Response{
+ Status: "303 ",
+ StatusCode: 303,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Close: true,
+ ContentLength: -1,
+ },
+
+ "",
+ },
+
+ // golang.org/issue/4767: don't special-case multipart/byteranges responses
+ {
+ `HTTP/1.1 206 Partial Content
+Connection: close
+Content-Type: multipart/byteranges; boundary=18a75608c8f47cef
+
+some body`,
+ Response{
+ Status: "206 Partial Content",
+ StatusCode: 206,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("GET"),
+ Header: Header{
+ "Content-Type": []string{"multipart/byteranges; boundary=18a75608c8f47cef"},
+ },
+ Close: true,
+ ContentLength: -1,
+ },
+
+ "some body",
+ },
+
+ // Unchunked response without Content-Length, Request is nil
+ {
+ "HTTP/1.0 200 OK\r\n" +
+ "Connection: close\r\n" +
+ "\r\n" +
+ "Body here\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Header: Header{
+ "Connection": {"close"}, // TODO(rsc): Delete?
+ },
+ Close: true,
+ ContentLength: -1,
+ },
+
+ "Body here\n",
+ },
+
+ // 206 Partial Content. golang.org/issue/8923
+ {
+ "HTTP/1.1 206 Partial Content\r\n" +
+ "Content-Type: text/plain; charset=utf-8\r\n" +
+ "Accept-Ranges: bytes\r\n" +
+ "Content-Range: bytes 0-5/1862\r\n" +
+ "Content-Length: 6\r\n\r\n" +
+ "foobar",
+
+ Response{
+ Status: "206 Partial Content",
+ StatusCode: 206,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("GET"),
+ Header: Header{
+ "Accept-Ranges": []string{"bytes"},
+ "Content-Length": []string{"6"},
+ "Content-Type": []string{"text/plain; charset=utf-8"},
+ "Content-Range": []string{"bytes 0-5/1862"},
+ },
+ ContentLength: 6,
+ },
+
+ "foobar",
+ },
+}
+
+func TestReadResponse(t *testing.T) {
+ for i, tt := range respTests {
+ resp, err := ReadResponse(bufio.NewReader(strings.NewReader(tt.Raw)), tt.Resp.Request)
+ if err != nil {
+ t.Errorf("#%d: %v", i, err)
+ continue
+ }
+ rbody := resp.Body
+ resp.Body = nil
+ diff(t, fmt.Sprintf("#%d Response", i), resp, &tt.Resp)
+ var bout bytes.Buffer
+ if rbody != nil {
+ _, err = io.Copy(&bout, rbody)
+ if err != nil {
+ t.Errorf("#%d: %v", i, err)
+ continue
+ }
+ rbody.Close()
+ }
+ body := bout.String()
+ if body != tt.Body {
+ t.Errorf("#%d: Body = %q want %q", i, body, tt.Body)
+ }
+ }
+}
+
+func TestWriteResponse(t *testing.T) {
+ for i, tt := range respTests {
+ resp, err := ReadResponse(bufio.NewReader(strings.NewReader(tt.Raw)), tt.Resp.Request)
+ if err != nil {
+ t.Errorf("#%d: %v", i, err)
+ continue
+ }
+ err = resp.Write(ioutil.Discard)
+ if err != nil {
+ t.Errorf("#%d: %v", i, err)
+ continue
+ }
+ }
+}
+
+var readResponseCloseInMiddleTests = []struct {
+ chunked, compressed bool
+}{
+ {false, false},
+ {true, false},
+ {true, true},
+}
+
+// TestReadResponseCloseInMiddle tests that closing a body after
+// reading only part of its contents advances the read to the end of
+// the request, right up until the next request.
+func TestReadResponseCloseInMiddle(t *testing.T) {
+ for _, test := range readResponseCloseInMiddleTests {
+ fatalf := func(format string, args ...interface{}) {
+ args = append([]interface{}{test.chunked, test.compressed}, args...)
+ t.Fatalf("on test chunked=%v, compressed=%v: "+format, args...)
+ }
+ checkErr := func(err error, msg string) {
+ if err == nil {
+ return
+ }
+ fatalf(msg+": %v", err)
+ }
+ var buf bytes.Buffer
+ buf.WriteString("HTTP/1.1 200 OK\r\n")
+ if test.chunked {
+ buf.WriteString("Transfer-Encoding: chunked\r\n")
+ } else {
+ buf.WriteString("Content-Length: 1000000\r\n")
+ }
+ var wr io.Writer = &buf
+ if test.chunked {
+ wr = internal.NewChunkedWriter(wr)
+ }
+ if test.compressed {
+ buf.WriteString("Content-Encoding: gzip\r\n")
+ wr = gzip.NewWriter(wr)
+ }
+ buf.WriteString("\r\n")
+
+ chunk := bytes.Repeat([]byte{'x'}, 1000)
+ for i := 0; i < 1000; i++ {
+ if test.compressed {
+ // Otherwise this compresses too well.
+ _, err := io.ReadFull(rand.Reader, chunk)
+ checkErr(err, "rand.Reader ReadFull")
+ }
+ wr.Write(chunk)
+ }
+ if test.compressed {
+ err := wr.(*gzip.Writer).Close()
+ checkErr(err, "compressor close")
+ }
+ if test.chunked {
+ buf.WriteString("0\r\n\r\n")
+ }
+ buf.WriteString("Next Request Here")
+
+ bufr := bufio.NewReader(&buf)
+ resp, err := ReadResponse(bufr, dummyReq("GET"))
+ checkErr(err, "ReadResponse")
+ expectedLength := int64(-1)
+ if !test.chunked {
+ expectedLength = 1000000
+ }
+ if resp.ContentLength != expectedLength {
+ fatalf("expected response length %d, got %d", expectedLength, resp.ContentLength)
+ }
+ if resp.Body == nil {
+ fatalf("nil body")
+ }
+ if test.compressed {
+ gzReader, err := gzip.NewReader(resp.Body)
+ checkErr(err, "gzip.NewReader")
+ resp.Body = &readerAndCloser{gzReader, resp.Body}
+ }
+
+ rbuf := make([]byte, 2500)
+ n, err := io.ReadFull(resp.Body, rbuf)
+ checkErr(err, "2500 byte ReadFull")
+ if n != 2500 {
+ fatalf("ReadFull only read %d bytes", n)
+ }
+ if test.compressed == false && !bytes.Equal(bytes.Repeat([]byte{'x'}, 2500), rbuf) {
+ fatalf("ReadFull didn't read 2500 'x'; got %q", string(rbuf))
+ }
+ resp.Body.Close()
+
+ rest, err := ioutil.ReadAll(bufr)
+ checkErr(err, "ReadAll on remainder")
+ if e, g := "Next Request Here", string(rest); e != g {
+ g = regexp.MustCompile(`(xx+)`).ReplaceAllStringFunc(g, func(match string) string {
+ return fmt.Sprintf("x(repeated x%d)", len(match))
+ })
+ fatalf("remainder = %q, expected %q", g, e)
+ }
+ }
+}
+
+func diff(t *testing.T, prefix string, have, want interface{}) {
+ hv := reflect.ValueOf(have).Elem()
+ wv := reflect.ValueOf(want).Elem()
+ if hv.Type() != wv.Type() {
+ t.Errorf("%s: type mismatch %v want %v", prefix, hv.Type(), wv.Type())
+ }
+ for i := 0; i < hv.NumField(); i++ {
+ hf := hv.Field(i).Interface()
+ wf := wv.Field(i).Interface()
+ if !reflect.DeepEqual(hf, wf) {
+ t.Errorf("%s: %s = %v want %v", prefix, hv.Type().Field(i).Name, hf, wf)
+ }
+ }
+}
+
+type responseLocationTest struct {
+ location string // Response's Location header or ""
+ requrl string // Response.Request.URL or ""
+ want string
+ wantErr error
+}
+
+var responseLocationTests = []responseLocationTest{
+ {"/foo", "http://bar.com/baz", "http://bar.com/foo", nil},
+ {"http://foo.com/", "http://bar.com/baz", "http://foo.com/", nil},
+ {"", "http://bar.com/baz", "", ErrNoLocation},
+}
+
+func TestLocationResponse(t *testing.T) {
+ for i, tt := range responseLocationTests {
+ res := new(Response)
+ res.Header = make(Header)
+ res.Header.Set("Location", tt.location)
+ if tt.requrl != "" {
+ res.Request = &Request{}
+ var err error
+ res.Request.URL, err = url.Parse(tt.requrl)
+ if err != nil {
+ t.Fatalf("bad test URL %q: %v", tt.requrl, err)
+ }
+ }
+
+ got, err := res.Location()
+ if tt.wantErr != nil {
+ if err == nil {
+ t.Errorf("%d. err=nil; want %q", i, tt.wantErr)
+ continue
+ }
+ if g, e := err.Error(), tt.wantErr.Error(); g != e {
+ t.Errorf("%d. err=%q; want %q", i, g, e)
+ continue
+ }
+ continue
+ }
+ if err != nil {
+ t.Errorf("%d. err=%q", i, err)
+ continue
+ }
+ if g, e := got.String(), tt.want; g != e {
+ t.Errorf("%d. Location=%q; want %q", i, g, e)
+ }
+ }
+}
+
+func TestResponseStatusStutter(t *testing.T) {
+ r := &Response{
+ Status: "123 some status",
+ StatusCode: 123,
+ ProtoMajor: 1,
+ ProtoMinor: 3,
+ }
+ var buf bytes.Buffer
+ r.Write(&buf)
+ if strings.Contains(buf.String(), "123 123") {
+ t.Errorf("stutter in status: %s", buf.String())
+ }
+}
+
+func TestResponseContentLengthShortBody(t *testing.T) {
+ const shortBody = "Short body, not 123 bytes."
+ br := bufio.NewReader(strings.NewReader("HTTP/1.1 200 OK\r\n" +
+ "Content-Length: 123\r\n" +
+ "\r\n" +
+ shortBody))
+ res, err := ReadResponse(br, &Request{Method: "GET"})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.ContentLength != 123 {
+ t.Fatalf("Content-Length = %d; want 123", res.ContentLength)
+ }
+ var buf bytes.Buffer
+ n, err := io.Copy(&buf, res.Body)
+ if n != int64(len(shortBody)) {
+ t.Errorf("Copied %d bytes; want %d, len(%q)", n, len(shortBody), shortBody)
+ }
+ if buf.String() != shortBody {
+ t.Errorf("Read body %q; want %q", buf.String(), shortBody)
+ }
+ if err != io.ErrUnexpectedEOF {
+ t.Errorf("io.Copy error = %#v; want io.ErrUnexpectedEOF", err)
+ }
+}
+
+func TestReadResponseUnexpectedEOF(t *testing.T) {
+ br := bufio.NewReader(strings.NewReader("HTTP/1.1 301 Moved Permanently\r\n" +
+ "Location: http://example.com"))
+ _, err := ReadResponse(br, nil)
+ if err != io.ErrUnexpectedEOF {
+ t.Errorf("ReadResponse = %v; want io.ErrUnexpectedEOF", err)
+ }
+}
+
+func TestNeedsSniff(t *testing.T) {
+ // needsSniff returns true with an empty response.
+ r := &response{}
+ if got, want := r.needsSniff(), true; got != want {
+ t.Errorf("needsSniff = %t; want %t", got, want)
+ }
+ // needsSniff returns false when Content-Type = nil.
+ r.handlerHeader = Header{"Content-Type": nil}
+ if got, want := r.needsSniff(), false; got != want {
+ t.Errorf("needsSniff empty Content-Type = %t; want %t", got, want)
+ }
+}
diff --git a/src/net/http/responsewrite_test.go b/src/net/http/responsewrite_test.go
new file mode 100644
index 000000000..585b13b85
--- /dev/null
+++ b/src/net/http/responsewrite_test.go
@@ -0,0 +1,226 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bytes"
+ "io/ioutil"
+ "strings"
+ "testing"
+)
+
+type respWriteTest struct {
+ Resp Response
+ Raw string
+}
+
+func TestResponseWrite(t *testing.T) {
+ respWriteTests := []respWriteTest{
+ // HTTP/1.0, identity coding; no trailer
+ {
+ Response{
+ StatusCode: 503,
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Body: ioutil.NopCloser(strings.NewReader("abcdef")),
+ ContentLength: 6,
+ },
+
+ "HTTP/1.0 503 Service Unavailable\r\n" +
+ "Content-Length: 6\r\n\r\n" +
+ "abcdef",
+ },
+ // Unchunked response without Content-Length.
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Body: ioutil.NopCloser(strings.NewReader("abcdef")),
+ ContentLength: -1,
+ },
+ "HTTP/1.0 200 OK\r\n" +
+ "\r\n" +
+ "abcdef",
+ },
+ // HTTP/1.1 response with unknown length and Connection: close
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Body: ioutil.NopCloser(strings.NewReader("abcdef")),
+ ContentLength: -1,
+ Close: true,
+ },
+ "HTTP/1.1 200 OK\r\n" +
+ "Connection: close\r\n" +
+ "\r\n" +
+ "abcdef",
+ },
+ // HTTP/1.1 response with unknown length and not setting connection: close
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq11("GET"),
+ Header: Header{},
+ Body: ioutil.NopCloser(strings.NewReader("abcdef")),
+ ContentLength: -1,
+ Close: false,
+ },
+ "HTTP/1.1 200 OK\r\n" +
+ "Connection: close\r\n" +
+ "\r\n" +
+ "abcdef",
+ },
+ // HTTP/1.1 response with unknown length and not setting connection: close, but
+ // setting chunked.
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq11("GET"),
+ Header: Header{},
+ Body: ioutil.NopCloser(strings.NewReader("abcdef")),
+ ContentLength: -1,
+ TransferEncoding: []string{"chunked"},
+ Close: false,
+ },
+ "HTTP/1.1 200 OK\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ "6\r\nabcdef\r\n0\r\n\r\n",
+ },
+ // HTTP/1.1 response 0 content-length, and nil body
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq11("GET"),
+ Header: Header{},
+ Body: nil,
+ ContentLength: 0,
+ Close: false,
+ },
+ "HTTP/1.1 200 OK\r\n" +
+ "Content-Length: 0\r\n" +
+ "\r\n",
+ },
+ // HTTP/1.1 response 0 content-length, and non-nil empty body
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq11("GET"),
+ Header: Header{},
+ Body: ioutil.NopCloser(strings.NewReader("")),
+ ContentLength: 0,
+ Close: false,
+ },
+ "HTTP/1.1 200 OK\r\n" +
+ "Content-Length: 0\r\n" +
+ "\r\n",
+ },
+ // HTTP/1.1 response 0 content-length, and non-nil non-empty body
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq11("GET"),
+ Header: Header{},
+ Body: ioutil.NopCloser(strings.NewReader("foo")),
+ ContentLength: 0,
+ Close: false,
+ },
+ "HTTP/1.1 200 OK\r\n" +
+ "Connection: close\r\n" +
+ "\r\nfoo",
+ },
+ // HTTP/1.1, chunked coding; empty trailer; close
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Body: ioutil.NopCloser(strings.NewReader("abcdef")),
+ ContentLength: 6,
+ TransferEncoding: []string{"chunked"},
+ Close: true,
+ },
+
+ "HTTP/1.1 200 OK\r\n" +
+ "Connection: close\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ "6\r\nabcdef\r\n0\r\n\r\n",
+ },
+
+ // Header value with a newline character (Issue 914).
+ // Also tests removal of leading and trailing whitespace.
+ {
+ Response{
+ StatusCode: 204,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("GET"),
+ Header: Header{
+ "Foo": []string{" Bar\nBaz "},
+ },
+ Body: nil,
+ ContentLength: 0,
+ TransferEncoding: []string{"chunked"},
+ Close: true,
+ },
+
+ "HTTP/1.1 204 No Content\r\n" +
+ "Connection: close\r\n" +
+ "Foo: Bar Baz\r\n" +
+ "\r\n",
+ },
+
+ // Want a single Content-Length header. Fixing issue 8180 where
+ // there were two.
+ {
+ Response{
+ StatusCode: StatusOK,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: &Request{Method: "POST"},
+ Header: Header{},
+ ContentLength: 0,
+ TransferEncoding: nil,
+ Body: nil,
+ },
+ "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n",
+ },
+ }
+
+ for i := range respWriteTests {
+ tt := &respWriteTests[i]
+ var braw bytes.Buffer
+ err := tt.Resp.Write(&braw)
+ if err != nil {
+ t.Errorf("error writing #%d: %s", i, err)
+ continue
+ }
+ sraw := braw.String()
+ if sraw != tt.Raw {
+ t.Errorf("Test %d, expecting:\n%q\nGot:\n%q\n", i, tt.Raw, sraw)
+ continue
+ }
+ }
+}
diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go
new file mode 100644
index 000000000..5e0a0053c
--- /dev/null
+++ b/src/net/http/serve_test.go
@@ -0,0 +1,3094 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// End-to-end serving tests
+
+package http_test
+
+import (
+ "bufio"
+ "bytes"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "math/rand"
+ "net"
+ . "net/http"
+ "net/http/httptest"
+ "net/http/httputil"
+ "net/url"
+ "os"
+ "os/exec"
+ "reflect"
+ "runtime"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "syscall"
+ "testing"
+ "time"
+)
+
+type dummyAddr string
+type oneConnListener struct {
+ conn net.Conn
+}
+
+func (l *oneConnListener) Accept() (c net.Conn, err error) {
+ c = l.conn
+ if c == nil {
+ err = io.EOF
+ return
+ }
+ err = nil
+ l.conn = nil
+ return
+}
+
+func (l *oneConnListener) Close() error {
+ return nil
+}
+
+func (l *oneConnListener) Addr() net.Addr {
+ return dummyAddr("test-address")
+}
+
+func (a dummyAddr) Network() string {
+ return string(a)
+}
+
+func (a dummyAddr) String() string {
+ return string(a)
+}
+
+type noopConn struct{}
+
+func (noopConn) LocalAddr() net.Addr { return dummyAddr("local-addr") }
+func (noopConn) RemoteAddr() net.Addr { return dummyAddr("remote-addr") }
+func (noopConn) SetDeadline(t time.Time) error { return nil }
+func (noopConn) SetReadDeadline(t time.Time) error { return nil }
+func (noopConn) SetWriteDeadline(t time.Time) error { return nil }
+
+type rwTestConn struct {
+ io.Reader
+ io.Writer
+ noopConn
+
+ closeFunc func() error // called if non-nil
+ closec chan bool // else, if non-nil, send value to it on close
+}
+
+func (c *rwTestConn) Close() error {
+ if c.closeFunc != nil {
+ return c.closeFunc()
+ }
+ select {
+ case c.closec <- true:
+ default:
+ }
+ return nil
+}
+
+type testConn struct {
+ readBuf bytes.Buffer
+ writeBuf bytes.Buffer
+ closec chan bool // if non-nil, send value to it on close
+ noopConn
+}
+
+func (c *testConn) Read(b []byte) (int, error) {
+ return c.readBuf.Read(b)
+}
+
+func (c *testConn) Write(b []byte) (int, error) {
+ return c.writeBuf.Write(b)
+}
+
+func (c *testConn) Close() error {
+ select {
+ case c.closec <- true:
+ default:
+ }
+ return nil
+}
+
+// reqBytes treats req as a request (with \n delimiters) and returns it with \r\n delimiters,
+// ending in \r\n\r\n
+func reqBytes(req string) []byte {
+ return []byte(strings.Replace(strings.TrimSpace(req), "\n", "\r\n", -1) + "\r\n\r\n")
+}
+
+type handlerTest struct {
+ handler Handler
+}
+
+func newHandlerTest(h Handler) handlerTest {
+ return handlerTest{h}
+}
+
+func (ht handlerTest) rawResponse(req string) string {
+ reqb := reqBytes(req)
+ var output bytes.Buffer
+ conn := &rwTestConn{
+ Reader: bytes.NewReader(reqb),
+ Writer: &output,
+ closec: make(chan bool, 1),
+ }
+ ln := &oneConnListener{conn: conn}
+ go Serve(ln, ht.handler)
+ <-conn.closec
+ return output.String()
+}
+
+func TestConsumingBodyOnNextConn(t *testing.T) {
+ conn := new(testConn)
+ for i := 0; i < 2; i++ {
+ conn.readBuf.Write([]byte(
+ "POST / HTTP/1.1\r\n" +
+ "Host: test\r\n" +
+ "Content-Length: 11\r\n" +
+ "\r\n" +
+ "foo=1&bar=1"))
+ }
+
+ reqNum := 0
+ ch := make(chan *Request)
+ servech := make(chan error)
+ listener := &oneConnListener{conn}
+ handler := func(res ResponseWriter, req *Request) {
+ reqNum++
+ ch <- req
+ }
+
+ go func() {
+ servech <- Serve(listener, HandlerFunc(handler))
+ }()
+
+ var req *Request
+ req = <-ch
+ if req == nil {
+ t.Fatal("Got nil first request.")
+ }
+ if req.Method != "POST" {
+ t.Errorf("For request #1's method, got %q; expected %q",
+ req.Method, "POST")
+ }
+
+ req = <-ch
+ if req == nil {
+ t.Fatal("Got nil first request.")
+ }
+ if req.Method != "POST" {
+ t.Errorf("For request #2's method, got %q; expected %q",
+ req.Method, "POST")
+ }
+
+ if serveerr := <-servech; serveerr != io.EOF {
+ t.Errorf("Serve returned %q; expected EOF", serveerr)
+ }
+}
+
+type stringHandler string
+
+func (s stringHandler) ServeHTTP(w ResponseWriter, r *Request) {
+ w.Header().Set("Result", string(s))
+}
+
+var handlers = []struct {
+ pattern string
+ msg string
+}{
+ {"/", "Default"},
+ {"/someDir/", "someDir"},
+ {"someHost.com/someDir/", "someHost.com/someDir"},
+}
+
+var vtests = []struct {
+ url string
+ expected string
+}{
+ {"http://localhost/someDir/apage", "someDir"},
+ {"http://localhost/otherDir/apage", "Default"},
+ {"http://someHost.com/someDir/apage", "someHost.com/someDir"},
+ {"http://otherHost.com/someDir/apage", "someDir"},
+ {"http://otherHost.com/aDir/apage", "Default"},
+ // redirections for trees
+ {"http://localhost/someDir", "/someDir/"},
+ {"http://someHost.com/someDir", "/someDir/"},
+}
+
+func TestHostHandlers(t *testing.T) {
+ defer afterTest(t)
+ mux := NewServeMux()
+ for _, h := range handlers {
+ mux.Handle(h.pattern, stringHandler(h.msg))
+ }
+ ts := httptest.NewServer(mux)
+ defer ts.Close()
+
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+ cc := httputil.NewClientConn(conn, nil)
+ for _, vt := range vtests {
+ var r *Response
+ var req Request
+ if req.URL, err = url.Parse(vt.url); err != nil {
+ t.Errorf("cannot parse url: %v", err)
+ continue
+ }
+ if err := cc.Write(&req); err != nil {
+ t.Errorf("writing request: %v", err)
+ continue
+ }
+ r, err := cc.Read(&req)
+ if err != nil {
+ t.Errorf("reading response: %v", err)
+ continue
+ }
+ switch r.StatusCode {
+ case StatusOK:
+ s := r.Header.Get("Result")
+ if s != vt.expected {
+ t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
+ }
+ case StatusMovedPermanently:
+ s := r.Header.Get("Location")
+ if s != vt.expected {
+ t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
+ }
+ default:
+ t.Errorf("Get(%q) unhandled status code %d", vt.url, r.StatusCode)
+ }
+ }
+}
+
+var serveMuxRegister = []struct {
+ pattern string
+ h Handler
+}{
+ {"/dir/", serve(200)},
+ {"/search", serve(201)},
+ {"codesearch.google.com/search", serve(202)},
+ {"codesearch.google.com/", serve(203)},
+ {"example.com/", HandlerFunc(checkQueryStringHandler)},
+}
+
+// serve returns a handler that sends a response with the given code.
+func serve(code int) HandlerFunc {
+ return func(w ResponseWriter, r *Request) {
+ w.WriteHeader(code)
+ }
+}
+
+// checkQueryStringHandler checks if r.URL.RawQuery has the same value
+// as the URL excluding the scheme and the query string and sends 200
+// response code if it is, 500 otherwise.
+func checkQueryStringHandler(w ResponseWriter, r *Request) {
+ u := *r.URL
+ u.Scheme = "http"
+ u.Host = r.Host
+ u.RawQuery = ""
+ if "http://"+r.URL.RawQuery == u.String() {
+ w.WriteHeader(200)
+ } else {
+ w.WriteHeader(500)
+ }
+}
+
+var serveMuxTests = []struct {
+ method string
+ host string
+ path string
+ code int
+ pattern string
+}{
+ {"GET", "google.com", "/", 404, ""},
+ {"GET", "google.com", "/dir", 301, "/dir/"},
+ {"GET", "google.com", "/dir/", 200, "/dir/"},
+ {"GET", "google.com", "/dir/file", 200, "/dir/"},
+ {"GET", "google.com", "/search", 201, "/search"},
+ {"GET", "google.com", "/search/", 404, ""},
+ {"GET", "google.com", "/search/foo", 404, ""},
+ {"GET", "codesearch.google.com", "/search", 202, "codesearch.google.com/search"},
+ {"GET", "codesearch.google.com", "/search/", 203, "codesearch.google.com/"},
+ {"GET", "codesearch.google.com", "/search/foo", 203, "codesearch.google.com/"},
+ {"GET", "codesearch.google.com", "/", 203, "codesearch.google.com/"},
+ {"GET", "images.google.com", "/search", 201, "/search"},
+ {"GET", "images.google.com", "/search/", 404, ""},
+ {"GET", "images.google.com", "/search/foo", 404, ""},
+ {"GET", "google.com", "/../search", 301, "/search"},
+ {"GET", "google.com", "/dir/..", 301, ""},
+ {"GET", "google.com", "/dir/..", 301, ""},
+ {"GET", "google.com", "/dir/./file", 301, "/dir/"},
+
+ // The /foo -> /foo/ redirect applies to CONNECT requests
+ // but the path canonicalization does not.
+ {"CONNECT", "google.com", "/dir", 301, "/dir/"},
+ {"CONNECT", "google.com", "/../search", 404, ""},
+ {"CONNECT", "google.com", "/dir/..", 200, "/dir/"},
+ {"CONNECT", "google.com", "/dir/..", 200, "/dir/"},
+ {"CONNECT", "google.com", "/dir/./file", 200, "/dir/"},
+}
+
+func TestServeMuxHandler(t *testing.T) {
+ mux := NewServeMux()
+ for _, e := range serveMuxRegister {
+ mux.Handle(e.pattern, e.h)
+ }
+
+ for _, tt := range serveMuxTests {
+ r := &Request{
+ Method: tt.method,
+ Host: tt.host,
+ URL: &url.URL{
+ Path: tt.path,
+ },
+ }
+ h, pattern := mux.Handler(r)
+ rr := httptest.NewRecorder()
+ h.ServeHTTP(rr, r)
+ if pattern != tt.pattern || rr.Code != tt.code {
+ t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, rr.Code, pattern, tt.code, tt.pattern)
+ }
+ }
+}
+
+var serveMuxTests2 = []struct {
+ method string
+ host string
+ url string
+ code int
+ redirOk bool
+}{
+ {"GET", "google.com", "/", 404, false},
+ {"GET", "example.com", "/test/?example.com/test/", 200, false},
+ {"GET", "example.com", "test/?example.com/test/", 200, true},
+}
+
+// TestServeMuxHandlerRedirects tests that automatic redirects generated by
+// mux.Handler() shouldn't clear the request's query string.
+func TestServeMuxHandlerRedirects(t *testing.T) {
+ mux := NewServeMux()
+ for _, e := range serveMuxRegister {
+ mux.Handle(e.pattern, e.h)
+ }
+
+ for _, tt := range serveMuxTests2 {
+ tries := 1
+ turl := tt.url
+ for tries > 0 {
+ u, e := url.Parse(turl)
+ if e != nil {
+ t.Fatal(e)
+ }
+ r := &Request{
+ Method: tt.method,
+ Host: tt.host,
+ URL: u,
+ }
+ h, _ := mux.Handler(r)
+ rr := httptest.NewRecorder()
+ h.ServeHTTP(rr, r)
+ if rr.Code != 301 {
+ if rr.Code != tt.code {
+ t.Errorf("%s %s %s = %d, want %d", tt.method, tt.host, tt.url, rr.Code, tt.code)
+ }
+ break
+ }
+ if !tt.redirOk {
+ t.Errorf("%s %s %s, unexpected redirect", tt.method, tt.host, tt.url)
+ break
+ }
+ turl = rr.HeaderMap.Get("Location")
+ tries--
+ }
+ if tries < 0 {
+ t.Errorf("%s %s %s, too many redirects", tt.method, tt.host, tt.url)
+ }
+ }
+}
+
+// Tests for http://code.google.com/p/go/issues/detail?id=900
+func TestMuxRedirectLeadingSlashes(t *testing.T) {
+ paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"}
+ for _, path := range paths {
+ req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n")))
+ if err != nil {
+ t.Errorf("%s", err)
+ }
+ mux := NewServeMux()
+ resp := httptest.NewRecorder()
+
+ mux.ServeHTTP(resp, req)
+
+ if loc, expected := resp.Header().Get("Location"), "/foo.txt"; loc != expected {
+ t.Errorf("Expected Location header set to %q; got %q", expected, loc)
+ return
+ }
+
+ if code, expected := resp.Code, StatusMovedPermanently; code != expected {
+ t.Errorf("Expected response code of StatusMovedPermanently; got %d", code)
+ return
+ }
+ }
+}
+
+func TestServerTimeouts(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("skipping test; see http://golang.org/issue/7237")
+ }
+ defer afterTest(t)
+ reqNum := 0
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) {
+ reqNum++
+ fmt.Fprintf(res, "req=%d", reqNum)
+ }))
+ ts.Config.ReadTimeout = 250 * time.Millisecond
+ ts.Config.WriteTimeout = 250 * time.Millisecond
+ ts.Start()
+ defer ts.Close()
+
+ // Hit the HTTP server successfully.
+ tr := &Transport{DisableKeepAlives: true} // they interfere with this test
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+ r, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatalf("http Get #1: %v", err)
+ }
+ got, _ := ioutil.ReadAll(r.Body)
+ expected := "req=1"
+ if string(got) != expected {
+ t.Errorf("Unexpected response for request #1; got %q; expected %q",
+ string(got), expected)
+ }
+
+ // Slow client that should timeout.
+ t1 := time.Now()
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ buf := make([]byte, 1)
+ n, err := conn.Read(buf)
+ latency := time.Since(t1)
+ if n != 0 || err != io.EOF {
+ t.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF)
+ }
+ if latency < 200*time.Millisecond /* fudge from 250 ms above */ {
+ t.Errorf("got EOF after %s, want >= %s", latency, 200*time.Millisecond)
+ }
+
+ // Hit the HTTP server successfully again, verifying that the
+ // previous slow connection didn't run our handler. (that we
+ // get "req=2", not "req=3")
+ r, err = Get(ts.URL)
+ if err != nil {
+ t.Fatalf("http Get #2: %v", err)
+ }
+ got, _ = ioutil.ReadAll(r.Body)
+ expected = "req=2"
+ if string(got) != expected {
+ t.Errorf("Get #2 got %q, want %q", string(got), expected)
+ }
+
+ if !testing.Short() {
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer conn.Close()
+ go io.Copy(ioutil.Discard, conn)
+ for i := 0; i < 5; i++ {
+ _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
+ if err != nil {
+ t.Fatalf("on write %d: %v", i, err)
+ }
+ time.Sleep(ts.Config.ReadTimeout / 2)
+ }
+ }
+}
+
+// golang.org/issue/4741 -- setting only a write timeout that triggers
+// shouldn't cause a handler to block forever on reads (next HTTP
+// request) that will never happen.
+func TestOnlyWriteTimeout(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("skipping test; see http://golang.org/issue/7237")
+ }
+ defer afterTest(t)
+ var conn net.Conn
+ var afterTimeoutErrc = make(chan error, 1)
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, req *Request) {
+ buf := make([]byte, 512<<10)
+ _, err := w.Write(buf)
+ if err != nil {
+ t.Errorf("handler Write error: %v", err)
+ return
+ }
+ conn.SetWriteDeadline(time.Now().Add(-30 * time.Second))
+ _, err = w.Write(buf)
+ afterTimeoutErrc <- err
+ }))
+ ts.Listener = trackLastConnListener{ts.Listener, &conn}
+ ts.Start()
+ defer ts.Close()
+
+ tr := &Transport{DisableKeepAlives: false}
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ errc := make(chan error)
+ go func() {
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ errc <- err
+ return
+ }
+ _, err = io.Copy(ioutil.Discard, res.Body)
+ errc <- err
+ }()
+ select {
+ case err := <-errc:
+ if err == nil {
+ t.Errorf("expected an error from Get request")
+ }
+ case <-time.After(5 * time.Second):
+ t.Fatal("timeout waiting for Get error")
+ }
+ if err := <-afterTimeoutErrc; err == nil {
+ t.Error("expected write error after timeout")
+ }
+}
+
+// trackLastConnListener tracks the last net.Conn that was accepted.
+type trackLastConnListener struct {
+ net.Listener
+ last *net.Conn // destination
+}
+
+func (l trackLastConnListener) Accept() (c net.Conn, err error) {
+ c, err = l.Listener.Accept()
+ *l.last = c
+ return
+}
+
+// TestIdentityResponse verifies that a handler can unset
+func TestIdentityResponse(t *testing.T) {
+ defer afterTest(t)
+ handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
+ rw.Header().Set("Content-Length", "3")
+ rw.Header().Set("Transfer-Encoding", req.FormValue("te"))
+ switch {
+ case req.FormValue("overwrite") == "1":
+ _, err := rw.Write([]byte("foo TOO LONG"))
+ if err != ErrContentLength {
+ t.Errorf("expected ErrContentLength; got %v", err)
+ }
+ case req.FormValue("underwrite") == "1":
+ rw.Header().Set("Content-Length", "500")
+ rw.Write([]byte("too short"))
+ default:
+ rw.Write([]byte("foo"))
+ }
+ })
+
+ ts := httptest.NewServer(handler)
+ defer ts.Close()
+
+ // Note: this relies on the assumption (which is true) that
+ // Get sends HTTP/1.1 or greater requests. Otherwise the
+ // server wouldn't have the choice to send back chunked
+ // responses.
+ for _, te := range []string{"", "identity"} {
+ url := ts.URL + "/?te=" + te
+ res, err := Get(url)
+ if err != nil {
+ t.Fatalf("error with Get of %s: %v", url, err)
+ }
+ if cl, expected := res.ContentLength, int64(3); cl != expected {
+ t.Errorf("for %s expected res.ContentLength of %d; got %d", url, expected, cl)
+ }
+ if cl, expected := res.Header.Get("Content-Length"), "3"; cl != expected {
+ t.Errorf("for %s expected Content-Length header of %q; got %q", url, expected, cl)
+ }
+ if tl, expected := len(res.TransferEncoding), 0; tl != expected {
+ t.Errorf("for %s expected len(res.TransferEncoding) of %d; got %d (%v)",
+ url, expected, tl, res.TransferEncoding)
+ }
+ res.Body.Close()
+ }
+
+ // Verify that ErrContentLength is returned
+ url := ts.URL + "/?overwrite=1"
+ res, err := Get(url)
+ if err != nil {
+ t.Fatalf("error with Get of %s: %v", url, err)
+ }
+ res.Body.Close()
+
+ // Verify that the connection is closed when the declared Content-Length
+ // is larger than what the handler wrote.
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("error dialing: %v", err)
+ }
+ _, err = conn.Write([]byte("GET /?underwrite=1 HTTP/1.1\r\nHost: foo\r\n\r\n"))
+ if err != nil {
+ t.Fatalf("error writing: %v", err)
+ }
+
+ // The ReadAll will hang for a failing test, so use a Timer to
+ // fail explicitly.
+ goTimeout(t, 2*time.Second, func() {
+ got, _ := ioutil.ReadAll(conn)
+ expectedSuffix := "\r\n\r\ntoo short"
+ if !strings.HasSuffix(string(got), expectedSuffix) {
+ t.Errorf("Expected output to end with %q; got response body %q",
+ expectedSuffix, string(got))
+ }
+ })
+}
+
+func testTCPConnectionCloses(t *testing.T, req string, h Handler) {
+ defer afterTest(t)
+ s := httptest.NewServer(h)
+ defer s.Close()
+
+ conn, err := net.Dial("tcp", s.Listener.Addr().String())
+ if err != nil {
+ t.Fatal("dial error:", err)
+ }
+ defer conn.Close()
+
+ _, err = fmt.Fprint(conn, req)
+ if err != nil {
+ t.Fatal("print error:", err)
+ }
+
+ r := bufio.NewReader(conn)
+ res, err := ReadResponse(r, &Request{Method: "GET"})
+ if err != nil {
+ t.Fatal("ReadResponse error:", err)
+ }
+
+ didReadAll := make(chan bool, 1)
+ go func() {
+ select {
+ case <-time.After(5 * time.Second):
+ t.Error("body not closed after 5s")
+ return
+ case <-didReadAll:
+ }
+ }()
+
+ _, err = ioutil.ReadAll(r)
+ if err != nil {
+ t.Fatal("read error:", err)
+ }
+ didReadAll <- true
+
+ if !res.Close {
+ t.Errorf("Response.Close = false; want true")
+ }
+}
+
+// TestServeHTTP10Close verifies that HTTP/1.0 requests won't be kept alive.
+func TestServeHTTP10Close(t *testing.T) {
+ testTCPConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
+ ServeFile(w, r, "testdata/file")
+ }))
+}
+
+// TestClientCanClose verifies that clients can also force a connection to close.
+func TestClientCanClose(t *testing.T) {
+ testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
+ // Nothing.
+ }))
+}
+
+// TestHandlersCanSetConnectionClose verifies that handlers can force a connection to close,
+// even for HTTP/1.1 requests.
+func TestHandlersCanSetConnectionClose11(t *testing.T) {
+ testTCPConnectionCloses(t, "GET / HTTP/1.1\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Connection", "close")
+ }))
+}
+
+func TestHandlersCanSetConnectionClose10(t *testing.T) {
+ testTCPConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Connection", "close")
+ }))
+}
+
+func TestSetsRemoteAddr(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "%s", r.RemoteAddr)
+ }))
+ defer ts.Close()
+
+ res, err := Get(ts.URL)
+ if err != nil {
+ t.Fatalf("Get error: %v", err)
+ }
+ body, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("ReadAll error: %v", err)
+ }
+ ip := string(body)
+ if !strings.HasPrefix(ip, "127.0.0.1:") && !strings.HasPrefix(ip, "[::1]:") {
+ t.Fatalf("Expected local addr; got %q", ip)
+ }
+}
+
+func TestChunkedResponseHeaders(t *testing.T) {
+ defer afterTest(t)
+ log.SetOutput(ioutil.Discard) // is noisy otherwise
+ defer log.SetOutput(os.Stderr)
+
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted
+ w.(Flusher).Flush()
+ fmt.Fprintf(w, "I am a chunked response.")
+ }))
+ defer ts.Close()
+
+ res, err := Get(ts.URL)
+ if err != nil {
+ t.Fatalf("Get error: %v", err)
+ }
+ defer res.Body.Close()
+ if g, e := res.ContentLength, int64(-1); g != e {
+ t.Errorf("expected ContentLength of %d; got %d", e, g)
+ }
+ if g, e := res.TransferEncoding, []string{"chunked"}; !reflect.DeepEqual(g, e) {
+ t.Errorf("expected TransferEncoding of %v; got %v", e, g)
+ }
+ if _, haveCL := res.Header["Content-Length"]; haveCL {
+ t.Errorf("Unexpected Content-Length")
+ }
+}
+
+func TestIdentityResponseHeaders(t *testing.T) {
+ defer afterTest(t)
+ log.SetOutput(ioutil.Discard) // is noisy otherwise
+ defer log.SetOutput(os.Stderr)
+
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Transfer-Encoding", "identity")
+ w.(Flusher).Flush()
+ fmt.Fprintf(w, "I am an identity response.")
+ }))
+ defer ts.Close()
+
+ res, err := Get(ts.URL)
+ if err != nil {
+ t.Fatalf("Get error: %v", err)
+ }
+ defer res.Body.Close()
+
+ if g, e := res.TransferEncoding, []string(nil); !reflect.DeepEqual(g, e) {
+ t.Errorf("expected TransferEncoding of %v; got %v", e, g)
+ }
+ if _, haveCL := res.Header["Content-Length"]; haveCL {
+ t.Errorf("Unexpected Content-Length")
+ }
+ if !res.Close {
+ t.Errorf("expected Connection: close; got %v", res.Close)
+ }
+}
+
+// Test304Responses verifies that 304s don't declare that they're
+// chunking in their response headers and aren't allowed to produce
+// output.
+func Test304Responses(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.WriteHeader(StatusNotModified)
+ _, err := w.Write([]byte("illegal body"))
+ if err != ErrBodyNotAllowed {
+ t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err)
+ }
+ }))
+ defer ts.Close()
+ res, err := Get(ts.URL)
+ if err != nil {
+ t.Error(err)
+ }
+ if len(res.TransferEncoding) > 0 {
+ t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
+ }
+ body, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Error(err)
+ }
+ if len(body) > 0 {
+ t.Errorf("got unexpected body %q", string(body))
+ }
+}
+
+// TestHeadResponses verifies that all MIME type sniffing and Content-Length
+// counting of GET requests also happens on HEAD requests.
+func TestHeadResponses(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ _, err := w.Write([]byte("<html>"))
+ if err != nil {
+ t.Errorf("ResponseWriter.Write: %v", err)
+ }
+
+ // Also exercise the ReaderFrom path
+ _, err = io.Copy(w, strings.NewReader("789a"))
+ if err != nil {
+ t.Errorf("Copy(ResponseWriter, ...): %v", err)
+ }
+ }))
+ defer ts.Close()
+ res, err := Head(ts.URL)
+ if err != nil {
+ t.Error(err)
+ }
+ if len(res.TransferEncoding) > 0 {
+ t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
+ }
+ if ct := res.Header.Get("Content-Type"); ct != "text/html; charset=utf-8" {
+ t.Errorf("Content-Type: %q; want text/html; charset=utf-8", ct)
+ }
+ if v := res.ContentLength; v != 10 {
+ t.Errorf("Content-Length: %d; want 10", v)
+ }
+ body, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Error(err)
+ }
+ if len(body) > 0 {
+ t.Errorf("got unexpected body %q", string(body))
+ }
+}
+
+func TestTLSHandshakeTimeout(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("skipping test; see http://golang.org/issue/7237")
+ }
+ defer afterTest(t)
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
+ errc := make(chanWriter, 10) // but only expecting 1
+ ts.Config.ReadTimeout = 250 * time.Millisecond
+ ts.Config.ErrorLog = log.New(errc, "", 0)
+ ts.StartTLS()
+ defer ts.Close()
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer conn.Close()
+ goTimeout(t, 10*time.Second, func() {
+ var buf [1]byte
+ n, err := conn.Read(buf[:])
+ if err == nil || n != 0 {
+ t.Errorf("Read = %d, %v; want an error and no bytes", n, err)
+ }
+ })
+ select {
+ case v := <-errc:
+ if !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") {
+ t.Errorf("expected a TLS handshake timeout error; got %q", v)
+ }
+ case <-time.After(5 * time.Second):
+ t.Errorf("timeout waiting for logged error")
+ }
+}
+
+func TestTLSServer(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.TLS != nil {
+ w.Header().Set("X-TLS-Set", "true")
+ if r.TLS.HandshakeComplete {
+ w.Header().Set("X-TLS-HandshakeComplete", "true")
+ }
+ }
+ }))
+ ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
+ defer ts.Close()
+
+ // Connect an idle TCP connection to this server before we run
+ // our real tests. This idle connection used to block forever
+ // in the TLS handshake, preventing future connections from
+ // being accepted. It may prevent future accidental blocking
+ // in newConn.
+ idleConn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer idleConn.Close()
+ goTimeout(t, 10*time.Second, func() {
+ if !strings.HasPrefix(ts.URL, "https://") {
+ t.Errorf("expected test TLS server to start with https://, got %q", ts.URL)
+ return
+ }
+ noVerifyTransport := &Transport{
+ TLSClientConfig: &tls.Config{
+ InsecureSkipVerify: true,
+ },
+ }
+ client := &Client{Transport: noVerifyTransport}
+ res, err := client.Get(ts.URL)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ if res == nil {
+ t.Errorf("got nil Response")
+ return
+ }
+ defer res.Body.Close()
+ if res.Header.Get("X-TLS-Set") != "true" {
+ t.Errorf("expected X-TLS-Set response header")
+ return
+ }
+ if res.Header.Get("X-TLS-HandshakeComplete") != "true" {
+ t.Errorf("expected X-TLS-HandshakeComplete header")
+ }
+ })
+}
+
+type serverExpectTest struct {
+ contentLength int // of request body
+ chunked bool
+ expectation string // e.g. "100-continue"
+ readBody bool // whether handler should read the body (if false, sends StatusUnauthorized)
+ expectedResponse string // expected substring in first line of http response
+}
+
+func expectTest(contentLength int, expectation string, readBody bool, expectedResponse string) serverExpectTest {
+ return serverExpectTest{
+ contentLength: contentLength,
+ expectation: expectation,
+ readBody: readBody,
+ expectedResponse: expectedResponse,
+ }
+}
+
+var serverExpectTests = []serverExpectTest{
+ // Normal 100-continues, case-insensitive.
+ expectTest(100, "100-continue", true, "100 Continue"),
+ expectTest(100, "100-cOntInUE", true, "100 Continue"),
+
+ // No 100-continue.
+ expectTest(100, "", true, "200 OK"),
+
+ // 100-continue but requesting client to deny us,
+ // so it never reads the body.
+ expectTest(100, "100-continue", false, "401 Unauthorized"),
+ // Likewise without 100-continue:
+ expectTest(100, "", false, "401 Unauthorized"),
+
+ // Non-standard expectations are failures
+ expectTest(0, "a-pony", false, "417 Expectation Failed"),
+
+ // Expect-100 requested but no body (is apparently okay: Issue 7625)
+ expectTest(0, "100-continue", true, "200 OK"),
+ // Expect-100 requested but handler doesn't read the body
+ expectTest(0, "100-continue", false, "401 Unauthorized"),
+ // Expect-100 continue with no body, but a chunked body.
+ {
+ expectation: "100-continue",
+ readBody: true,
+ chunked: true,
+ expectedResponse: "100 Continue",
+ },
+}
+
+// Tests that the server responds to the "Expect" request header
+// correctly.
+func TestServerExpect(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ // Note using r.FormValue("readbody") because for POST
+ // requests that would read from r.Body, which we only
+ // conditionally want to do.
+ if strings.Contains(r.URL.RawQuery, "readbody=true") {
+ ioutil.ReadAll(r.Body)
+ w.Write([]byte("Hi"))
+ } else {
+ w.WriteHeader(StatusUnauthorized)
+ }
+ }))
+ defer ts.Close()
+
+ runTest := func(test serverExpectTest) {
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer conn.Close()
+
+ // Only send the body immediately if we're acting like an HTTP client
+ // that doesn't send 100-continue expectations.
+ writeBody := test.contentLength != 0 && strings.ToLower(test.expectation) != "100-continue"
+
+ go func() {
+ contentLen := fmt.Sprintf("Content-Length: %d", test.contentLength)
+ if test.chunked {
+ contentLen = "Transfer-Encoding: chunked"
+ }
+ _, err := fmt.Fprintf(conn, "POST /?readbody=%v HTTP/1.1\r\n"+
+ "Connection: close\r\n"+
+ "%s\r\n"+
+ "Expect: %s\r\nHost: foo\r\n\r\n",
+ test.readBody, contentLen, test.expectation)
+ if err != nil {
+ t.Errorf("On test %#v, error writing request headers: %v", test, err)
+ return
+ }
+ if writeBody {
+ var targ io.WriteCloser = struct {
+ io.Writer
+ io.Closer
+ }{
+ conn,
+ ioutil.NopCloser(nil),
+ }
+ if test.chunked {
+ targ = httputil.NewChunkedWriter(conn)
+ }
+ body := strings.Repeat("A", test.contentLength)
+ _, err = fmt.Fprint(targ, body)
+ if err == nil {
+ err = targ.Close()
+ }
+ if err != nil {
+ if !test.readBody {
+ // Server likely already hung up on us.
+ // See larger comment below.
+ t.Logf("On test %#v, acceptable error writing request body: %v", test, err)
+ return
+ }
+ t.Errorf("On test %#v, error writing request body: %v", test, err)
+ }
+ }
+ }()
+ bufr := bufio.NewReader(conn)
+ line, err := bufr.ReadString('\n')
+ if err != nil {
+ if writeBody && !test.readBody {
+ // This is an acceptable failure due to a possible TCP race:
+ // We were still writing data and the server hung up on us. A TCP
+ // implementation may send a RST if our request body data was known
+ // to be lost, which may trigger our reads to fail.
+ // See RFC 1122 page 88.
+ t.Logf("On test %#v, acceptable error from ReadString: %v", test, err)
+ return
+ }
+ t.Fatalf("On test %#v, ReadString: %v", test, err)
+ }
+ if !strings.Contains(line, test.expectedResponse) {
+ t.Errorf("On test %#v, got first line = %q; want %q", test, line, test.expectedResponse)
+ }
+ }
+
+ for _, test := range serverExpectTests {
+ runTest(test)
+ }
+}
+
+// Under a ~256KB (maxPostHandlerReadBytes) threshold, the server
+// should consume client request bodies that a handler didn't read.
+func TestServerUnreadRequestBodyLittle(t *testing.T) {
+ conn := new(testConn)
+ body := strings.Repeat("x", 100<<10)
+ conn.readBuf.Write([]byte(fmt.Sprintf(
+ "POST / HTTP/1.1\r\n"+
+ "Host: test\r\n"+
+ "Content-Length: %d\r\n"+
+ "\r\n", len(body))))
+ conn.readBuf.Write([]byte(body))
+
+ done := make(chan bool)
+
+ ls := &oneConnListener{conn}
+ go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ defer close(done)
+ if conn.readBuf.Len() < len(body)/2 {
+ t.Errorf("on request, read buffer length is %d; expected about 100 KB", conn.readBuf.Len())
+ }
+ rw.WriteHeader(200)
+ rw.(Flusher).Flush()
+ if g, e := conn.readBuf.Len(), 0; g != e {
+ t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e)
+ }
+ if c := rw.Header().Get("Connection"); c != "" {
+ t.Errorf(`Connection header = %q; want ""`, c)
+ }
+ }))
+ <-done
+}
+
+// Over a ~256KB (maxPostHandlerReadBytes) threshold, the server
+// should ignore client request bodies that a handler didn't read
+// and close the connection.
+func TestServerUnreadRequestBodyLarge(t *testing.T) {
+ conn := new(testConn)
+ body := strings.Repeat("x", 1<<20)
+ conn.readBuf.Write([]byte(fmt.Sprintf(
+ "POST / HTTP/1.1\r\n"+
+ "Host: test\r\n"+
+ "Content-Length: %d\r\n"+
+ "\r\n", len(body))))
+ conn.readBuf.Write([]byte(body))
+ conn.closec = make(chan bool, 1)
+
+ ls := &oneConnListener{conn}
+ go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ if conn.readBuf.Len() < len(body)/2 {
+ t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
+ }
+ rw.WriteHeader(200)
+ rw.(Flusher).Flush()
+ if conn.readBuf.Len() < len(body)/2 {
+ t.Errorf("post-WriteHeader, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
+ }
+ }))
+ <-conn.closec
+
+ if res := conn.writeBuf.String(); !strings.Contains(res, "Connection: close") {
+ t.Errorf("Expected a Connection: close header; got response: %s", res)
+ }
+}
+
+func TestTimeoutHandler(t *testing.T) {
+ defer afterTest(t)
+ sendHi := make(chan bool, 1)
+ writeErrors := make(chan error, 1)
+ sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
+ <-sendHi
+ _, werr := w.Write([]byte("hi"))
+ writeErrors <- werr
+ })
+ timeout := make(chan time.Time, 1) // write to this to force timeouts
+ ts := httptest.NewServer(NewTestTimeoutHandler(sayHi, timeout))
+ defer ts.Close()
+
+ // Succeed without timing out:
+ sendHi <- true
+ res, err := Get(ts.URL)
+ if err != nil {
+ t.Error(err)
+ }
+ if g, e := res.StatusCode, StatusOK; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ body, _ := ioutil.ReadAll(res.Body)
+ if g, e := string(body), "hi"; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+ if g := <-writeErrors; g != nil {
+ t.Errorf("got unexpected Write error on first request: %v", g)
+ }
+
+ // Times out:
+ timeout <- time.Time{}
+ res, err = Get(ts.URL)
+ if err != nil {
+ t.Error(err)
+ }
+ if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ body, _ = ioutil.ReadAll(res.Body)
+ if !strings.Contains(string(body), "<title>Timeout</title>") {
+ t.Errorf("expected timeout body; got %q", string(body))
+ }
+
+ // Now make the previously-timed out handler speak again,
+ // which verifies the panic is handled:
+ sendHi <- true
+ if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
+ t.Errorf("expected Write error of %v; got %v", e, g)
+ }
+}
+
+// See issues 8209 and 8414.
+func TestTimeoutHandlerRace(t *testing.T) {
+ defer afterTest(t)
+
+ delayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
+ ms, _ := strconv.Atoi(r.URL.Path[1:])
+ if ms == 0 {
+ ms = 1
+ }
+ for i := 0; i < ms; i++ {
+ w.Write([]byte("hi"))
+ time.Sleep(time.Millisecond)
+ }
+ })
+
+ ts := httptest.NewServer(TimeoutHandler(delayHi, 20*time.Millisecond, ""))
+ defer ts.Close()
+
+ var wg sync.WaitGroup
+ gate := make(chan bool, 10)
+ n := 50
+ if testing.Short() {
+ n = 10
+ gate = make(chan bool, 3)
+ }
+ for i := 0; i < n; i++ {
+ gate <- true
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ defer func() { <-gate }()
+ res, err := Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50)))
+ if err == nil {
+ io.Copy(ioutil.Discard, res.Body)
+ res.Body.Close()
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+// See issues 8209 and 8414.
+func TestTimeoutHandlerRaceHeader(t *testing.T) {
+ defer afterTest(t)
+
+ delay204 := HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.WriteHeader(204)
+ })
+
+ ts := httptest.NewServer(TimeoutHandler(delay204, time.Nanosecond, ""))
+ defer ts.Close()
+
+ var wg sync.WaitGroup
+ gate := make(chan bool, 50)
+ n := 500
+ if testing.Short() {
+ n = 10
+ }
+ for i := 0; i < n; i++ {
+ gate <- true
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ defer func() { <-gate }()
+ res, err := Get(ts.URL)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer res.Body.Close()
+ io.Copy(ioutil.Discard, res.Body)
+ }()
+ }
+ wg.Wait()
+}
+
+// Verifies we don't path.Clean() on the wrong parts in redirects.
+func TestRedirectMunging(t *testing.T) {
+ req, _ := NewRequest("GET", "http://example.com/", nil)
+
+ resp := httptest.NewRecorder()
+ Redirect(resp, req, "/foo?next=http://bar.com/", 302)
+ if g, e := resp.Header().Get("Location"), "/foo?next=http://bar.com/"; g != e {
+ t.Errorf("Location header was %q; want %q", g, e)
+ }
+
+ resp = httptest.NewRecorder()
+ Redirect(resp, req, "http://localhost:8080/_ah/login?continue=http://localhost:8080/", 302)
+ if g, e := resp.Header().Get("Location"), "http://localhost:8080/_ah/login?continue=http://localhost:8080/"; g != e {
+ t.Errorf("Location header was %q; want %q", g, e)
+ }
+}
+
+func TestRedirectBadPath(t *testing.T) {
+ // This used to crash. It's not valid input (bad path), but it
+ // shouldn't crash.
+ rr := httptest.NewRecorder()
+ req := &Request{
+ Method: "GET",
+ URL: &url.URL{
+ Scheme: "http",
+ Path: "not-empty-but-no-leading-slash", // bogus
+ },
+ }
+ Redirect(rr, req, "", 304)
+ if rr.Code != 304 {
+ t.Errorf("Code = %d; want 304", rr.Code)
+ }
+}
+
+// TestZeroLengthPostAndResponse exercises an optimization done by the Transport:
+// when there is no body (either because the method doesn't permit a body, or an
+// explicit Content-Length of zero is present), then the transport can re-use the
+// connection immediately. But when it re-uses the connection, it typically closes
+// the previous request's body, which is not optimal for zero-lengthed bodies,
+// as the client would then see http.ErrBodyReadAfterClose and not 0, io.EOF.
+func TestZeroLengthPostAndResponse(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
+ all, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ t.Fatalf("handler ReadAll: %v", err)
+ }
+ if len(all) != 0 {
+ t.Errorf("handler got %d bytes; expected 0", len(all))
+ }
+ rw.Header().Set("Content-Length", "0")
+ }))
+ defer ts.Close()
+
+ req, err := NewRequest("POST", ts.URL, strings.NewReader(""))
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.ContentLength = 0
+
+ var resp [5]*Response
+ for i := range resp {
+ resp[i], err = DefaultClient.Do(req)
+ if err != nil {
+ t.Fatalf("client post #%d: %v", i, err)
+ }
+ }
+
+ for i := range resp {
+ all, err := ioutil.ReadAll(resp[i].Body)
+ if err != nil {
+ t.Fatalf("req #%d: client ReadAll: %v", i, err)
+ }
+ if len(all) != 0 {
+ t.Errorf("req #%d: client got %d bytes; expected 0", i, len(all))
+ }
+ }
+}
+
+func TestHandlerPanicNil(t *testing.T) {
+ testHandlerPanic(t, false, nil)
+}
+
+func TestHandlerPanic(t *testing.T) {
+ testHandlerPanic(t, false, "intentional death for testing")
+}
+
+func TestHandlerPanicWithHijack(t *testing.T) {
+ testHandlerPanic(t, true, "intentional death for testing")
+}
+
+func testHandlerPanic(t *testing.T, withHijack bool, panicValue interface{}) {
+ defer afterTest(t)
+ // Unlike the other tests that set the log output to ioutil.Discard
+ // to quiet the output, this test uses a pipe. The pipe serves three
+ // purposes:
+ //
+ // 1) The log.Print from the http server (generated by the caught
+ // panic) will go to the pipe instead of stderr, making the
+ // output quiet.
+ //
+ // 2) We read from the pipe to verify that the handler
+ // actually caught the panic and logged something.
+ //
+ // 3) The blocking Read call prevents this TestHandlerPanic
+ // function from exiting before the HTTP server handler
+ // finishes crashing. If this text function exited too
+ // early (and its defer log.SetOutput(os.Stderr) ran),
+ // then the crash output could spill into the next test.
+ pr, pw := io.Pipe()
+ log.SetOutput(pw)
+ defer log.SetOutput(os.Stderr)
+ defer pw.Close()
+
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if withHijack {
+ rwc, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Logf("unexpected error: %v", err)
+ }
+ defer rwc.Close()
+ }
+ panic(panicValue)
+ }))
+ defer ts.Close()
+
+ // Do a blocking read on the log output pipe so its logging
+ // doesn't bleed into the next test. But wait only 5 seconds
+ // for it.
+ done := make(chan bool, 1)
+ go func() {
+ buf := make([]byte, 4<<10)
+ _, err := pr.Read(buf)
+ pr.Close()
+ if err != nil && err != io.EOF {
+ t.Error(err)
+ }
+ done <- true
+ }()
+
+ _, err := Get(ts.URL)
+ if err == nil {
+ t.Logf("expected an error")
+ }
+
+ if panicValue == nil {
+ return
+ }
+
+ select {
+ case <-done:
+ return
+ case <-time.After(5 * time.Second):
+ t.Fatal("expected server handler to log an error")
+ }
+}
+
+func TestNoDate(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header()["Date"] = nil
+ }))
+ defer ts.Close()
+ res, err := Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, present := res.Header["Date"]
+ if present {
+ t.Fatalf("Expected no Date header; got %v", res.Header["Date"])
+ }
+}
+
+func TestStripPrefix(t *testing.T) {
+ defer afterTest(t)
+ h := HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("X-Path", r.URL.Path)
+ })
+ ts := httptest.NewServer(StripPrefix("/foo", h))
+ defer ts.Close()
+
+ res, err := Get(ts.URL + "/foo/bar")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if g, e := res.Header.Get("X-Path"), "/bar"; g != e {
+ t.Errorf("test 1: got %s, want %s", g, e)
+ }
+ res.Body.Close()
+
+ res, err = Get(ts.URL + "/bar")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if g, e := res.StatusCode, 404; g != e {
+ t.Errorf("test 2: got status %v, want %v", g, e)
+ }
+ res.Body.Close()
+}
+
+func TestRequestLimit(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ t.Fatalf("didn't expect to get request in Handler")
+ }))
+ defer ts.Close()
+ req, _ := NewRequest("GET", ts.URL, nil)
+ var bytesPerHeader = len("header12345: val12345\r\n")
+ for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ {
+ req.Header.Set(fmt.Sprintf("header%05d", i), fmt.Sprintf("val%05d", i))
+ }
+ res, err := DefaultClient.Do(req)
+ if err != nil {
+ // Some HTTP clients may fail on this undefined behavior (server replying and
+ // closing the connection while the request is still being written), but
+ // we do support it (at least currently), so we expect a response below.
+ t.Fatalf("Do: %v", err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != 413 {
+ t.Fatalf("expected 413 response status; got: %d %s", res.StatusCode, res.Status)
+ }
+}
+
+type neverEnding byte
+
+func (b neverEnding) Read(p []byte) (n int, err error) {
+ for i := range p {
+ p[i] = byte(b)
+ }
+ return len(p), nil
+}
+
+type countReader struct {
+ r io.Reader
+ n *int64
+}
+
+func (cr countReader) Read(p []byte) (n int, err error) {
+ n, err = cr.r.Read(p)
+ atomic.AddInt64(cr.n, int64(n))
+ return
+}
+
+func TestRequestBodyLimit(t *testing.T) {
+ defer afterTest(t)
+ const limit = 1 << 20
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ r.Body = MaxBytesReader(w, r.Body, limit)
+ n, err := io.Copy(ioutil.Discard, r.Body)
+ if err == nil {
+ t.Errorf("expected error from io.Copy")
+ }
+ if n != limit {
+ t.Errorf("io.Copy = %d, want %d", n, limit)
+ }
+ }))
+ defer ts.Close()
+
+ nWritten := new(int64)
+ req, _ := NewRequest("POST", ts.URL, io.LimitReader(countReader{neverEnding('a'), nWritten}, limit*200))
+
+ // Send the POST, but don't care it succeeds or not. The
+ // remote side is going to reply and then close the TCP
+ // connection, and HTTP doesn't really define if that's
+ // allowed or not. Some HTTP clients will get the response
+ // and some (like ours, currently) will complain that the
+ // request write failed, without reading the response.
+ //
+ // But that's okay, since what we're really testing is that
+ // the remote side hung up on us before we wrote too much.
+ _, _ = DefaultClient.Do(req)
+
+ if atomic.LoadInt64(nWritten) > limit*100 {
+ t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d",
+ limit, nWritten)
+ }
+}
+
+// TestClientWriteShutdown tests that if the client shuts down the write
+// side of their TCP connection, the server doesn't send a 400 Bad Request.
+func TestClientWriteShutdown(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("skipping test; see http://golang.org/issue/7237")
+ }
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
+ defer ts.Close()
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ err = conn.(*net.TCPConn).CloseWrite()
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ donec := make(chan bool)
+ go func() {
+ defer close(donec)
+ bs, err := ioutil.ReadAll(conn)
+ if err != nil {
+ t.Fatalf("ReadAll: %v", err)
+ }
+ got := string(bs)
+ if got != "" {
+ t.Errorf("read %q from server; want nothing", got)
+ }
+ }()
+ select {
+ case <-donec:
+ case <-time.After(10 * time.Second):
+ t.Fatalf("timeout")
+ }
+}
+
+// Tests that chunked server responses that write 1 byte at a time are
+// buffered before chunk headers are added, not after chunk headers.
+func TestServerBufferedChunking(t *testing.T) {
+ conn := new(testConn)
+ conn.readBuf.Write([]byte("GET / HTTP/1.1\r\n\r\n"))
+ conn.closec = make(chan bool, 1)
+ ls := &oneConnListener{conn}
+ go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ rw.(Flusher).Flush() // force the Header to be sent, in chunking mode, not counting the length
+ rw.Write([]byte{'x'})
+ rw.Write([]byte{'y'})
+ rw.Write([]byte{'z'})
+ }))
+ <-conn.closec
+ if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) {
+ t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q",
+ conn.writeBuf.Bytes())
+ }
+}
+
+// Tests that the server flushes its response headers out when it's
+// ignoring the response body and waits a bit before forcefully
+// closing the TCP connection, causing the client to get a RST.
+// See http://golang.org/issue/3595
+func TestServerGracefulClose(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ Error(w, "bye", StatusUnauthorized)
+ }))
+ defer ts.Close()
+
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+ const bodySize = 5 << 20
+ req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize))
+ for i := 0; i < bodySize; i++ {
+ req = append(req, 'x')
+ }
+ writeErr := make(chan error)
+ go func() {
+ _, err := conn.Write(req)
+ writeErr <- err
+ }()
+ br := bufio.NewReader(conn)
+ lineNum := 0
+ for {
+ line, err := br.ReadString('\n')
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ t.Fatalf("ReadLine: %v", err)
+ }
+ lineNum++
+ if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") {
+ t.Errorf("Response line = %q; want a 401", line)
+ }
+ }
+ // Wait for write to finish. This is a broken pipe on both
+ // Darwin and Linux, but checking this isn't the point of
+ // the test.
+ <-writeErr
+}
+
+func TestCaseSensitiveMethod(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method != "get" {
+ t.Errorf(`Got method %q; want "get"`, r.Method)
+ }
+ }))
+ defer ts.Close()
+ req, _ := NewRequest("get", ts.URL, nil)
+ res, err := DefaultClient.Do(req)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ res.Body.Close()
+}
+
+// TestContentLengthZero tests that for both an HTTP/1.0 and HTTP/1.1
+// request (both keep-alive), when a Handler never writes any
+// response, the net/http package adds a "Content-Length: 0" response
+// header.
+func TestContentLengthZero(t *testing.T) {
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {}))
+ defer ts.Close()
+
+ for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} {
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("error dialing: %v", err)
+ }
+ _, err = fmt.Fprintf(conn, "GET / %v\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n", version)
+ if err != nil {
+ t.Fatalf("error writing: %v", err)
+ }
+ req, _ := NewRequest("GET", "/", nil)
+ res, err := ReadResponse(bufio.NewReader(conn), req)
+ if err != nil {
+ t.Fatalf("error reading response: %v", err)
+ }
+ if te := res.TransferEncoding; len(te) > 0 {
+ t.Errorf("For version %q, Transfer-Encoding = %q; want none", version, te)
+ }
+ if cl := res.ContentLength; cl != 0 {
+ t.Errorf("For version %q, Content-Length = %v; want 0", version, cl)
+ }
+ conn.Close()
+ }
+}
+
+func TestCloseNotifier(t *testing.T) {
+ defer afterTest(t)
+ gotReq := make(chan bool, 1)
+ sawClose := make(chan bool, 1)
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ gotReq <- true
+ cc := rw.(CloseNotifier).CloseNotify()
+ <-cc
+ sawClose <- true
+ }))
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("error dialing: %v", err)
+ }
+ diec := make(chan bool)
+ go func() {
+ _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
+ if err != nil {
+ t.Fatal(err)
+ }
+ <-diec
+ conn.Close()
+ }()
+For:
+ for {
+ select {
+ case <-gotReq:
+ diec <- true
+ case <-sawClose:
+ break For
+ case <-time.After(5 * time.Second):
+ t.Fatal("timeout")
+ }
+ }
+ ts.Close()
+}
+
+func TestCloseNotifierChanLeak(t *testing.T) {
+ defer afterTest(t)
+ req := reqBytes("GET / HTTP/1.0\nHost: golang.org")
+ for i := 0; i < 20; i++ {
+ var output bytes.Buffer
+ conn := &rwTestConn{
+ Reader: bytes.NewReader(req),
+ Writer: &output,
+ closec: make(chan bool, 1),
+ }
+ ln := &oneConnListener{conn: conn}
+ handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
+ // Ignore the return value and never read from
+ // it, testing that we don't leak goroutines
+ // on the sending side:
+ _ = rw.(CloseNotifier).CloseNotify()
+ })
+ go Serve(ln, handler)
+ <-conn.closec
+ }
+}
+
+func TestOptions(t *testing.T) {
+ uric := make(chan string, 2) // only expect 1, but leave space for 2
+ mux := NewServeMux()
+ mux.HandleFunc("/", func(w ResponseWriter, r *Request) {
+ uric <- r.RequestURI
+ })
+ ts := httptest.NewServer(mux)
+ defer ts.Close()
+
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+
+ // An OPTIONS * request should succeed.
+ _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ br := bufio.NewReader(conn)
+ res, err := ReadResponse(br, &Request{Method: "OPTIONS"})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != 200 {
+ t.Errorf("Got non-200 response to OPTIONS *: %#v", res)
+ }
+
+ // A GET * request on a ServeMux should fail.
+ _, err = conn.Write([]byte("GET * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ res, err = ReadResponse(br, &Request{Method: "GET"})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != 400 {
+ t.Errorf("Got non-400 response to GET *: %#v", res)
+ }
+
+ res, err = Get(ts.URL + "/second")
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if got := <-uric; got != "/second" {
+ t.Errorf("Handler saw request for %q; want /second", got)
+ }
+}
+
+// Tests regarding the ordering of Write, WriteHeader, Header, and
+// Flush calls. In Go 1.0, rw.WriteHeader immediately flushed the
+// (*response).header to the wire. In Go 1.1, the actual wire flush is
+// delayed, so we could maybe tack on a Content-Length and better
+// Content-Type after we see more (or all) of the output. To preserve
+// compatibility with Go 1, we need to be careful to track which
+// headers were live at the time of WriteHeader, so we write the same
+// ones, even if the handler modifies them (~erroneously) after the
+// first Write.
+func TestHeaderToWire(t *testing.T) {
+ tests := []struct {
+ name string
+ handler func(ResponseWriter, *Request)
+ check func(output string) error
+ }{
+ {
+ name: "write without Header",
+ handler: func(rw ResponseWriter, r *Request) {
+ rw.Write([]byte("hello world"))
+ },
+ check: func(got string) error {
+ if !strings.Contains(got, "Content-Length:") {
+ return errors.New("no content-length")
+ }
+ if !strings.Contains(got, "Content-Type: text/plain") {
+ return errors.New("no content-length")
+ }
+ return nil
+ },
+ },
+ {
+ name: "Header mutation before write",
+ handler: func(rw ResponseWriter, r *Request) {
+ h := rw.Header()
+ h.Set("Content-Type", "some/type")
+ rw.Write([]byte("hello world"))
+ h.Set("Too-Late", "bogus")
+ },
+ check: func(got string) error {
+ if !strings.Contains(got, "Content-Length:") {
+ return errors.New("no content-length")
+ }
+ if !strings.Contains(got, "Content-Type: some/type") {
+ return errors.New("wrong content-type")
+ }
+ if strings.Contains(got, "Too-Late") {
+ return errors.New("don't want too-late header")
+ }
+ return nil
+ },
+ },
+ {
+ name: "write then useless Header mutation",
+ handler: func(rw ResponseWriter, r *Request) {
+ rw.Write([]byte("hello world"))
+ rw.Header().Set("Too-Late", "Write already wrote headers")
+ },
+ check: func(got string) error {
+ if strings.Contains(got, "Too-Late") {
+ return errors.New("header appeared from after WriteHeader")
+ }
+ return nil
+ },
+ },
+ {
+ name: "flush then write",
+ handler: func(rw ResponseWriter, r *Request) {
+ rw.(Flusher).Flush()
+ rw.Write([]byte("post-flush"))
+ rw.Header().Set("Too-Late", "Write already wrote headers")
+ },
+ check: func(got string) error {
+ if !strings.Contains(got, "Transfer-Encoding: chunked") {
+ return errors.New("not chunked")
+ }
+ if strings.Contains(got, "Too-Late") {
+ return errors.New("header appeared from after WriteHeader")
+ }
+ return nil
+ },
+ },
+ {
+ name: "header then flush",
+ handler: func(rw ResponseWriter, r *Request) {
+ rw.Header().Set("Content-Type", "some/type")
+ rw.(Flusher).Flush()
+ rw.Write([]byte("post-flush"))
+ rw.Header().Set("Too-Late", "Write already wrote headers")
+ },
+ check: func(got string) error {
+ if !strings.Contains(got, "Transfer-Encoding: chunked") {
+ return errors.New("not chunked")
+ }
+ if strings.Contains(got, "Too-Late") {
+ return errors.New("header appeared from after WriteHeader")
+ }
+ if !strings.Contains(got, "Content-Type: some/type") {
+ return errors.New("wrong content-length")
+ }
+ return nil
+ },
+ },
+ {
+ name: "sniff-on-first-write content-type",
+ handler: func(rw ResponseWriter, r *Request) {
+ rw.Write([]byte("<html><head></head><body>some html</body></html>"))
+ rw.Header().Set("Content-Type", "x/wrong")
+ },
+ check: func(got string) error {
+ if !strings.Contains(got, "Content-Type: text/html") {
+ return errors.New("wrong content-length; want html")
+ }
+ return nil
+ },
+ },
+ {
+ name: "explicit content-type wins",
+ handler: func(rw ResponseWriter, r *Request) {
+ rw.Header().Set("Content-Type", "some/type")
+ rw.Write([]byte("<html><head></head><body>some html</body></html>"))
+ },
+ check: func(got string) error {
+ if !strings.Contains(got, "Content-Type: some/type") {
+ return errors.New("wrong content-length; want html")
+ }
+ return nil
+ },
+ },
+ {
+ name: "empty handler",
+ handler: func(rw ResponseWriter, r *Request) {
+ },
+ check: func(got string) error {
+ if !strings.Contains(got, "Content-Type: text/plain") {
+ return errors.New("wrong content-length; want text/plain")
+ }
+ if !strings.Contains(got, "Content-Length: 0") {
+ return errors.New("want 0 content-length")
+ }
+ return nil
+ },
+ },
+ {
+ name: "only Header, no write",
+ handler: func(rw ResponseWriter, r *Request) {
+ rw.Header().Set("Some-Header", "some-value")
+ },
+ check: func(got string) error {
+ if !strings.Contains(got, "Some-Header") {
+ return errors.New("didn't get header")
+ }
+ return nil
+ },
+ },
+ {
+ name: "WriteHeader call",
+ handler: func(rw ResponseWriter, r *Request) {
+ rw.WriteHeader(404)
+ rw.Header().Set("Too-Late", "some-value")
+ },
+ check: func(got string) error {
+ if !strings.Contains(got, "404") {
+ return errors.New("wrong status")
+ }
+ if strings.Contains(got, "Some-Header") {
+ return errors.New("shouldn't have seen Too-Late")
+ }
+ return nil
+ },
+ },
+ }
+ for _, tc := range tests {
+ ht := newHandlerTest(HandlerFunc(tc.handler))
+ got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
+ if err := tc.check(got); err != nil {
+ t.Errorf("%s: %v\nGot response:\n%s", tc.name, err, got)
+ }
+ }
+}
+
+// goTimeout runs f, failing t if f takes more than ns to complete.
+func goTimeout(t *testing.T, d time.Duration, f func()) {
+ ch := make(chan bool, 2)
+ timer := time.AfterFunc(d, func() {
+ t.Errorf("Timeout expired after %v", d)
+ ch <- true
+ })
+ defer timer.Stop()
+ go func() {
+ defer func() { ch <- true }()
+ f()
+ }()
+ <-ch
+}
+
+type errorListener struct {
+ errs []error
+}
+
+func (l *errorListener) Accept() (c net.Conn, err error) {
+ if len(l.errs) == 0 {
+ return nil, io.EOF
+ }
+ err = l.errs[0]
+ l.errs = l.errs[1:]
+ return
+}
+
+func (l *errorListener) Close() error {
+ return nil
+}
+
+func (l *errorListener) Addr() net.Addr {
+ return dummyAddr("test-address")
+}
+
+func TestAcceptMaxFds(t *testing.T) {
+ log.SetOutput(ioutil.Discard) // is noisy otherwise
+ defer log.SetOutput(os.Stderr)
+
+ ln := &errorListener{[]error{
+ &net.OpError{
+ Op: "accept",
+ Err: syscall.EMFILE,
+ }}}
+ err := Serve(ln, HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {})))
+ if err != io.EOF {
+ t.Errorf("got error %v, want EOF", err)
+ }
+}
+
+func TestWriteAfterHijack(t *testing.T) {
+ req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
+ var buf bytes.Buffer
+ wrotec := make(chan bool, 1)
+ conn := &rwTestConn{
+ Reader: bytes.NewReader(req),
+ Writer: &buf,
+ closec: make(chan bool, 1),
+ }
+ handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
+ conn, bufrw, err := rw.(Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ go func() {
+ bufrw.Write([]byte("[hijack-to-bufw]"))
+ bufrw.Flush()
+ conn.Write([]byte("[hijack-to-conn]"))
+ conn.Close()
+ wrotec <- true
+ }()
+ })
+ ln := &oneConnListener{conn: conn}
+ go Serve(ln, handler)
+ <-conn.closec
+ <-wrotec
+ if g, w := buf.String(), "[hijack-to-bufw][hijack-to-conn]"; g != w {
+ t.Errorf("wrote %q; want %q", g, w)
+ }
+}
+
+func TestDoubleHijack(t *testing.T) {
+ req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
+ var buf bytes.Buffer
+ conn := &rwTestConn{
+ Reader: bytes.NewReader(req),
+ Writer: &buf,
+ closec: make(chan bool, 1),
+ }
+ handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
+ conn, _, err := rw.(Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ _, _, err = rw.(Hijacker).Hijack()
+ if err == nil {
+ t.Errorf("got err = nil; want err != nil")
+ }
+ conn.Close()
+ })
+ ln := &oneConnListener{conn: conn}
+ go Serve(ln, handler)
+ <-conn.closec
+}
+
+// http://code.google.com/p/go/issues/detail?id=5955
+// Note that this does not test the "request too large"
+// exit path from the http server. This is intentional;
+// not sending Connection: close is just a minor wire
+// optimization and is pointless if dealing with a
+// badly behaved client.
+func TestHTTP10ConnectionHeader(t *testing.T) {
+ defer afterTest(t)
+
+ mux := NewServeMux()
+ mux.Handle("/", HandlerFunc(func(resp ResponseWriter, req *Request) {}))
+ ts := httptest.NewServer(mux)
+ defer ts.Close()
+
+ // net/http uses HTTP/1.1 for requests, so write requests manually
+ tests := []struct {
+ req string // raw http request
+ expect []string // expected Connection header(s)
+ }{
+ {
+ req: "GET / HTTP/1.0\r\n\r\n",
+ expect: nil,
+ },
+ {
+ req: "OPTIONS * HTTP/1.0\r\n\r\n",
+ expect: nil,
+ },
+ {
+ req: "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n",
+ expect: []string{"keep-alive"},
+ },
+ }
+
+ for _, tt := range tests {
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal("dial err:", err)
+ }
+
+ _, err = fmt.Fprint(conn, tt.req)
+ if err != nil {
+ t.Fatal("conn write err:", err)
+ }
+
+ resp, err := ReadResponse(bufio.NewReader(conn), &Request{Method: "GET"})
+ if err != nil {
+ t.Fatal("ReadResponse err:", err)
+ }
+ conn.Close()
+ resp.Body.Close()
+
+ got := resp.Header["Connection"]
+ if !reflect.DeepEqual(got, tt.expect) {
+ t.Errorf("wrong Connection headers for request %q. Got %q expect %q", tt.req, got, tt.expect)
+ }
+ }
+}
+
+// See golang.org/issue/5660
+func TestServerReaderFromOrder(t *testing.T) {
+ defer afterTest(t)
+ pr, pw := io.Pipe()
+ const size = 3 << 20
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ rw.Header().Set("Content-Type", "text/plain") // prevent sniffing path
+ done := make(chan bool)
+ go func() {
+ io.Copy(rw, pr)
+ close(done)
+ }()
+ time.Sleep(25 * time.Millisecond) // give Copy a chance to break things
+ n, err := io.Copy(ioutil.Discard, req.Body)
+ if err != nil {
+ t.Errorf("handler Copy: %v", err)
+ return
+ }
+ if n != size {
+ t.Errorf("handler Copy = %d; want %d", n, size)
+ }
+ pw.Write([]byte("hi"))
+ pw.Close()
+ <-done
+ }))
+ defer ts.Close()
+
+ req, err := NewRequest("POST", ts.URL, io.LimitReader(neverEnding('a'), size))
+ if err != nil {
+ t.Fatal(err)
+ }
+ res, err := DefaultClient.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ all, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if string(all) != "hi" {
+ t.Errorf("Body = %q; want hi", all)
+ }
+}
+
+// Issue 6157, Issue 6685
+func TestCodesPreventingContentTypeAndBody(t *testing.T) {
+ for _, code := range []int{StatusNotModified, StatusNoContent, StatusContinue} {
+ ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.URL.Path == "/header" {
+ w.Header().Set("Content-Length", "123")
+ }
+ w.WriteHeader(code)
+ if r.URL.Path == "/more" {
+ w.Write([]byte("stuff"))
+ }
+ }))
+ for _, req := range []string{
+ "GET / HTTP/1.0",
+ "GET /header HTTP/1.0",
+ "GET /more HTTP/1.0",
+ "GET / HTTP/1.1",
+ "GET /header HTTP/1.1",
+ "GET /more HTTP/1.1",
+ } {
+ got := ht.rawResponse(req)
+ wantStatus := fmt.Sprintf("%d %s", code, StatusText(code))
+ if !strings.Contains(got, wantStatus) {
+ t.Errorf("Code %d: Wanted %q Modified for %q: %s", code, wantStatus, req, got)
+ } else if strings.Contains(got, "Content-Length") {
+ t.Errorf("Code %d: Got a Content-Length from %q: %s", code, req, got)
+ } else if strings.Contains(got, "stuff") {
+ t.Errorf("Code %d: Response contains a body from %q: %s", code, req, got)
+ }
+ }
+ }
+}
+
+func TestContentTypeOkayOn204(t *testing.T) {
+ ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "123") // suppressed
+ w.Header().Set("Content-Type", "foo/bar")
+ w.WriteHeader(204)
+ }))
+ got := ht.rawResponse("GET / HTTP/1.1")
+ if !strings.Contains(got, "Content-Type: foo/bar") {
+ t.Errorf("Response = %q; want Content-Type: foo/bar", got)
+ }
+ if strings.Contains(got, "Content-Length: 123") {
+ t.Errorf("Response = %q; don't want a Content-Length", got)
+ }
+}
+
+// Issue 6995
+// A server Handler can receive a Request, and then turn around and
+// give a copy of that Request.Body out to the Transport (e.g. any
+// proxy). So then two people own that Request.Body (both the server
+// and the http client), and both think they can close it on failure.
+// Therefore, all incoming server requests Bodies need to be thread-safe.
+func TestTransportAndServerSharedBodyRace(t *testing.T) {
+ defer afterTest(t)
+
+ const bodySize = 1 << 20
+
+ unblockBackend := make(chan bool)
+ backend := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ io.CopyN(rw, req.Body, bodySize/2)
+ <-unblockBackend
+ }))
+ defer backend.Close()
+
+ backendRespc := make(chan *Response, 1)
+ proxy := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ if req.RequestURI == "/foo" {
+ rw.Write([]byte("bar"))
+ return
+ }
+ req2, _ := NewRequest("POST", backend.URL, req.Body)
+ req2.ContentLength = bodySize
+
+ bresp, err := DefaultClient.Do(req2)
+ if err != nil {
+ t.Errorf("Proxy outbound request: %v", err)
+ return
+ }
+ _, err = io.CopyN(ioutil.Discard, bresp.Body, bodySize/4)
+ if err != nil {
+ t.Errorf("Proxy copy error: %v", err)
+ return
+ }
+ backendRespc <- bresp // to close later
+
+ // Try to cause a race: Both the DefaultTransport and the proxy handler's Server
+ // will try to read/close req.Body (aka req2.Body)
+ DefaultTransport.(*Transport).CancelRequest(req2)
+ rw.Write([]byte("OK"))
+ }))
+ defer proxy.Close()
+
+ req, _ := NewRequest("POST", proxy.URL, io.LimitReader(neverEnding('a'), bodySize))
+ res, err := DefaultClient.Do(req)
+ if err != nil {
+ t.Fatalf("Original request: %v", err)
+ }
+
+ // Cleanup, so we don't leak goroutines.
+ res.Body.Close()
+ close(unblockBackend)
+ (<-backendRespc).Body.Close()
+}
+
+// Test that a hanging Request.Body.Read from another goroutine can't
+// cause the Handler goroutine's Request.Body.Close to block.
+func TestRequestBodyCloseDoesntBlock(t *testing.T) {
+ t.Skipf("Skipping known issue; see golang.org/issue/7121")
+ if testing.Short() {
+ t.Skip("skipping in -short mode")
+ }
+ defer afterTest(t)
+
+ readErrCh := make(chan error, 1)
+ errCh := make(chan error, 2)
+
+ server := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ go func(body io.Reader) {
+ _, err := body.Read(make([]byte, 100))
+ readErrCh <- err
+ }(req.Body)
+ time.Sleep(500 * time.Millisecond)
+ }))
+ defer server.Close()
+
+ closeConn := make(chan bool)
+ defer close(closeConn)
+ go func() {
+ conn, err := net.Dial("tcp", server.Listener.Addr().String())
+ if err != nil {
+ errCh <- err
+ return
+ }
+ defer conn.Close()
+ _, err = conn.Write([]byte("POST / HTTP/1.1\r\nConnection: close\r\nHost: foo\r\nContent-Length: 100000\r\n\r\n"))
+ if err != nil {
+ errCh <- err
+ return
+ }
+ // And now just block, making the server block on our
+ // 100000 bytes of body that will never arrive.
+ <-closeConn
+ }()
+ select {
+ case err := <-readErrCh:
+ if err == nil {
+ t.Error("Read was nil. Expected error.")
+ }
+ case err := <-errCh:
+ t.Error(err)
+ case <-time.After(5 * time.Second):
+ t.Error("timeout")
+ }
+}
+
+func TestResponseWriterWriteStringAllocs(t *testing.T) {
+ ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.URL.Path == "/s" {
+ io.WriteString(w, "Hello world")
+ } else {
+ w.Write([]byte("Hello world"))
+ }
+ }))
+ before := testing.AllocsPerRun(50, func() { ht.rawResponse("GET / HTTP/1.0") })
+ after := testing.AllocsPerRun(50, func() { ht.rawResponse("GET /s HTTP/1.0") })
+ if int(after) >= int(before) {
+ t.Errorf("WriteString allocs of %v >= Write allocs of %v", after, before)
+ }
+}
+
+func TestAppendTime(t *testing.T) {
+ var b [len(TimeFormat)]byte
+ t1 := time.Date(2013, 9, 21, 15, 41, 0, 0, time.FixedZone("CEST", 2*60*60))
+ res := ExportAppendTime(b[:0], t1)
+ t2, err := ParseTime(string(res))
+ if err != nil {
+ t.Fatalf("Error parsing time: %s", err)
+ }
+ if !t1.Equal(t2) {
+ t.Fatalf("Times differ; expected: %v, got %v (%s)", t1, t2, string(res))
+ }
+}
+
+func TestServerConnState(t *testing.T) {
+ defer afterTest(t)
+ handler := map[string]func(w ResponseWriter, r *Request){
+ "/": func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "Hello.")
+ },
+ "/close": func(w ResponseWriter, r *Request) {
+ w.Header().Set("Connection", "close")
+ fmt.Fprintf(w, "Hello.")
+ },
+ "/hijack": func(w ResponseWriter, r *Request) {
+ c, _, _ := w.(Hijacker).Hijack()
+ c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
+ c.Close()
+ },
+ "/hijack-panic": func(w ResponseWriter, r *Request) {
+ c, _, _ := w.(Hijacker).Hijack()
+ c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
+ c.Close()
+ panic("intentional panic")
+ },
+ }
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ handler[r.URL.Path](w, r)
+ }))
+ defer ts.Close()
+
+ var mu sync.Mutex // guard stateLog and connID
+ var stateLog = map[int][]ConnState{}
+ var connID = map[net.Conn]int{}
+
+ ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
+ ts.Config.ConnState = func(c net.Conn, state ConnState) {
+ if c == nil {
+ t.Errorf("nil conn seen in state %s", state)
+ return
+ }
+ mu.Lock()
+ defer mu.Unlock()
+ id, ok := connID[c]
+ if !ok {
+ id = len(connID) + 1
+ connID[c] = id
+ }
+ stateLog[id] = append(stateLog[id], state)
+ }
+ ts.Start()
+
+ mustGet(t, ts.URL+"/")
+ mustGet(t, ts.URL+"/close")
+
+ mustGet(t, ts.URL+"/")
+ mustGet(t, ts.URL+"/", "Connection", "close")
+
+ mustGet(t, ts.URL+"/hijack")
+ mustGet(t, ts.URL+"/hijack-panic")
+
+ // New->Closed
+ {
+ c, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.Close()
+ }
+
+ // New->Active->Closed
+ {
+ c, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := io.WriteString(c, "BOGUS REQUEST\r\n\r\n"); err != nil {
+ t.Fatal(err)
+ }
+ c.Close()
+ }
+
+ // New->Idle->Closed
+ {
+ c, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
+ t.Fatal(err)
+ }
+ res, err := ReadResponse(bufio.NewReader(c), nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := io.Copy(ioutil.Discard, res.Body); err != nil {
+ t.Fatal(err)
+ }
+ c.Close()
+ }
+
+ want := map[int][]ConnState{
+ 1: {StateNew, StateActive, StateIdle, StateActive, StateClosed},
+ 2: {StateNew, StateActive, StateIdle, StateActive, StateClosed},
+ 3: {StateNew, StateActive, StateHijacked},
+ 4: {StateNew, StateActive, StateHijacked},
+ 5: {StateNew, StateClosed},
+ 6: {StateNew, StateActive, StateClosed},
+ 7: {StateNew, StateActive, StateIdle, StateClosed},
+ }
+ logString := func(m map[int][]ConnState) string {
+ var b bytes.Buffer
+ for id, l := range m {
+ fmt.Fprintf(&b, "Conn %d: ", id)
+ for _, s := range l {
+ fmt.Fprintf(&b, "%s ", s)
+ }
+ b.WriteString("\n")
+ }
+ return b.String()
+ }
+
+ for i := 0; i < 5; i++ {
+ time.Sleep(time.Duration(i) * 50 * time.Millisecond)
+ mu.Lock()
+ match := reflect.DeepEqual(stateLog, want)
+ mu.Unlock()
+ if match {
+ return
+ }
+ }
+
+ mu.Lock()
+ t.Errorf("Unexpected events.\nGot log: %s\n Want: %s\n", logString(stateLog), logString(want))
+ mu.Unlock()
+}
+
+func mustGet(t *testing.T, url string, headers ...string) {
+ req, err := NewRequest("GET", url, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ for len(headers) > 0 {
+ req.Header.Add(headers[0], headers[1])
+ headers = headers[2:]
+ }
+ res, err := DefaultClient.Do(req)
+ if err != nil {
+ t.Errorf("Error fetching %s: %v", url, err)
+ return
+ }
+ _, err = ioutil.ReadAll(res.Body)
+ defer res.Body.Close()
+ if err != nil {
+ t.Errorf("Error reading %s: %v", url, err)
+ }
+}
+
+func TestServerKeepAlivesEnabled(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
+ ts.Config.SetKeepAlivesEnabled(false)
+ ts.Start()
+ defer ts.Close()
+ res, err := Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if !res.Close {
+ t.Errorf("Body.Close == false; want true")
+ }
+}
+
+// golang.org/issue/7856
+func TestServerEmptyBodyRace(t *testing.T) {
+ defer afterTest(t)
+ var n int32
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ atomic.AddInt32(&n, 1)
+ }))
+ defer ts.Close()
+ var wg sync.WaitGroup
+ const reqs = 20
+ for i := 0; i < reqs; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ res, err := Get(ts.URL)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer res.Body.Close()
+ _, err = io.Copy(ioutil.Discard, res.Body)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ }()
+ }
+ wg.Wait()
+ if got := atomic.LoadInt32(&n); got != reqs {
+ t.Errorf("handler ran %d times; want %d", got, reqs)
+ }
+}
+
+func TestServerConnStateNew(t *testing.T) {
+ sawNew := false // if the test is buggy, we'll race on this variable.
+ srv := &Server{
+ ConnState: func(c net.Conn, state ConnState) {
+ if state == StateNew {
+ sawNew = true // testing that this write isn't racy
+ }
+ },
+ Handler: HandlerFunc(func(w ResponseWriter, r *Request) {}), // irrelevant
+ }
+ srv.Serve(&oneConnListener{
+ conn: &rwTestConn{
+ Reader: strings.NewReader("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"),
+ Writer: ioutil.Discard,
+ },
+ })
+ if !sawNew { // testing that this read isn't racy
+ t.Error("StateNew not seen")
+ }
+}
+
+type closeWriteTestConn struct {
+ rwTestConn
+ didCloseWrite bool
+}
+
+func (c *closeWriteTestConn) CloseWrite() error {
+ c.didCloseWrite = true
+ return nil
+}
+
+func TestCloseWrite(t *testing.T) {
+ var srv Server
+ var testConn closeWriteTestConn
+ c, err := ExportServerNewConn(&srv, &testConn)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ExportCloseWriteAndWait(c)
+ if !testConn.didCloseWrite {
+ t.Error("didn't see CloseWrite call")
+ }
+}
+
+// This verifies that a handler can Flush and then Hijack.
+//
+// An similar test crashed once during development, but it was only
+// testing this tangentially and temporarily until another TODO was
+// fixed.
+//
+// So add an explicit test for this.
+func TestServerFlushAndHijack(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.WriteString(w, "Hello, ")
+ w.(Flusher).Flush()
+ conn, buf, _ := w.(Hijacker).Hijack()
+ buf.WriteString("6\r\nworld!\r\n0\r\n\r\n")
+ if err := buf.Flush(); err != nil {
+ t.Error(err)
+ }
+ if err := conn.Close(); err != nil {
+ t.Error(err)
+ }
+ }))
+ defer ts.Close()
+ res, err := Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ all, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if want := "Hello, world!"; string(all) != want {
+ t.Errorf("Got %q; want %q", all, want)
+ }
+}
+
+// golang.org/issue/8534 -- the Server shouldn't reuse a connection
+// for keep-alive after it's seen any Write error (e.g. a timeout) on
+// that net.Conn.
+//
+// To test, verify we don't timeout or see fewer unique client
+// addresses (== unique connections) than requests.
+func TestServerKeepAliveAfterWriteError(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping in -short mode")
+ }
+ defer afterTest(t)
+ const numReq = 3
+ addrc := make(chan string, numReq)
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ addrc <- r.RemoteAddr
+ time.Sleep(500 * time.Millisecond)
+ w.(Flusher).Flush()
+ }))
+ ts.Config.WriteTimeout = 250 * time.Millisecond
+ ts.Start()
+ defer ts.Close()
+
+ errc := make(chan error, numReq)
+ go func() {
+ defer close(errc)
+ for i := 0; i < numReq; i++ {
+ res, err := Get(ts.URL)
+ if res != nil {
+ res.Body.Close()
+ }
+ errc <- err
+ }
+ }()
+
+ timeout := time.NewTimer(numReq * 2 * time.Second) // 4x overkill
+ defer timeout.Stop()
+ addrSeen := map[string]bool{}
+ numOkay := 0
+ for {
+ select {
+ case v := <-addrc:
+ addrSeen[v] = true
+ case err, ok := <-errc:
+ if !ok {
+ if len(addrSeen) != numReq {
+ t.Errorf("saw %d unique client addresses; want %d", len(addrSeen), numReq)
+ }
+ if numOkay != 0 {
+ t.Errorf("got %d successful client requests; want 0", numOkay)
+ }
+ return
+ }
+ if err == nil {
+ numOkay++
+ }
+ case <-timeout.C:
+ t.Fatal("timeout waiting for requests to complete")
+ }
+ }
+}
+
+func BenchmarkClientServer(b *testing.B) {
+ b.ReportAllocs()
+ b.StopTimer()
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
+ fmt.Fprintf(rw, "Hello world.\n")
+ }))
+ defer ts.Close()
+ b.StartTimer()
+
+ for i := 0; i < b.N; i++ {
+ res, err := Get(ts.URL)
+ if err != nil {
+ b.Fatal("Get:", err)
+ }
+ all, err := ioutil.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ b.Fatal("ReadAll:", err)
+ }
+ body := string(all)
+ if body != "Hello world.\n" {
+ b.Fatal("Got body:", body)
+ }
+ }
+
+ b.StopTimer()
+}
+
+func BenchmarkClientServerParallel4(b *testing.B) {
+ benchmarkClientServerParallel(b, 4, false)
+}
+
+func BenchmarkClientServerParallel64(b *testing.B) {
+ benchmarkClientServerParallel(b, 64, false)
+}
+
+func BenchmarkClientServerParallelTLS4(b *testing.B) {
+ benchmarkClientServerParallel(b, 4, true)
+}
+
+func BenchmarkClientServerParallelTLS64(b *testing.B) {
+ benchmarkClientServerParallel(b, 64, true)
+}
+
+func benchmarkClientServerParallel(b *testing.B, parallelism int, useTLS bool) {
+ b.ReportAllocs()
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
+ fmt.Fprintf(rw, "Hello world.\n")
+ }))
+ if useTLS {
+ ts.StartTLS()
+ } else {
+ ts.Start()
+ }
+ defer ts.Close()
+ b.ResetTimer()
+ b.SetParallelism(parallelism)
+ b.RunParallel(func(pb *testing.PB) {
+ noVerifyTransport := &Transport{
+ TLSClientConfig: &tls.Config{
+ InsecureSkipVerify: true,
+ },
+ }
+ defer noVerifyTransport.CloseIdleConnections()
+ client := &Client{Transport: noVerifyTransport}
+ for pb.Next() {
+ res, err := client.Get(ts.URL)
+ if err != nil {
+ b.Logf("Get: %v", err)
+ continue
+ }
+ all, err := ioutil.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ b.Logf("ReadAll: %v", err)
+ continue
+ }
+ body := string(all)
+ if body != "Hello world.\n" {
+ panic("Got body: " + body)
+ }
+ }
+ })
+}
+
+// A benchmark for profiling the server without the HTTP client code.
+// The client code runs in a subprocess.
+//
+// For use like:
+// $ go test -c
+// $ ./http.test -test.run=XX -test.bench=BenchmarkServer -test.benchtime=15s -test.cpuprofile=http.prof
+// $ go tool pprof http.test http.prof
+// (pprof) web
+func BenchmarkServer(b *testing.B) {
+ b.ReportAllocs()
+ // Child process mode;
+ if url := os.Getenv("TEST_BENCH_SERVER_URL"); url != "" {
+ n, err := strconv.Atoi(os.Getenv("TEST_BENCH_CLIENT_N"))
+ if err != nil {
+ panic(err)
+ }
+ for i := 0; i < n; i++ {
+ res, err := Get(url)
+ if err != nil {
+ log.Panicf("Get: %v", err)
+ }
+ all, err := ioutil.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ log.Panicf("ReadAll: %v", err)
+ }
+ body := string(all)
+ if body != "Hello world.\n" {
+ log.Panicf("Got body: %q", body)
+ }
+ }
+ os.Exit(0)
+ return
+ }
+
+ var res = []byte("Hello world.\n")
+ b.StopTimer()
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
+ rw.Header().Set("Content-Type", "text/html; charset=utf-8")
+ rw.Write(res)
+ }))
+ defer ts.Close()
+ b.StartTimer()
+
+ cmd := exec.Command(os.Args[0], "-test.run=XXXX", "-test.bench=BenchmarkServer")
+ cmd.Env = append([]string{
+ fmt.Sprintf("TEST_BENCH_CLIENT_N=%d", b.N),
+ fmt.Sprintf("TEST_BENCH_SERVER_URL=%s", ts.URL),
+ }, os.Environ()...)
+ out, err := cmd.CombinedOutput()
+ if err != nil {
+ b.Errorf("Test failure: %v, with output: %s", err, out)
+ }
+}
+
+func BenchmarkServerFakeConnNoKeepAlive(b *testing.B) {
+ b.ReportAllocs()
+ req := reqBytes(`GET / HTTP/1.0
+Host: golang.org
+Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
+User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
+Accept-Encoding: gzip,deflate,sdch
+Accept-Language: en-US,en;q=0.8
+Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
+`)
+ res := []byte("Hello world!\n")
+
+ conn := &testConn{
+ // testConn.Close will not push into the channel
+ // if it's full.
+ closec: make(chan bool, 1),
+ }
+ handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
+ rw.Header().Set("Content-Type", "text/html; charset=utf-8")
+ rw.Write(res)
+ })
+ ln := new(oneConnListener)
+ for i := 0; i < b.N; i++ {
+ conn.readBuf.Reset()
+ conn.writeBuf.Reset()
+ conn.readBuf.Write(req)
+ ln.conn = conn
+ Serve(ln, handler)
+ <-conn.closec
+ }
+}
+
+// repeatReader reads content count times, then EOFs.
+type repeatReader struct {
+ content []byte
+ count int
+ off int
+}
+
+func (r *repeatReader) Read(p []byte) (n int, err error) {
+ if r.count <= 0 {
+ return 0, io.EOF
+ }
+ n = copy(p, r.content[r.off:])
+ r.off += n
+ if r.off == len(r.content) {
+ r.count--
+ r.off = 0
+ }
+ return
+}
+
+func BenchmarkServerFakeConnWithKeepAlive(b *testing.B) {
+ b.ReportAllocs()
+
+ req := reqBytes(`GET / HTTP/1.1
+Host: golang.org
+Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
+User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
+Accept-Encoding: gzip,deflate,sdch
+Accept-Language: en-US,en;q=0.8
+Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
+`)
+ res := []byte("Hello world!\n")
+
+ conn := &rwTestConn{
+ Reader: &repeatReader{content: req, count: b.N},
+ Writer: ioutil.Discard,
+ closec: make(chan bool, 1),
+ }
+ handled := 0
+ handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
+ handled++
+ rw.Header().Set("Content-Type", "text/html; charset=utf-8")
+ rw.Write(res)
+ })
+ ln := &oneConnListener{conn: conn}
+ go Serve(ln, handler)
+ <-conn.closec
+ if b.N != handled {
+ b.Errorf("b.N=%d but handled %d", b.N, handled)
+ }
+}
+
+// same as above, but representing the most simple possible request
+// and handler. Notably: the handler does not call rw.Header().
+func BenchmarkServerFakeConnWithKeepAliveLite(b *testing.B) {
+ b.ReportAllocs()
+
+ req := reqBytes(`GET / HTTP/1.1
+Host: golang.org
+`)
+ res := []byte("Hello world!\n")
+
+ conn := &rwTestConn{
+ Reader: &repeatReader{content: req, count: b.N},
+ Writer: ioutil.Discard,
+ closec: make(chan bool, 1),
+ }
+ handled := 0
+ handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
+ handled++
+ rw.Write(res)
+ })
+ ln := &oneConnListener{conn: conn}
+ go Serve(ln, handler)
+ <-conn.closec
+ if b.N != handled {
+ b.Errorf("b.N=%d but handled %d", b.N, handled)
+ }
+}
+
+const someResponse = "<html>some response</html>"
+
+// A Response that's just no bigger than 2KB, the buffer-before-chunking threshold.
+var response = bytes.Repeat([]byte(someResponse), 2<<10/len(someResponse))
+
+// Both Content-Type and Content-Length set. Should be no buffering.
+func BenchmarkServerHandlerTypeLen(b *testing.B) {
+ benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Type", "text/html")
+ w.Header().Set("Content-Length", strconv.Itoa(len(response)))
+ w.Write(response)
+ }))
+}
+
+// A Content-Type is set, but no length. No sniffing, but will count the Content-Length.
+func BenchmarkServerHandlerNoLen(b *testing.B) {
+ benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Type", "text/html")
+ w.Write(response)
+ }))
+}
+
+// A Content-Length is set, but the Content-Type will be sniffed.
+func BenchmarkServerHandlerNoType(b *testing.B) {
+ benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", strconv.Itoa(len(response)))
+ w.Write(response)
+ }))
+}
+
+// Neither a Content-Type or Content-Length, so sniffed and counted.
+func BenchmarkServerHandlerNoHeader(b *testing.B) {
+ benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Write(response)
+ }))
+}
+
+func benchmarkHandler(b *testing.B, h Handler) {
+ b.ReportAllocs()
+ req := reqBytes(`GET / HTTP/1.1
+Host: golang.org
+`)
+ conn := &rwTestConn{
+ Reader: &repeatReader{content: req, count: b.N},
+ Writer: ioutil.Discard,
+ closec: make(chan bool, 1),
+ }
+ handled := 0
+ handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
+ handled++
+ h.ServeHTTP(rw, r)
+ })
+ ln := &oneConnListener{conn: conn}
+ go Serve(ln, handler)
+ <-conn.closec
+ if b.N != handled {
+ b.Errorf("b.N=%d but handled %d", b.N, handled)
+ }
+}
+
+func BenchmarkServerHijack(b *testing.B) {
+ b.ReportAllocs()
+ req := reqBytes(`GET / HTTP/1.1
+Host: golang.org
+`)
+ h := HandlerFunc(func(w ResponseWriter, r *Request) {
+ conn, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ panic(err)
+ }
+ conn.Close()
+ })
+ conn := &rwTestConn{
+ Writer: ioutil.Discard,
+ closec: make(chan bool, 1),
+ }
+ ln := &oneConnListener{conn: conn}
+ for i := 0; i < b.N; i++ {
+ conn.Reader = bytes.NewReader(req)
+ ln.conn = conn
+ Serve(ln, h)
+ <-conn.closec
+ }
+}
diff --git a/src/net/http/server.go b/src/net/http/server.go
new file mode 100644
index 000000000..008d5aa7a
--- /dev/null
+++ b/src/net/http/server.go
@@ -0,0 +1,2096 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// HTTP server. See RFC 2616.
+
+package http
+
+import (
+ "bufio"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "net"
+ "net/url"
+ "os"
+ "path"
+ "runtime"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+// Errors introduced by the HTTP server.
+var (
+ ErrWriteAfterFlush = errors.New("Conn.Write called after Flush")
+ ErrBodyNotAllowed = errors.New("http: request method or response status code does not allow body")
+ ErrHijacked = errors.New("Conn has been hijacked")
+ ErrContentLength = errors.New("Conn.Write wrote more than the declared Content-Length")
+)
+
+// Objects implementing the Handler interface can be
+// registered to serve a particular path or subtree
+// in the HTTP server.
+//
+// ServeHTTP should write reply headers and data to the ResponseWriter
+// and then return. Returning signals that the request is finished
+// and that the HTTP server can move on to the next request on
+// the connection.
+//
+// If ServeHTTP panics, the server (the caller of ServeHTTP) assumes
+// that the effect of the panic was isolated to the active request.
+// It recovers the panic, logs a stack trace to the server error log,
+// and hangs up the connection.
+//
+type Handler interface {
+ ServeHTTP(ResponseWriter, *Request)
+}
+
+// A ResponseWriter interface is used by an HTTP handler to
+// construct an HTTP response.
+type ResponseWriter interface {
+ // Header returns the header map that will be sent by WriteHeader.
+ // Changing the header after a call to WriteHeader (or Write) has
+ // no effect.
+ Header() Header
+
+ // Write writes the data to the connection as part of an HTTP reply.
+ // If WriteHeader has not yet been called, Write calls WriteHeader(http.StatusOK)
+ // before writing the data. If the Header does not contain a
+ // Content-Type line, Write adds a Content-Type set to the result of passing
+ // the initial 512 bytes of written data to DetectContentType.
+ Write([]byte) (int, error)
+
+ // WriteHeader sends an HTTP response header with status code.
+ // If WriteHeader is not called explicitly, the first call to Write
+ // will trigger an implicit WriteHeader(http.StatusOK).
+ // Thus explicit calls to WriteHeader are mainly used to
+ // send error codes.
+ WriteHeader(int)
+}
+
+// The Flusher interface is implemented by ResponseWriters that allow
+// an HTTP handler to flush buffered data to the client.
+//
+// Note that even for ResponseWriters that support Flush,
+// if the client is connected through an HTTP proxy,
+// the buffered data may not reach the client until the response
+// completes.
+type Flusher interface {
+ // Flush sends any buffered data to the client.
+ Flush()
+}
+
+// The Hijacker interface is implemented by ResponseWriters that allow
+// an HTTP handler to take over the connection.
+type Hijacker interface {
+ // Hijack lets the caller take over the connection.
+ // After a call to Hijack(), the HTTP server library
+ // will not do anything else with the connection.
+ // It becomes the caller's responsibility to manage
+ // and close the connection.
+ Hijack() (net.Conn, *bufio.ReadWriter, error)
+}
+
+// The CloseNotifier interface is implemented by ResponseWriters which
+// allow detecting when the underlying connection has gone away.
+//
+// This mechanism can be used to cancel long operations on the server
+// if the client has disconnected before the response is ready.
+type CloseNotifier interface {
+ // CloseNotify returns a channel that receives a single value
+ // when the client connection has gone away.
+ CloseNotify() <-chan bool
+}
+
+// A conn represents the server side of an HTTP connection.
+type conn struct {
+ remoteAddr string // network address of remote side
+ server *Server // the Server on which the connection arrived
+ rwc net.Conn // i/o connection
+ w io.Writer // checkConnErrorWriter's copy of wrc, not zeroed on Hijack
+ werr error // any errors writing to w
+ sr liveSwitchReader // where the LimitReader reads from; usually the rwc
+ lr *io.LimitedReader // io.LimitReader(sr)
+ buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->sr->rwc
+ tlsState *tls.ConnectionState // or nil when not using TLS
+
+ mu sync.Mutex // guards the following
+ clientGone bool // if client has disconnected mid-request
+ closeNotifyc chan bool // made lazily
+ hijackedv bool // connection has been hijacked by handler
+}
+
+func (c *conn) hijacked() bool {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ return c.hijackedv
+}
+
+func (c *conn) hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.hijackedv {
+ return nil, nil, ErrHijacked
+ }
+ if c.closeNotifyc != nil {
+ return nil, nil, errors.New("http: Hijack is incompatible with use of CloseNotifier")
+ }
+ c.hijackedv = true
+ rwc = c.rwc
+ buf = c.buf
+ c.rwc = nil
+ c.buf = nil
+ c.setState(rwc, StateHijacked)
+ return
+}
+
+func (c *conn) closeNotify() <-chan bool {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.closeNotifyc == nil {
+ c.closeNotifyc = make(chan bool, 1)
+ if c.hijackedv {
+ // to obey the function signature, even though
+ // it'll never receive a value.
+ return c.closeNotifyc
+ }
+ pr, pw := io.Pipe()
+
+ readSource := c.sr.r
+ c.sr.Lock()
+ c.sr.r = pr
+ c.sr.Unlock()
+ go func() {
+ _, err := io.Copy(pw, readSource)
+ if err == nil {
+ err = io.EOF
+ }
+ pw.CloseWithError(err)
+ c.noteClientGone()
+ }()
+ }
+ return c.closeNotifyc
+}
+
+func (c *conn) noteClientGone() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.closeNotifyc != nil && !c.clientGone {
+ c.closeNotifyc <- true
+ }
+ c.clientGone = true
+}
+
+// A switchReader can have its Reader changed at runtime.
+// It's not safe for concurrent Reads and switches.
+type switchReader struct {
+ io.Reader
+}
+
+// A switchWriter can have its Writer changed at runtime.
+// It's not safe for concurrent Writes and switches.
+type switchWriter struct {
+ io.Writer
+}
+
+// A liveSwitchReader is a switchReader that's safe for concurrent
+// reads and switches, if its mutex is held.
+type liveSwitchReader struct {
+ sync.Mutex
+ r io.Reader
+}
+
+func (sr *liveSwitchReader) Read(p []byte) (n int, err error) {
+ sr.Lock()
+ r := sr.r
+ sr.Unlock()
+ return r.Read(p)
+}
+
+// This should be >= 512 bytes for DetectContentType,
+// but otherwise it's somewhat arbitrary.
+const bufferBeforeChunkingSize = 2048
+
+// chunkWriter writes to a response's conn buffer, and is the writer
+// wrapped by the response.bufw buffered writer.
+//
+// chunkWriter also is responsible for finalizing the Header, including
+// conditionally setting the Content-Type and setting a Content-Length
+// in cases where the handler's final output is smaller than the buffer
+// size. It also conditionally adds chunk headers, when in chunking mode.
+//
+// See the comment above (*response).Write for the entire write flow.
+type chunkWriter struct {
+ res *response
+
+ // header is either nil or a deep clone of res.handlerHeader
+ // at the time of res.WriteHeader, if res.WriteHeader is
+ // called and extra buffering is being done to calculate
+ // Content-Type and/or Content-Length.
+ header Header
+
+ // wroteHeader tells whether the header's been written to "the
+ // wire" (or rather: w.conn.buf). this is unlike
+ // (*response).wroteHeader, which tells only whether it was
+ // logically written.
+ wroteHeader bool
+
+ // set by the writeHeader method:
+ chunking bool // using chunked transfer encoding for reply body
+}
+
+var (
+ crlf = []byte("\r\n")
+ colonSpace = []byte(": ")
+)
+
+func (cw *chunkWriter) Write(p []byte) (n int, err error) {
+ if !cw.wroteHeader {
+ cw.writeHeader(p)
+ }
+ if cw.res.req.Method == "HEAD" {
+ // Eat writes.
+ return len(p), nil
+ }
+ if cw.chunking {
+ _, err = fmt.Fprintf(cw.res.conn.buf, "%x\r\n", len(p))
+ if err != nil {
+ cw.res.conn.rwc.Close()
+ return
+ }
+ }
+ n, err = cw.res.conn.buf.Write(p)
+ if cw.chunking && err == nil {
+ _, err = cw.res.conn.buf.Write(crlf)
+ }
+ if err != nil {
+ cw.res.conn.rwc.Close()
+ }
+ return
+}
+
+func (cw *chunkWriter) flush() {
+ if !cw.wroteHeader {
+ cw.writeHeader(nil)
+ }
+ cw.res.conn.buf.Flush()
+}
+
+func (cw *chunkWriter) close() {
+ if !cw.wroteHeader {
+ cw.writeHeader(nil)
+ }
+ if cw.chunking {
+ // zero EOF chunk, trailer key/value pairs (currently
+ // unsupported in Go's server), followed by a blank
+ // line.
+ cw.res.conn.buf.WriteString("0\r\n\r\n")
+ }
+}
+
+// A response represents the server side of an HTTP response.
+type response struct {
+ conn *conn
+ req *Request // request for this response
+ wroteHeader bool // reply header has been (logically) written
+ wroteContinue bool // 100 Continue response was written
+
+ w *bufio.Writer // buffers output in chunks to chunkWriter
+ cw chunkWriter
+ sw *switchWriter // of the bufio.Writer, for return to putBufioWriter
+
+ // handlerHeader is the Header that Handlers get access to,
+ // which may be retained and mutated even after WriteHeader.
+ // handlerHeader is copied into cw.header at WriteHeader
+ // time, and privately mutated thereafter.
+ handlerHeader Header
+ calledHeader bool // handler accessed handlerHeader via Header
+
+ written int64 // number of bytes written in body
+ contentLength int64 // explicitly-declared Content-Length; or -1
+ status int // status code passed to WriteHeader
+
+ // close connection after this reply. set on request and
+ // updated after response from handler if there's a
+ // "Connection: keep-alive" response header and a
+ // Content-Length.
+ closeAfterReply bool
+
+ // requestBodyLimitHit is set by requestTooLarge when
+ // maxBytesReader hits its max size. It is checked in
+ // WriteHeader, to make sure we don't consume the
+ // remaining request body to try to advance to the next HTTP
+ // request. Instead, when this is set, we stop reading
+ // subsequent requests on this connection and stop reading
+ // input from it.
+ requestBodyLimitHit bool
+
+ handlerDone bool // set true when the handler exits
+
+ // Buffers for Date and Content-Length
+ dateBuf [len(TimeFormat)]byte
+ clenBuf [10]byte
+}
+
+// requestTooLarge is called by maxBytesReader when too much input has
+// been read from the client.
+func (w *response) requestTooLarge() {
+ w.closeAfterReply = true
+ w.requestBodyLimitHit = true
+ if !w.wroteHeader {
+ w.Header().Set("Connection", "close")
+ }
+}
+
+// needsSniff reports whether a Content-Type still needs to be sniffed.
+func (w *response) needsSniff() bool {
+ _, haveType := w.handlerHeader["Content-Type"]
+ return !w.cw.wroteHeader && !haveType && w.written < sniffLen
+}
+
+// writerOnly hides an io.Writer value's optional ReadFrom method
+// from io.Copy.
+type writerOnly struct {
+ io.Writer
+}
+
+func srcIsRegularFile(src io.Reader) (isRegular bool, err error) {
+ switch v := src.(type) {
+ case *os.File:
+ fi, err := v.Stat()
+ if err != nil {
+ return false, err
+ }
+ return fi.Mode().IsRegular(), nil
+ case *io.LimitedReader:
+ return srcIsRegularFile(v.R)
+ default:
+ return
+ }
+}
+
+// ReadFrom is here to optimize copying from an *os.File regular file
+// to a *net.TCPConn with sendfile.
+func (w *response) ReadFrom(src io.Reader) (n int64, err error) {
+ // Our underlying w.conn.rwc is usually a *TCPConn (with its
+ // own ReadFrom method). If not, or if our src isn't a regular
+ // file, just fall back to the normal copy method.
+ rf, ok := w.conn.rwc.(io.ReaderFrom)
+ regFile, err := srcIsRegularFile(src)
+ if err != nil {
+ return 0, err
+ }
+ if !ok || !regFile {
+ return io.Copy(writerOnly{w}, src)
+ }
+
+ // sendfile path:
+
+ if !w.wroteHeader {
+ w.WriteHeader(StatusOK)
+ }
+
+ if w.needsSniff() {
+ n0, err := io.Copy(writerOnly{w}, io.LimitReader(src, sniffLen))
+ n += n0
+ if err != nil {
+ return n, err
+ }
+ }
+
+ w.w.Flush() // get rid of any previous writes
+ w.cw.flush() // make sure Header is written; flush data to rwc
+
+ // Now that cw has been flushed, its chunking field is guaranteed initialized.
+ if !w.cw.chunking && w.bodyAllowed() {
+ n0, err := rf.ReadFrom(src)
+ n += n0
+ w.written += n0
+ return n, err
+ }
+
+ n0, err := io.Copy(writerOnly{w}, src)
+ n += n0
+ return n, err
+}
+
+// noLimit is an effective infinite upper bound for io.LimitedReader
+const noLimit int64 = (1 << 63) - 1
+
+// debugServerConnections controls whether all server connections are wrapped
+// with a verbose logging wrapper.
+const debugServerConnections = false
+
+// Create new connection from rwc.
+func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) {
+ c = new(conn)
+ c.remoteAddr = rwc.RemoteAddr().String()
+ c.server = srv
+ c.rwc = rwc
+ c.w = rwc
+ if debugServerConnections {
+ c.rwc = newLoggingConn("server", c.rwc)
+ }
+ c.sr = liveSwitchReader{r: c.rwc}
+ c.lr = io.LimitReader(&c.sr, noLimit).(*io.LimitedReader)
+ br := newBufioReader(c.lr)
+ bw := newBufioWriterSize(checkConnErrorWriter{c}, 4<<10)
+ c.buf = bufio.NewReadWriter(br, bw)
+ return c, nil
+}
+
+var (
+ bufioReaderPool sync.Pool
+ bufioWriter2kPool sync.Pool
+ bufioWriter4kPool sync.Pool
+)
+
+func bufioWriterPool(size int) *sync.Pool {
+ switch size {
+ case 2 << 10:
+ return &bufioWriter2kPool
+ case 4 << 10:
+ return &bufioWriter4kPool
+ }
+ return nil
+}
+
+func newBufioReader(r io.Reader) *bufio.Reader {
+ if v := bufioReaderPool.Get(); v != nil {
+ br := v.(*bufio.Reader)
+ br.Reset(r)
+ return br
+ }
+ return bufio.NewReader(r)
+}
+
+func putBufioReader(br *bufio.Reader) {
+ br.Reset(nil)
+ bufioReaderPool.Put(br)
+}
+
+func newBufioWriterSize(w io.Writer, size int) *bufio.Writer {
+ pool := bufioWriterPool(size)
+ if pool != nil {
+ if v := pool.Get(); v != nil {
+ bw := v.(*bufio.Writer)
+ bw.Reset(w)
+ return bw
+ }
+ }
+ return bufio.NewWriterSize(w, size)
+}
+
+func putBufioWriter(bw *bufio.Writer) {
+ bw.Reset(nil)
+ if pool := bufioWriterPool(bw.Available()); pool != nil {
+ pool.Put(bw)
+ }
+}
+
+// DefaultMaxHeaderBytes is the maximum permitted size of the headers
+// in an HTTP request.
+// This can be overridden by setting Server.MaxHeaderBytes.
+const DefaultMaxHeaderBytes = 1 << 20 // 1 MB
+
+func (srv *Server) maxHeaderBytes() int {
+ if srv.MaxHeaderBytes > 0 {
+ return srv.MaxHeaderBytes
+ }
+ return DefaultMaxHeaderBytes
+}
+
+func (srv *Server) initialLimitedReaderSize() int64 {
+ return int64(srv.maxHeaderBytes()) + 4096 // bufio slop
+}
+
+// wrapper around io.ReaderCloser which on first read, sends an
+// HTTP/1.1 100 Continue header
+type expectContinueReader struct {
+ resp *response
+ readCloser io.ReadCloser
+ closed bool
+}
+
+func (ecr *expectContinueReader) Read(p []byte) (n int, err error) {
+ if ecr.closed {
+ return 0, ErrBodyReadAfterClose
+ }
+ if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked() {
+ ecr.resp.wroteContinue = true
+ ecr.resp.conn.buf.WriteString("HTTP/1.1 100 Continue\r\n\r\n")
+ ecr.resp.conn.buf.Flush()
+ }
+ return ecr.readCloser.Read(p)
+}
+
+func (ecr *expectContinueReader) Close() error {
+ ecr.closed = true
+ return ecr.readCloser.Close()
+}
+
+// TimeFormat is the time format to use with
+// time.Parse and time.Time.Format when parsing
+// or generating times in HTTP headers.
+// It is like time.RFC1123 but hard codes GMT as the time zone.
+const TimeFormat = "Mon, 02 Jan 2006 15:04:05 GMT"
+
+// appendTime is a non-allocating version of []byte(t.UTC().Format(TimeFormat))
+func appendTime(b []byte, t time.Time) []byte {
+ const days = "SunMonTueWedThuFriSat"
+ const months = "JanFebMarAprMayJunJulAugSepOctNovDec"
+
+ t = t.UTC()
+ yy, mm, dd := t.Date()
+ hh, mn, ss := t.Clock()
+ day := days[3*t.Weekday():]
+ mon := months[3*(mm-1):]
+
+ return append(b,
+ day[0], day[1], day[2], ',', ' ',
+ byte('0'+dd/10), byte('0'+dd%10), ' ',
+ mon[0], mon[1], mon[2], ' ',
+ byte('0'+yy/1000), byte('0'+(yy/100)%10), byte('0'+(yy/10)%10), byte('0'+yy%10), ' ',
+ byte('0'+hh/10), byte('0'+hh%10), ':',
+ byte('0'+mn/10), byte('0'+mn%10), ':',
+ byte('0'+ss/10), byte('0'+ss%10), ' ',
+ 'G', 'M', 'T')
+}
+
+var errTooLarge = errors.New("http: request too large")
+
+// Read next request from connection.
+func (c *conn) readRequest() (w *response, err error) {
+ if c.hijacked() {
+ return nil, ErrHijacked
+ }
+
+ if d := c.server.ReadTimeout; d != 0 {
+ c.rwc.SetReadDeadline(time.Now().Add(d))
+ }
+ if d := c.server.WriteTimeout; d != 0 {
+ defer func() {
+ c.rwc.SetWriteDeadline(time.Now().Add(d))
+ }()
+ }
+
+ c.lr.N = c.server.initialLimitedReaderSize()
+ var req *Request
+ if req, err = ReadRequest(c.buf.Reader); err != nil {
+ if c.lr.N == 0 {
+ return nil, errTooLarge
+ }
+ return nil, err
+ }
+ c.lr.N = noLimit
+
+ req.RemoteAddr = c.remoteAddr
+ req.TLS = c.tlsState
+
+ w = &response{
+ conn: c,
+ req: req,
+ handlerHeader: make(Header),
+ contentLength: -1,
+ }
+ w.cw.res = w
+ w.w = newBufioWriterSize(&w.cw, bufferBeforeChunkingSize)
+ return w, nil
+}
+
+func (w *response) Header() Header {
+ if w.cw.header == nil && w.wroteHeader && !w.cw.wroteHeader {
+ // Accessing the header between logically writing it
+ // and physically writing it means we need to allocate
+ // a clone to snapshot the logically written state.
+ w.cw.header = w.handlerHeader.clone()
+ }
+ w.calledHeader = true
+ return w.handlerHeader
+}
+
+// maxPostHandlerReadBytes is the max number of Request.Body bytes not
+// consumed by a handler that the server will read from the client
+// in order to keep a connection alive. If there are more bytes than
+// this then the server to be paranoid instead sends a "Connection:
+// close" response.
+//
+// This number is approximately what a typical machine's TCP buffer
+// size is anyway. (if we have the bytes on the machine, we might as
+// well read them)
+const maxPostHandlerReadBytes = 256 << 10
+
+func (w *response) WriteHeader(code int) {
+ if w.conn.hijacked() {
+ w.conn.server.logf("http: response.WriteHeader on hijacked connection")
+ return
+ }
+ if w.wroteHeader {
+ w.conn.server.logf("http: multiple response.WriteHeader calls")
+ return
+ }
+ w.wroteHeader = true
+ w.status = code
+
+ if w.calledHeader && w.cw.header == nil {
+ w.cw.header = w.handlerHeader.clone()
+ }
+
+ if cl := w.handlerHeader.get("Content-Length"); cl != "" {
+ v, err := strconv.ParseInt(cl, 10, 64)
+ if err == nil && v >= 0 {
+ w.contentLength = v
+ } else {
+ w.conn.server.logf("http: invalid Content-Length of %q", cl)
+ w.handlerHeader.Del("Content-Length")
+ }
+ }
+}
+
+// extraHeader is the set of headers sometimes added by chunkWriter.writeHeader.
+// This type is used to avoid extra allocations from cloning and/or populating
+// the response Header map and all its 1-element slices.
+type extraHeader struct {
+ contentType string
+ connection string
+ transferEncoding string
+ date []byte // written if not nil
+ contentLength []byte // written if not nil
+}
+
+// Sorted the same as extraHeader.Write's loop.
+var extraHeaderKeys = [][]byte{
+ []byte("Content-Type"),
+ []byte("Connection"),
+ []byte("Transfer-Encoding"),
+}
+
+var (
+ headerContentLength = []byte("Content-Length: ")
+ headerDate = []byte("Date: ")
+)
+
+// Write writes the headers described in h to w.
+//
+// This method has a value receiver, despite the somewhat large size
+// of h, because it prevents an allocation. The escape analysis isn't
+// smart enough to realize this function doesn't mutate h.
+func (h extraHeader) Write(w *bufio.Writer) {
+ if h.date != nil {
+ w.Write(headerDate)
+ w.Write(h.date)
+ w.Write(crlf)
+ }
+ if h.contentLength != nil {
+ w.Write(headerContentLength)
+ w.Write(h.contentLength)
+ w.Write(crlf)
+ }
+ for i, v := range []string{h.contentType, h.connection, h.transferEncoding} {
+ if v != "" {
+ w.Write(extraHeaderKeys[i])
+ w.Write(colonSpace)
+ w.WriteString(v)
+ w.Write(crlf)
+ }
+ }
+}
+
+// writeHeader finalizes the header sent to the client and writes it
+// to cw.res.conn.buf.
+//
+// p is not written by writeHeader, but is the first chunk of the body
+// that will be written. It is sniffed for a Content-Type if none is
+// set explicitly. It's also used to set the Content-Length, if the
+// total body size was small and the handler has already finished
+// running.
+func (cw *chunkWriter) writeHeader(p []byte) {
+ if cw.wroteHeader {
+ return
+ }
+ cw.wroteHeader = true
+
+ w := cw.res
+ keepAlivesEnabled := w.conn.server.doKeepAlives()
+ isHEAD := w.req.Method == "HEAD"
+
+ // header is written out to w.conn.buf below. Depending on the
+ // state of the handler, we either own the map or not. If we
+ // don't own it, the exclude map is created lazily for
+ // WriteSubset to remove headers. The setHeader struct holds
+ // headers we need to add.
+ header := cw.header
+ owned := header != nil
+ if !owned {
+ header = w.handlerHeader
+ }
+ var excludeHeader map[string]bool
+ delHeader := func(key string) {
+ if owned {
+ header.Del(key)
+ return
+ }
+ if _, ok := header[key]; !ok {
+ return
+ }
+ if excludeHeader == nil {
+ excludeHeader = make(map[string]bool)
+ }
+ excludeHeader[key] = true
+ }
+ var setHeader extraHeader
+
+ // If the handler is done but never sent a Content-Length
+ // response header and this is our first (and last) write, set
+ // it, even to zero. This helps HTTP/1.0 clients keep their
+ // "keep-alive" connections alive.
+ // Exceptions: 304/204/1xx responses never get Content-Length, and if
+ // it was a HEAD request, we don't know the difference between
+ // 0 actual bytes and 0 bytes because the handler noticed it
+ // was a HEAD request and chose not to write anything. So for
+ // HEAD, the handler should either write the Content-Length or
+ // write non-zero bytes. If it's actually 0 bytes and the
+ // handler never looked at the Request.Method, we just don't
+ // send a Content-Length header.
+ if w.handlerDone && bodyAllowedForStatus(w.status) && header.get("Content-Length") == "" && (!isHEAD || len(p) > 0) {
+ w.contentLength = int64(len(p))
+ setHeader.contentLength = strconv.AppendInt(cw.res.clenBuf[:0], int64(len(p)), 10)
+ }
+
+ // If this was an HTTP/1.0 request with keep-alive and we sent a
+ // Content-Length back, we can make this a keep-alive response ...
+ if w.req.wantsHttp10KeepAlive() && keepAlivesEnabled {
+ sentLength := header.get("Content-Length") != ""
+ if sentLength && header.get("Connection") == "keep-alive" {
+ w.closeAfterReply = false
+ }
+ }
+
+ // Check for a explicit (and valid) Content-Length header.
+ hasCL := w.contentLength != -1
+
+ if w.req.wantsHttp10KeepAlive() && (isHEAD || hasCL) {
+ _, connectionHeaderSet := header["Connection"]
+ if !connectionHeaderSet {
+ setHeader.connection = "keep-alive"
+ }
+ } else if !w.req.ProtoAtLeast(1, 1) || w.req.wantsClose() {
+ w.closeAfterReply = true
+ }
+
+ if header.get("Connection") == "close" || !keepAlivesEnabled {
+ w.closeAfterReply = true
+ }
+
+ // Per RFC 2616, we should consume the request body before
+ // replying, if the handler hasn't already done so. But we
+ // don't want to do an unbounded amount of reading here for
+ // DoS reasons, so we only try up to a threshold.
+ if w.req.ContentLength != 0 && !w.closeAfterReply {
+ ecr, isExpecter := w.req.Body.(*expectContinueReader)
+ if !isExpecter || ecr.resp.wroteContinue {
+ n, _ := io.CopyN(ioutil.Discard, w.req.Body, maxPostHandlerReadBytes+1)
+ if n >= maxPostHandlerReadBytes {
+ w.requestTooLarge()
+ delHeader("Connection")
+ setHeader.connection = "close"
+ } else {
+ w.req.Body.Close()
+ }
+ }
+ }
+
+ code := w.status
+ if bodyAllowedForStatus(code) {
+ // If no content type, apply sniffing algorithm to body.
+ _, haveType := header["Content-Type"]
+ if !haveType {
+ setHeader.contentType = DetectContentType(p)
+ }
+ } else {
+ for _, k := range suppressedHeaders(code) {
+ delHeader(k)
+ }
+ }
+
+ if _, ok := header["Date"]; !ok {
+ setHeader.date = appendTime(cw.res.dateBuf[:0], time.Now())
+ }
+
+ te := header.get("Transfer-Encoding")
+ hasTE := te != ""
+ if hasCL && hasTE && te != "identity" {
+ // TODO: return an error if WriteHeader gets a return parameter
+ // For now just ignore the Content-Length.
+ w.conn.server.logf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d",
+ te, w.contentLength)
+ delHeader("Content-Length")
+ hasCL = false
+ }
+
+ if w.req.Method == "HEAD" || !bodyAllowedForStatus(code) {
+ // do nothing
+ } else if code == StatusNoContent {
+ delHeader("Transfer-Encoding")
+ } else if hasCL {
+ delHeader("Transfer-Encoding")
+ } else if w.req.ProtoAtLeast(1, 1) {
+ // HTTP/1.1 or greater: Transfer-Encoding has been set to identity, and no
+ // content-length has been provided. The connection must be closed after the
+ // reply is written, and no chunking is to be done. This is the setup
+ // recommended in the Server-Sent Events candidate recommendation 11,
+ // section 8.
+ if hasTE && te == "identity" {
+ cw.chunking = false
+ w.closeAfterReply = true
+ } else {
+ // HTTP/1.1 or greater: use chunked transfer encoding
+ // to avoid closing the connection at EOF.
+ cw.chunking = true
+ setHeader.transferEncoding = "chunked"
+ }
+ } else {
+ // HTTP version < 1.1: cannot do chunked transfer
+ // encoding and we don't know the Content-Length so
+ // signal EOF by closing connection.
+ w.closeAfterReply = true
+ delHeader("Transfer-Encoding") // in case already set
+ }
+
+ // Cannot use Content-Length with non-identity Transfer-Encoding.
+ if cw.chunking {
+ delHeader("Content-Length")
+ }
+ if !w.req.ProtoAtLeast(1, 0) {
+ return
+ }
+
+ if w.closeAfterReply && (!keepAlivesEnabled || !hasToken(cw.header.get("Connection"), "close")) {
+ delHeader("Connection")
+ if w.req.ProtoAtLeast(1, 1) {
+ setHeader.connection = "close"
+ }
+ }
+
+ w.conn.buf.WriteString(statusLine(w.req, code))
+ cw.header.WriteSubset(w.conn.buf, excludeHeader)
+ setHeader.Write(w.conn.buf.Writer)
+ w.conn.buf.Write(crlf)
+}
+
+// statusLines is a cache of Status-Line strings, keyed by code (for
+// HTTP/1.1) or negative code (for HTTP/1.0). This is faster than a
+// map keyed by struct of two fields. This map's max size is bounded
+// by 2*len(statusText), two protocol types for each known official
+// status code in the statusText map.
+var (
+ statusMu sync.RWMutex
+ statusLines = make(map[int]string)
+)
+
+// statusLine returns a response Status-Line (RFC 2616 Section 6.1)
+// for the given request and response status code.
+func statusLine(req *Request, code int) string {
+ // Fast path:
+ key := code
+ proto11 := req.ProtoAtLeast(1, 1)
+ if !proto11 {
+ key = -key
+ }
+ statusMu.RLock()
+ line, ok := statusLines[key]
+ statusMu.RUnlock()
+ if ok {
+ return line
+ }
+
+ // Slow path:
+ proto := "HTTP/1.0"
+ if proto11 {
+ proto = "HTTP/1.1"
+ }
+ codestring := strconv.Itoa(code)
+ text, ok := statusText[code]
+ if !ok {
+ text = "status code " + codestring
+ }
+ line = proto + " " + codestring + " " + text + "\r\n"
+ if ok {
+ statusMu.Lock()
+ defer statusMu.Unlock()
+ statusLines[key] = line
+ }
+ return line
+}
+
+// bodyAllowed returns true if a Write is allowed for this response type.
+// It's illegal to call this before the header has been flushed.
+func (w *response) bodyAllowed() bool {
+ if !w.wroteHeader {
+ panic("")
+ }
+ return bodyAllowedForStatus(w.status)
+}
+
+// The Life Of A Write is like this:
+//
+// Handler starts. No header has been sent. The handler can either
+// write a header, or just start writing. Writing before sending a header
+// sends an implicitly empty 200 OK header.
+//
+// If the handler didn't declare a Content-Length up front, we either
+// go into chunking mode or, if the handler finishes running before
+// the chunking buffer size, we compute a Content-Length and send that
+// in the header instead.
+//
+// Likewise, if the handler didn't set a Content-Type, we sniff that
+// from the initial chunk of output.
+//
+// The Writers are wired together like:
+//
+// 1. *response (the ResponseWriter) ->
+// 2. (*response).w, a *bufio.Writer of bufferBeforeChunkingSize bytes
+// 3. chunkWriter.Writer (whose writeHeader finalizes Content-Length/Type)
+// and which writes the chunk headers, if needed.
+// 4. conn.buf, a bufio.Writer of default (4kB) bytes, writing to ->
+// 5. checkConnErrorWriter{c}, which notes any non-nil error on Write
+// and populates c.werr with it if so. but otherwise writes to:
+// 6. the rwc, the net.Conn.
+//
+// TODO(bradfitz): short-circuit some of the buffering when the
+// initial header contains both a Content-Type and Content-Length.
+// Also short-circuit in (1) when the header's been sent and not in
+// chunking mode, writing directly to (4) instead, if (2) has no
+// buffered data. More generally, we could short-circuit from (1) to
+// (3) even in chunking mode if the write size from (1) is over some
+// threshold and nothing is in (2). The answer might be mostly making
+// bufferBeforeChunkingSize smaller and having bufio's fast-paths deal
+// with this instead.
+func (w *response) Write(data []byte) (n int, err error) {
+ return w.write(len(data), data, "")
+}
+
+func (w *response) WriteString(data string) (n int, err error) {
+ return w.write(len(data), nil, data)
+}
+
+// either dataB or dataS is non-zero.
+func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err error) {
+ if w.conn.hijacked() {
+ w.conn.server.logf("http: response.Write on hijacked connection")
+ return 0, ErrHijacked
+ }
+ if !w.wroteHeader {
+ w.WriteHeader(StatusOK)
+ }
+ if lenData == 0 {
+ return 0, nil
+ }
+ if !w.bodyAllowed() {
+ return 0, ErrBodyNotAllowed
+ }
+
+ w.written += int64(lenData) // ignoring errors, for errorKludge
+ if w.contentLength != -1 && w.written > w.contentLength {
+ return 0, ErrContentLength
+ }
+ if dataB != nil {
+ return w.w.Write(dataB)
+ } else {
+ return w.w.WriteString(dataS)
+ }
+}
+
+func (w *response) finishRequest() {
+ w.handlerDone = true
+
+ if !w.wroteHeader {
+ w.WriteHeader(StatusOK)
+ }
+
+ w.w.Flush()
+ putBufioWriter(w.w)
+ w.cw.close()
+ w.conn.buf.Flush()
+
+ // Close the body (regardless of w.closeAfterReply) so we can
+ // re-use its bufio.Reader later safely.
+ w.req.Body.Close()
+
+ if w.req.MultipartForm != nil {
+ w.req.MultipartForm.RemoveAll()
+ }
+
+ if w.req.Method != "HEAD" && w.contentLength != -1 && w.bodyAllowed() && w.contentLength != w.written {
+ // Did not write enough. Avoid getting out of sync.
+ w.closeAfterReply = true
+ }
+
+ // There was some error writing to the underlying connection
+ // during the request, so don't re-use this conn.
+ if w.conn.werr != nil {
+ w.closeAfterReply = true
+ }
+}
+
+func (w *response) Flush() {
+ if !w.wroteHeader {
+ w.WriteHeader(StatusOK)
+ }
+ w.w.Flush()
+ w.cw.flush()
+}
+
+func (c *conn) finalFlush() {
+ if c.buf != nil {
+ c.buf.Flush()
+
+ // Steal the bufio.Reader (~4KB worth of memory) and its associated
+ // reader for a future connection.
+ putBufioReader(c.buf.Reader)
+
+ // Steal the bufio.Writer (~4KB worth of memory) and its associated
+ // writer for a future connection.
+ putBufioWriter(c.buf.Writer)
+
+ c.buf = nil
+ }
+}
+
+// Close the connection.
+func (c *conn) close() {
+ c.finalFlush()
+ if c.rwc != nil {
+ c.rwc.Close()
+ c.rwc = nil
+ }
+}
+
+// rstAvoidanceDelay is the amount of time we sleep after closing the
+// write side of a TCP connection before closing the entire socket.
+// By sleeping, we increase the chances that the client sees our FIN
+// and processes its final data before they process the subsequent RST
+// from closing a connection with known unread data.
+// This RST seems to occur mostly on BSD systems. (And Windows?)
+// This timeout is somewhat arbitrary (~latency around the planet).
+const rstAvoidanceDelay = 500 * time.Millisecond
+
+type closeWriter interface {
+ CloseWrite() error
+}
+
+var _ closeWriter = (*net.TCPConn)(nil)
+
+// closeWrite flushes any outstanding data and sends a FIN packet (if
+// client is connected via TCP), signalling that we're done. We then
+// pause for a bit, hoping the client processes it before any
+// subsequent RST.
+//
+// See http://golang.org/issue/3595
+func (c *conn) closeWriteAndWait() {
+ c.finalFlush()
+ if tcp, ok := c.rwc.(closeWriter); ok {
+ tcp.CloseWrite()
+ }
+ time.Sleep(rstAvoidanceDelay)
+}
+
+// validNPN reports whether the proto is not a blacklisted Next
+// Protocol Negotiation protocol. Empty and built-in protocol types
+// are blacklisted and can't be overridden with alternate
+// implementations.
+func validNPN(proto string) bool {
+ switch proto {
+ case "", "http/1.1", "http/1.0":
+ return false
+ }
+ return true
+}
+
+func (c *conn) setState(nc net.Conn, state ConnState) {
+ if hook := c.server.ConnState; hook != nil {
+ hook(nc, state)
+ }
+}
+
+// Serve a new connection.
+func (c *conn) serve() {
+ origConn := c.rwc // copy it before it's set nil on Close or Hijack
+ defer func() {
+ if err := recover(); err != nil {
+ const size = 64 << 10
+ buf := make([]byte, size)
+ buf = buf[:runtime.Stack(buf, false)]
+ c.server.logf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf)
+ }
+ if !c.hijacked() {
+ c.close()
+ c.setState(origConn, StateClosed)
+ }
+ }()
+
+ if tlsConn, ok := c.rwc.(*tls.Conn); ok {
+ if d := c.server.ReadTimeout; d != 0 {
+ c.rwc.SetReadDeadline(time.Now().Add(d))
+ }
+ if d := c.server.WriteTimeout; d != 0 {
+ c.rwc.SetWriteDeadline(time.Now().Add(d))
+ }
+ if err := tlsConn.Handshake(); err != nil {
+ c.server.logf("http: TLS handshake error from %s: %v", c.rwc.RemoteAddr(), err)
+ return
+ }
+ c.tlsState = new(tls.ConnectionState)
+ *c.tlsState = tlsConn.ConnectionState()
+ if proto := c.tlsState.NegotiatedProtocol; validNPN(proto) {
+ if fn := c.server.TLSNextProto[proto]; fn != nil {
+ h := initNPNRequest{tlsConn, serverHandler{c.server}}
+ fn(c.server, tlsConn, h)
+ }
+ return
+ }
+ }
+
+ for {
+ w, err := c.readRequest()
+ if c.lr.N != c.server.initialLimitedReaderSize() {
+ // If we read any bytes off the wire, we're active.
+ c.setState(c.rwc, StateActive)
+ }
+ if err != nil {
+ if err == errTooLarge {
+ // Their HTTP client may or may not be
+ // able to read this if we're
+ // responding to them and hanging up
+ // while they're still writing their
+ // request. Undefined behavior.
+ io.WriteString(c.rwc, "HTTP/1.1 413 Request Entity Too Large\r\n\r\n")
+ c.closeWriteAndWait()
+ break
+ } else if err == io.EOF {
+ break // Don't reply
+ } else if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
+ break // Don't reply
+ }
+ io.WriteString(c.rwc, "HTTP/1.1 400 Bad Request\r\n\r\n")
+ break
+ }
+
+ // Expect 100 Continue support
+ req := w.req
+ if req.expectsContinue() {
+ if req.ProtoAtLeast(1, 1) && req.ContentLength != 0 {
+ // Wrap the Body reader with one that replies on the connection
+ req.Body = &expectContinueReader{readCloser: req.Body, resp: w}
+ }
+ req.Header.Del("Expect")
+ } else if req.Header.get("Expect") != "" {
+ w.sendExpectationFailed()
+ break
+ }
+
+ // HTTP cannot have multiple simultaneous active requests.[*]
+ // Until the server replies to this request, it can't read another,
+ // so we might as well run the handler in this goroutine.
+ // [*] Not strictly true: HTTP pipelining. We could let them all process
+ // in parallel even if their responses need to be serialized.
+ serverHandler{c.server}.ServeHTTP(w, w.req)
+ if c.hijacked() {
+ return
+ }
+ w.finishRequest()
+ if w.closeAfterReply {
+ if w.requestBodyLimitHit {
+ c.closeWriteAndWait()
+ }
+ break
+ }
+ c.setState(c.rwc, StateIdle)
+ }
+}
+
+func (w *response) sendExpectationFailed() {
+ // TODO(bradfitz): let ServeHTTP handlers handle
+ // requests with non-standard expectation[s]? Seems
+ // theoretical at best, and doesn't fit into the
+ // current ServeHTTP model anyway. We'd need to
+ // make the ResponseWriter an optional
+ // "ExpectReplier" interface or something.
+ //
+ // For now we'll just obey RFC 2616 14.20 which says
+ // "If a server receives a request containing an
+ // Expect field that includes an expectation-
+ // extension that it does not support, it MUST
+ // respond with a 417 (Expectation Failed) status."
+ w.Header().Set("Connection", "close")
+ w.WriteHeader(StatusExpectationFailed)
+ w.finishRequest()
+}
+
+// Hijack implements the Hijacker.Hijack method. Our response is both a ResponseWriter
+// and a Hijacker.
+func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
+ if w.wroteHeader {
+ w.cw.flush()
+ }
+ // Release the bufioWriter that writes to the chunk writer, it is not
+ // used after a connection has been hijacked.
+ rwc, buf, err = w.conn.hijack()
+ if err == nil {
+ putBufioWriter(w.w)
+ w.w = nil
+ }
+ return rwc, buf, err
+}
+
+func (w *response) CloseNotify() <-chan bool {
+ return w.conn.closeNotify()
+}
+
+// The HandlerFunc type is an adapter to allow the use of
+// ordinary functions as HTTP handlers. If f is a function
+// with the appropriate signature, HandlerFunc(f) is a
+// Handler object that calls f.
+type HandlerFunc func(ResponseWriter, *Request)
+
+// ServeHTTP calls f(w, r).
+func (f HandlerFunc) ServeHTTP(w ResponseWriter, r *Request) {
+ f(w, r)
+}
+
+// Helper handlers
+
+// Error replies to the request with the specified error message and HTTP code.
+// The error message should be plain text.
+func Error(w ResponseWriter, error string, code int) {
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+ w.WriteHeader(code)
+ fmt.Fprintln(w, error)
+}
+
+// NotFound replies to the request with an HTTP 404 not found error.
+func NotFound(w ResponseWriter, r *Request) { Error(w, "404 page not found", StatusNotFound) }
+
+// NotFoundHandler returns a simple request handler
+// that replies to each request with a ``404 page not found'' reply.
+func NotFoundHandler() Handler { return HandlerFunc(NotFound) }
+
+// StripPrefix returns a handler that serves HTTP requests
+// by removing the given prefix from the request URL's Path
+// and invoking the handler h. StripPrefix handles a
+// request for a path that doesn't begin with prefix by
+// replying with an HTTP 404 not found error.
+func StripPrefix(prefix string, h Handler) Handler {
+ if prefix == "" {
+ return h
+ }
+ return HandlerFunc(func(w ResponseWriter, r *Request) {
+ if p := strings.TrimPrefix(r.URL.Path, prefix); len(p) < len(r.URL.Path) {
+ r.URL.Path = p
+ h.ServeHTTP(w, r)
+ } else {
+ NotFound(w, r)
+ }
+ })
+}
+
+// Redirect replies to the request with a redirect to url,
+// which may be a path relative to the request path.
+func Redirect(w ResponseWriter, r *Request, urlStr string, code int) {
+ if u, err := url.Parse(urlStr); err == nil {
+ // If url was relative, make absolute by
+ // combining with request path.
+ // The browser would probably do this for us,
+ // but doing it ourselves is more reliable.
+
+ // NOTE(rsc): RFC 2616 says that the Location
+ // line must be an absolute URI, like
+ // "http://www.google.com/redirect/",
+ // not a path like "/redirect/".
+ // Unfortunately, we don't know what to
+ // put in the host name section to get the
+ // client to connect to us again, so we can't
+ // know the right absolute URI to send back.
+ // Because of this problem, no one pays attention
+ // to the RFC; they all send back just a new path.
+ // So do we.
+ oldpath := r.URL.Path
+ if oldpath == "" { // should not happen, but avoid a crash if it does
+ oldpath = "/"
+ }
+ if u.Scheme == "" {
+ // no leading http://server
+ if urlStr == "" || urlStr[0] != '/' {
+ // make relative path absolute
+ olddir, _ := path.Split(oldpath)
+ urlStr = olddir + urlStr
+ }
+
+ var query string
+ if i := strings.Index(urlStr, "?"); i != -1 {
+ urlStr, query = urlStr[:i], urlStr[i:]
+ }
+
+ // clean up but preserve trailing slash
+ trailing := strings.HasSuffix(urlStr, "/")
+ urlStr = path.Clean(urlStr)
+ if trailing && !strings.HasSuffix(urlStr, "/") {
+ urlStr += "/"
+ }
+ urlStr += query
+ }
+ }
+
+ w.Header().Set("Location", urlStr)
+ w.WriteHeader(code)
+
+ // RFC2616 recommends that a short note "SHOULD" be included in the
+ // response because older user agents may not understand 301/307.
+ // Shouldn't send the response for POST or HEAD; that leaves GET.
+ if r.Method == "GET" {
+ note := "<a href=\"" + htmlEscape(urlStr) + "\">" + statusText[code] + "</a>.\n"
+ fmt.Fprintln(w, note)
+ }
+}
+
+var htmlReplacer = strings.NewReplacer(
+ "&", "&amp;",
+ "<", "&lt;",
+ ">", "&gt;",
+ // "&#34;" is shorter than "&quot;".
+ `"`, "&#34;",
+ // "&#39;" is shorter than "&apos;" and apos was not in HTML until HTML5.
+ "'", "&#39;",
+)
+
+func htmlEscape(s string) string {
+ return htmlReplacer.Replace(s)
+}
+
+// Redirect to a fixed URL
+type redirectHandler struct {
+ url string
+ code int
+}
+
+func (rh *redirectHandler) ServeHTTP(w ResponseWriter, r *Request) {
+ Redirect(w, r, rh.url, rh.code)
+}
+
+// RedirectHandler returns a request handler that redirects
+// each request it receives to the given url using the given
+// status code.
+func RedirectHandler(url string, code int) Handler {
+ return &redirectHandler{url, code}
+}
+
+// ServeMux is an HTTP request multiplexer.
+// It matches the URL of each incoming request against a list of registered
+// patterns and calls the handler for the pattern that
+// most closely matches the URL.
+//
+// Patterns name fixed, rooted paths, like "/favicon.ico",
+// or rooted subtrees, like "/images/" (note the trailing slash).
+// Longer patterns take precedence over shorter ones, so that
+// if there are handlers registered for both "/images/"
+// and "/images/thumbnails/", the latter handler will be
+// called for paths beginning "/images/thumbnails/" and the
+// former will receive requests for any other paths in the
+// "/images/" subtree.
+//
+// Note that since a pattern ending in a slash names a rooted subtree,
+// the pattern "/" matches all paths not matched by other registered
+// patterns, not just the URL with Path == "/".
+//
+// Patterns may optionally begin with a host name, restricting matches to
+// URLs on that host only. Host-specific patterns take precedence over
+// general patterns, so that a handler might register for the two patterns
+// "/codesearch" and "codesearch.google.com/" without also taking over
+// requests for "http://www.google.com/".
+//
+// ServeMux also takes care of sanitizing the URL request path,
+// redirecting any request containing . or .. elements to an
+// equivalent .- and ..-free URL.
+type ServeMux struct {
+ mu sync.RWMutex
+ m map[string]muxEntry
+ hosts bool // whether any patterns contain hostnames
+}
+
+type muxEntry struct {
+ explicit bool
+ h Handler
+ pattern string
+}
+
+// NewServeMux allocates and returns a new ServeMux.
+func NewServeMux() *ServeMux { return &ServeMux{m: make(map[string]muxEntry)} }
+
+// DefaultServeMux is the default ServeMux used by Serve.
+var DefaultServeMux = NewServeMux()
+
+// Does path match pattern?
+func pathMatch(pattern, path string) bool {
+ if len(pattern) == 0 {
+ // should not happen
+ return false
+ }
+ n := len(pattern)
+ if pattern[n-1] != '/' {
+ return pattern == path
+ }
+ return len(path) >= n && path[0:n] == pattern
+}
+
+// Return the canonical path for p, eliminating . and .. elements.
+func cleanPath(p string) string {
+ if p == "" {
+ return "/"
+ }
+ if p[0] != '/' {
+ p = "/" + p
+ }
+ np := path.Clean(p)
+ // path.Clean removes trailing slash except for root;
+ // put the trailing slash back if necessary.
+ if p[len(p)-1] == '/' && np != "/" {
+ np += "/"
+ }
+ return np
+}
+
+// Find a handler on a handler map given a path string
+// Most-specific (longest) pattern wins
+func (mux *ServeMux) match(path string) (h Handler, pattern string) {
+ var n = 0
+ for k, v := range mux.m {
+ if !pathMatch(k, path) {
+ continue
+ }
+ if h == nil || len(k) > n {
+ n = len(k)
+ h = v.h
+ pattern = v.pattern
+ }
+ }
+ return
+}
+
+// Handler returns the handler to use for the given request,
+// consulting r.Method, r.Host, and r.URL.Path. It always returns
+// a non-nil handler. If the path is not in its canonical form, the
+// handler will be an internally-generated handler that redirects
+// to the canonical path.
+//
+// Handler also returns the registered pattern that matches the
+// request or, in the case of internally-generated redirects,
+// the pattern that will match after following the redirect.
+//
+// If there is no registered handler that applies to the request,
+// Handler returns a ``page not found'' handler and an empty pattern.
+func (mux *ServeMux) Handler(r *Request) (h Handler, pattern string) {
+ if r.Method != "CONNECT" {
+ if p := cleanPath(r.URL.Path); p != r.URL.Path {
+ _, pattern = mux.handler(r.Host, p)
+ url := *r.URL
+ url.Path = p
+ return RedirectHandler(url.String(), StatusMovedPermanently), pattern
+ }
+ }
+
+ return mux.handler(r.Host, r.URL.Path)
+}
+
+// handler is the main implementation of Handler.
+// The path is known to be in canonical form, except for CONNECT methods.
+func (mux *ServeMux) handler(host, path string) (h Handler, pattern string) {
+ mux.mu.RLock()
+ defer mux.mu.RUnlock()
+
+ // Host-specific pattern takes precedence over generic ones
+ if mux.hosts {
+ h, pattern = mux.match(host + path)
+ }
+ if h == nil {
+ h, pattern = mux.match(path)
+ }
+ if h == nil {
+ h, pattern = NotFoundHandler(), ""
+ }
+ return
+}
+
+// ServeHTTP dispatches the request to the handler whose
+// pattern most closely matches the request URL.
+func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) {
+ if r.RequestURI == "*" {
+ if r.ProtoAtLeast(1, 1) {
+ w.Header().Set("Connection", "close")
+ }
+ w.WriteHeader(StatusBadRequest)
+ return
+ }
+ h, _ := mux.Handler(r)
+ h.ServeHTTP(w, r)
+}
+
+// Handle registers the handler for the given pattern.
+// If a handler already exists for pattern, Handle panics.
+func (mux *ServeMux) Handle(pattern string, handler Handler) {
+ mux.mu.Lock()
+ defer mux.mu.Unlock()
+
+ if pattern == "" {
+ panic("http: invalid pattern " + pattern)
+ }
+ if handler == nil {
+ panic("http: nil handler")
+ }
+ if mux.m[pattern].explicit {
+ panic("http: multiple registrations for " + pattern)
+ }
+
+ mux.m[pattern] = muxEntry{explicit: true, h: handler, pattern: pattern}
+
+ if pattern[0] != '/' {
+ mux.hosts = true
+ }
+
+ // Helpful behavior:
+ // If pattern is /tree/, insert an implicit permanent redirect for /tree.
+ // It can be overridden by an explicit registration.
+ n := len(pattern)
+ if n > 0 && pattern[n-1] == '/' && !mux.m[pattern[0:n-1]].explicit {
+ // If pattern contains a host name, strip it and use remaining
+ // path for redirect.
+ path := pattern
+ if pattern[0] != '/' {
+ // In pattern, at least the last character is a '/', so
+ // strings.Index can't be -1.
+ path = pattern[strings.Index(pattern, "/"):]
+ }
+ mux.m[pattern[0:n-1]] = muxEntry{h: RedirectHandler(path, StatusMovedPermanently), pattern: pattern}
+ }
+}
+
+// HandleFunc registers the handler function for the given pattern.
+func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Request)) {
+ mux.Handle(pattern, HandlerFunc(handler))
+}
+
+// Handle registers the handler for the given pattern
+// in the DefaultServeMux.
+// The documentation for ServeMux explains how patterns are matched.
+func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) }
+
+// HandleFunc registers the handler function for the given pattern
+// in the DefaultServeMux.
+// The documentation for ServeMux explains how patterns are matched.
+func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) {
+ DefaultServeMux.HandleFunc(pattern, handler)
+}
+
+// Serve accepts incoming HTTP connections on the listener l,
+// creating a new service goroutine for each. The service goroutines
+// read requests and then call handler to reply to them.
+// Handler is typically nil, in which case the DefaultServeMux is used.
+func Serve(l net.Listener, handler Handler) error {
+ srv := &Server{Handler: handler}
+ return srv.Serve(l)
+}
+
+// A Server defines parameters for running an HTTP server.
+// The zero value for Server is a valid configuration.
+type Server struct {
+ Addr string // TCP address to listen on, ":http" if empty
+ Handler Handler // handler to invoke, http.DefaultServeMux if nil
+ ReadTimeout time.Duration // maximum duration before timing out read of the request
+ WriteTimeout time.Duration // maximum duration before timing out write of the response
+ MaxHeaderBytes int // maximum size of request headers, DefaultMaxHeaderBytes if 0
+ TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS
+
+ // TLSNextProto optionally specifies a function to take over
+ // ownership of the provided TLS connection when an NPN
+ // protocol upgrade has occurred. The map key is the protocol
+ // name negotiated. The Handler argument should be used to
+ // handle HTTP requests and will initialize the Request's TLS
+ // and RemoteAddr if not already set. The connection is
+ // automatically closed when the function returns.
+ TLSNextProto map[string]func(*Server, *tls.Conn, Handler)
+
+ // ConnState specifies an optional callback function that is
+ // called when a client connection changes state. See the
+ // ConnState type and associated constants for details.
+ ConnState func(net.Conn, ConnState)
+
+ // ErrorLog specifies an optional logger for errors accepting
+ // connections and unexpected behavior from handlers.
+ // If nil, logging goes to os.Stderr via the log package's
+ // standard logger.
+ ErrorLog *log.Logger
+
+ disableKeepAlives int32 // accessed atomically.
+}
+
+// A ConnState represents the state of a client connection to a server.
+// It's used by the optional Server.ConnState hook.
+type ConnState int
+
+const (
+ // StateNew represents a new connection that is expected to
+ // send a request immediately. Connections begin at this
+ // state and then transition to either StateActive or
+ // StateClosed.
+ StateNew ConnState = iota
+
+ // StateActive represents a connection that has read 1 or more
+ // bytes of a request. The Server.ConnState hook for
+ // StateActive fires before the request has entered a handler
+ // and doesn't fire again until the request has been
+ // handled. After the request is handled, the state
+ // transitions to StateClosed, StateHijacked, or StateIdle.
+ StateActive
+
+ // StateIdle represents a connection that has finished
+ // handling a request and is in the keep-alive state, waiting
+ // for a new request. Connections transition from StateIdle
+ // to either StateActive or StateClosed.
+ StateIdle
+
+ // StateHijacked represents a hijacked connection.
+ // This is a terminal state. It does not transition to StateClosed.
+ StateHijacked
+
+ // StateClosed represents a closed connection.
+ // This is a terminal state. Hijacked connections do not
+ // transition to StateClosed.
+ StateClosed
+)
+
+var stateName = map[ConnState]string{
+ StateNew: "new",
+ StateActive: "active",
+ StateIdle: "idle",
+ StateHijacked: "hijacked",
+ StateClosed: "closed",
+}
+
+func (c ConnState) String() string {
+ return stateName[c]
+}
+
+// serverHandler delegates to either the server's Handler or
+// DefaultServeMux and also handles "OPTIONS *" requests.
+type serverHandler struct {
+ srv *Server
+}
+
+func (sh serverHandler) ServeHTTP(rw ResponseWriter, req *Request) {
+ handler := sh.srv.Handler
+ if handler == nil {
+ handler = DefaultServeMux
+ }
+ if req.RequestURI == "*" && req.Method == "OPTIONS" {
+ handler = globalOptionsHandler{}
+ }
+ handler.ServeHTTP(rw, req)
+}
+
+// ListenAndServe listens on the TCP network address srv.Addr and then
+// calls Serve to handle requests on incoming connections. If
+// srv.Addr is blank, ":http" is used.
+func (srv *Server) ListenAndServe() error {
+ addr := srv.Addr
+ if addr == "" {
+ addr = ":http"
+ }
+ ln, err := net.Listen("tcp", addr)
+ if err != nil {
+ return err
+ }
+ return srv.Serve(tcpKeepAliveListener{ln.(*net.TCPListener)})
+}
+
+// Serve accepts incoming connections on the Listener l, creating a
+// new service goroutine for each. The service goroutines read requests and
+// then call srv.Handler to reply to them.
+func (srv *Server) Serve(l net.Listener) error {
+ defer l.Close()
+ var tempDelay time.Duration // how long to sleep on accept failure
+ for {
+ rw, e := l.Accept()
+ if e != nil {
+ if ne, ok := e.(net.Error); ok && ne.Temporary() {
+ if tempDelay == 0 {
+ tempDelay = 5 * time.Millisecond
+ } else {
+ tempDelay *= 2
+ }
+ if max := 1 * time.Second; tempDelay > max {
+ tempDelay = max
+ }
+ srv.logf("http: Accept error: %v; retrying in %v", e, tempDelay)
+ time.Sleep(tempDelay)
+ continue
+ }
+ return e
+ }
+ tempDelay = 0
+ c, err := srv.newConn(rw)
+ if err != nil {
+ continue
+ }
+ c.setState(c.rwc, StateNew) // before Serve can return
+ go c.serve()
+ }
+}
+
+func (s *Server) doKeepAlives() bool {
+ return atomic.LoadInt32(&s.disableKeepAlives) == 0
+}
+
+// SetKeepAlivesEnabled controls whether HTTP keep-alives are enabled.
+// By default, keep-alives are always enabled. Only very
+// resource-constrained environments or servers in the process of
+// shutting down should disable them.
+func (s *Server) SetKeepAlivesEnabled(v bool) {
+ if v {
+ atomic.StoreInt32(&s.disableKeepAlives, 0)
+ } else {
+ atomic.StoreInt32(&s.disableKeepAlives, 1)
+ }
+}
+
+func (s *Server) logf(format string, args ...interface{}) {
+ if s.ErrorLog != nil {
+ s.ErrorLog.Printf(format, args...)
+ } else {
+ log.Printf(format, args...)
+ }
+}
+
+// ListenAndServe listens on the TCP network address addr
+// and then calls Serve with handler to handle requests
+// on incoming connections. Handler is typically nil,
+// in which case the DefaultServeMux is used.
+//
+// A trivial example server is:
+//
+// package main
+//
+// import (
+// "io"
+// "net/http"
+// "log"
+// )
+//
+// // hello world, the web server
+// func HelloServer(w http.ResponseWriter, req *http.Request) {
+// io.WriteString(w, "hello, world!\n")
+// }
+//
+// func main() {
+// http.HandleFunc("/hello", HelloServer)
+// err := http.ListenAndServe(":12345", nil)
+// if err != nil {
+// log.Fatal("ListenAndServe: ", err)
+// }
+// }
+func ListenAndServe(addr string, handler Handler) error {
+ server := &Server{Addr: addr, Handler: handler}
+ return server.ListenAndServe()
+}
+
+// ListenAndServeTLS acts identically to ListenAndServe, except that it
+// expects HTTPS connections. Additionally, files containing a certificate and
+// matching private key for the server must be provided. If the certificate
+// is signed by a certificate authority, the certFile should be the concatenation
+// of the server's certificate followed by the CA's certificate.
+//
+// A trivial example server is:
+//
+// import (
+// "log"
+// "net/http"
+// )
+//
+// func handler(w http.ResponseWriter, req *http.Request) {
+// w.Header().Set("Content-Type", "text/plain")
+// w.Write([]byte("This is an example server.\n"))
+// }
+//
+// func main() {
+// http.HandleFunc("/", handler)
+// log.Printf("About to listen on 10443. Go to https://127.0.0.1:10443/")
+// err := http.ListenAndServeTLS(":10443", "cert.pem", "key.pem", nil)
+// if err != nil {
+// log.Fatal(err)
+// }
+// }
+//
+// One can use generate_cert.go in crypto/tls to generate cert.pem and key.pem.
+func ListenAndServeTLS(addr string, certFile string, keyFile string, handler Handler) error {
+ server := &Server{Addr: addr, Handler: handler}
+ return server.ListenAndServeTLS(certFile, keyFile)
+}
+
+// ListenAndServeTLS listens on the TCP network address srv.Addr and
+// then calls Serve to handle requests on incoming TLS connections.
+//
+// Filenames containing a certificate and matching private key for
+// the server must be provided. If the certificate is signed by a
+// certificate authority, the certFile should be the concatenation
+// of the server's certificate followed by the CA's certificate.
+//
+// If srv.Addr is blank, ":https" is used.
+func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error {
+ addr := srv.Addr
+ if addr == "" {
+ addr = ":https"
+ }
+ config := &tls.Config{}
+ if srv.TLSConfig != nil {
+ *config = *srv.TLSConfig
+ }
+ if config.NextProtos == nil {
+ config.NextProtos = []string{"http/1.1"}
+ }
+
+ var err error
+ config.Certificates = make([]tls.Certificate, 1)
+ config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
+ if err != nil {
+ return err
+ }
+
+ ln, err := net.Listen("tcp", addr)
+ if err != nil {
+ return err
+ }
+
+ tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config)
+ return srv.Serve(tlsListener)
+}
+
+// TimeoutHandler returns a Handler that runs h with the given time limit.
+//
+// The new Handler calls h.ServeHTTP to handle each request, but if a
+// call runs for longer than its time limit, the handler responds with
+// a 503 Service Unavailable error and the given message in its body.
+// (If msg is empty, a suitable default message will be sent.)
+// After such a timeout, writes by h to its ResponseWriter will return
+// ErrHandlerTimeout.
+func TimeoutHandler(h Handler, dt time.Duration, msg string) Handler {
+ f := func() <-chan time.Time {
+ return time.After(dt)
+ }
+ return &timeoutHandler{h, f, msg}
+}
+
+// ErrHandlerTimeout is returned on ResponseWriter Write calls
+// in handlers which have timed out.
+var ErrHandlerTimeout = errors.New("http: Handler timeout")
+
+type timeoutHandler struct {
+ handler Handler
+ timeout func() <-chan time.Time // returns channel producing a timeout
+ body string
+}
+
+func (h *timeoutHandler) errorBody() string {
+ if h.body != "" {
+ return h.body
+ }
+ return "<html><head><title>Timeout</title></head><body><h1>Timeout</h1></body></html>"
+}
+
+func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) {
+ done := make(chan bool, 1)
+ tw := &timeoutWriter{w: w}
+ go func() {
+ h.handler.ServeHTTP(tw, r)
+ done <- true
+ }()
+ select {
+ case <-done:
+ return
+ case <-h.timeout():
+ tw.mu.Lock()
+ defer tw.mu.Unlock()
+ if !tw.wroteHeader {
+ tw.w.WriteHeader(StatusServiceUnavailable)
+ tw.w.Write([]byte(h.errorBody()))
+ }
+ tw.timedOut = true
+ }
+}
+
+type timeoutWriter struct {
+ w ResponseWriter
+
+ mu sync.Mutex
+ timedOut bool
+ wroteHeader bool
+}
+
+func (tw *timeoutWriter) Header() Header {
+ return tw.w.Header()
+}
+
+func (tw *timeoutWriter) Write(p []byte) (int, error) {
+ tw.mu.Lock()
+ defer tw.mu.Unlock()
+ tw.wroteHeader = true // implicitly at least
+ if tw.timedOut {
+ return 0, ErrHandlerTimeout
+ }
+ return tw.w.Write(p)
+}
+
+func (tw *timeoutWriter) WriteHeader(code int) {
+ tw.mu.Lock()
+ defer tw.mu.Unlock()
+ if tw.timedOut || tw.wroteHeader {
+ return
+ }
+ tw.wroteHeader = true
+ tw.w.WriteHeader(code)
+}
+
+// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
+// connections. It's used by ListenAndServe and ListenAndServeTLS so
+// dead TCP connections (e.g. closing laptop mid-download) eventually
+// go away.
+type tcpKeepAliveListener struct {
+ *net.TCPListener
+}
+
+func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
+ tc, err := ln.AcceptTCP()
+ if err != nil {
+ return
+ }
+ tc.SetKeepAlive(true)
+ tc.SetKeepAlivePeriod(3 * time.Minute)
+ return tc, nil
+}
+
+// globalOptionsHandler responds to "OPTIONS *" requests.
+type globalOptionsHandler struct{}
+
+func (globalOptionsHandler) ServeHTTP(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "0")
+ if r.ContentLength != 0 {
+ // Read up to 4KB of OPTIONS body (as mentioned in the
+ // spec as being reserved for future use), but anything
+ // over that is considered a waste of server resources
+ // (or an attack) and we abort and close the connection,
+ // courtesy of MaxBytesReader's EOF behavior.
+ mb := MaxBytesReader(w, r.Body, 4<<10)
+ io.Copy(ioutil.Discard, mb)
+ }
+}
+
+type eofReaderWithWriteTo struct{}
+
+func (eofReaderWithWriteTo) WriteTo(io.Writer) (int64, error) { return 0, nil }
+func (eofReaderWithWriteTo) Read([]byte) (int, error) { return 0, io.EOF }
+
+// eofReader is a non-nil io.ReadCloser that always returns EOF.
+// It has a WriteTo method so io.Copy won't need a buffer.
+var eofReader = &struct {
+ eofReaderWithWriteTo
+ io.Closer
+}{
+ eofReaderWithWriteTo{},
+ ioutil.NopCloser(nil),
+}
+
+// Verify that an io.Copy from an eofReader won't require a buffer.
+var _ io.WriterTo = eofReader
+
+// initNPNRequest is an HTTP handler that initializes certain
+// uninitialized fields in its *Request. Such partially-initialized
+// Requests come from NPN protocol handlers.
+type initNPNRequest struct {
+ c *tls.Conn
+ h serverHandler
+}
+
+func (h initNPNRequest) ServeHTTP(rw ResponseWriter, req *Request) {
+ if req.TLS == nil {
+ req.TLS = &tls.ConnectionState{}
+ *req.TLS = h.c.ConnectionState()
+ }
+ if req.Body == nil {
+ req.Body = eofReader
+ }
+ if req.RemoteAddr == "" {
+ req.RemoteAddr = h.c.RemoteAddr().String()
+ }
+ h.h.ServeHTTP(rw, req)
+}
+
+// loggingConn is used for debugging.
+type loggingConn struct {
+ name string
+ net.Conn
+}
+
+var (
+ uniqNameMu sync.Mutex
+ uniqNameNext = make(map[string]int)
+)
+
+func newLoggingConn(baseName string, c net.Conn) net.Conn {
+ uniqNameMu.Lock()
+ defer uniqNameMu.Unlock()
+ uniqNameNext[baseName]++
+ return &loggingConn{
+ name: fmt.Sprintf("%s-%d", baseName, uniqNameNext[baseName]),
+ Conn: c,
+ }
+}
+
+func (c *loggingConn) Write(p []byte) (n int, err error) {
+ log.Printf("%s.Write(%d) = ....", c.name, len(p))
+ n, err = c.Conn.Write(p)
+ log.Printf("%s.Write(%d) = %d, %v", c.name, len(p), n, err)
+ return
+}
+
+func (c *loggingConn) Read(p []byte) (n int, err error) {
+ log.Printf("%s.Read(%d) = ....", c.name, len(p))
+ n, err = c.Conn.Read(p)
+ log.Printf("%s.Read(%d) = %d, %v", c.name, len(p), n, err)
+ return
+}
+
+func (c *loggingConn) Close() (err error) {
+ log.Printf("%s.Close() = ...", c.name)
+ err = c.Conn.Close()
+ log.Printf("%s.Close() = %v", c.name, err)
+ return
+}
+
+// checkConnErrorWriter writes to c.rwc and records any write errors to c.werr.
+// It only contains one field (and a pointer field at that), so it
+// fits in an interface value without an extra allocation.
+type checkConnErrorWriter struct {
+ c *conn
+}
+
+func (w checkConnErrorWriter) Write(p []byte) (n int, err error) {
+ n, err = w.c.w.Write(p) // c.w == c.rwc, except after a hijack, when rwc is nil.
+ if err != nil && w.c.werr == nil {
+ w.c.werr = err
+ }
+ return
+}
diff --git a/src/net/http/sniff.go b/src/net/http/sniff.go
new file mode 100644
index 000000000..68f519b05
--- /dev/null
+++ b/src/net/http/sniff.go
@@ -0,0 +1,214 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bytes"
+ "encoding/binary"
+)
+
+// The algorithm uses at most sniffLen bytes to make its decision.
+const sniffLen = 512
+
+// DetectContentType implements the algorithm described
+// at http://mimesniff.spec.whatwg.org/ to determine the
+// Content-Type of the given data. It considers at most the
+// first 512 bytes of data. DetectContentType always returns
+// a valid MIME type: if it cannot determine a more specific one, it
+// returns "application/octet-stream".
+func DetectContentType(data []byte) string {
+ if len(data) > sniffLen {
+ data = data[:sniffLen]
+ }
+
+ // Index of the first non-whitespace byte in data.
+ firstNonWS := 0
+ for ; firstNonWS < len(data) && isWS(data[firstNonWS]); firstNonWS++ {
+ }
+
+ for _, sig := range sniffSignatures {
+ if ct := sig.match(data, firstNonWS); ct != "" {
+ return ct
+ }
+ }
+
+ return "application/octet-stream" // fallback
+}
+
+func isWS(b byte) bool {
+ return bytes.IndexByte([]byte("\t\n\x0C\r "), b) != -1
+}
+
+type sniffSig interface {
+ // match returns the MIME type of the data, or "" if unknown.
+ match(data []byte, firstNonWS int) string
+}
+
+// Data matching the table in section 6.
+var sniffSignatures = []sniffSig{
+ htmlSig("<!DOCTYPE HTML"),
+ htmlSig("<HTML"),
+ htmlSig("<HEAD"),
+ htmlSig("<SCRIPT"),
+ htmlSig("<IFRAME"),
+ htmlSig("<H1"),
+ htmlSig("<DIV"),
+ htmlSig("<FONT"),
+ htmlSig("<TABLE"),
+ htmlSig("<A"),
+ htmlSig("<STYLE"),
+ htmlSig("<TITLE"),
+ htmlSig("<B"),
+ htmlSig("<BODY"),
+ htmlSig("<BR"),
+ htmlSig("<P"),
+ htmlSig("<!--"),
+
+ &maskedSig{mask: []byte("\xFF\xFF\xFF\xFF\xFF"), pat: []byte("<?xml"), skipWS: true, ct: "text/xml; charset=utf-8"},
+
+ &exactSig{[]byte("%PDF-"), "application/pdf"},
+ &exactSig{[]byte("%!PS-Adobe-"), "application/postscript"},
+
+ // UTF BOMs.
+ &maskedSig{mask: []byte("\xFF\xFF\x00\x00"), pat: []byte("\xFE\xFF\x00\x00"), ct: "text/plain; charset=utf-16be"},
+ &maskedSig{mask: []byte("\xFF\xFF\x00\x00"), pat: []byte("\xFF\xFE\x00\x00"), ct: "text/plain; charset=utf-16le"},
+ &maskedSig{mask: []byte("\xFF\xFF\xFF\x00"), pat: []byte("\xEF\xBB\xBF\x00"), ct: "text/plain; charset=utf-8"},
+
+ &exactSig{[]byte("GIF87a"), "image/gif"},
+ &exactSig{[]byte("GIF89a"), "image/gif"},
+ &exactSig{[]byte("\x89\x50\x4E\x47\x0D\x0A\x1A\x0A"), "image/png"},
+ &exactSig{[]byte("\xFF\xD8\xFF"), "image/jpeg"},
+ &exactSig{[]byte("BM"), "image/bmp"},
+ &maskedSig{
+ mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF\xFF\xFF"),
+ pat: []byte("RIFF\x00\x00\x00\x00WEBPVP"),
+ ct: "image/webp",
+ },
+ &exactSig{[]byte("\x00\x00\x01\x00"), "image/vnd.microsoft.icon"},
+ &exactSig{[]byte("\x4F\x67\x67\x53\x00"), "application/ogg"},
+ &maskedSig{
+ mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF"),
+ pat: []byte("RIFF\x00\x00\x00\x00WAVE"),
+ ct: "audio/wave",
+ },
+ &exactSig{[]byte("\x1A\x45\xDF\xA3"), "video/webm"},
+ &exactSig{[]byte("\x52\x61\x72\x20\x1A\x07\x00"), "application/x-rar-compressed"},
+ &exactSig{[]byte("\x50\x4B\x03\x04"), "application/zip"},
+ &exactSig{[]byte("\x1F\x8B\x08"), "application/x-gzip"},
+
+ // TODO(dsymonds): Re-enable this when the spec is sorted w.r.t. MP4.
+ //mp4Sig(0),
+
+ textSig(0), // should be last
+}
+
+type exactSig struct {
+ sig []byte
+ ct string
+}
+
+func (e *exactSig) match(data []byte, firstNonWS int) string {
+ if bytes.HasPrefix(data, e.sig) {
+ return e.ct
+ }
+ return ""
+}
+
+type maskedSig struct {
+ mask, pat []byte
+ skipWS bool
+ ct string
+}
+
+func (m *maskedSig) match(data []byte, firstNonWS int) string {
+ if m.skipWS {
+ data = data[firstNonWS:]
+ }
+ if len(data) < len(m.mask) {
+ return ""
+ }
+ for i, mask := range m.mask {
+ db := data[i] & mask
+ if db != m.pat[i] {
+ return ""
+ }
+ }
+ return m.ct
+}
+
+type htmlSig []byte
+
+func (h htmlSig) match(data []byte, firstNonWS int) string {
+ data = data[firstNonWS:]
+ if len(data) < len(h)+1 {
+ return ""
+ }
+ for i, b := range h {
+ db := data[i]
+ if 'A' <= b && b <= 'Z' {
+ db &= 0xDF
+ }
+ if b != db {
+ return ""
+ }
+ }
+ // Next byte must be space or right angle bracket.
+ if db := data[len(h)]; db != ' ' && db != '>' {
+ return ""
+ }
+ return "text/html; charset=utf-8"
+}
+
+type mp4Sig int
+
+func (mp4Sig) match(data []byte, firstNonWS int) string {
+ // c.f. section 6.1.
+ if len(data) < 8 {
+ return ""
+ }
+ boxSize := int(binary.BigEndian.Uint32(data[:4]))
+ if boxSize%4 != 0 || len(data) < boxSize {
+ return ""
+ }
+ if !bytes.Equal(data[4:8], []byte("ftyp")) {
+ return ""
+ }
+ for st := 8; st < boxSize; st += 4 {
+ if st == 12 {
+ // minor version number
+ continue
+ }
+ seg := string(data[st : st+3])
+ switch seg {
+ case "mp4", "iso", "M4V", "M4P", "M4B":
+ return "video/mp4"
+ /* The remainder are not in the spec.
+ case "M4A":
+ return "audio/mp4"
+ case "3gp":
+ return "video/3gpp"
+ case "jp2":
+ return "image/jp2" // JPEG 2000
+ */
+ }
+ }
+ return ""
+}
+
+type textSig int
+
+func (textSig) match(data []byte, firstNonWS int) string {
+ // c.f. section 5, step 4.
+ for _, b := range data[firstNonWS:] {
+ switch {
+ case 0x00 <= b && b <= 0x08,
+ b == 0x0B,
+ 0x0E <= b && b <= 0x1A,
+ 0x1C <= b && b <= 0x1F:
+ return ""
+ }
+ }
+ return "text/plain; charset=utf-8"
+}
diff --git a/src/net/http/sniff_test.go b/src/net/http/sniff_test.go
new file mode 100644
index 000000000..24ca27afc
--- /dev/null
+++ b/src/net/http/sniff_test.go
@@ -0,0 +1,171 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ . "net/http"
+ "net/http/httptest"
+ "reflect"
+ "strconv"
+ "strings"
+ "testing"
+)
+
+var sniffTests = []struct {
+ desc string
+ data []byte
+ contentType string
+}{
+ // Some nonsense.
+ {"Empty", []byte{}, "text/plain; charset=utf-8"},
+ {"Binary", []byte{1, 2, 3}, "application/octet-stream"},
+
+ {"HTML document #1", []byte(`<HtMl><bOdY>blah blah blah</body></html>`), "text/html; charset=utf-8"},
+ {"HTML document #2", []byte(`<HTML></HTML>`), "text/html; charset=utf-8"},
+ {"HTML document #3 (leading whitespace)", []byte(` <!DOCTYPE HTML>...`), "text/html; charset=utf-8"},
+ {"HTML document #4 (leading CRLF)", []byte("\r\n<html>..."), "text/html; charset=utf-8"},
+
+ {"Plain text", []byte(`This is not HTML. It has โ˜ƒ though.`), "text/plain; charset=utf-8"},
+
+ {"XML", []byte("\n<?xml!"), "text/xml; charset=utf-8"},
+
+ // Image types.
+ {"GIF 87a", []byte(`GIF87a`), "image/gif"},
+ {"GIF 89a", []byte(`GIF89a...`), "image/gif"},
+
+ // TODO(dsymonds): Re-enable this when the spec is sorted w.r.t. MP4.
+ //{"MP4 video", []byte("\x00\x00\x00\x18ftypmp42\x00\x00\x00\x00mp42isom<\x06t\xbfmdat"), "video/mp4"},
+ //{"MP4 audio", []byte("\x00\x00\x00\x20ftypM4A \x00\x00\x00\x00M4A mp42isom\x00\x00\x00\x00"), "audio/mp4"},
+}
+
+func TestDetectContentType(t *testing.T) {
+ for _, tt := range sniffTests {
+ ct := DetectContentType(tt.data)
+ if ct != tt.contentType {
+ t.Errorf("%v: DetectContentType = %q, want %q", tt.desc, ct, tt.contentType)
+ }
+ }
+}
+
+func TestServerContentType(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ i, _ := strconv.Atoi(r.FormValue("i"))
+ tt := sniffTests[i]
+ n, err := w.Write(tt.data)
+ if n != len(tt.data) || err != nil {
+ log.Fatalf("%v: Write(%q) = %v, %v want %d, nil", tt.desc, tt.data, n, err, len(tt.data))
+ }
+ }))
+ defer ts.Close()
+
+ for i, tt := range sniffTests {
+ resp, err := Get(ts.URL + "/?i=" + strconv.Itoa(i))
+ if err != nil {
+ t.Errorf("%v: %v", tt.desc, err)
+ continue
+ }
+ if ct := resp.Header.Get("Content-Type"); ct != tt.contentType {
+ t.Errorf("%v: Content-Type = %q, want %q", tt.desc, ct, tt.contentType)
+ }
+ data, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ t.Errorf("%v: reading body: %v", tt.desc, err)
+ } else if !bytes.Equal(data, tt.data) {
+ t.Errorf("%v: data is %q, want %q", tt.desc, data, tt.data)
+ }
+ resp.Body.Close()
+ }
+}
+
+// Issue 5953: shouldn't sniff if the handler set a Content-Type header,
+// even if it's the empty string.
+func TestServerIssue5953(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header()["Content-Type"] = []string{""}
+ fmt.Fprintf(w, "<html><head></head><body>hi</body></html>")
+ }))
+ defer ts.Close()
+
+ resp, err := Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ got := resp.Header["Content-Type"]
+ want := []string{""}
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("Content-Type = %q; want %q", got, want)
+ }
+ resp.Body.Close()
+}
+
+func TestContentTypeWithCopy(t *testing.T) {
+ defer afterTest(t)
+
+ const (
+ input = "\n<html>\n\t<head>\n"
+ expected = "text/html; charset=utf-8"
+ )
+
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ // Use io.Copy from a bytes.Buffer to trigger ReadFrom.
+ buf := bytes.NewBuffer([]byte(input))
+ n, err := io.Copy(w, buf)
+ if int(n) != len(input) || err != nil {
+ t.Errorf("io.Copy(w, %q) = %v, %v want %d, nil", input, n, err, len(input))
+ }
+ }))
+ defer ts.Close()
+
+ resp, err := Get(ts.URL)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ if ct := resp.Header.Get("Content-Type"); ct != expected {
+ t.Errorf("Content-Type = %q, want %q", ct, expected)
+ }
+ data, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ t.Errorf("reading body: %v", err)
+ } else if !bytes.Equal(data, []byte(input)) {
+ t.Errorf("data is %q, want %q", data, input)
+ }
+ resp.Body.Close()
+}
+
+func TestSniffWriteSize(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ size, _ := strconv.Atoi(r.FormValue("size"))
+ written, err := io.WriteString(w, strings.Repeat("a", size))
+ if err != nil {
+ t.Errorf("write of %d bytes: %v", size, err)
+ return
+ }
+ if written != size {
+ t.Errorf("write of %d bytes wrote %d bytes", size, written)
+ }
+ }))
+ defer ts.Close()
+ for _, size := range []int{0, 1, 200, 600, 999, 1000, 1023, 1024, 512 << 10, 1 << 20} {
+ res, err := Get(fmt.Sprintf("%s/?size=%d", ts.URL, size))
+ if err != nil {
+ t.Fatalf("size %d: %v", size, err)
+ }
+ if _, err := io.Copy(ioutil.Discard, res.Body); err != nil {
+ t.Fatalf("size %d: io.Copy of body = %v", size, err)
+ }
+ if err := res.Body.Close(); err != nil {
+ t.Fatalf("size %d: body Close = %v", size, err)
+ }
+ }
+}
diff --git a/src/net/http/status.go b/src/net/http/status.go
new file mode 100644
index 000000000..d253bd5cb
--- /dev/null
+++ b/src/net/http/status.go
@@ -0,0 +1,120 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+// HTTP status codes, defined in RFC 2616.
+const (
+ StatusContinue = 100
+ StatusSwitchingProtocols = 101
+
+ StatusOK = 200
+ StatusCreated = 201
+ StatusAccepted = 202
+ StatusNonAuthoritativeInfo = 203
+ StatusNoContent = 204
+ StatusResetContent = 205
+ StatusPartialContent = 206
+
+ StatusMultipleChoices = 300
+ StatusMovedPermanently = 301
+ StatusFound = 302
+ StatusSeeOther = 303
+ StatusNotModified = 304
+ StatusUseProxy = 305
+ StatusTemporaryRedirect = 307
+
+ StatusBadRequest = 400
+ StatusUnauthorized = 401
+ StatusPaymentRequired = 402
+ StatusForbidden = 403
+ StatusNotFound = 404
+ StatusMethodNotAllowed = 405
+ StatusNotAcceptable = 406
+ StatusProxyAuthRequired = 407
+ StatusRequestTimeout = 408
+ StatusConflict = 409
+ StatusGone = 410
+ StatusLengthRequired = 411
+ StatusPreconditionFailed = 412
+ StatusRequestEntityTooLarge = 413
+ StatusRequestURITooLong = 414
+ StatusUnsupportedMediaType = 415
+ StatusRequestedRangeNotSatisfiable = 416
+ StatusExpectationFailed = 417
+ StatusTeapot = 418
+
+ StatusInternalServerError = 500
+ StatusNotImplemented = 501
+ StatusBadGateway = 502
+ StatusServiceUnavailable = 503
+ StatusGatewayTimeout = 504
+ StatusHTTPVersionNotSupported = 505
+
+ // New HTTP status codes from RFC 6585. Not exported yet in Go 1.1.
+ // See discussion at https://codereview.appspot.com/7678043/
+ statusPreconditionRequired = 428
+ statusTooManyRequests = 429
+ statusRequestHeaderFieldsTooLarge = 431
+ statusNetworkAuthenticationRequired = 511
+)
+
+var statusText = map[int]string{
+ StatusContinue: "Continue",
+ StatusSwitchingProtocols: "Switching Protocols",
+
+ StatusOK: "OK",
+ StatusCreated: "Created",
+ StatusAccepted: "Accepted",
+ StatusNonAuthoritativeInfo: "Non-Authoritative Information",
+ StatusNoContent: "No Content",
+ StatusResetContent: "Reset Content",
+ StatusPartialContent: "Partial Content",
+
+ StatusMultipleChoices: "Multiple Choices",
+ StatusMovedPermanently: "Moved Permanently",
+ StatusFound: "Found",
+ StatusSeeOther: "See Other",
+ StatusNotModified: "Not Modified",
+ StatusUseProxy: "Use Proxy",
+ StatusTemporaryRedirect: "Temporary Redirect",
+
+ StatusBadRequest: "Bad Request",
+ StatusUnauthorized: "Unauthorized",
+ StatusPaymentRequired: "Payment Required",
+ StatusForbidden: "Forbidden",
+ StatusNotFound: "Not Found",
+ StatusMethodNotAllowed: "Method Not Allowed",
+ StatusNotAcceptable: "Not Acceptable",
+ StatusProxyAuthRequired: "Proxy Authentication Required",
+ StatusRequestTimeout: "Request Timeout",
+ StatusConflict: "Conflict",
+ StatusGone: "Gone",
+ StatusLengthRequired: "Length Required",
+ StatusPreconditionFailed: "Precondition Failed",
+ StatusRequestEntityTooLarge: "Request Entity Too Large",
+ StatusRequestURITooLong: "Request URI Too Long",
+ StatusUnsupportedMediaType: "Unsupported Media Type",
+ StatusRequestedRangeNotSatisfiable: "Requested Range Not Satisfiable",
+ StatusExpectationFailed: "Expectation Failed",
+ StatusTeapot: "I'm a teapot",
+
+ StatusInternalServerError: "Internal Server Error",
+ StatusNotImplemented: "Not Implemented",
+ StatusBadGateway: "Bad Gateway",
+ StatusServiceUnavailable: "Service Unavailable",
+ StatusGatewayTimeout: "Gateway Timeout",
+ StatusHTTPVersionNotSupported: "HTTP Version Not Supported",
+
+ statusPreconditionRequired: "Precondition Required",
+ statusTooManyRequests: "Too Many Requests",
+ statusRequestHeaderFieldsTooLarge: "Request Header Fields Too Large",
+ statusNetworkAuthenticationRequired: "Network Authentication Required",
+}
+
+// StatusText returns a text for the HTTP status code. It returns the empty
+// string if the code is unknown.
+func StatusText(code int) string {
+ return statusText[code]
+}
diff --git a/src/net/http/testdata/file b/src/net/http/testdata/file
new file mode 100644
index 000000000..11f11f9be
--- /dev/null
+++ b/src/net/http/testdata/file
@@ -0,0 +1 @@
+0123456789
diff --git a/src/net/http/testdata/index.html b/src/net/http/testdata/index.html
new file mode 100644
index 000000000..da8e1e93d
--- /dev/null
+++ b/src/net/http/testdata/index.html
@@ -0,0 +1 @@
+index.html says hello
diff --git a/src/net/http/testdata/style.css b/src/net/http/testdata/style.css
new file mode 100644
index 000000000..208d16d42
--- /dev/null
+++ b/src/net/http/testdata/style.css
@@ -0,0 +1 @@
+body {}
diff --git a/src/net/http/transfer.go b/src/net/http/transfer.go
new file mode 100644
index 000000000..520500330
--- /dev/null
+++ b/src/net/http/transfer.go
@@ -0,0 +1,737 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bufio"
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/http/internal"
+ "net/textproto"
+ "sort"
+ "strconv"
+ "strings"
+ "sync"
+)
+
+// ErrLineTooLong is returned when reading request or response bodies
+// with malformed chunked encoding.
+var ErrLineTooLong = internal.ErrLineTooLong
+
+type errorReader struct {
+ err error
+}
+
+func (r *errorReader) Read(p []byte) (n int, err error) {
+ return 0, r.err
+}
+
+// transferWriter inspects the fields of a user-supplied Request or Response,
+// sanitizes them without changing the user object and provides methods for
+// writing the respective header, body and trailer in wire format.
+type transferWriter struct {
+ Method string
+ Body io.Reader
+ BodyCloser io.Closer
+ ResponseToHEAD bool
+ ContentLength int64 // -1 means unknown, 0 means exactly none
+ Close bool
+ TransferEncoding []string
+ Trailer Header
+}
+
+func newTransferWriter(r interface{}) (t *transferWriter, err error) {
+ t = &transferWriter{}
+
+ // Extract relevant fields
+ atLeastHTTP11 := false
+ switch rr := r.(type) {
+ case *Request:
+ if rr.ContentLength != 0 && rr.Body == nil {
+ return nil, fmt.Errorf("http: Request.ContentLength=%d with nil Body", rr.ContentLength)
+ }
+ t.Method = rr.Method
+ t.Body = rr.Body
+ t.BodyCloser = rr.Body
+ t.ContentLength = rr.ContentLength
+ t.Close = rr.Close
+ t.TransferEncoding = rr.TransferEncoding
+ t.Trailer = rr.Trailer
+ atLeastHTTP11 = rr.ProtoAtLeast(1, 1)
+ if t.Body != nil && len(t.TransferEncoding) == 0 && atLeastHTTP11 {
+ if t.ContentLength == 0 {
+ // Test to see if it's actually zero or just unset.
+ var buf [1]byte
+ n, rerr := io.ReadFull(t.Body, buf[:])
+ if rerr != nil && rerr != io.EOF {
+ t.ContentLength = -1
+ t.Body = &errorReader{rerr}
+ } else if n == 1 {
+ // Oh, guess there is data in this Body Reader after all.
+ // The ContentLength field just wasn't set.
+ // Stich the Body back together again, re-attaching our
+ // consumed byte.
+ t.ContentLength = -1
+ t.Body = io.MultiReader(bytes.NewReader(buf[:]), t.Body)
+ } else {
+ // Body is actually empty.
+ t.Body = nil
+ t.BodyCloser = nil
+ }
+ }
+ if t.ContentLength < 0 {
+ t.TransferEncoding = []string{"chunked"}
+ }
+ }
+ case *Response:
+ if rr.Request != nil {
+ t.Method = rr.Request.Method
+ }
+ t.Body = rr.Body
+ t.BodyCloser = rr.Body
+ t.ContentLength = rr.ContentLength
+ t.Close = rr.Close
+ t.TransferEncoding = rr.TransferEncoding
+ t.Trailer = rr.Trailer
+ atLeastHTTP11 = rr.ProtoAtLeast(1, 1)
+ t.ResponseToHEAD = noBodyExpected(t.Method)
+ }
+
+ // Sanitize Body,ContentLength,TransferEncoding
+ if t.ResponseToHEAD {
+ t.Body = nil
+ if chunked(t.TransferEncoding) {
+ t.ContentLength = -1
+ }
+ } else {
+ if !atLeastHTTP11 || t.Body == nil {
+ t.TransferEncoding = nil
+ }
+ if chunked(t.TransferEncoding) {
+ t.ContentLength = -1
+ } else if t.Body == nil { // no chunking, no body
+ t.ContentLength = 0
+ }
+ }
+
+ // Sanitize Trailer
+ if !chunked(t.TransferEncoding) {
+ t.Trailer = nil
+ }
+
+ return t, nil
+}
+
+func noBodyExpected(requestMethod string) bool {
+ return requestMethod == "HEAD"
+}
+
+func (t *transferWriter) shouldSendContentLength() bool {
+ if chunked(t.TransferEncoding) {
+ return false
+ }
+ if t.ContentLength > 0 {
+ return true
+ }
+ // Many servers expect a Content-Length for these methods
+ if t.Method == "POST" || t.Method == "PUT" {
+ return true
+ }
+ if t.ContentLength == 0 && isIdentity(t.TransferEncoding) {
+ return true
+ }
+
+ return false
+}
+
+func (t *transferWriter) WriteHeader(w io.Writer) error {
+ if t.Close {
+ if _, err := io.WriteString(w, "Connection: close\r\n"); err != nil {
+ return err
+ }
+ }
+
+ // Write Content-Length and/or Transfer-Encoding whose values are a
+ // function of the sanitized field triple (Body, ContentLength,
+ // TransferEncoding)
+ if t.shouldSendContentLength() {
+ if _, err := io.WriteString(w, "Content-Length: "); err != nil {
+ return err
+ }
+ if _, err := io.WriteString(w, strconv.FormatInt(t.ContentLength, 10)+"\r\n"); err != nil {
+ return err
+ }
+ } else if chunked(t.TransferEncoding) {
+ if _, err := io.WriteString(w, "Transfer-Encoding: chunked\r\n"); err != nil {
+ return err
+ }
+ }
+
+ // Write Trailer header
+ if t.Trailer != nil {
+ keys := make([]string, 0, len(t.Trailer))
+ for k := range t.Trailer {
+ k = CanonicalHeaderKey(k)
+ switch k {
+ case "Transfer-Encoding", "Trailer", "Content-Length":
+ return &badStringError{"invalid Trailer key", k}
+ }
+ keys = append(keys, k)
+ }
+ if len(keys) > 0 {
+ sort.Strings(keys)
+ // TODO: could do better allocation-wise here, but trailers are rare,
+ // so being lazy for now.
+ if _, err := io.WriteString(w, "Trailer: "+strings.Join(keys, ",")+"\r\n"); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
+
+func (t *transferWriter) WriteBody(w io.Writer) error {
+ var err error
+ var ncopy int64
+
+ // Write body
+ if t.Body != nil {
+ if chunked(t.TransferEncoding) {
+ cw := internal.NewChunkedWriter(w)
+ _, err = io.Copy(cw, t.Body)
+ if err == nil {
+ err = cw.Close()
+ }
+ } else if t.ContentLength == -1 {
+ ncopy, err = io.Copy(w, t.Body)
+ } else {
+ ncopy, err = io.Copy(w, io.LimitReader(t.Body, t.ContentLength))
+ if err != nil {
+ return err
+ }
+ var nextra int64
+ nextra, err = io.Copy(ioutil.Discard, t.Body)
+ ncopy += nextra
+ }
+ if err != nil {
+ return err
+ }
+ if err = t.BodyCloser.Close(); err != nil {
+ return err
+ }
+ }
+
+ if !t.ResponseToHEAD && t.ContentLength != -1 && t.ContentLength != ncopy {
+ return fmt.Errorf("http: ContentLength=%d with Body length %d",
+ t.ContentLength, ncopy)
+ }
+
+ // TODO(petar): Place trailer writer code here.
+ if chunked(t.TransferEncoding) {
+ // Write Trailer header
+ if t.Trailer != nil {
+ if err := t.Trailer.Write(w); err != nil {
+ return err
+ }
+ }
+ // Last chunk, empty trailer
+ _, err = io.WriteString(w, "\r\n")
+ }
+ return err
+}
+
+type transferReader struct {
+ // Input
+ Header Header
+ StatusCode int
+ RequestMethod string
+ ProtoMajor int
+ ProtoMinor int
+ // Output
+ Body io.ReadCloser
+ ContentLength int64
+ TransferEncoding []string
+ Close bool
+ Trailer Header
+}
+
+// bodyAllowedForStatus reports whether a given response status code
+// permits a body. See RFC2616, section 4.4.
+func bodyAllowedForStatus(status int) bool {
+ switch {
+ case status >= 100 && status <= 199:
+ return false
+ case status == 204:
+ return false
+ case status == 304:
+ return false
+ }
+ return true
+}
+
+var (
+ suppressedHeaders304 = []string{"Content-Type", "Content-Length", "Transfer-Encoding"}
+ suppressedHeadersNoBody = []string{"Content-Length", "Transfer-Encoding"}
+)
+
+func suppressedHeaders(status int) []string {
+ switch {
+ case status == 304:
+ // RFC 2616 section 10.3.5: "the response MUST NOT include other entity-headers"
+ return suppressedHeaders304
+ case !bodyAllowedForStatus(status):
+ return suppressedHeadersNoBody
+ }
+ return nil
+}
+
+// msg is *Request or *Response.
+func readTransfer(msg interface{}, r *bufio.Reader) (err error) {
+ t := &transferReader{RequestMethod: "GET"}
+
+ // Unify input
+ isResponse := false
+ switch rr := msg.(type) {
+ case *Response:
+ t.Header = rr.Header
+ t.StatusCode = rr.StatusCode
+ t.ProtoMajor = rr.ProtoMajor
+ t.ProtoMinor = rr.ProtoMinor
+ t.Close = shouldClose(t.ProtoMajor, t.ProtoMinor, t.Header, true)
+ isResponse = true
+ if rr.Request != nil {
+ t.RequestMethod = rr.Request.Method
+ }
+ case *Request:
+ t.Header = rr.Header
+ t.ProtoMajor = rr.ProtoMajor
+ t.ProtoMinor = rr.ProtoMinor
+ // Transfer semantics for Requests are exactly like those for
+ // Responses with status code 200, responding to a GET method
+ t.StatusCode = 200
+ default:
+ panic("unexpected type")
+ }
+
+ // Default to HTTP/1.1
+ if t.ProtoMajor == 0 && t.ProtoMinor == 0 {
+ t.ProtoMajor, t.ProtoMinor = 1, 1
+ }
+
+ // Transfer encoding, content length
+ t.TransferEncoding, err = fixTransferEncoding(t.RequestMethod, t.Header)
+ if err != nil {
+ return err
+ }
+
+ realLength, err := fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.TransferEncoding)
+ if err != nil {
+ return err
+ }
+ if isResponse && t.RequestMethod == "HEAD" {
+ if n, err := parseContentLength(t.Header.get("Content-Length")); err != nil {
+ return err
+ } else {
+ t.ContentLength = n
+ }
+ } else {
+ t.ContentLength = realLength
+ }
+
+ // Trailer
+ t.Trailer, err = fixTrailer(t.Header, t.TransferEncoding)
+ if err != nil {
+ return err
+ }
+
+ // If there is no Content-Length or chunked Transfer-Encoding on a *Response
+ // and the status is not 1xx, 204 or 304, then the body is unbounded.
+ // See RFC2616, section 4.4.
+ switch msg.(type) {
+ case *Response:
+ if realLength == -1 &&
+ !chunked(t.TransferEncoding) &&
+ bodyAllowedForStatus(t.StatusCode) {
+ // Unbounded body.
+ t.Close = true
+ }
+ }
+
+ // Prepare body reader. ContentLength < 0 means chunked encoding
+ // or close connection when finished, since multipart is not supported yet
+ switch {
+ case chunked(t.TransferEncoding):
+ if noBodyExpected(t.RequestMethod) {
+ t.Body = eofReader
+ } else {
+ t.Body = &body{src: internal.NewChunkedReader(r), hdr: msg, r: r, closing: t.Close}
+ }
+ case realLength == 0:
+ t.Body = eofReader
+ case realLength > 0:
+ t.Body = &body{src: io.LimitReader(r, realLength), closing: t.Close}
+ default:
+ // realLength < 0, i.e. "Content-Length" not mentioned in header
+ if t.Close {
+ // Close semantics (i.e. HTTP/1.0)
+ t.Body = &body{src: r, closing: t.Close}
+ } else {
+ // Persistent connection (i.e. HTTP/1.1)
+ t.Body = eofReader
+ }
+ }
+
+ // Unify output
+ switch rr := msg.(type) {
+ case *Request:
+ rr.Body = t.Body
+ rr.ContentLength = t.ContentLength
+ rr.TransferEncoding = t.TransferEncoding
+ rr.Close = t.Close
+ rr.Trailer = t.Trailer
+ case *Response:
+ rr.Body = t.Body
+ rr.ContentLength = t.ContentLength
+ rr.TransferEncoding = t.TransferEncoding
+ rr.Close = t.Close
+ rr.Trailer = t.Trailer
+ }
+
+ return nil
+}
+
+// Checks whether chunked is part of the encodings stack
+func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" }
+
+// Checks whether the encoding is explicitly "identity".
+func isIdentity(te []string) bool { return len(te) == 1 && te[0] == "identity" }
+
+// Sanitize transfer encoding
+func fixTransferEncoding(requestMethod string, header Header) ([]string, error) {
+ raw, present := header["Transfer-Encoding"]
+ if !present {
+ return nil, nil
+ }
+
+ delete(header, "Transfer-Encoding")
+
+ encodings := strings.Split(raw[0], ",")
+ te := make([]string, 0, len(encodings))
+ // TODO: Even though we only support "identity" and "chunked"
+ // encodings, the loop below is designed with foresight. One
+ // invariant that must be maintained is that, if present,
+ // chunked encoding must always come first.
+ for _, encoding := range encodings {
+ encoding = strings.ToLower(strings.TrimSpace(encoding))
+ // "identity" encoding is not recorded
+ if encoding == "identity" {
+ break
+ }
+ if encoding != "chunked" {
+ return nil, &badStringError{"unsupported transfer encoding", encoding}
+ }
+ te = te[0 : len(te)+1]
+ te[len(te)-1] = encoding
+ }
+ if len(te) > 1 {
+ return nil, &badStringError{"too many transfer encodings", strings.Join(te, ",")}
+ }
+ if len(te) > 0 {
+ // Chunked encoding trumps Content-Length. See RFC 2616
+ // Section 4.4. Currently len(te) > 0 implies chunked
+ // encoding.
+ delete(header, "Content-Length")
+ return te, nil
+ }
+
+ return nil, nil
+}
+
+// Determine the expected body length, using RFC 2616 Section 4.4. This
+// function is not a method, because ultimately it should be shared by
+// ReadResponse and ReadRequest.
+func fixLength(isResponse bool, status int, requestMethod string, header Header, te []string) (int64, error) {
+
+ // Logic based on response type or status
+ if noBodyExpected(requestMethod) {
+ return 0, nil
+ }
+ if status/100 == 1 {
+ return 0, nil
+ }
+ switch status {
+ case 204, 304:
+ return 0, nil
+ }
+
+ // Logic based on Transfer-Encoding
+ if chunked(te) {
+ return -1, nil
+ }
+
+ // Logic based on Content-Length
+ cl := strings.TrimSpace(header.get("Content-Length"))
+ if cl != "" {
+ n, err := parseContentLength(cl)
+ if err != nil {
+ return -1, err
+ }
+ return n, nil
+ } else {
+ header.Del("Content-Length")
+ }
+
+ if !isResponse && requestMethod == "GET" {
+ // RFC 2616 doesn't explicitly permit nor forbid an
+ // entity-body on a GET request so we permit one if
+ // declared, but we default to 0 here (not -1 below)
+ // if there's no mention of a body.
+ return 0, nil
+ }
+
+ // Body-EOF logic based on other methods (like closing, or chunked coding)
+ return -1, nil
+}
+
+// Determine whether to hang up after sending a request and body, or
+// receiving a response and body
+// 'header' is the request headers
+func shouldClose(major, minor int, header Header, removeCloseHeader bool) bool {
+ if major < 1 {
+ return true
+ } else if major == 1 && minor == 0 {
+ if !strings.Contains(strings.ToLower(header.get("Connection")), "keep-alive") {
+ return true
+ }
+ return false
+ } else {
+ // TODO: Should split on commas, toss surrounding white space,
+ // and check each field.
+ if strings.ToLower(header.get("Connection")) == "close" {
+ if removeCloseHeader {
+ header.Del("Connection")
+ }
+ return true
+ }
+ }
+ return false
+}
+
+// Parse the trailer header
+func fixTrailer(header Header, te []string) (Header, error) {
+ raw := header.get("Trailer")
+ if raw == "" {
+ return nil, nil
+ }
+
+ header.Del("Trailer")
+ trailer := make(Header)
+ keys := strings.Split(raw, ",")
+ for _, key := range keys {
+ key = CanonicalHeaderKey(strings.TrimSpace(key))
+ switch key {
+ case "Transfer-Encoding", "Trailer", "Content-Length":
+ return nil, &badStringError{"bad trailer key", key}
+ }
+ trailer[key] = nil
+ }
+ if len(trailer) == 0 {
+ return nil, nil
+ }
+ if !chunked(te) {
+ // Trailer and no chunking
+ return nil, ErrUnexpectedTrailer
+ }
+ return trailer, nil
+}
+
+// body turns a Reader into a ReadCloser.
+// Close ensures that the body has been fully read
+// and then reads the trailer if necessary.
+type body struct {
+ src io.Reader
+ hdr interface{} // non-nil (Response or Request) value means read trailer
+ r *bufio.Reader // underlying wire-format reader for the trailer
+ closing bool // is the connection to be closed after reading body?
+
+ mu sync.Mutex // guards closed, and calls to Read and Close
+ closed bool
+}
+
+// ErrBodyReadAfterClose is returned when reading a Request or Response
+// Body after the body has been closed. This typically happens when the body is
+// read after an HTTP Handler calls WriteHeader or Write on its
+// ResponseWriter.
+var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed Body")
+
+func (b *body) Read(p []byte) (n int, err error) {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+ if b.closed {
+ return 0, ErrBodyReadAfterClose
+ }
+ return b.readLocked(p)
+}
+
+// Must hold b.mu.
+func (b *body) readLocked(p []byte) (n int, err error) {
+ n, err = b.src.Read(p)
+
+ if err == io.EOF {
+ // Chunked case. Read the trailer.
+ if b.hdr != nil {
+ if e := b.readTrailer(); e != nil {
+ err = e
+ }
+ b.hdr = nil
+ } else {
+ // If the server declared the Content-Length, our body is a LimitedReader
+ // and we need to check whether this EOF arrived early.
+ if lr, ok := b.src.(*io.LimitedReader); ok && lr.N > 0 {
+ err = io.ErrUnexpectedEOF
+ }
+ }
+ }
+
+ // If we can return an EOF here along with the read data, do
+ // so. This is optional per the io.Reader contract, but doing
+ // so helps the HTTP transport code recycle its connection
+ // earlier (since it will see this EOF itself), even if the
+ // client doesn't do future reads or Close.
+ if err == nil && n > 0 {
+ if lr, ok := b.src.(*io.LimitedReader); ok && lr.N == 0 {
+ err = io.EOF
+ }
+ }
+
+ return n, err
+}
+
+var (
+ singleCRLF = []byte("\r\n")
+ doubleCRLF = []byte("\r\n\r\n")
+)
+
+func seeUpcomingDoubleCRLF(r *bufio.Reader) bool {
+ for peekSize := 4; ; peekSize++ {
+ // This loop stops when Peek returns an error,
+ // which it does when r's buffer has been filled.
+ buf, err := r.Peek(peekSize)
+ if bytes.HasSuffix(buf, doubleCRLF) {
+ return true
+ }
+ if err != nil {
+ break
+ }
+ }
+ return false
+}
+
+var errTrailerEOF = errors.New("http: unexpected EOF reading trailer")
+
+func (b *body) readTrailer() error {
+ // The common case, since nobody uses trailers.
+ buf, err := b.r.Peek(2)
+ if bytes.Equal(buf, singleCRLF) {
+ b.r.ReadByte()
+ b.r.ReadByte()
+ return nil
+ }
+ if len(buf) < 2 {
+ return errTrailerEOF
+ }
+ if err != nil {
+ return err
+ }
+
+ // Make sure there's a header terminator coming up, to prevent
+ // a DoS with an unbounded size Trailer. It's not easy to
+ // slip in a LimitReader here, as textproto.NewReader requires
+ // a concrete *bufio.Reader. Also, we can't get all the way
+ // back up to our conn's LimitedReader that *might* be backing
+ // this bufio.Reader. Instead, a hack: we iteratively Peek up
+ // to the bufio.Reader's max size, looking for a double CRLF.
+ // This limits the trailer to the underlying buffer size, typically 4kB.
+ if !seeUpcomingDoubleCRLF(b.r) {
+ return errors.New("http: suspiciously long trailer after chunked body")
+ }
+
+ hdr, err := textproto.NewReader(b.r).ReadMIMEHeader()
+ if err != nil {
+ if err == io.EOF {
+ return errTrailerEOF
+ }
+ return err
+ }
+ switch rr := b.hdr.(type) {
+ case *Request:
+ mergeSetHeader(&rr.Trailer, Header(hdr))
+ case *Response:
+ mergeSetHeader(&rr.Trailer, Header(hdr))
+ }
+ return nil
+}
+
+func mergeSetHeader(dst *Header, src Header) {
+ if *dst == nil {
+ *dst = src
+ return
+ }
+ for k, vv := range src {
+ (*dst)[k] = vv
+ }
+}
+
+func (b *body) Close() error {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+ if b.closed {
+ return nil
+ }
+ var err error
+ switch {
+ case b.hdr == nil && b.closing:
+ // no trailer and closing the connection next.
+ // no point in reading to EOF.
+ default:
+ // Fully consume the body, which will also lead to us reading
+ // the trailer headers after the body, if present.
+ _, err = io.Copy(ioutil.Discard, bodyLocked{b})
+ }
+ b.closed = true
+ return err
+}
+
+// bodyLocked is a io.Reader reading from a *body when its mutex is
+// already held.
+type bodyLocked struct {
+ b *body
+}
+
+func (bl bodyLocked) Read(p []byte) (n int, err error) {
+ if bl.b.closed {
+ return 0, ErrBodyReadAfterClose
+ }
+ return bl.b.readLocked(p)
+}
+
+// parseContentLength trims whitespace from s and returns -1 if no value
+// is set, or the value if it's >= 0.
+func parseContentLength(cl string) (int64, error) {
+ cl = strings.TrimSpace(cl)
+ if cl == "" {
+ return -1, nil
+ }
+ n, err := strconv.ParseInt(cl, 10, 64)
+ if err != nil || n < 0 {
+ return 0, &badStringError{"bad Content-Length", cl}
+ }
+ return n, nil
+
+}
diff --git a/src/net/http/transfer_test.go b/src/net/http/transfer_test.go
new file mode 100644
index 000000000..48cd540b9
--- /dev/null
+++ b/src/net/http/transfer_test.go
@@ -0,0 +1,64 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bufio"
+ "io"
+ "strings"
+ "testing"
+)
+
+func TestBodyReadBadTrailer(t *testing.T) {
+ b := &body{
+ src: strings.NewReader("foobar"),
+ hdr: true, // force reading the trailer
+ r: bufio.NewReader(strings.NewReader("")),
+ }
+ buf := make([]byte, 7)
+ n, err := b.Read(buf[:3])
+ got := string(buf[:n])
+ if got != "foo" || err != nil {
+ t.Fatalf(`first Read = %d (%q), %v; want 3 ("foo")`, n, got, err)
+ }
+
+ n, err = b.Read(buf[:])
+ got = string(buf[:n])
+ if got != "bar" || err != nil {
+ t.Fatalf(`second Read = %d (%q), %v; want 3 ("bar")`, n, got, err)
+ }
+
+ n, err = b.Read(buf[:])
+ got = string(buf[:n])
+ if err == nil {
+ t.Errorf("final Read was successful (%q), expected error from trailer read", got)
+ }
+}
+
+func TestFinalChunkedBodyReadEOF(t *testing.T) {
+ res, err := ReadResponse(bufio.NewReader(strings.NewReader(
+ "HTTP/1.1 200 OK\r\n"+
+ "Transfer-Encoding: chunked\r\n"+
+ "\r\n"+
+ "0a\r\n"+
+ "Body here\n\r\n"+
+ "09\r\n"+
+ "continued\r\n"+
+ "0\r\n"+
+ "\r\n")), nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ want := "Body here\ncontinued"
+ buf := make([]byte, len(want))
+ n, err := res.Body.Read(buf)
+ if n != len(want) || err != io.EOF {
+ t.Logf("body = %#v", res.Body)
+ t.Errorf("Read = %v, %v; want %d, EOF", n, err, len(want))
+ }
+ if string(buf) != want {
+ t.Errorf("buf = %q; want %q", buf, want)
+ }
+}
diff --git a/src/net/http/transport.go b/src/net/http/transport.go
new file mode 100644
index 000000000..782f7cd39
--- /dev/null
+++ b/src/net/http/transport.go
@@ -0,0 +1,1275 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// HTTP client implementation. See RFC 2616.
+//
+// This is the low-level Transport implementation of RoundTripper.
+// The high-level interface is in client.go.
+
+package http
+
+import (
+ "bufio"
+ "compress/gzip"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "net/url"
+ "os"
+ "strings"
+ "sync"
+ "time"
+)
+
+// DefaultTransport is the default implementation of Transport and is
+// used by DefaultClient. It establishes network connections as needed
+// and caches them for reuse by subsequent calls. It uses HTTP proxies
+// as directed by the $HTTP_PROXY and $NO_PROXY (or $http_proxy and
+// $no_proxy) environment variables.
+var DefaultTransport RoundTripper = &Transport{
+ Proxy: ProxyFromEnvironment,
+ Dial: (&net.Dialer{
+ Timeout: 30 * time.Second,
+ KeepAlive: 30 * time.Second,
+ }).Dial,
+ TLSHandshakeTimeout: 10 * time.Second,
+}
+
+// DefaultMaxIdleConnsPerHost is the default value of Transport's
+// MaxIdleConnsPerHost.
+const DefaultMaxIdleConnsPerHost = 2
+
+// Transport is an implementation of RoundTripper that supports HTTP,
+// HTTPS, and HTTP proxies (for either HTTP or HTTPS with CONNECT).
+// Transport can also cache connections for future re-use.
+type Transport struct {
+ idleMu sync.Mutex
+ wantIdle bool // user has requested to close all idle conns
+ idleConn map[connectMethodKey][]*persistConn
+ idleConnCh map[connectMethodKey]chan *persistConn
+
+ reqMu sync.Mutex
+ reqCanceler map[*Request]func()
+
+ altMu sync.RWMutex
+ altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper
+
+ // Proxy specifies a function to return a proxy for a given
+ // Request. If the function returns a non-nil error, the
+ // request is aborted with the provided error.
+ // If Proxy is nil or returns a nil *URL, no proxy is used.
+ Proxy func(*Request) (*url.URL, error)
+
+ // Dial specifies the dial function for creating unencrypted
+ // TCP connections.
+ // If Dial is nil, net.Dial is used.
+ Dial func(network, addr string) (net.Conn, error)
+
+ // DialTLS specifies an optional dial function for creating
+ // TLS connections for non-proxied HTTPS requests.
+ //
+ // If DialTLS is nil, Dial and TLSClientConfig are used.
+ //
+ // If DialTLS is set, the Dial hook is not used for HTTPS
+ // requests and the TLSClientConfig and TLSHandshakeTimeout
+ // are ignored. The returned net.Conn is assumed to already be
+ // past the TLS handshake.
+ DialTLS func(network, addr string) (net.Conn, error)
+
+ // TLSClientConfig specifies the TLS configuration to use with
+ // tls.Client. If nil, the default configuration is used.
+ TLSClientConfig *tls.Config
+
+ // TLSHandshakeTimeout specifies the maximum amount of time waiting to
+ // wait for a TLS handshake. Zero means no timeout.
+ TLSHandshakeTimeout time.Duration
+
+ // DisableKeepAlives, if true, prevents re-use of TCP connections
+ // between different HTTP requests.
+ DisableKeepAlives bool
+
+ // DisableCompression, if true, prevents the Transport from
+ // requesting compression with an "Accept-Encoding: gzip"
+ // request header when the Request contains no existing
+ // Accept-Encoding value. If the Transport requests gzip on
+ // its own and gets a gzipped response, it's transparently
+ // decoded in the Response.Body. However, if the user
+ // explicitly requested gzip it is not automatically
+ // uncompressed.
+ DisableCompression bool
+
+ // MaxIdleConnsPerHost, if non-zero, controls the maximum idle
+ // (keep-alive) to keep per-host. If zero,
+ // DefaultMaxIdleConnsPerHost is used.
+ MaxIdleConnsPerHost int
+
+ // ResponseHeaderTimeout, if non-zero, specifies the amount of
+ // time to wait for a server's response headers after fully
+ // writing the request (including its body, if any). This
+ // time does not include the time to read the response body.
+ ResponseHeaderTimeout time.Duration
+
+ // TODO: tunable on global max cached connections
+ // TODO: tunable on timeout on cached connections
+}
+
+// ProxyFromEnvironment returns the URL of the proxy to use for a
+// given request, as indicated by the environment variables
+// HTTP_PROXY, HTTPS_PROXY and NO_PROXY (or the lowercase versions
+// thereof). HTTPS_PROXY takes precedence over HTTP_PROXY for https
+// requests.
+//
+// The environment values may be either a complete URL or a
+// "host[:port]", in which case the "http" scheme is assumed.
+// An error is returned if the value is a different form.
+//
+// A nil URL and nil error are returned if no proxy is defined in the
+// environment, or a proxy should not be used for the given request,
+// as defined by NO_PROXY.
+//
+// As a special case, if req.URL.Host is "localhost" (with or without
+// a port number), then a nil URL and nil error will be returned.
+func ProxyFromEnvironment(req *Request) (*url.URL, error) {
+ var proxy string
+ if req.URL.Scheme == "https" {
+ proxy = httpsProxyEnv.Get()
+ }
+ if proxy == "" {
+ proxy = httpProxyEnv.Get()
+ }
+ if proxy == "" {
+ return nil, nil
+ }
+ if !useProxy(canonicalAddr(req.URL)) {
+ return nil, nil
+ }
+ proxyURL, err := url.Parse(proxy)
+ if err != nil || !strings.HasPrefix(proxyURL.Scheme, "http") {
+ // proxy was bogus. Try prepending "http://" to it and
+ // see if that parses correctly. If not, we fall
+ // through and complain about the original one.
+ if proxyURL, err := url.Parse("http://" + proxy); err == nil {
+ return proxyURL, nil
+ }
+ }
+ if err != nil {
+ return nil, fmt.Errorf("invalid proxy address %q: %v", proxy, err)
+ }
+ return proxyURL, nil
+}
+
+// ProxyURL returns a proxy function (for use in a Transport)
+// that always returns the same URL.
+func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, error) {
+ return func(*Request) (*url.URL, error) {
+ return fixedURL, nil
+ }
+}
+
+// transportRequest is a wrapper around a *Request that adds
+// optional extra headers to write.
+type transportRequest struct {
+ *Request // original request, not to be mutated
+ extra Header // extra headers to write, or nil
+}
+
+func (tr *transportRequest) extraHeaders() Header {
+ if tr.extra == nil {
+ tr.extra = make(Header)
+ }
+ return tr.extra
+}
+
+// RoundTrip implements the RoundTripper interface.
+//
+// For higher-level HTTP client support (such as handling of cookies
+// and redirects), see Get, Post, and the Client type.
+func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) {
+ if req.URL == nil {
+ req.closeBody()
+ return nil, errors.New("http: nil Request.URL")
+ }
+ if req.Header == nil {
+ req.closeBody()
+ return nil, errors.New("http: nil Request.Header")
+ }
+ if req.URL.Scheme != "http" && req.URL.Scheme != "https" {
+ t.altMu.RLock()
+ var rt RoundTripper
+ if t.altProto != nil {
+ rt = t.altProto[req.URL.Scheme]
+ }
+ t.altMu.RUnlock()
+ if rt == nil {
+ req.closeBody()
+ return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme}
+ }
+ return rt.RoundTrip(req)
+ }
+ if req.URL.Host == "" {
+ req.closeBody()
+ return nil, errors.New("http: no Host in request URL")
+ }
+ treq := &transportRequest{Request: req}
+ cm, err := t.connectMethodForRequest(treq)
+ if err != nil {
+ req.closeBody()
+ return nil, err
+ }
+
+ // Get the cached or newly-created connection to either the
+ // host (for http or https), the http proxy, or the http proxy
+ // pre-CONNECTed to https server. In any case, we'll be ready
+ // to send it requests.
+ pconn, err := t.getConn(req, cm)
+ if err != nil {
+ t.setReqCanceler(req, nil)
+ req.closeBody()
+ return nil, err
+ }
+
+ return pconn.roundTrip(treq)
+}
+
+// RegisterProtocol registers a new protocol with scheme.
+// The Transport will pass requests using the given scheme to rt.
+// It is rt's responsibility to simulate HTTP request semantics.
+//
+// RegisterProtocol can be used by other packages to provide
+// implementations of protocol schemes like "ftp" or "file".
+func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) {
+ if scheme == "http" || scheme == "https" {
+ panic("protocol " + scheme + " already registered")
+ }
+ t.altMu.Lock()
+ defer t.altMu.Unlock()
+ if t.altProto == nil {
+ t.altProto = make(map[string]RoundTripper)
+ }
+ if _, exists := t.altProto[scheme]; exists {
+ panic("protocol " + scheme + " already registered")
+ }
+ t.altProto[scheme] = rt
+}
+
+// CloseIdleConnections closes any connections which were previously
+// connected from previous requests but are now sitting idle in
+// a "keep-alive" state. It does not interrupt any connections currently
+// in use.
+func (t *Transport) CloseIdleConnections() {
+ t.idleMu.Lock()
+ m := t.idleConn
+ t.idleConn = nil
+ t.idleConnCh = nil
+ t.wantIdle = true
+ t.idleMu.Unlock()
+ for _, conns := range m {
+ for _, pconn := range conns {
+ pconn.close()
+ }
+ }
+}
+
+// CancelRequest cancels an in-flight request by closing its
+// connection.
+func (t *Transport) CancelRequest(req *Request) {
+ t.reqMu.Lock()
+ cancel := t.reqCanceler[req]
+ t.reqMu.Unlock()
+ if cancel != nil {
+ cancel()
+ }
+}
+
+//
+// Private implementation past this point.
+//
+
+var (
+ httpProxyEnv = &envOnce{
+ names: []string{"HTTP_PROXY", "http_proxy"},
+ }
+ httpsProxyEnv = &envOnce{
+ names: []string{"HTTPS_PROXY", "https_proxy"},
+ }
+ noProxyEnv = &envOnce{
+ names: []string{"NO_PROXY", "no_proxy"},
+ }
+)
+
+// envOnce looks up an environment variable (optionally by multiple
+// names) once. It mitigates expensive lookups on some platforms
+// (e.g. Windows).
+type envOnce struct {
+ names []string
+ once sync.Once
+ val string
+}
+
+func (e *envOnce) Get() string {
+ e.once.Do(e.init)
+ return e.val
+}
+
+func (e *envOnce) init() {
+ for _, n := range e.names {
+ e.val = os.Getenv(n)
+ if e.val != "" {
+ return
+ }
+ }
+}
+
+// reset is used by tests
+func (e *envOnce) reset() {
+ e.once = sync.Once{}
+ e.val = ""
+}
+
+func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) {
+ cm.targetScheme = treq.URL.Scheme
+ cm.targetAddr = canonicalAddr(treq.URL)
+ if t.Proxy != nil {
+ cm.proxyURL, err = t.Proxy(treq.Request)
+ }
+ return cm, err
+}
+
+// proxyAuth returns the Proxy-Authorization header to set
+// on requests, if applicable.
+func (cm *connectMethod) proxyAuth() string {
+ if cm.proxyURL == nil {
+ return ""
+ }
+ if u := cm.proxyURL.User; u != nil {
+ username := u.Username()
+ password, _ := u.Password()
+ return "Basic " + basicAuth(username, password)
+ }
+ return ""
+}
+
+// putIdleConn adds pconn to the list of idle persistent connections awaiting
+// a new request.
+// If pconn is no longer needed or not in a good state, putIdleConn
+// returns false.
+func (t *Transport) putIdleConn(pconn *persistConn) bool {
+ if t.DisableKeepAlives || t.MaxIdleConnsPerHost < 0 {
+ pconn.close()
+ return false
+ }
+ if pconn.isBroken() {
+ return false
+ }
+ key := pconn.cacheKey
+ max := t.MaxIdleConnsPerHost
+ if max == 0 {
+ max = DefaultMaxIdleConnsPerHost
+ }
+ t.idleMu.Lock()
+
+ waitingDialer := t.idleConnCh[key]
+ select {
+ case waitingDialer <- pconn:
+ // We're done with this pconn and somebody else is
+ // currently waiting for a conn of this type (they're
+ // actively dialing, but this conn is ready
+ // first). Chrome calls this socket late binding. See
+ // https://insouciant.org/tech/connection-management-in-chromium/
+ t.idleMu.Unlock()
+ return true
+ default:
+ if waitingDialer != nil {
+ // They had populated this, but their dial won
+ // first, so we can clean up this map entry.
+ delete(t.idleConnCh, key)
+ }
+ }
+ if t.wantIdle {
+ t.idleMu.Unlock()
+ pconn.close()
+ return false
+ }
+ if t.idleConn == nil {
+ t.idleConn = make(map[connectMethodKey][]*persistConn)
+ }
+ if len(t.idleConn[key]) >= max {
+ t.idleMu.Unlock()
+ pconn.close()
+ return false
+ }
+ for _, exist := range t.idleConn[key] {
+ if exist == pconn {
+ log.Fatalf("dup idle pconn %p in freelist", pconn)
+ }
+ }
+ t.idleConn[key] = append(t.idleConn[key], pconn)
+ t.idleMu.Unlock()
+ return true
+}
+
+// getIdleConnCh returns a channel to receive and return idle
+// persistent connection for the given connectMethod.
+// It may return nil, if persistent connections are not being used.
+func (t *Transport) getIdleConnCh(cm connectMethod) chan *persistConn {
+ if t.DisableKeepAlives {
+ return nil
+ }
+ key := cm.key()
+ t.idleMu.Lock()
+ defer t.idleMu.Unlock()
+ t.wantIdle = false
+ if t.idleConnCh == nil {
+ t.idleConnCh = make(map[connectMethodKey]chan *persistConn)
+ }
+ ch, ok := t.idleConnCh[key]
+ if !ok {
+ ch = make(chan *persistConn)
+ t.idleConnCh[key] = ch
+ }
+ return ch
+}
+
+func (t *Transport) getIdleConn(cm connectMethod) (pconn *persistConn) {
+ key := cm.key()
+ t.idleMu.Lock()
+ defer t.idleMu.Unlock()
+ if t.idleConn == nil {
+ return nil
+ }
+ for {
+ pconns, ok := t.idleConn[key]
+ if !ok {
+ return nil
+ }
+ if len(pconns) == 1 {
+ pconn = pconns[0]
+ delete(t.idleConn, key)
+ } else {
+ // 2 or more cached connections; pop last
+ // TODO: queue?
+ pconn = pconns[len(pconns)-1]
+ t.idleConn[key] = pconns[:len(pconns)-1]
+ }
+ if !pconn.isBroken() {
+ return
+ }
+ }
+}
+
+func (t *Transport) setReqCanceler(r *Request, fn func()) {
+ t.reqMu.Lock()
+ defer t.reqMu.Unlock()
+ if t.reqCanceler == nil {
+ t.reqCanceler = make(map[*Request]func())
+ }
+ if fn != nil {
+ t.reqCanceler[r] = fn
+ } else {
+ delete(t.reqCanceler, r)
+ }
+}
+
+func (t *Transport) dial(network, addr string) (c net.Conn, err error) {
+ if t.Dial != nil {
+ return t.Dial(network, addr)
+ }
+ return net.Dial(network, addr)
+}
+
+// Testing hooks:
+var prePendingDial, postPendingDial func()
+
+// getConn dials and creates a new persistConn to the target as
+// specified in the connectMethod. This includes doing a proxy CONNECT
+// and/or setting up TLS. If this doesn't return an error, the persistConn
+// is ready to write requests to.
+func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error) {
+ if pc := t.getIdleConn(cm); pc != nil {
+ return pc, nil
+ }
+
+ type dialRes struct {
+ pc *persistConn
+ err error
+ }
+ dialc := make(chan dialRes)
+
+ handlePendingDial := func() {
+ if prePendingDial != nil {
+ prePendingDial()
+ }
+ go func() {
+ if v := <-dialc; v.err == nil {
+ t.putIdleConn(v.pc)
+ }
+ if postPendingDial != nil {
+ postPendingDial()
+ }
+ }()
+ }
+
+ cancelc := make(chan struct{})
+ t.setReqCanceler(req, func() { close(cancelc) })
+
+ go func() {
+ pc, err := t.dialConn(cm)
+ dialc <- dialRes{pc, err}
+ }()
+
+ idleConnCh := t.getIdleConnCh(cm)
+ select {
+ case v := <-dialc:
+ // Our dial finished.
+ return v.pc, v.err
+ case pc := <-idleConnCh:
+ // Another request finished first and its net.Conn
+ // became available before our dial. Or somebody
+ // else's dial that they didn't use.
+ // But our dial is still going, so give it away
+ // when it finishes:
+ handlePendingDial()
+ return pc, nil
+ case <-cancelc:
+ handlePendingDial()
+ return nil, errors.New("net/http: request canceled while waiting for connection")
+ }
+}
+
+func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
+ pconn := &persistConn{
+ t: t,
+ cacheKey: cm.key(),
+ reqch: make(chan requestAndChan, 1),
+ writech: make(chan writeRequest, 1),
+ closech: make(chan struct{}),
+ writeErrCh: make(chan error, 1),
+ }
+ tlsDial := t.DialTLS != nil && cm.targetScheme == "https" && cm.proxyURL == nil
+ if tlsDial {
+ var err error
+ pconn.conn, err = t.DialTLS("tcp", cm.addr())
+ if err != nil {
+ return nil, err
+ }
+ if tc, ok := pconn.conn.(*tls.Conn); ok {
+ cs := tc.ConnectionState()
+ pconn.tlsState = &cs
+ }
+ } else {
+ conn, err := t.dial("tcp", cm.addr())
+ if err != nil {
+ if cm.proxyURL != nil {
+ err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err)
+ }
+ return nil, err
+ }
+ pconn.conn = conn
+ }
+
+ // Proxy setup.
+ switch {
+ case cm.proxyURL == nil:
+ // Do nothing. Not using a proxy.
+ case cm.targetScheme == "http":
+ pconn.isProxy = true
+ if pa := cm.proxyAuth(); pa != "" {
+ pconn.mutateHeaderFunc = func(h Header) {
+ h.Set("Proxy-Authorization", pa)
+ }
+ }
+ case cm.targetScheme == "https":
+ conn := pconn.conn
+ connectReq := &Request{
+ Method: "CONNECT",
+ URL: &url.URL{Opaque: cm.targetAddr},
+ Host: cm.targetAddr,
+ Header: make(Header),
+ }
+ if pa := cm.proxyAuth(); pa != "" {
+ connectReq.Header.Set("Proxy-Authorization", pa)
+ }
+ connectReq.Write(conn)
+
+ // Read response.
+ // Okay to use and discard buffered reader here, because
+ // TLS server will not speak until spoken to.
+ br := bufio.NewReader(conn)
+ resp, err := ReadResponse(br, connectReq)
+ if err != nil {
+ conn.Close()
+ return nil, err
+ }
+ if resp.StatusCode != 200 {
+ f := strings.SplitN(resp.Status, " ", 2)
+ conn.Close()
+ return nil, errors.New(f[1])
+ }
+ }
+
+ if cm.targetScheme == "https" && !tlsDial {
+ // Initiate TLS and check remote host name against certificate.
+ cfg := t.TLSClientConfig
+ if cfg == nil || cfg.ServerName == "" {
+ host := cm.tlsHost()
+ if cfg == nil {
+ cfg = &tls.Config{ServerName: host}
+ } else {
+ clone := *cfg // shallow clone
+ clone.ServerName = host
+ cfg = &clone
+ }
+ }
+ plainConn := pconn.conn
+ tlsConn := tls.Client(plainConn, cfg)
+ errc := make(chan error, 2)
+ var timer *time.Timer // for canceling TLS handshake
+ if d := t.TLSHandshakeTimeout; d != 0 {
+ timer = time.AfterFunc(d, func() {
+ errc <- tlsHandshakeTimeoutError{}
+ })
+ }
+ go func() {
+ err := tlsConn.Handshake()
+ if timer != nil {
+ timer.Stop()
+ }
+ errc <- err
+ }()
+ if err := <-errc; err != nil {
+ plainConn.Close()
+ return nil, err
+ }
+ if !cfg.InsecureSkipVerify {
+ if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
+ plainConn.Close()
+ return nil, err
+ }
+ }
+ cs := tlsConn.ConnectionState()
+ pconn.tlsState = &cs
+ pconn.conn = tlsConn
+ }
+
+ pconn.br = bufio.NewReader(noteEOFReader{pconn.conn, &pconn.sawEOF})
+ pconn.bw = bufio.NewWriter(pconn.conn)
+ go pconn.readLoop()
+ go pconn.writeLoop()
+ return pconn, nil
+}
+
+// useProxy returns true if requests to addr should use a proxy,
+// according to the NO_PROXY or no_proxy environment variable.
+// addr is always a canonicalAddr with a host and port.
+func useProxy(addr string) bool {
+ if len(addr) == 0 {
+ return true
+ }
+ host, _, err := net.SplitHostPort(addr)
+ if err != nil {
+ return false
+ }
+ if host == "localhost" {
+ return false
+ }
+ if ip := net.ParseIP(host); ip != nil {
+ if ip.IsLoopback() {
+ return false
+ }
+ }
+
+ no_proxy := noProxyEnv.Get()
+ if no_proxy == "*" {
+ return false
+ }
+
+ addr = strings.ToLower(strings.TrimSpace(addr))
+ if hasPort(addr) {
+ addr = addr[:strings.LastIndex(addr, ":")]
+ }
+
+ for _, p := range strings.Split(no_proxy, ",") {
+ p = strings.ToLower(strings.TrimSpace(p))
+ if len(p) == 0 {
+ continue
+ }
+ if hasPort(p) {
+ p = p[:strings.LastIndex(p, ":")]
+ }
+ if addr == p {
+ return false
+ }
+ if p[0] == '.' && (strings.HasSuffix(addr, p) || addr == p[1:]) {
+ // no_proxy ".foo.com" matches "bar.foo.com" or "foo.com"
+ return false
+ }
+ if p[0] != '.' && strings.HasSuffix(addr, p) && addr[len(addr)-len(p)-1] == '.' {
+ // no_proxy "foo.com" matches "bar.foo.com"
+ return false
+ }
+ }
+ return true
+}
+
+// connectMethod is the map key (in its String form) for keeping persistent
+// TCP connections alive for subsequent HTTP requests.
+//
+// A connect method may be of the following types:
+//
+// Cache key form Description
+// ----------------- -------------------------
+// |http|foo.com http directly to server, no proxy
+// |https|foo.com https directly to server, no proxy
+// http://proxy.com|https|foo.com http to proxy, then CONNECT to foo.com
+// http://proxy.com|http http to proxy, http to anywhere after that
+//
+// Note: no support to https to the proxy yet.
+//
+type connectMethod struct {
+ proxyURL *url.URL // nil for no proxy, else full proxy URL
+ targetScheme string // "http" or "https"
+ targetAddr string // Not used if proxy + http targetScheme (4th example in table)
+}
+
+func (cm *connectMethod) key() connectMethodKey {
+ proxyStr := ""
+ targetAddr := cm.targetAddr
+ if cm.proxyURL != nil {
+ proxyStr = cm.proxyURL.String()
+ if cm.targetScheme == "http" {
+ targetAddr = ""
+ }
+ }
+ return connectMethodKey{
+ proxy: proxyStr,
+ scheme: cm.targetScheme,
+ addr: targetAddr,
+ }
+}
+
+// addr returns the first hop "host:port" to which we need to TCP connect.
+func (cm *connectMethod) addr() string {
+ if cm.proxyURL != nil {
+ return canonicalAddr(cm.proxyURL)
+ }
+ return cm.targetAddr
+}
+
+// tlsHost returns the host name to match against the peer's
+// TLS certificate.
+func (cm *connectMethod) tlsHost() string {
+ h := cm.targetAddr
+ if hasPort(h) {
+ h = h[:strings.LastIndex(h, ":")]
+ }
+ return h
+}
+
+// connectMethodKey is the map key version of connectMethod, with a
+// stringified proxy URL (or the empty string) instead of a pointer to
+// a URL.
+type connectMethodKey struct {
+ proxy, scheme, addr string
+}
+
+func (k connectMethodKey) String() string {
+ // Only used by tests.
+ return fmt.Sprintf("%s|%s|%s", k.proxy, k.scheme, k.addr)
+}
+
+// persistConn wraps a connection, usually a persistent one
+// (but may be used for non-keep-alive requests as well)
+type persistConn struct {
+ t *Transport
+ cacheKey connectMethodKey
+ conn net.Conn
+ tlsState *tls.ConnectionState
+ br *bufio.Reader // from conn
+ sawEOF bool // whether we've seen EOF from conn; owned by readLoop
+ bw *bufio.Writer // to conn
+ reqch chan requestAndChan // written by roundTrip; read by readLoop
+ writech chan writeRequest // written by roundTrip; read by writeLoop
+ closech chan struct{} // closed when conn closed
+ isProxy bool
+ // writeErrCh passes the request write error (usually nil)
+ // from the writeLoop goroutine to the readLoop which passes
+ // it off to the res.Body reader, which then uses it to decide
+ // whether or not a connection can be reused. Issue 7569.
+ writeErrCh chan error
+
+ lk sync.Mutex // guards following fields
+ numExpectedResponses int
+ closed bool // whether conn has been closed
+ broken bool // an error has happened on this connection; marked broken so it's not reused.
+ // mutateHeaderFunc is an optional func to modify extra
+ // headers on each outbound request before it's written. (the
+ // original Request given to RoundTrip is not modified)
+ mutateHeaderFunc func(Header)
+}
+
+// isBroken reports whether this connection is in a known broken state.
+func (pc *persistConn) isBroken() bool {
+ pc.lk.Lock()
+ b := pc.broken
+ pc.lk.Unlock()
+ return b
+}
+
+func (pc *persistConn) cancelRequest() {
+ pc.conn.Close()
+}
+
+var remoteSideClosedFunc func(error) bool // or nil to use default
+
+func remoteSideClosed(err error) bool {
+ if err == io.EOF {
+ return true
+ }
+ if remoteSideClosedFunc != nil {
+ return remoteSideClosedFunc(err)
+ }
+ return false
+}
+
+func (pc *persistConn) readLoop() {
+ alive := true
+
+ for alive {
+ pb, err := pc.br.Peek(1)
+
+ pc.lk.Lock()
+ if pc.numExpectedResponses == 0 {
+ if !pc.closed {
+ pc.closeLocked()
+ if len(pb) > 0 {
+ log.Printf("Unsolicited response received on idle HTTP channel starting with %q; err=%v",
+ string(pb), err)
+ }
+ }
+ pc.lk.Unlock()
+ return
+ }
+ pc.lk.Unlock()
+
+ rc := <-pc.reqch
+
+ var resp *Response
+ if err == nil {
+ resp, err = ReadResponse(pc.br, rc.req)
+ if err == nil && resp.StatusCode == 100 {
+ // Skip any 100-continue for now.
+ // TODO(bradfitz): if rc.req had "Expect: 100-continue",
+ // actually block the request body write and signal the
+ // writeLoop now to begin sending it. (Issue 2184) For now we
+ // eat it, since we're never expecting one.
+ resp, err = ReadResponse(pc.br, rc.req)
+ }
+ }
+
+ if resp != nil {
+ resp.TLS = pc.tlsState
+ }
+
+ hasBody := resp != nil && rc.req.Method != "HEAD" && resp.ContentLength != 0
+
+ if err != nil {
+ pc.close()
+ } else {
+ if rc.addedGzip && hasBody && resp.Header.Get("Content-Encoding") == "gzip" {
+ resp.Header.Del("Content-Encoding")
+ resp.Header.Del("Content-Length")
+ resp.ContentLength = -1
+ resp.Body = &gzipReader{body: resp.Body}
+ }
+ resp.Body = &bodyEOFSignal{body: resp.Body}
+ }
+
+ if err != nil || resp.Close || rc.req.Close || resp.StatusCode <= 199 {
+ // Don't do keep-alive on error if either party requested a close
+ // or we get an unexpected informational (1xx) response.
+ // StatusCode 100 is already handled above.
+ alive = false
+ }
+
+ var waitForBodyRead chan bool
+ if hasBody {
+ waitForBodyRead = make(chan bool, 2)
+ resp.Body.(*bodyEOFSignal).earlyCloseFn = func() error {
+ // Sending false here sets alive to
+ // false and closes the connection
+ // below.
+ waitForBodyRead <- false
+ return nil
+ }
+ resp.Body.(*bodyEOFSignal).fn = func(err error) {
+ waitForBodyRead <- alive &&
+ err == nil &&
+ !pc.sawEOF &&
+ pc.wroteRequest() &&
+ pc.t.putIdleConn(pc)
+ }
+ }
+
+ if alive && !hasBody {
+ alive = !pc.sawEOF &&
+ pc.wroteRequest() &&
+ pc.t.putIdleConn(pc)
+ }
+
+ rc.ch <- responseAndError{resp, err}
+
+ // Wait for the just-returned response body to be fully consumed
+ // before we race and peek on the underlying bufio reader.
+ if waitForBodyRead != nil {
+ select {
+ case alive = <-waitForBodyRead:
+ case <-pc.closech:
+ alive = false
+ }
+ }
+
+ pc.t.setReqCanceler(rc.req, nil)
+
+ if !alive {
+ pc.close()
+ }
+ }
+}
+
+func (pc *persistConn) writeLoop() {
+ for {
+ select {
+ case wr := <-pc.writech:
+ if pc.isBroken() {
+ wr.ch <- errors.New("http: can't write HTTP request on broken connection")
+ continue
+ }
+ err := wr.req.Request.write(pc.bw, pc.isProxy, wr.req.extra)
+ if err == nil {
+ err = pc.bw.Flush()
+ }
+ if err != nil {
+ pc.markBroken()
+ wr.req.Request.closeBody()
+ }
+ pc.writeErrCh <- err // to the body reader, which might recycle us
+ wr.ch <- err // to the roundTrip function
+ case <-pc.closech:
+ return
+ }
+ }
+}
+
+// wroteRequest is a check before recycling a connection that the previous write
+// (from writeLoop above) happened and was successful.
+func (pc *persistConn) wroteRequest() bool {
+ select {
+ case err := <-pc.writeErrCh:
+ // Common case: the write happened well before the response, so
+ // avoid creating a timer.
+ return err == nil
+ default:
+ // Rare case: the request was written in writeLoop above but
+ // before it could send to pc.writeErrCh, the reader read it
+ // all, processed it, and called us here. In this case, give the
+ // write goroutine a bit of time to finish its send.
+ //
+ // Less rare case: We also get here in the legitimate case of
+ // Issue 7569, where the writer is still writing (or stalled),
+ // but the server has already replied. In this case, we don't
+ // want to wait too long, and we want to return false so this
+ // connection isn't re-used.
+ select {
+ case err := <-pc.writeErrCh:
+ return err == nil
+ case <-time.After(50 * time.Millisecond):
+ return false
+ }
+ }
+}
+
+type responseAndError struct {
+ res *Response
+ err error
+}
+
+type requestAndChan struct {
+ req *Request
+ ch chan responseAndError
+
+ // did the Transport (as opposed to the client code) add an
+ // Accept-Encoding gzip header? only if it we set it do
+ // we transparently decode the gzip.
+ addedGzip bool
+}
+
+// A writeRequest is sent by the readLoop's goroutine to the
+// writeLoop's goroutine to write a request while the read loop
+// concurrently waits on both the write response and the server's
+// reply.
+type writeRequest struct {
+ req *transportRequest
+ ch chan<- error
+}
+
+type httpError struct {
+ err string
+ timeout bool
+}
+
+func (e *httpError) Error() string { return e.err }
+func (e *httpError) Timeout() bool { return e.timeout }
+func (e *httpError) Temporary() bool { return true }
+
+var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true}
+var errClosed error = &httpError{err: "net/http: transport closed before response was received"}
+
+func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
+ pc.t.setReqCanceler(req.Request, pc.cancelRequest)
+ pc.lk.Lock()
+ pc.numExpectedResponses++
+ headerFn := pc.mutateHeaderFunc
+ pc.lk.Unlock()
+
+ if headerFn != nil {
+ headerFn(req.extraHeaders())
+ }
+
+ // Ask for a compressed version if the caller didn't set their
+ // own value for Accept-Encoding. We only attempt to
+ // uncompress the gzip stream if we were the layer that
+ // requested it.
+ requestedGzip := false
+ if !pc.t.DisableCompression &&
+ req.Header.Get("Accept-Encoding") == "" &&
+ req.Header.Get("Range") == "" &&
+ req.Method != "HEAD" {
+ // Request gzip only, not deflate. Deflate is ambiguous and
+ // not as universally supported anyway.
+ // See: http://www.gzip.org/zlib/zlib_faq.html#faq38
+ //
+ // Note that we don't request this for HEAD requests,
+ // due to a bug in nginx:
+ // http://trac.nginx.org/nginx/ticket/358
+ // http://golang.org/issue/5522
+ //
+ // We don't request gzip if the request is for a range, since
+ // auto-decoding a portion of a gzipped document will just fail
+ // anyway. See http://golang.org/issue/8923
+ requestedGzip = true
+ req.extraHeaders().Set("Accept-Encoding", "gzip")
+ }
+
+ // Write the request concurrently with waiting for a response,
+ // in case the server decides to reply before reading our full
+ // request body.
+ writeErrCh := make(chan error, 1)
+ pc.writech <- writeRequest{req, writeErrCh}
+
+ resc := make(chan responseAndError, 1)
+ pc.reqch <- requestAndChan{req.Request, resc, requestedGzip}
+
+ var re responseAndError
+ var pconnDeadCh = pc.closech
+ var failTicker <-chan time.Time
+ var respHeaderTimer <-chan time.Time
+WaitResponse:
+ for {
+ select {
+ case err := <-writeErrCh:
+ if err != nil {
+ re = responseAndError{nil, err}
+ pc.close()
+ break WaitResponse
+ }
+ if d := pc.t.ResponseHeaderTimeout; d > 0 {
+ respHeaderTimer = time.After(d)
+ }
+ case <-pconnDeadCh:
+ // The persist connection is dead. This shouldn't
+ // usually happen (only with Connection: close responses
+ // with no response bodies), but if it does happen it
+ // means either a) the remote server hung up on us
+ // prematurely, or b) the readLoop sent us a response &
+ // closed its closech at roughly the same time, and we
+ // selected this case first, in which case a response
+ // might still be coming soon.
+ //
+ // We can't avoid the select race in b) by using a unbuffered
+ // resc channel instead, because then goroutines can
+ // leak if we exit due to other errors.
+ pconnDeadCh = nil // avoid spinning
+ failTicker = time.After(100 * time.Millisecond) // arbitrary time to wait for resc
+ case <-failTicker:
+ re = responseAndError{err: errClosed}
+ break WaitResponse
+ case <-respHeaderTimer:
+ pc.close()
+ re = responseAndError{err: errTimeout}
+ break WaitResponse
+ case re = <-resc:
+ break WaitResponse
+ }
+ }
+
+ pc.lk.Lock()
+ pc.numExpectedResponses--
+ pc.lk.Unlock()
+
+ if re.err != nil {
+ pc.t.setReqCanceler(req.Request, nil)
+ }
+ return re.res, re.err
+}
+
+// markBroken marks a connection as broken (so it's not reused).
+// It differs from close in that it doesn't close the underlying
+// connection for use when it's still being read.
+func (pc *persistConn) markBroken() {
+ pc.lk.Lock()
+ defer pc.lk.Unlock()
+ pc.broken = true
+}
+
+func (pc *persistConn) close() {
+ pc.lk.Lock()
+ defer pc.lk.Unlock()
+ pc.closeLocked()
+}
+
+func (pc *persistConn) closeLocked() {
+ pc.broken = true
+ if !pc.closed {
+ pc.conn.Close()
+ pc.closed = true
+ close(pc.closech)
+ }
+ pc.mutateHeaderFunc = nil
+}
+
+var portMap = map[string]string{
+ "http": "80",
+ "https": "443",
+}
+
+// canonicalAddr returns url.Host but always with a ":port" suffix
+func canonicalAddr(url *url.URL) string {
+ addr := url.Host
+ if !hasPort(addr) {
+ return addr + ":" + portMap[url.Scheme]
+ }
+ return addr
+}
+
+// bodyEOFSignal wraps a ReadCloser but runs fn (if non-nil) at most
+// once, right before its final (error-producing) Read or Close call
+// returns. If earlyCloseFn is non-nil and Close is called before
+// io.EOF is seen, earlyCloseFn is called instead of fn, and its
+// return value is the return value from Close.
+type bodyEOFSignal struct {
+ body io.ReadCloser
+ mu sync.Mutex // guards following 4 fields
+ closed bool // whether Close has been called
+ rerr error // sticky Read error
+ fn func(error) // error will be nil on Read io.EOF
+ earlyCloseFn func() error // optional alt Close func used if io.EOF not seen
+}
+
+func (es *bodyEOFSignal) Read(p []byte) (n int, err error) {
+ es.mu.Lock()
+ closed, rerr := es.closed, es.rerr
+ es.mu.Unlock()
+ if closed {
+ return 0, errors.New("http: read on closed response body")
+ }
+ if rerr != nil {
+ return 0, rerr
+ }
+
+ n, err = es.body.Read(p)
+ if err != nil {
+ es.mu.Lock()
+ defer es.mu.Unlock()
+ if es.rerr == nil {
+ es.rerr = err
+ }
+ es.condfn(err)
+ }
+ return
+}
+
+func (es *bodyEOFSignal) Close() error {
+ es.mu.Lock()
+ defer es.mu.Unlock()
+ if es.closed {
+ return nil
+ }
+ es.closed = true
+ if es.earlyCloseFn != nil && es.rerr != io.EOF {
+ return es.earlyCloseFn()
+ }
+ err := es.body.Close()
+ es.condfn(err)
+ return err
+}
+
+// caller must hold es.mu.
+func (es *bodyEOFSignal) condfn(err error) {
+ if es.fn == nil {
+ return
+ }
+ if err == io.EOF {
+ err = nil
+ }
+ es.fn(err)
+ es.fn = nil
+}
+
+// gzipReader wraps a response body so it can lazily
+// call gzip.NewReader on the first call to Read
+type gzipReader struct {
+ body io.ReadCloser // underlying Response.Body
+ zr io.Reader // lazily-initialized gzip reader
+}
+
+func (gz *gzipReader) Read(p []byte) (n int, err error) {
+ if gz.zr == nil {
+ gz.zr, err = gzip.NewReader(gz.body)
+ if err != nil {
+ return 0, err
+ }
+ }
+ return gz.zr.Read(p)
+}
+
+func (gz *gzipReader) Close() error {
+ return gz.body.Close()
+}
+
+type readerAndCloser struct {
+ io.Reader
+ io.Closer
+}
+
+type tlsHandshakeTimeoutError struct{}
+
+func (tlsHandshakeTimeoutError) Timeout() bool { return true }
+func (tlsHandshakeTimeoutError) Temporary() bool { return true }
+func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" }
+
+type noteEOFReader struct {
+ r io.Reader
+ sawEOF *bool
+}
+
+func (nr noteEOFReader) Read(p []byte) (n int, err error) {
+ n, err = nr.r.Read(p)
+ if err == io.EOF {
+ *nr.sawEOF = true
+ }
+ return
+}
diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go
new file mode 100644
index 000000000..defa63370
--- /dev/null
+++ b/src/net/http/transport_test.go
@@ -0,0 +1,2324 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tests for transport.go
+
+package http_test
+
+import (
+ "bufio"
+ "bytes"
+ "compress/gzip"
+ "crypto/rand"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "net"
+ "net/http"
+ . "net/http"
+ "net/http/httptest"
+ "net/url"
+ "os"
+ "runtime"
+ "strconv"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+)
+
+// TODO: test 5 pipelined requests with responses: 1) OK, 2) OK, Connection: Close
+// and then verify that the final 2 responses get errors back.
+
+// hostPortHandler writes back the client's "host:port".
+var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.FormValue("close") == "true" {
+ w.Header().Set("Connection", "close")
+ }
+ w.Write([]byte(r.RemoteAddr))
+})
+
+// testCloseConn is a net.Conn tracked by a testConnSet.
+type testCloseConn struct {
+ net.Conn
+ set *testConnSet
+}
+
+func (c *testCloseConn) Close() error {
+ c.set.remove(c)
+ return c.Conn.Close()
+}
+
+// testConnSet tracks a set of TCP connections and whether they've
+// been closed.
+type testConnSet struct {
+ t *testing.T
+ mu sync.Mutex // guards closed and list
+ closed map[net.Conn]bool
+ list []net.Conn // in order created
+}
+
+func (tcs *testConnSet) insert(c net.Conn) {
+ tcs.mu.Lock()
+ defer tcs.mu.Unlock()
+ tcs.closed[c] = false
+ tcs.list = append(tcs.list, c)
+}
+
+func (tcs *testConnSet) remove(c net.Conn) {
+ tcs.mu.Lock()
+ defer tcs.mu.Unlock()
+ tcs.closed[c] = true
+}
+
+// some tests use this to manage raw tcp connections for later inspection
+func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
+ connSet := &testConnSet{
+ t: t,
+ closed: make(map[net.Conn]bool),
+ }
+ dial := func(n, addr string) (net.Conn, error) {
+ c, err := net.Dial(n, addr)
+ if err != nil {
+ return nil, err
+ }
+ tc := &testCloseConn{c, connSet}
+ connSet.insert(tc)
+ return tc, nil
+ }
+ return connSet, dial
+}
+
+func (tcs *testConnSet) check(t *testing.T) {
+ tcs.mu.Lock()
+ defer tcs.mu.Unlock()
+ for i := 4; i >= 0; i-- {
+ for i, c := range tcs.list {
+ if tcs.closed[c] {
+ continue
+ }
+ if i != 0 {
+ tcs.mu.Unlock()
+ time.Sleep(50 * time.Millisecond)
+ tcs.mu.Lock()
+ continue
+ }
+ t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
+ }
+ }
+}
+
+// Two subsequent requests and verify their response is the same.
+// The response from the server is our own IP:port
+func TestTransportKeepAlives(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(hostPortHandler)
+ defer ts.Close()
+
+ for _, disableKeepAlive := range []bool{false, true} {
+ tr := &Transport{DisableKeepAlives: disableKeepAlive}
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ fetch := func(n int) string {
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err)
+ }
+ body, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err)
+ }
+ return string(body)
+ }
+
+ body1 := fetch(1)
+ body2 := fetch(2)
+
+ bodiesDiffer := body1 != body2
+ if bodiesDiffer != disableKeepAlive {
+ t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
+ disableKeepAlive, bodiesDiffer, body1, body2)
+ }
+ }
+}
+
+func TestTransportConnectionCloseOnResponse(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(hostPortHandler)
+ defer ts.Close()
+
+ connSet, testDial := makeTestDial(t)
+
+ for _, connectionClose := range []bool{false, true} {
+ tr := &Transport{
+ Dial: testDial,
+ }
+ c := &Client{Transport: tr}
+
+ fetch := func(n int) string {
+ req := new(Request)
+ var err error
+ req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose))
+ if err != nil {
+ t.Fatalf("URL parse error: %v", err)
+ }
+ req.Method = "GET"
+ req.Proto = "HTTP/1.1"
+ req.ProtoMajor = 1
+ req.ProtoMinor = 1
+
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
+ }
+ defer res.Body.Close()
+ body, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
+ }
+ return string(body)
+ }
+
+ body1 := fetch(1)
+ body2 := fetch(2)
+ bodiesDiffer := body1 != body2
+ if bodiesDiffer != connectionClose {
+ t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
+ connectionClose, bodiesDiffer, body1, body2)
+ }
+
+ tr.CloseIdleConnections()
+ }
+
+ connSet.check(t)
+}
+
+func TestTransportConnectionCloseOnRequest(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(hostPortHandler)
+ defer ts.Close()
+
+ connSet, testDial := makeTestDial(t)
+
+ for _, connectionClose := range []bool{false, true} {
+ tr := &Transport{
+ Dial: testDial,
+ }
+ c := &Client{Transport: tr}
+
+ fetch := func(n int) string {
+ req := new(Request)
+ var err error
+ req.URL, err = url.Parse(ts.URL)
+ if err != nil {
+ t.Fatalf("URL parse error: %v", err)
+ }
+ req.Method = "GET"
+ req.Proto = "HTTP/1.1"
+ req.ProtoMajor = 1
+ req.ProtoMinor = 1
+ req.Close = connectionClose
+
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
+ }
+ body, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
+ }
+ return string(body)
+ }
+
+ body1 := fetch(1)
+ body2 := fetch(2)
+ bodiesDiffer := body1 != body2
+ if bodiesDiffer != connectionClose {
+ t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
+ connectionClose, bodiesDiffer, body1, body2)
+ }
+
+ tr.CloseIdleConnections()
+ }
+
+ connSet.check(t)
+}
+
+func TestTransportIdleCacheKeys(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(hostPortHandler)
+ defer ts.Close()
+
+ tr := &Transport{DisableKeepAlives: false}
+ c := &Client{Transport: tr}
+
+ if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
+ t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
+ }
+
+ resp, err := c.Get(ts.URL)
+ if err != nil {
+ t.Error(err)
+ }
+ ioutil.ReadAll(resp.Body)
+
+ keys := tr.IdleConnKeysForTesting()
+ if e, g := 1, len(keys); e != g {
+ t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g)
+ }
+
+ if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e {
+ t.Errorf("Expected idle cache key %q; got %q", e, keys[0])
+ }
+
+ tr.CloseIdleConnections()
+ if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
+ t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
+ }
+}
+
+// Tests that the HTTP transport re-uses connections when a client
+// reads to the end of a response Body without closing it.
+func TestTransportReadToEndReusesConn(t *testing.T) {
+ defer afterTest(t)
+ const msg = "foobar"
+
+ var addrSeen map[string]int
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ addrSeen[r.RemoteAddr]++
+ if r.URL.Path == "/chunked/" {
+ w.WriteHeader(200)
+ w.(http.Flusher).Flush()
+ } else {
+ w.Header().Set("Content-Type", strconv.Itoa(len(msg)))
+ w.WriteHeader(200)
+ }
+ w.Write([]byte(msg))
+ }))
+ defer ts.Close()
+
+ buf := make([]byte, len(msg))
+
+ for pi, path := range []string{"/content-length/", "/chunked/"} {
+ wantLen := []int{len(msg), -1}[pi]
+ addrSeen = make(map[string]int)
+ for i := 0; i < 3; i++ {
+ res, err := http.Get(ts.URL + path)
+ if err != nil {
+ t.Errorf("Get %s: %v", path, err)
+ continue
+ }
+ // We want to close this body eventually (before the
+ // defer afterTest at top runs), but not before the
+ // len(addrSeen) check at the bottom of this test,
+ // since Closing this early in the loop would risk
+ // making connections be re-used for the wrong reason.
+ defer res.Body.Close()
+
+ if res.ContentLength != int64(wantLen) {
+ t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen)
+ }
+ n, err := res.Body.Read(buf)
+ if n != len(msg) || err != io.EOF {
+ t.Errorf("%s Read = %v, %v; want %d, EOF", path, n, err, len(msg))
+ }
+ }
+ if len(addrSeen) != 1 {
+ t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen))
+ }
+ }
+}
+
+func TestTransportMaxPerHostIdleConns(t *testing.T) {
+ defer afterTest(t)
+ resch := make(chan string)
+ gotReq := make(chan bool)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ gotReq <- true
+ msg := <-resch
+ _, err := w.Write([]byte(msg))
+ if err != nil {
+ t.Fatalf("Write: %v", err)
+ }
+ }))
+ defer ts.Close()
+ maxIdleConns := 2
+ tr := &Transport{DisableKeepAlives: false, MaxIdleConnsPerHost: maxIdleConns}
+ c := &Client{Transport: tr}
+
+ // Start 3 outstanding requests and wait for the server to get them.
+ // Their responses will hang until we write to resch, though.
+ donech := make(chan bool)
+ doReq := func() {
+ resp, err := c.Get(ts.URL)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ if _, err := ioutil.ReadAll(resp.Body); err != nil {
+ t.Errorf("ReadAll: %v", err)
+ return
+ }
+ donech <- true
+ }
+ go doReq()
+ <-gotReq
+ go doReq()
+ <-gotReq
+ go doReq()
+ <-gotReq
+
+ if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
+ t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g)
+ }
+
+ resch <- "res1"
+ <-donech
+ keys := tr.IdleConnKeysForTesting()
+ if e, g := 1, len(keys); e != g {
+ t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
+ }
+ cacheKey := "|http|" + ts.Listener.Addr().String()
+ if keys[0] != cacheKey {
+ t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
+ }
+ if e, g := 1, tr.IdleConnCountForTesting(cacheKey); e != g {
+ t.Errorf("after first response, expected %d idle conns; got %d", e, g)
+ }
+
+ resch <- "res2"
+ <-donech
+ if e, g := 2, tr.IdleConnCountForTesting(cacheKey); e != g {
+ t.Errorf("after second response, expected %d idle conns; got %d", e, g)
+ }
+
+ resch <- "res3"
+ <-donech
+ if e, g := maxIdleConns, tr.IdleConnCountForTesting(cacheKey); e != g {
+ t.Errorf("after third response, still expected %d idle conns; got %d", e, g)
+ }
+}
+
+func TestTransportServerClosingUnexpectedly(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(hostPortHandler)
+ defer ts.Close()
+
+ tr := &Transport{}
+ c := &Client{Transport: tr}
+
+ fetch := func(n, retries int) string {
+ condFatalf := func(format string, arg ...interface{}) {
+ if retries <= 0 {
+ t.Fatalf(format, arg...)
+ }
+ t.Logf("retrying shortly after expected error: "+format, arg...)
+ time.Sleep(time.Second / time.Duration(retries))
+ }
+ for retries >= 0 {
+ retries--
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ condFatalf("error in req #%d, GET: %v", n, err)
+ continue
+ }
+ body, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ condFatalf("error in req #%d, ReadAll: %v", n, err)
+ continue
+ }
+ res.Body.Close()
+ return string(body)
+ }
+ panic("unreachable")
+ }
+
+ body1 := fetch(1, 0)
+ body2 := fetch(2, 0)
+
+ ts.CloseClientConnections() // surprise!
+
+ // This test has an expected race. Sleeping for 25 ms prevents
+ // it on most fast machines, causing the next fetch() call to
+ // succeed quickly. But if we do get errors, fetch() will retry 5
+ // times with some delays between.
+ time.Sleep(25 * time.Millisecond)
+
+ body3 := fetch(3, 5)
+
+ if body1 != body2 {
+ t.Errorf("expected body1 and body2 to be equal")
+ }
+ if body2 == body3 {
+ t.Errorf("expected body2 and body3 to be different")
+ }
+}
+
+// Test for http://golang.org/issue/2616 (appropriate issue number)
+// This fails pretty reliably with GOMAXPROCS=100 or something high.
+func TestStressSurpriseServerCloses(t *testing.T) {
+ defer afterTest(t)
+ if testing.Short() {
+ t.Skip("skipping test in short mode")
+ }
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "5")
+ w.Header().Set("Content-Type", "text/plain")
+ w.Write([]byte("Hello"))
+ w.(Flusher).Flush()
+ conn, buf, _ := w.(Hijacker).Hijack()
+ buf.Flush()
+ conn.Close()
+ }))
+ defer ts.Close()
+
+ tr := &Transport{DisableKeepAlives: false}
+ c := &Client{Transport: tr}
+
+ // Do a bunch of traffic from different goroutines. Send to activityc
+ // after each request completes, regardless of whether it failed.
+ const (
+ numClients = 50
+ reqsPerClient = 250
+ )
+ activityc := make(chan bool)
+ for i := 0; i < numClients; i++ {
+ go func() {
+ for i := 0; i < reqsPerClient; i++ {
+ res, err := c.Get(ts.URL)
+ if err == nil {
+ // We expect errors since the server is
+ // hanging up on us after telling us to
+ // send more requests, so we don't
+ // actually care what the error is.
+ // But we want to close the body in cases
+ // where we won the race.
+ res.Body.Close()
+ }
+ activityc <- true
+ }
+ }()
+ }
+
+ // Make sure all the request come back, one way or another.
+ for i := 0; i < numClients*reqsPerClient; i++ {
+ select {
+ case <-activityc:
+ case <-time.After(5 * time.Second):
+ t.Fatalf("presumed deadlock; no HTTP client activity seen in awhile")
+ }
+ }
+}
+
+// TestTransportHeadResponses verifies that we deal with Content-Lengths
+// with no bodies properly
+func TestTransportHeadResponses(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method != "HEAD" {
+ panic("expected HEAD; got " + r.Method)
+ }
+ w.Header().Set("Content-Length", "123")
+ w.WriteHeader(200)
+ }))
+ defer ts.Close()
+
+ tr := &Transport{DisableKeepAlives: false}
+ c := &Client{Transport: tr}
+ for i := 0; i < 2; i++ {
+ res, err := c.Head(ts.URL)
+ if err != nil {
+ t.Errorf("error on loop %d: %v", i, err)
+ continue
+ }
+ if e, g := "123", res.Header.Get("Content-Length"); e != g {
+ t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
+ }
+ if e, g := int64(123), res.ContentLength; e != g {
+ t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
+ }
+ if all, err := ioutil.ReadAll(res.Body); err != nil {
+ t.Errorf("loop %d: Body ReadAll: %v", i, err)
+ } else if len(all) != 0 {
+ t.Errorf("Bogus body %q", all)
+ }
+ }
+}
+
+// TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding
+// on responses to HEAD requests.
+func TestTransportHeadChunkedResponse(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method != "HEAD" {
+ panic("expected HEAD; got " + r.Method)
+ }
+ w.Header().Set("Transfer-Encoding", "chunked") // client should ignore
+ w.Header().Set("x-client-ipport", r.RemoteAddr)
+ w.WriteHeader(200)
+ }))
+ defer ts.Close()
+
+ tr := &Transport{DisableKeepAlives: false}
+ c := &Client{Transport: tr}
+
+ res1, err := c.Head(ts.URL)
+ if err != nil {
+ t.Fatalf("request 1 error: %v", err)
+ }
+ res2, err := c.Head(ts.URL)
+ if err != nil {
+ t.Fatalf("request 2 error: %v", err)
+ }
+ if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 {
+ t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2)
+ }
+}
+
+var roundTripTests = []struct {
+ accept string
+ expectAccept string
+ compressed bool
+}{
+ // Requests with no accept-encoding header use transparent compression
+ {"", "gzip", false},
+ // Requests with other accept-encoding should pass through unmodified
+ {"foo", "foo", false},
+ // Requests with accept-encoding == gzip should be passed through
+ {"gzip", "gzip", true},
+}
+
+// Test that the modification made to the Request by the RoundTripper is cleaned up
+func TestRoundTripGzip(t *testing.T) {
+ defer afterTest(t)
+ const responseBody = "test response body"
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ accept := req.Header.Get("Accept-Encoding")
+ if expect := req.FormValue("expect_accept"); accept != expect {
+ t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
+ req.FormValue("testnum"), accept, expect)
+ }
+ if accept == "gzip" {
+ rw.Header().Set("Content-Encoding", "gzip")
+ gz := gzip.NewWriter(rw)
+ gz.Write([]byte(responseBody))
+ gz.Close()
+ } else {
+ rw.Header().Set("Content-Encoding", accept)
+ rw.Write([]byte(responseBody))
+ }
+ }))
+ defer ts.Close()
+
+ for i, test := range roundTripTests {
+ // Test basic request (no accept-encoding)
+ req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
+ if test.accept != "" {
+ req.Header.Set("Accept-Encoding", test.accept)
+ }
+ res, err := DefaultTransport.RoundTrip(req)
+ var body []byte
+ if test.compressed {
+ var r *gzip.Reader
+ r, err = gzip.NewReader(res.Body)
+ if err != nil {
+ t.Errorf("%d. gzip NewReader: %v", i, err)
+ continue
+ }
+ body, err = ioutil.ReadAll(r)
+ res.Body.Close()
+ } else {
+ body, err = ioutil.ReadAll(res.Body)
+ }
+ if err != nil {
+ t.Errorf("%d. Error: %q", i, err)
+ continue
+ }
+ if g, e := string(body), responseBody; g != e {
+ t.Errorf("%d. body = %q; want %q", i, g, e)
+ }
+ if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
+ t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
+ }
+ if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
+ t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
+ }
+ }
+
+}
+
+func TestTransportGzip(t *testing.T) {
+ defer afterTest(t)
+ const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+ const nRandBytes = 1024 * 1024
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ if req.Method == "HEAD" {
+ if g := req.Header.Get("Accept-Encoding"); g != "" {
+ t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g)
+ }
+ return
+ }
+ if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e {
+ t.Errorf("Accept-Encoding = %q, want %q", g, e)
+ }
+ rw.Header().Set("Content-Encoding", "gzip")
+
+ var w io.Writer = rw
+ var buf bytes.Buffer
+ if req.FormValue("chunked") == "0" {
+ w = &buf
+ defer io.Copy(rw, &buf)
+ defer func() {
+ rw.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
+ }()
+ }
+ gz := gzip.NewWriter(w)
+ gz.Write([]byte(testString))
+ if req.FormValue("body") == "large" {
+ io.CopyN(gz, rand.Reader, nRandBytes)
+ }
+ gz.Close()
+ }))
+ defer ts.Close()
+
+ for _, chunked := range []string{"1", "0"} {
+ c := &Client{Transport: &Transport{}}
+
+ // First fetch something large, but only read some of it.
+ res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked)
+ if err != nil {
+ t.Fatalf("large get: %v", err)
+ }
+ buf := make([]byte, len(testString))
+ n, err := io.ReadFull(res.Body, buf)
+ if err != nil {
+ t.Fatalf("partial read of large response: size=%d, %v", n, err)
+ }
+ if e, g := testString, string(buf); e != g {
+ t.Errorf("partial read got %q, expected %q", g, e)
+ }
+ res.Body.Close()
+ // Read on the body, even though it's closed
+ n, err = res.Body.Read(buf)
+ if n != 0 || err == nil {
+ t.Errorf("expected error post-closed large Read; got = %d, %v", n, err)
+ }
+
+ // Then something small.
+ res, err = c.Get(ts.URL + "/?chunked=" + chunked)
+ if err != nil {
+ t.Fatal(err)
+ }
+ body, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if g, e := string(body), testString; g != e {
+ t.Fatalf("body = %q; want %q", g, e)
+ }
+ if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
+ t.Fatalf("Content-Encoding = %q; want %q", g, e)
+ }
+
+ // Read on the body after it's been fully read:
+ n, err = res.Body.Read(buf)
+ if n != 0 || err == nil {
+ t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err)
+ }
+ res.Body.Close()
+ n, err = res.Body.Read(buf)
+ if n != 0 || err == nil {
+ t.Errorf("expected Read error after Close; got %d, %v", n, err)
+ }
+ }
+
+ // And a HEAD request too, because they're always weird.
+ c := &Client{Transport: &Transport{}}
+ res, err := c.Head(ts.URL)
+ if err != nil {
+ t.Fatalf("Head: %v", err)
+ }
+ if res.StatusCode != 200 {
+ t.Errorf("Head status=%d; want=200", res.StatusCode)
+ }
+}
+
+func TestTransportProxy(t *testing.T) {
+ defer afterTest(t)
+ ch := make(chan string, 1)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ ch <- "real server"
+ }))
+ defer ts.Close()
+ proxy := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ ch <- "proxy for " + r.URL.String()
+ }))
+ defer proxy.Close()
+
+ pu, err := url.Parse(proxy.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ c := &Client{Transport: &Transport{Proxy: ProxyURL(pu)}}
+ c.Head(ts.URL)
+ got := <-ch
+ want := "proxy for " + ts.URL + "/"
+ if got != want {
+ t.Errorf("want %q, got %q", want, got)
+ }
+}
+
+// TestTransportGzipRecursive sends a gzip quine and checks that the
+// client gets the same value back. This is more cute than anything,
+// but checks that we don't recurse forever, and checks that
+// Content-Encoding is removed.
+func TestTransportGzipRecursive(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Encoding", "gzip")
+ w.Write(rgz)
+ }))
+ defer ts.Close()
+
+ c := &Client{Transport: &Transport{}}
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ body, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(body, rgz) {
+ t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x",
+ body, rgz)
+ }
+ if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
+ t.Fatalf("Content-Encoding = %q; want %q", g, e)
+ }
+}
+
+// golang.org/issue/7750: request fails when server replies with
+// a short gzip body
+func TestTransportGzipShort(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Encoding", "gzip")
+ w.Write([]byte{0x1f, 0x8b})
+ }))
+ defer ts.Close()
+
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ _, err = ioutil.ReadAll(res.Body)
+ if err == nil {
+ t.Fatal("Expect an error from reading a body.")
+ }
+ if err != io.ErrUnexpectedEOF {
+ t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err)
+ }
+}
+
+// tests that persistent goroutine connections shut down when no longer desired.
+func TestTransportPersistConnLeak(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("skipping test; see http://golang.org/issue/7237")
+ }
+ defer afterTest(t)
+ gotReqCh := make(chan bool)
+ unblockCh := make(chan bool)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ gotReqCh <- true
+ <-unblockCh
+ w.Header().Set("Content-Length", "0")
+ w.WriteHeader(204)
+ }))
+ defer ts.Close()
+
+ tr := &Transport{}
+ c := &Client{Transport: tr}
+
+ n0 := runtime.NumGoroutine()
+
+ const numReq = 25
+ didReqCh := make(chan bool)
+ for i := 0; i < numReq; i++ {
+ go func() {
+ res, err := c.Get(ts.URL)
+ didReqCh <- true
+ if err != nil {
+ t.Errorf("client fetch error: %v", err)
+ return
+ }
+ res.Body.Close()
+ }()
+ }
+
+ // Wait for all goroutines to be stuck in the Handler.
+ for i := 0; i < numReq; i++ {
+ <-gotReqCh
+ }
+
+ nhigh := runtime.NumGoroutine()
+
+ // Tell all handlers to unblock and reply.
+ for i := 0; i < numReq; i++ {
+ unblockCh <- true
+ }
+
+ // Wait for all HTTP clients to be done.
+ for i := 0; i < numReq; i++ {
+ <-didReqCh
+ }
+
+ tr.CloseIdleConnections()
+ time.Sleep(100 * time.Millisecond)
+ runtime.GC()
+ runtime.GC() // even more.
+ nfinal := runtime.NumGoroutine()
+
+ growth := nfinal - n0
+
+ // We expect 0 or 1 extra goroutine, empirically. Allow up to 5.
+ // Previously we were leaking one per numReq.
+ if int(growth) > 5 {
+ t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
+ t.Error("too many new goroutines")
+ }
+}
+
+// golang.org/issue/4531: Transport leaks goroutines when
+// request.ContentLength is explicitly short
+func TestTransportPersistConnLeakShortBody(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("skipping test; see http://golang.org/issue/7237")
+ }
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ }))
+ defer ts.Close()
+
+ tr := &Transport{}
+ c := &Client{Transport: tr}
+
+ n0 := runtime.NumGoroutine()
+ body := []byte("Hello")
+ for i := 0; i < 20; i++ {
+ req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.ContentLength = int64(len(body) - 2) // explicitly short
+ _, err = c.Do(req)
+ if err == nil {
+ t.Fatal("Expect an error from writing too long of a body.")
+ }
+ }
+ nhigh := runtime.NumGoroutine()
+ tr.CloseIdleConnections()
+ time.Sleep(400 * time.Millisecond)
+ runtime.GC()
+ nfinal := runtime.NumGoroutine()
+
+ growth := nfinal - n0
+
+ // We expect 0 or 1 extra goroutine, empirically. Allow up to 5.
+ // Previously we were leaking one per numReq.
+ t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
+ if int(growth) > 5 {
+ t.Error("too many new goroutines")
+ }
+}
+
+// This used to crash; http://golang.org/issue/3266
+func TestTransportIdleConnCrash(t *testing.T) {
+ defer afterTest(t)
+ tr := &Transport{}
+ c := &Client{Transport: tr}
+
+ unblockCh := make(chan bool, 1)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ <-unblockCh
+ tr.CloseIdleConnections()
+ }))
+ defer ts.Close()
+
+ didreq := make(chan bool)
+ go func() {
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Error(err)
+ } else {
+ res.Body.Close() // returns idle conn
+ }
+ didreq <- true
+ }()
+ unblockCh <- true
+ <-didreq
+}
+
+// Test that the transport doesn't close the TCP connection early,
+// before the response body has been read. This was a regression
+// which sadly lacked a triggering test. The large response body made
+// the old race easier to trigger.
+func TestIssue3644(t *testing.T) {
+ defer afterTest(t)
+ const numFoos = 5000
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Connection", "close")
+ for i := 0; i < numFoos; i++ {
+ w.Write([]byte("foo "))
+ }
+ }))
+ defer ts.Close()
+ tr := &Transport{}
+ c := &Client{Transport: tr}
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ bs, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(bs) != numFoos*len("foo ") {
+ t.Errorf("unexpected response length")
+ }
+}
+
+// Test that a client receives a server's reply, even if the server doesn't read
+// the entire request body.
+func TestIssue3595(t *testing.T) {
+ defer afterTest(t)
+ const deniedMsg = "sorry, denied."
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ Error(w, deniedMsg, StatusUnauthorized)
+ }))
+ defer ts.Close()
+ tr := &Transport{}
+ c := &Client{Transport: tr}
+ res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
+ if err != nil {
+ t.Errorf("Post: %v", err)
+ return
+ }
+ got, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("Body ReadAll: %v", err)
+ }
+ if !strings.Contains(string(got), deniedMsg) {
+ t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
+ }
+}
+
+// From http://golang.org/issue/4454 ,
+// "client fails to handle requests with no body and chunked encoding"
+func TestChunkedNoContent(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.WriteHeader(StatusNoContent)
+ }))
+ defer ts.Close()
+
+ for _, closeBody := range []bool{true, false} {
+ c := &Client{Transport: &Transport{}}
+ const n = 4
+ for i := 1; i <= n; i++ {
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err)
+ } else {
+ if closeBody {
+ res.Body.Close()
+ }
+ }
+ }
+ }
+}
+
+func TestTransportConcurrency(t *testing.T) {
+ defer afterTest(t)
+ maxProcs, numReqs := 16, 500
+ if testing.Short() {
+ maxProcs, numReqs = 4, 50
+ }
+ defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "%v", r.FormValue("echo"))
+ }))
+ defer ts.Close()
+
+ var wg sync.WaitGroup
+ wg.Add(numReqs)
+
+ // Due to the Transport's "socket late binding" (see
+ // idleConnCh in transport.go), the numReqs HTTP requests
+ // below can finish with a dial still outstanding. To keep
+ // the leak checker happy, keep track of pending dials and
+ // wait for them to finish (and be closed or returned to the
+ // idle pool) before we close idle connections.
+ SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
+ defer SetPendingDialHooks(nil, nil)
+
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+
+ c := &Client{Transport: tr}
+ reqs := make(chan string)
+ defer close(reqs)
+
+ for i := 0; i < maxProcs*2; i++ {
+ go func() {
+ for req := range reqs {
+ res, err := c.Get(ts.URL + "/?echo=" + req)
+ if err != nil {
+ t.Errorf("error on req %s: %v", req, err)
+ wg.Done()
+ continue
+ }
+ all, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Errorf("read error on req %s: %v", req, err)
+ wg.Done()
+ continue
+ }
+ if string(all) != req {
+ t.Errorf("body of req %s = %q; want %q", req, all, req)
+ }
+ res.Body.Close()
+ wg.Done()
+ }
+ }()
+ }
+ for i := 0; i < numReqs; i++ {
+ reqs <- fmt.Sprintf("request-%d", i)
+ }
+ wg.Wait()
+}
+
+func TestIssue4191_InfiniteGetTimeout(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("skipping test; see http://golang.org/issue/7237")
+ }
+ defer afterTest(t)
+ const debug = false
+ mux := NewServeMux()
+ mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
+ io.Copy(w, neverEnding('a'))
+ })
+ ts := httptest.NewServer(mux)
+ timeout := 100 * time.Millisecond
+
+ client := &Client{
+ Transport: &Transport{
+ Dial: func(n, addr string) (net.Conn, error) {
+ conn, err := net.Dial(n, addr)
+ if err != nil {
+ return nil, err
+ }
+ conn.SetDeadline(time.Now().Add(timeout))
+ if debug {
+ conn = NewLoggingConn("client", conn)
+ }
+ return conn, nil
+ },
+ DisableKeepAlives: true,
+ },
+ }
+
+ getFailed := false
+ nRuns := 5
+ if testing.Short() {
+ nRuns = 1
+ }
+ for i := 0; i < nRuns; i++ {
+ if debug {
+ println("run", i+1, "of", nRuns)
+ }
+ sres, err := client.Get(ts.URL + "/get")
+ if err != nil {
+ if !getFailed {
+ // Make the timeout longer, once.
+ getFailed = true
+ t.Logf("increasing timeout")
+ i--
+ timeout *= 10
+ continue
+ }
+ t.Errorf("Error issuing GET: %v", err)
+ break
+ }
+ _, err = io.Copy(ioutil.Discard, sres.Body)
+ if err == nil {
+ t.Errorf("Unexpected successful copy")
+ break
+ }
+ }
+ if debug {
+ println("tests complete; waiting for handlers to finish")
+ }
+ ts.Close()
+}
+
+func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("skipping test; see http://golang.org/issue/7237")
+ }
+ defer afterTest(t)
+ const debug = false
+ mux := NewServeMux()
+ mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
+ io.Copy(w, neverEnding('a'))
+ })
+ mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
+ defer r.Body.Close()
+ io.Copy(ioutil.Discard, r.Body)
+ })
+ ts := httptest.NewServer(mux)
+ timeout := 100 * time.Millisecond
+
+ client := &Client{
+ Transport: &Transport{
+ Dial: func(n, addr string) (net.Conn, error) {
+ conn, err := net.Dial(n, addr)
+ if err != nil {
+ return nil, err
+ }
+ conn.SetDeadline(time.Now().Add(timeout))
+ if debug {
+ conn = NewLoggingConn("client", conn)
+ }
+ return conn, nil
+ },
+ DisableKeepAlives: true,
+ },
+ }
+
+ getFailed := false
+ nRuns := 5
+ if testing.Short() {
+ nRuns = 1
+ }
+ for i := 0; i < nRuns; i++ {
+ if debug {
+ println("run", i+1, "of", nRuns)
+ }
+ sres, err := client.Get(ts.URL + "/get")
+ if err != nil {
+ if !getFailed {
+ // Make the timeout longer, once.
+ getFailed = true
+ t.Logf("increasing timeout")
+ i--
+ timeout *= 10
+ continue
+ }
+ t.Errorf("Error issuing GET: %v", err)
+ break
+ }
+ req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
+ _, err = client.Do(req)
+ if err == nil {
+ sres.Body.Close()
+ t.Errorf("Unexpected successful PUT")
+ break
+ }
+ sres.Body.Close()
+ }
+ if debug {
+ println("tests complete; waiting for handlers to finish")
+ }
+ ts.Close()
+}
+
+func TestTransportResponseHeaderTimeout(t *testing.T) {
+ defer afterTest(t)
+ if testing.Short() {
+ t.Skip("skipping timeout test in -short mode")
+ }
+ inHandler := make(chan bool, 1)
+ mux := NewServeMux()
+ mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {
+ inHandler <- true
+ })
+ mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
+ inHandler <- true
+ time.Sleep(2 * time.Second)
+ })
+ ts := httptest.NewServer(mux)
+ defer ts.Close()
+
+ tr := &Transport{
+ ResponseHeaderTimeout: 500 * time.Millisecond,
+ }
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ tests := []struct {
+ path string
+ want int
+ wantErr string
+ }{
+ {path: "/fast", want: 200},
+ {path: "/slow", wantErr: "timeout awaiting response headers"},
+ {path: "/fast", want: 200},
+ }
+ for i, tt := range tests {
+ res, err := c.Get(ts.URL + tt.path)
+ select {
+ case <-inHandler:
+ case <-time.After(5 * time.Second):
+ t.Errorf("never entered handler for test index %d, %s", i, tt.path)
+ continue
+ }
+ if err != nil {
+ uerr, ok := err.(*url.Error)
+ if !ok {
+ t.Errorf("error is not an url.Error; got: %#v", err)
+ continue
+ }
+ nerr, ok := uerr.Err.(net.Error)
+ if !ok {
+ t.Errorf("error does not satisfy net.Error interface; got: %#v", err)
+ continue
+ }
+ if !nerr.Timeout() {
+ t.Errorf("want timeout error; got: %q", nerr)
+ continue
+ }
+ if strings.Contains(err.Error(), tt.wantErr) {
+ continue
+ }
+ t.Errorf("%d. unexpected error: %v", i, err)
+ continue
+ }
+ if tt.wantErr != "" {
+ t.Errorf("%d. no error. expected error: %v", i, tt.wantErr)
+ continue
+ }
+ if res.StatusCode != tt.want {
+ t.Errorf("%d for path %q status = %d; want %d", i, tt.path, res.StatusCode, tt.want)
+ }
+ }
+}
+
+func TestTransportCancelRequest(t *testing.T) {
+ defer afterTest(t)
+ if testing.Short() {
+ t.Skip("skipping test in -short mode")
+ }
+ unblockc := make(chan bool)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "Hello")
+ w.(Flusher).Flush() // send headers and some body
+ <-unblockc
+ }))
+ defer ts.Close()
+ defer close(unblockc)
+
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ req, _ := NewRequest("GET", ts.URL, nil)
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ go func() {
+ time.Sleep(1 * time.Second)
+ tr.CancelRequest(req)
+ }()
+ t0 := time.Now()
+ body, err := ioutil.ReadAll(res.Body)
+ d := time.Since(t0)
+
+ if err == nil {
+ t.Error("expected an error reading the body")
+ }
+ if string(body) != "Hello" {
+ t.Errorf("Body = %q; want Hello", body)
+ }
+ if d < 500*time.Millisecond {
+ t.Errorf("expected ~1 second delay; got %v", d)
+ }
+ // Verify no outstanding requests after readLoop/writeLoop
+ // goroutines shut down.
+ for tries := 3; tries > 0; tries-- {
+ n := tr.NumPendingRequestsForTesting()
+ if n == 0 {
+ break
+ }
+ time.Sleep(100 * time.Millisecond)
+ if tries == 1 {
+ t.Errorf("pending requests = %d; want 0", n)
+ }
+ }
+}
+
+func TestTransportCancelRequestInDial(t *testing.T) {
+ defer afterTest(t)
+ if testing.Short() {
+ t.Skip("skipping test in -short mode")
+ }
+ var logbuf bytes.Buffer
+ eventLog := log.New(&logbuf, "", 0)
+
+ unblockDial := make(chan bool)
+ defer close(unblockDial)
+
+ inDial := make(chan bool)
+ tr := &Transport{
+ Dial: func(network, addr string) (net.Conn, error) {
+ eventLog.Println("dial: blocking")
+ inDial <- true
+ <-unblockDial
+ return nil, errors.New("nope")
+ },
+ }
+ cl := &Client{Transport: tr}
+ gotres := make(chan bool)
+ req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
+ go func() {
+ _, err := cl.Do(req)
+ eventLog.Printf("Get = %v", err)
+ gotres <- true
+ }()
+
+ select {
+ case <-inDial:
+ case <-time.After(5 * time.Second):
+ t.Fatal("timeout; never saw blocking dial")
+ }
+
+ eventLog.Printf("canceling")
+ tr.CancelRequest(req)
+
+ select {
+ case <-gotres:
+ case <-time.After(5 * time.Second):
+ panic("hang. events are: " + logbuf.String())
+ }
+
+ got := logbuf.String()
+ want := `dial: blocking
+canceling
+Get = Get http://something.no-network.tld/: net/http: request canceled while waiting for connection
+`
+ if got != want {
+ t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
+ }
+}
+
+// golang.org/issue/3672 -- Client can't close HTTP stream
+// Calling Close on a Response.Body used to just read until EOF.
+// Now it actually closes the TCP connection.
+func TestTransportCloseResponseBody(t *testing.T) {
+ defer afterTest(t)
+ writeErr := make(chan error, 1)
+ msg := []byte("young\n")
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ for {
+ _, err := w.Write(msg)
+ if err != nil {
+ writeErr <- err
+ return
+ }
+ w.(Flusher).Flush()
+ }
+ }))
+ defer ts.Close()
+
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ req, _ := NewRequest("GET", ts.URL, nil)
+ defer tr.CancelRequest(req)
+
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ const repeats = 3
+ buf := make([]byte, len(msg)*repeats)
+ want := bytes.Repeat(msg, repeats)
+
+ _, err = io.ReadFull(res.Body, buf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("read %q; want %q", buf, want)
+ }
+ didClose := make(chan error, 1)
+ go func() {
+ didClose <- res.Body.Close()
+ }()
+ select {
+ case err := <-didClose:
+ if err != nil {
+ t.Errorf("Close = %v", err)
+ }
+ case <-time.After(10 * time.Second):
+ t.Fatal("too long waiting for close")
+ }
+ select {
+ case err := <-writeErr:
+ if err == nil {
+ t.Errorf("expected non-nil write error")
+ }
+ case <-time.After(10 * time.Second):
+ t.Fatal("too long waiting for write error")
+ }
+}
+
+type fooProto struct{}
+
+func (fooProto) RoundTrip(req *Request) (*Response, error) {
+ res := &Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Header: make(Header),
+ Body: ioutil.NopCloser(strings.NewReader("You wanted " + req.URL.String())),
+ }
+ return res, nil
+}
+
+func TestTransportAltProto(t *testing.T) {
+ defer afterTest(t)
+ tr := &Transport{}
+ c := &Client{Transport: tr}
+ tr.RegisterProtocol("foo", fooProto{})
+ res, err := c.Get("foo://bar.com/path")
+ if err != nil {
+ t.Fatal(err)
+ }
+ bodyb, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ body := string(bodyb)
+ if e := "You wanted foo://bar.com/path"; body != e {
+ t.Errorf("got response %q, want %q", body, e)
+ }
+}
+
+func TestTransportNoHost(t *testing.T) {
+ defer afterTest(t)
+ tr := &Transport{}
+ _, err := tr.RoundTrip(&Request{
+ Header: make(Header),
+ URL: &url.URL{
+ Scheme: "http",
+ },
+ })
+ want := "http: no Host in request URL"
+ if got := fmt.Sprint(err); got != want {
+ t.Errorf("error = %v; want %q", err, want)
+ }
+}
+
+func TestTransportSocketLateBinding(t *testing.T) {
+ defer afterTest(t)
+
+ mux := NewServeMux()
+ fooGate := make(chan bool, 1)
+ mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) {
+ w.Header().Set("foo-ipport", r.RemoteAddr)
+ w.(Flusher).Flush()
+ <-fooGate
+ })
+ mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) {
+ w.Header().Set("bar-ipport", r.RemoteAddr)
+ })
+ ts := httptest.NewServer(mux)
+ defer ts.Close()
+
+ dialGate := make(chan bool, 1)
+ tr := &Transport{
+ Dial: func(n, addr string) (net.Conn, error) {
+ if <-dialGate {
+ return net.Dial(n, addr)
+ }
+ return nil, errors.New("manually closed")
+ },
+ DisableKeepAlives: false,
+ }
+ defer tr.CloseIdleConnections()
+ c := &Client{
+ Transport: tr,
+ }
+
+ dialGate <- true // only allow one dial
+ fooRes, err := c.Get(ts.URL + "/foo")
+ if err != nil {
+ t.Fatal(err)
+ }
+ fooAddr := fooRes.Header.Get("foo-ipport")
+ if fooAddr == "" {
+ t.Fatal("No addr on /foo request")
+ }
+ time.AfterFunc(200*time.Millisecond, func() {
+ // let the foo response finish so we can use its
+ // connection for /bar
+ fooGate <- true
+ io.Copy(ioutil.Discard, fooRes.Body)
+ fooRes.Body.Close()
+ })
+
+ barRes, err := c.Get(ts.URL + "/bar")
+ if err != nil {
+ t.Fatal(err)
+ }
+ barAddr := barRes.Header.Get("bar-ipport")
+ if barAddr != fooAddr {
+ t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr)
+ }
+ barRes.Body.Close()
+ dialGate <- false
+}
+
+// Issue 2184
+func TestTransportReading100Continue(t *testing.T) {
+ defer afterTest(t)
+
+ const numReqs = 5
+ reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) }
+ reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) }
+
+ send100Response := func(w *io.PipeWriter, r *io.PipeReader) {
+ defer w.Close()
+ defer r.Close()
+ br := bufio.NewReader(r)
+ n := 0
+ for {
+ n++
+ req, err := ReadRequest(br)
+ if err == io.EOF {
+ return
+ }
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ slurp, err := ioutil.ReadAll(req.Body)
+ if err != nil {
+ t.Errorf("Server request body slurp: %v", err)
+ return
+ }
+ id := req.Header.Get("Request-Id")
+ resCode := req.Header.Get("X-Want-Response-Code")
+ if resCode == "" {
+ resCode = "100 Continue"
+ if string(slurp) != reqBody(n) {
+ t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n))
+ }
+ }
+ body := fmt.Sprintf("Response number %d", n)
+ v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s
+Date: Thu, 28 Feb 2013 17:55:41 GMT
+
+HTTP/1.1 200 OK
+Content-Type: text/html
+Echo-Request-Id: %s
+Content-Length: %d
+
+%s`, resCode, id, len(body), body), "\n", "\r\n", -1))
+ w.Write(v)
+ if id == reqID(numReqs) {
+ return
+ }
+ }
+
+ }
+
+ tr := &Transport{
+ Dial: func(n, addr string) (net.Conn, error) {
+ sr, sw := io.Pipe() // server read/write
+ cr, cw := io.Pipe() // client read/write
+ conn := &rwTestConn{
+ Reader: cr,
+ Writer: sw,
+ closeFunc: func() error {
+ sw.Close()
+ cw.Close()
+ return nil
+ },
+ }
+ go send100Response(cw, sr)
+ return conn, nil
+ },
+ DisableKeepAlives: false,
+ }
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ testResponse := func(req *Request, name string, wantCode int) {
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatalf("%s: Do: %v", name, err)
+ }
+ if res.StatusCode != wantCode {
+ t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode)
+ }
+ if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack {
+ t.Errorf("%s: response id %q != request id %q", name, idBack, id)
+ }
+ _, err = ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("%s: Slurp error: %v", name, err)
+ }
+ }
+
+ // Few 100 responses, making sure we're not off-by-one.
+ for i := 1; i <= numReqs; i++ {
+ req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i)))
+ req.Header.Set("Request-Id", reqID(i))
+ testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200)
+ }
+
+ // And some other informational 1xx but non-100 responses, to test
+ // we return them but don't re-use the connection.
+ for i := 1; i <= numReqs; i++ {
+ req, _ := NewRequest("POST", "http://other.tld/", strings.NewReader(reqBody(i)))
+ req.Header.Set("X-Want-Response-Code", "123 Sesame Street")
+ testResponse(req, fmt.Sprintf("123, %d/%d", i, numReqs), 123)
+ }
+}
+
+type proxyFromEnvTest struct {
+ req string // URL to fetch; blank means "http://example.com"
+
+ env string // HTTP_PROXY
+ httpsenv string // HTTPS_PROXY
+ noenv string // NO_RPXY
+
+ want string
+ wanterr error
+}
+
+func (t proxyFromEnvTest) String() string {
+ var buf bytes.Buffer
+ space := func() {
+ if buf.Len() > 0 {
+ buf.WriteByte(' ')
+ }
+ }
+ if t.env != "" {
+ fmt.Fprintf(&buf, "http_proxy=%q", t.env)
+ }
+ if t.httpsenv != "" {
+ space()
+ fmt.Fprintf(&buf, "https_proxy=%q", t.httpsenv)
+ }
+ if t.noenv != "" {
+ space()
+ fmt.Fprintf(&buf, "no_proxy=%q", t.noenv)
+ }
+ req := "http://example.com"
+ if t.req != "" {
+ req = t.req
+ }
+ space()
+ fmt.Fprintf(&buf, "req=%q", req)
+ return strings.TrimSpace(buf.String())
+}
+
+var proxyFromEnvTests = []proxyFromEnvTest{
+ {env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"},
+ {env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"},
+ {env: "cache.corp.example.com", want: "http://cache.corp.example.com"},
+ {env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"},
+ {env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"},
+ {env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"},
+
+ // Don't use secure for http
+ {req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"},
+ // Use secure for https.
+ {req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"},
+ {req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"},
+
+ {want: "<nil>"},
+
+ {noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
+ {noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
+ {noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
+ {noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"},
+ {noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
+}
+
+func TestProxyFromEnvironment(t *testing.T) {
+ ResetProxyEnv()
+ for _, tt := range proxyFromEnvTests {
+ os.Setenv("HTTP_PROXY", tt.env)
+ os.Setenv("HTTPS_PROXY", tt.httpsenv)
+ os.Setenv("NO_PROXY", tt.noenv)
+ ResetCachedEnvironment()
+ reqURL := tt.req
+ if reqURL == "" {
+ reqURL = "http://example.com"
+ }
+ req, _ := NewRequest("GET", reqURL, nil)
+ url, err := ProxyFromEnvironment(req)
+ if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e {
+ t.Errorf("%v: got error = %q, want %q", tt, g, e)
+ continue
+ }
+ if got := fmt.Sprintf("%s", url); got != tt.want {
+ t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want)
+ }
+ }
+}
+
+func TestIdleConnChannelLeak(t *testing.T) {
+ var mu sync.Mutex
+ var n int
+
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ mu.Lock()
+ n++
+ mu.Unlock()
+ }))
+ defer ts.Close()
+
+ tr := &Transport{
+ Dial: func(netw, addr string) (net.Conn, error) {
+ return net.Dial(netw, ts.Listener.Addr().String())
+ },
+ }
+ defer tr.CloseIdleConnections()
+
+ c := &Client{Transport: tr}
+
+ // First, without keep-alives.
+ for _, disableKeep := range []bool{true, false} {
+ tr.DisableKeepAlives = disableKeep
+ for i := 0; i < 5; i++ {
+ _, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i))
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+ if got := tr.IdleConnChMapSizeForTesting(); got != 0 {
+ t.Fatalf("ForDisableKeepAlives = %v, map size = %d; want 0", disableKeep, got)
+ }
+ }
+}
+
+// Verify the status quo: that the Client.Post function coerces its
+// body into a ReadCloser if it's a Closer, and that the Transport
+// then closes it.
+func TestTransportClosesRequestBody(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(http.HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.Copy(ioutil.Discard, r.Body)
+ }))
+ defer ts.Close()
+
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+ cl := &Client{Transport: tr}
+
+ closes := 0
+
+ res, err := cl.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if closes != 1 {
+ t.Errorf("closes = %d; want 1", closes)
+ }
+}
+
+func TestTransportTLSHandshakeTimeout(t *testing.T) {
+ defer afterTest(t)
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ ln := newLocalListener(t)
+ defer ln.Close()
+ testdonec := make(chan struct{})
+ defer close(testdonec)
+
+ go func() {
+ c, err := ln.Accept()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ <-testdonec
+ c.Close()
+ }()
+
+ getdonec := make(chan struct{})
+ go func() {
+ defer close(getdonec)
+ tr := &Transport{
+ Dial: func(_, _ string) (net.Conn, error) {
+ return net.Dial("tcp", ln.Addr().String())
+ },
+ TLSHandshakeTimeout: 250 * time.Millisecond,
+ }
+ cl := &Client{Transport: tr}
+ _, err := cl.Get("https://dummy.tld/")
+ if err == nil {
+ t.Error("expected error")
+ return
+ }
+ ue, ok := err.(*url.Error)
+ if !ok {
+ t.Errorf("expected url.Error; got %#v", err)
+ return
+ }
+ ne, ok := ue.Err.(net.Error)
+ if !ok {
+ t.Errorf("expected net.Error; got %#v", err)
+ return
+ }
+ if !ne.Timeout() {
+ t.Errorf("expected timeout error; got %v", err)
+ }
+ if !strings.Contains(err.Error(), "handshake timeout") {
+ t.Errorf("expected 'handshake timeout' in error; got %v", err)
+ }
+ }()
+ select {
+ case <-getdonec:
+ case <-time.After(5 * time.Second):
+ t.Error("test timeout; TLS handshake hung?")
+ }
+}
+
+// Trying to repro golang.org/issue/3514
+func TestTLSServerClosesConnection(t *testing.T) {
+ defer afterTest(t)
+ if runtime.GOOS == "windows" {
+ t.Skip("skipping flaky test on Windows; golang.org/issue/7634")
+ }
+ closedc := make(chan bool, 1)
+ ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if strings.Contains(r.URL.Path, "/keep-alive-then-die") {
+ conn, _, _ := w.(Hijacker).Hijack()
+ conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
+ conn.Close()
+ closedc <- true
+ return
+ }
+ fmt.Fprintf(w, "hello")
+ }))
+ defer ts.Close()
+ tr := &Transport{
+ TLSClientConfig: &tls.Config{
+ InsecureSkipVerify: true,
+ },
+ }
+ defer tr.CloseIdleConnections()
+ client := &Client{Transport: tr}
+
+ var nSuccess = 0
+ var errs []error
+ const trials = 20
+ for i := 0; i < trials; i++ {
+ tr.CloseIdleConnections()
+ res, err := client.Get(ts.URL + "/keep-alive-then-die")
+ if err != nil {
+ t.Fatal(err)
+ }
+ <-closedc
+ slurp, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(slurp) != "foo" {
+ t.Errorf("Got %q, want foo", slurp)
+ }
+
+ // Now try again and see if we successfully
+ // pick a new connection.
+ res, err = client.Get(ts.URL + "/")
+ if err != nil {
+ errs = append(errs, err)
+ continue
+ }
+ slurp, err = ioutil.ReadAll(res.Body)
+ if err != nil {
+ errs = append(errs, err)
+ continue
+ }
+ nSuccess++
+ }
+ if nSuccess > 0 {
+ t.Logf("successes = %d of %d", nSuccess, trials)
+ } else {
+ t.Errorf("All runs failed:")
+ }
+ for _, err := range errs {
+ t.Logf(" err: %v", err)
+ }
+}
+
+// byteFromChanReader is an io.Reader that reads a single byte at a
+// time from the channel. When the channel is closed, the reader
+// returns io.EOF.
+type byteFromChanReader chan byte
+
+func (c byteFromChanReader) Read(p []byte) (n int, err error) {
+ if len(p) == 0 {
+ return
+ }
+ b, ok := <-c
+ if !ok {
+ return 0, io.EOF
+ }
+ p[0] = b
+ return 1, nil
+}
+
+// Verifies that the Transport doesn't reuse a connection in the case
+// where the server replies before the request has been fully
+// written. We still honor that reply (see TestIssue3595), but don't
+// send future requests on the connection because it's then in a
+// questionable state.
+// golang.org/issue/7569
+func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
+ defer afterTest(t)
+ var sconn struct {
+ sync.Mutex
+ c net.Conn
+ }
+ var getOkay bool
+ closeConn := func() {
+ sconn.Lock()
+ defer sconn.Unlock()
+ if sconn.c != nil {
+ sconn.c.Close()
+ sconn.c = nil
+ if !getOkay {
+ t.Logf("Closed server connection")
+ }
+ }
+ }
+ defer closeConn()
+
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method == "GET" {
+ io.WriteString(w, "bar")
+ return
+ }
+ conn, _, _ := w.(Hijacker).Hijack()
+ sconn.Lock()
+ sconn.c = conn
+ sconn.Unlock()
+ conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive
+ go io.Copy(ioutil.Discard, conn)
+ }))
+ defer ts.Close()
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+ client := &Client{Transport: tr}
+
+ const bodySize = 256 << 10
+ finalBit := make(byteFromChanReader, 1)
+ req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
+ req.ContentLength = bodySize
+ res, err := client.Do(req)
+ if err := wantBody(res, err, "foo"); err != nil {
+ t.Errorf("POST response: %v", err)
+ }
+ donec := make(chan bool)
+ go func() {
+ defer close(donec)
+ res, err = client.Get(ts.URL)
+ if err := wantBody(res, err, "bar"); err != nil {
+ t.Errorf("GET response: %v", err)
+ return
+ }
+ getOkay = true // suppress test noise
+ }()
+ time.AfterFunc(5*time.Second, closeConn)
+ select {
+ case <-donec:
+ finalBit <- 'x' // unblock the writeloop of the first Post
+ close(finalBit)
+ case <-time.After(7 * time.Second):
+ t.Fatal("timeout waiting for GET request to finish")
+ }
+}
+
+type errorReader struct {
+ err error
+}
+
+func (e errorReader) Read(p []byte) (int, error) { return 0, e.err }
+
+type closerFunc func() error
+
+func (f closerFunc) Close() error { return f() }
+
+// Issue 6981
+func TestTransportClosesBodyOnError(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("skipping test; see http://golang.org/issue/7782")
+ }
+ defer afterTest(t)
+ readBody := make(chan error, 1)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ _, err := ioutil.ReadAll(r.Body)
+ readBody <- err
+ }))
+ defer ts.Close()
+ fakeErr := errors.New("fake error")
+ didClose := make(chan bool, 1)
+ req, _ := NewRequest("POST", ts.URL, struct {
+ io.Reader
+ io.Closer
+ }{
+ io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), errorReader{fakeErr}),
+ closerFunc(func() error {
+ select {
+ case didClose <- true:
+ default:
+ }
+ return nil
+ }),
+ })
+ res, err := DefaultClient.Do(req)
+ if res != nil {
+ defer res.Body.Close()
+ }
+ if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) {
+ t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error())
+ }
+ select {
+ case err := <-readBody:
+ if err == nil {
+ t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'")
+ }
+ case <-time.After(5 * time.Second):
+ t.Error("timeout waiting for server handler to complete")
+ }
+ select {
+ case <-didClose:
+ default:
+ t.Errorf("didn't see Body.Close")
+ }
+}
+
+func TestTransportDialTLS(t *testing.T) {
+ var mu sync.Mutex // guards following
+ var gotReq, didDial bool
+
+ ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ mu.Lock()
+ gotReq = true
+ mu.Unlock()
+ }))
+ defer ts.Close()
+ tr := &Transport{
+ DialTLS: func(netw, addr string) (net.Conn, error) {
+ mu.Lock()
+ didDial = true
+ mu.Unlock()
+ c, err := tls.Dial(netw, addr, &tls.Config{
+ InsecureSkipVerify: true,
+ })
+ if err != nil {
+ return nil, err
+ }
+ return c, c.Handshake()
+ },
+ }
+ defer tr.CloseIdleConnections()
+ client := &Client{Transport: tr}
+ res, err := client.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ mu.Lock()
+ if !gotReq {
+ t.Error("didn't get request")
+ }
+ if !didDial {
+ t.Error("didn't use dial hook")
+ }
+}
+
+// Test for issue 8755
+// Ensure that if a proxy returns an error, it is exposed by RoundTrip
+func TestRoundTripReturnsProxyError(t *testing.T) {
+ badProxy := func(*http.Request) (*url.URL, error) {
+ return nil, errors.New("errorMessage")
+ }
+
+ tr := &Transport{Proxy: badProxy}
+
+ req, _ := http.NewRequest("GET", "http://example.com", nil)
+
+ _, err := tr.RoundTrip(req)
+
+ if err == nil {
+ t.Error("Expected proxy error to be returned by RoundTrip")
+ }
+}
+
+// tests that putting an idle conn after a call to CloseIdleConns does return it
+func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
+ tr := &Transport{}
+ wantIdle := func(when string, n int) bool {
+ got := tr.IdleConnCountForTesting("|http|example.com") // key used by PutIdleTestConn
+ if got == n {
+ return true
+ }
+ t.Errorf("%s: idle conns = %d; want %d", when, got, n)
+ return false
+ }
+ wantIdle("start", 0)
+ if !tr.PutIdleTestConn() {
+ t.Fatal("put failed")
+ }
+ if !tr.PutIdleTestConn() {
+ t.Fatal("second put failed")
+ }
+ wantIdle("after put", 2)
+ tr.CloseIdleConnections()
+ if !tr.IsIdleForTesting() {
+ t.Error("should be idle after CloseIdleConnections")
+ }
+ wantIdle("after close idle", 0)
+ if tr.PutIdleTestConn() {
+ t.Fatal("put didn't fail")
+ }
+ wantIdle("after second put", 0)
+
+ tr.RequestIdleConnChForTesting() // should toggle the transport out of idle mode
+ if tr.IsIdleForTesting() {
+ t.Error("shouldn't be idle after RequestIdleConnChForTesting")
+ }
+ if !tr.PutIdleTestConn() {
+ t.Fatal("after re-activation")
+ }
+ wantIdle("after final put", 1)
+}
+
+// This tests that an client requesting a content range won't also
+// implicitly ask for gzip support. If they want that, they need to do it
+// on their own.
+// golang.org/issue/8923
+func TestTransportRangeAndGzip(t *testing.T) {
+ defer afterTest(t)
+ reqc := make(chan *Request, 1)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ reqc <- r
+ }))
+ defer ts.Close()
+
+ req, _ := NewRequest("GET", ts.URL, nil)
+ req.Header.Set("Range", "bytes=7-11")
+ res, err := DefaultClient.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ select {
+ case r := <-reqc:
+ if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
+ t.Error("Transport advertised gzip support in the Accept header")
+ }
+ if r.Header.Get("Range") == "" {
+ t.Error("no Range in request")
+ }
+ case <-time.After(10 * time.Second):
+ t.Fatal("timeout")
+ }
+ res.Body.Close()
+}
+
+func wantBody(res *http.Response, err error, want string) error {
+ if err != nil {
+ return err
+ }
+ slurp, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ return fmt.Errorf("error reading body: %v", err)
+ }
+ if string(slurp) != want {
+ return fmt.Errorf("body = %q; want %q", slurp, want)
+ }
+ if err := res.Body.Close(); err != nil {
+ return fmt.Errorf("body Close = %v", err)
+ }
+ return nil
+}
+
+func newLocalListener(t *testing.T) net.Listener {
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ ln, err = net.Listen("tcp6", "[::1]:0")
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ return ln
+}
+
+type countCloseReader struct {
+ n *int
+ io.Reader
+}
+
+func (cr countCloseReader) Close() error {
+ (*cr.n)++
+ return nil
+}
+
+// rgz is a gzip quine that uncompresses to itself.
+var rgz = []byte{
+ 0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73,
+ 0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0,
+ 0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2,
+ 0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17,
+ 0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60,
+ 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2,
+ 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00,
+ 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00,
+ 0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16,
+ 0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05,
+ 0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff,
+ 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00,
+ 0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00,
+ 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
+ 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88,
+ 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff,
+ 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00,
+ 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00,
+ 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
+ 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00,
+ 0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff,
+ 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00,
+ 0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00,
+ 0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16,
+ 0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08,
+ 0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa,
+ 0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06,
+ 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00,
+ 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
+ 0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
+ 0x00, 0x00,
+}
diff --git a/src/net/http/triv.go b/src/net/http/triv.go
new file mode 100644
index 000000000..232d65089
--- /dev/null
+++ b/src/net/http/triv.go
@@ -0,0 +1,141 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build ignore
+
+package main
+
+import (
+ "bytes"
+ "expvar"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "os"
+ "os/exec"
+ "strconv"
+ "sync"
+)
+
+// hello world, the web server
+var helloRequests = expvar.NewInt("hello-requests")
+
+func HelloServer(w http.ResponseWriter, req *http.Request) {
+ helloRequests.Add(1)
+ io.WriteString(w, "hello, world!\n")
+}
+
+// Simple counter server. POSTing to it will set the value.
+type Counter struct {
+ mu sync.Mutex // protects n
+ n int
+}
+
+// This makes Counter satisfy the expvar.Var interface, so we can export
+// it directly.
+func (ctr *Counter) String() string {
+ ctr.mu.Lock()
+ defer ctr.mu.Unlock()
+ return fmt.Sprintf("%d", ctr.n)
+}
+
+func (ctr *Counter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+ ctr.mu.Lock()
+ defer ctr.mu.Unlock()
+ switch req.Method {
+ case "GET":
+ ctr.n++
+ case "POST":
+ buf := new(bytes.Buffer)
+ io.Copy(buf, req.Body)
+ body := buf.String()
+ if n, err := strconv.Atoi(body); err != nil {
+ fmt.Fprintf(w, "bad POST: %v\nbody: [%v]\n", err, body)
+ } else {
+ ctr.n = n
+ fmt.Fprint(w, "counter reset\n")
+ }
+ }
+ fmt.Fprintf(w, "counter = %d\n", ctr.n)
+}
+
+// simple flag server
+var booleanflag = flag.Bool("boolean", true, "another flag for testing")
+
+func FlagServer(w http.ResponseWriter, req *http.Request) {
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+ fmt.Fprint(w, "Flags:\n")
+ flag.VisitAll(func(f *flag.Flag) {
+ if f.Value.String() != f.DefValue {
+ fmt.Fprintf(w, "%s = %s [default = %s]\n", f.Name, f.Value.String(), f.DefValue)
+ } else {
+ fmt.Fprintf(w, "%s = %s\n", f.Name, f.Value.String())
+ }
+ })
+}
+
+// simple argument server
+func ArgServer(w http.ResponseWriter, req *http.Request) {
+ for _, s := range os.Args {
+ fmt.Fprint(w, s, " ")
+ }
+}
+
+// a channel (just for the fun of it)
+type Chan chan int
+
+func ChanCreate() Chan {
+ c := make(Chan)
+ go func(c Chan) {
+ for x := 0; ; x++ {
+ c <- x
+ }
+ }(c)
+ return c
+}
+
+func (ch Chan) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+ io.WriteString(w, fmt.Sprintf("channel send #%d\n", <-ch))
+}
+
+// exec a program, redirecting output
+func DateServer(rw http.ResponseWriter, req *http.Request) {
+ rw.Header().Set("Content-Type", "text/plain; charset=utf-8")
+
+ date, err := exec.Command("/bin/date").Output()
+ if err != nil {
+ http.Error(rw, err.Error(), 500)
+ return
+ }
+ rw.Write(date)
+}
+
+func Logger(w http.ResponseWriter, req *http.Request) {
+ log.Print(req.URL)
+ http.Error(w, "oops", 404)
+}
+
+var webroot = flag.String("root", os.Getenv("HOME"), "web root directory")
+
+func main() {
+ flag.Parse()
+
+ // The counter is published as a variable directly.
+ ctr := new(Counter)
+ expvar.Publish("counter", ctr)
+ http.Handle("/counter", ctr)
+ http.Handle("/", http.HandlerFunc(Logger))
+ http.Handle("/go/", http.StripPrefix("/go/", http.FileServer(http.Dir(*webroot))))
+ http.Handle("/chan", ChanCreate())
+ http.HandleFunc("/flags", FlagServer)
+ http.HandleFunc("/args", ArgServer)
+ http.HandleFunc("/go/hello", HelloServer)
+ http.HandleFunc("/date", DateServer)
+ err := http.ListenAndServe(":12345", nil)
+ if err != nil {
+ log.Panicln("ListenAndServe:", err)
+ }
+}